//! Transcription worker: a dedicated `whisper-gpu` OS thread that owns the
//! model (`transcriber_thread`) and an async runner (`run`) that decodes each
//! queued job's audio, splits it at silences and feeds the chunks to that thread.

use std::{
    collections::HashSet,
    path::PathBuf,
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc, Mutex,
    },
    time::{Duration, Instant},
};

use chrono::Utc;
use reqwest::Client;
use tokio::sync::{broadcast, mpsc, oneshot, RwLock};

use crate::{
    models::{Job, JobId, JobStatus, ModelEvent, ModelState, Segment},
    storage::Storage,
    transcriber::Transcriber,
    webhook, AppError,
};

/// Per-job broadcast channel for SSE subscribers.
pub type ProgressTx = broadcast::Sender;

/// Events published on a job's [`ProgressTx`] by the job runner.
#[derive(Debug, Clone)]
pub enum ProgressEvent {
    /// `percent` — overall 0–100; `chunk` — 1-based; `total` — total chunks.
    Progress { percent: u8, chunk: usize, total: usize },
    /// Terminal success: the final job record (status `Done`, segments filled).
    Done(Box),
    /// Terminal failure: the error's display text.
    Error(String),
}

/// Global registry: job_id → broadcast sender.
///
/// The runner inserts an entry when a job starts and removes it some time after
/// the job finishes; NOTE(review): SSE handlers presumably look jobs up here to
/// subscribe — confirm against the HTTP layer.
pub type ProgressRegistry = Arc>;

// ── Worker command channel ────────────────────────────────────────────────────

/// Commands sent to the GPU worker OS thread.
#[derive(Debug)]
pub enum WorkerCmd {
    /// Request a model load. Idempotent: if already loading/ready, ignored.
    Load,
    /// Unload the model immediately and free GPU memory.
    Unload,
    /// Internal: run a transcription chunk.
    Transcribe(TranscribeRequest),
}

// ── Transcription request/response types ─────────────────────────────────────

/// One chunk of audio to transcribe on the GPU thread.
pub struct TranscribeRequest {
    /// Chunk samples; the runner produces mono 16 kHz audio (see `decode_audio`).
    pub pcm: Vec,
    /// Language passed through from the job; presumably `None` means auto-detect.
    pub language: Option,
    /// Task string passed through from the job unchanged.
    pub task: String,
    /// Called with chunk-local progress, which the runner treats as 0–100.
    pub on_progress: Box,
    /// Receives the chunk's segments and detected language, or an error.
    pub reply: oneshot::Sender, String)>>,
}

/// Hand-written so the PCM buffer, callback and reply channel are not printed.
impl std::fmt::Debug for TranscribeRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TranscribeRequest")
            .field("language", &self.language)
            .field("task", &self.task)
            .finish_non_exhaustive()
    }
}

// ── Public API ────────────────────────────────────────────────────────────────

/// Spawn the single GPU worker.
///
/// Returns the SSE progress registry and a command sender for the worker thread.
/// The model starts **unloaded**; send `WorkerCmd::Load` or submit a job to
/// trigger loading.
#[allow(clippy::too_many_arguments)] pub fn start( job_rx: mpsc::UnboundedReceiver, storage: Arc, model_path: PathBuf, queue_depth: Arc, gpu_device: u32, model_state: Arc>, model_event_tx: broadcast::Sender, webhook_registry: Arc>>, idle_timeout: Duration, gpu_poll_interval: Duration, ) -> (ProgressRegistry, std::sync::mpsc::SyncSender) { let registry: ProgressRegistry = Arc::new(dashmap::DashMap::new()); let reg_clone = Arc::clone(®istry); // Bounded sync channel: capacity 8 is plenty (load/unload are rare). let (cmd_tx, cmd_rx) = std::sync::mpsc::sync_channel::(8); let cmd_tx_clone = cmd_tx.clone(); // Capture Tokio runtime handle so the OS thread can spawn async tasks. let rt_handle = tokio::runtime::Handle::current(); std::thread::Builder::new() .name("whisper-gpu".into()) .spawn(move || { transcriber_thread( cmd_rx, model_path, gpu_device, model_state, model_event_tx, webhook_registry, idle_timeout, gpu_poll_interval, rt_handle, ); }) .expect("failed to spawn whisper-gpu thread"); tokio::spawn(run(job_rx, storage, queue_depth, reg_clone, cmd_tx_clone)); (registry, cmd_tx) } // ── GPU OS thread ───────────────────────────────────────────────────────────── /// The worker OS thread that owns the `Transcriber` (non-`Send`). /// /// Uses `recv_timeout` with a 1-second tick to drive the idle timer without a /// separate thread. 
#[allow(clippy::too_many_arguments)] fn transcriber_thread( rx: std::sync::mpsc::Receiver, model_path: PathBuf, gpu_device: u32, model_state: Arc>, model_event_tx: broadcast::Sender, webhook_registry: Arc>>, idle_timeout: Duration, gpu_poll_interval: Duration, rt: tokio::runtime::Handle, ) { let mut transcriber: Option = None; let mut last_job = Instant::now(); loop { match rx.recv_timeout(Duration::from_secs(1)) { Ok(WorkerCmd::Load) => { if transcriber.is_some() { tracing::debug!("WorkerCmd::Load ignored — model already loaded"); continue; } transcriber = try_load_with_polling( &rx, &model_path, gpu_device, &model_state, &model_event_tx, &webhook_registry, gpu_poll_interval, &rt, ); if transcriber.is_some() { last_job = Instant::now(); } } Ok(WorkerCmd::Unload) => { do_unload(&mut transcriber, &model_state, &model_event_tx, &webhook_registry, &rt); } Ok(WorkerCmd::Transcribe(req)) => { let t = match &mut transcriber { Some(t) => t, None => { tracing::warn!("Transcribe cmd received but model is unloaded — failing job"); let _ = req.reply.send(Err(AppError::Internal( "model unloaded before job could run".into(), ))); continue; } }; let result = t.transcribe( &req.pcm, req.language.as_deref(), &req.task, move |p| (req.on_progress)(p), ); last_job = Instant::now(); let _ = req.reply.send(result); } Err(std::sync::mpsc::RecvTimeoutError::Timeout) => { if transcriber.is_some() && last_job.elapsed() >= idle_timeout { tracing::info!( elapsed_secs = last_job.elapsed().as_secs(), "idle timeout reached — unloading model" ); do_unload( &mut transcriber, &model_state, &model_event_tx, &webhook_registry, &rt, ); } } Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => { tracing::info!("worker command channel closed — shutting down GPU thread"); break; } } } } /// Attempt to load the model, polling on VRAM failures. 
/// /// While waiting for GPU, drains `rx` so that `WorkerCmd::Unload` cancels the /// load attempt and `WorkerCmd::Transcribe` commands get a "model not ready" /// rejection. Returns `Some(Transcriber)` on success, `None` if cancelled. #[allow(clippy::too_many_arguments)] fn try_load_with_polling( rx: &std::sync::mpsc::Receiver, model_path: &PathBuf, gpu_device: u32, model_state: &Arc>, model_event_tx: &broadcast::Sender, webhook_registry: &Arc>>, gpu_poll_interval: Duration, rt: &tokio::runtime::Handle, ) -> Option { loop { set_state(model_state, ModelState::Loading); broadcast_event(model_event_tx, ModelEvent::ModelLoading); tracing::info!("loading whisper model..."); match Transcriber::load(model_path, gpu_device) { Ok(t) => { let loaded_at = Utc::now(); set_state(model_state, ModelState::Ready { loaded_at }); broadcast_event(model_event_tx, ModelEvent::ModelReady { loaded_at }); fire_webhooks(webhook_registry, ModelEvent::ModelReady { loaded_at }, rt); tracing::info!("model loaded and ready"); return Some(t); } Err(AppError::OutOfMemory(msg)) => { let (vram_needed_mb, vram_free_mb) = parse_oom_vram(&msg, gpu_device); let retry_in_secs = gpu_poll_interval.as_secs(); tracing::warn!( vram_needed_mb, vram_free_mb, retry_in_secs, "insufficient VRAM — will retry" ); set_state(model_state, ModelState::WaitingForGpu { vram_needed_mb, vram_free_mb, retry_in_secs, }); broadcast_event(model_event_tx, ModelEvent::ModelWaitingForGpu { vram_needed_mb, vram_free_mb, retry_in_secs, }); // Interruptible sleep: drain rx while waiting for gpu_poll_interval. 
let deadline = Instant::now() + gpu_poll_interval; loop { let remaining = deadline.saturating_duration_since(Instant::now()); if remaining.is_zero() { break; } match rx.recv_timeout(remaining.min(Duration::from_secs(1))) { Ok(WorkerCmd::Unload) => { tracing::info!("Unload received while waiting for GPU — cancelling load"); set_state(model_state, ModelState::Unloaded); broadcast_event(model_event_tx, ModelEvent::ModelUnloaded); fire_webhooks(webhook_registry, ModelEvent::ModelUnloaded, rt); return None; } Ok(WorkerCmd::Load) => {} // idempotent Ok(WorkerCmd::Transcribe(req)) => { let _ = req.reply.send(Err(AppError::ModelNotReady { state: "waiting_for_gpu".into(), retry_after_secs: retry_in_secs, })); } Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {} Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => return None, } } // Loop back to retry load } Err(e) => { tracing::error!(error = %e, "model load failed with non-recoverable error"); set_state(model_state, ModelState::Unloaded); return None; } } } } fn do_unload( transcriber: &mut Option, model_state: &Arc>, model_event_tx: &broadcast::Sender, webhook_registry: &Arc>>, rt: &tokio::runtime::Handle, ) { *transcriber = None; set_state(model_state, ModelState::Unloaded); broadcast_event(model_event_tx, ModelEvent::ModelUnloaded); fire_webhooks(webhook_registry, ModelEvent::ModelUnloaded, rt); tracing::info!("model unloaded — GPU memory freed"); } // ── Helpers ─────────────────────────────────────────────────────────────────── fn set_state(arc: &Arc>, state: ModelState) { *arc.blocking_write() = state; } fn broadcast_event(tx: &broadcast::Sender, event: ModelEvent) { let _ = tx.send(event); } fn fire_webhooks( registry: &Arc>>, event: ModelEvent, rt: &tokio::runtime::Handle, ) { if !event.is_webhook_event() { return; } let urls: Vec = registry .lock() .unwrap_or_else(|e| e.into_inner()) .iter() .cloned() .collect(); if urls.is_empty() { return; } let payload = match serde_json::to_string(&event) { Ok(p) => p, 
Err(e) => { tracing::error!(error = %e, "failed to serialize model event"); return; } }; for url in urls { let body = payload.clone(); rt.spawn(async move { let http = Client::builder() .timeout(Duration::from_secs(10)) .build() .expect("http client"); for attempt in 0..3_u32 { match http.post(&url) .header("content-type", "application/json") .body(body.clone()) .send() .await { Ok(r) if r.status().is_success() => { tracing::debug!(url, "model event webhook delivered"); return; } Ok(r) => tracing::warn!(url, status = r.status().as_u16(), "webhook non-2xx"), Err(e) => tracing::warn!(url, error = %e, attempt, "webhook delivery failed"), } if attempt < 2 { tokio::time::sleep(Duration::from_secs(2u64.pow(attempt))).await; } } tracing::error!(url, "model event webhook failed after 3 attempts"); }); } } fn parse_oom_vram(msg: &str, gpu_device: u32) -> (u64, u64) { let needed = msg .split_whitespace() .zip(msg.split_whitespace().skip(1)) .find(|(_, next)| *next == "MiB") .and_then(|(n, _)| n.parse::().ok()) .map(|v| v as u64) .unwrap_or(0); let free = std::process::Command::new("nvidia-smi") .args([ &format!("--id={gpu_device}"), "--query-gpu=memory.free", "--format=csv,noheader,nounits", ]) .output() .ok() .and_then(|o| String::from_utf8(o.stdout).ok()) .and_then(|s| s.trim().parse::().ok()) .unwrap_or(0); (needed, free) } // ── Async job runner ────────────────────────────────────────────────────────── async fn run( mut job_rx: mpsc::UnboundedReceiver, storage: Arc, queue_depth: Arc, registry: ProgressRegistry, cmd_tx: std::sync::mpsc::SyncSender, ) { let http = Client::builder() .timeout(Duration::from_secs(30)) .build() .expect("failed to build reqwest client"); while let Some(job_id) = job_rx.recv().await { queue_depth.fetch_sub(1, Ordering::Relaxed); let mut job = match storage.get(&job_id).await { Ok(j) => j, Err(e) => { tracing::warn!(job_id = %job_id, error = %e, "job vanished before processing"); registry.remove(&job_id); continue; } }; if job.status == 
JobStatus::Cancelled { registry.remove(&job_id); continue; } job.status = JobStatus::Running; if let Err(e) = storage.save(&job).await { tracing::error!(job_id = %job_id, error = %e, "failed to persist running status"); } let progress_tx = registry .entry(job_id) .or_insert_with(|| broadcast::channel(64).0) .clone(); let audio_path = audio_path_for(&job_id); let result = process_job(&job, &audio_path, &progress_tx, &cmd_tx, &storage).await; let _ = tokio::fs::remove_file(&audio_path).await; match result { Ok((segments, language, duration_secs)) => { job.status = JobStatus::Done; job.segments = segments; job.language = Some(language); job.duration_secs = Some(duration_secs); job.progress = 100; job.completed_at = Some(Utc::now()); let _ = progress_tx.send(ProgressEvent::Done(Box::new(job.clone()))); } Err(e) => { let msg = e.to_string(); tracing::error!(job_id = %job_id, error = %msg, "transcription failed"); job.status = JobStatus::Failed; job.error = Some(msg.clone()); job.completed_at = Some(Utc::now()); let _ = progress_tx.send(ProgressEvent::Error(msg)); } } if let Err(e) = storage.save(&job).await { tracing::error!(job_id = %job_id, error = %e, "failed to persist final job state"); } if let Some(url) = &job.webhook_url.clone() { let http = http.clone(); let url = url.clone(); let job = job.clone(); tokio::spawn(async move { webhook::fire(&http, &url, &job).await; }); } tokio::time::sleep(Duration::from_secs(30)).await; registry.remove(&job_id); } } // ── Silence-based chunking ──────────────────────────────────────────────────── const TARGET_CHUNK_SECS: f32 = 60.0; const SNAP_WINDOW_SECS: f32 = 30.0; const SILENCE_DB: &str = "-35dB"; const SILENCE_DUR: &str = "0.4"; async fn detect_silence_midpoints(path: &std::path::Path) -> Vec { use tokio::process::Command; let filter = format!("silencedetect=n={}:d={}", SILENCE_DB, SILENCE_DUR); let output = Command::new("ffmpeg") .args([ "-nostdin", "-i", path.to_str().unwrap_or(""), "-af", &filter, "-f", "null", "-", ]) 
.output() .await; let output = match output { Ok(o) => o, Err(e) => { tracing::warn!(error = %e, "silencedetect unavailable; using hard cuts"); return Vec::new(); } }; let stderr = String::from_utf8_lossy(&output.stderr); let mut starts: Vec = Vec::new(); let mut ends: Vec = Vec::new(); for line in stderr.lines() { if let Some(i) = line.find("silence_start: ") { if let Ok(t) = line[i + "silence_start: ".len()..].trim().parse::() { starts.push(t); } } else if let Some(i) = line.find("silence_end: ") { let t_str = line[i + "silence_end: ".len()..] .split(" |") .next() .unwrap_or("") .trim(); if let Ok(t) = t_str.parse::() { ends.push(t); } } } let mids: Vec = starts.iter().zip(ends.iter()) .map(|(s, e)| (s + e) / 2.0) .collect(); tracing::debug!(n = mids.len(), "silence midpoints detected"); mids } fn snap_to_silence( mids: &[f32], total_secs: f32, target_secs: f32, snap_window: f32, ) -> Vec { let mut cuts: Vec = Vec::new(); let mut pos = target_secs; while pos < total_secs - target_secs * 0.25 { let prev_cut = cuts.last().copied().unwrap_or(0.0); let best = mids.iter().copied() .filter(|&t| t > prev_cut + 10.0 && (t - pos).abs() <= snap_window) .min_by(|a, b| (a - pos).abs().partial_cmp(&(b - pos).abs()).unwrap()); let cut = best.unwrap_or(pos); cuts.push(cut); pos = cut + target_secs; } cuts } fn to_chunk_ranges(cuts: &[f32], total_secs: f32) -> Vec<(f32, f32)> { let mut ranges = Vec::new(); let mut start = 0.0_f32; for &cut in cuts { if cut - start >= 5.0 { ranges.push((start, cut)); start = cut; } } if total_secs - start >= 1.0 { ranges.push((start, total_secs)); } ranges } // ── Job processing ──────────────────────────────────────────────────────────── async fn process_job( job: &Job, audio_path: &std::path::Path, progress_tx: &ProgressTx, cmd_tx: &std::sync::mpsc::SyncSender, storage: &Arc, ) -> crate::Result<(Vec, String, f32)> { let pcm = decode_audio(audio_path).await?; let total_secs = pcm.len() as f32 / 16_000.0; let silence_mids = 
detect_silence_midpoints(audio_path).await; let cuts = snap_to_silence(&silence_mids, total_secs, TARGET_CHUNK_SECS, SNAP_WINDOW_SECS); let chunks = to_chunk_ranges(&cuts, total_secs); let n = chunks.len(); tracing::info!( total_secs, n_chunks = n, silence_points = silence_mids.len(), "audio chunked by silence" ); let mut all_segments: Vec = Vec::new(); let mut language = String::new(); for (ci, (chunk_start, chunk_end)) in chunks.iter().enumerate() { let s0 = (*chunk_start * 16_000.0) as usize; let s1 = ((*chunk_end * 16_000.0) as usize).min(pcm.len()); let mut chunk_pcm = pcm[s0..s1].to_vec(); trim_trailing_silence(&mut chunk_pcm); let base = (ci * 100 / n) as u8; let span = (100usize / n).max(1) as u8; let _ = progress_tx.send(ProgressEvent::Progress { percent: base, chunk: ci + 1, total: n, }); let mut snapshot = job.clone(); snapshot.progress = base; if let Err(e) = storage.save(&snapshot).await { tracing::warn!(error = %e, "failed to persist mid-job progress"); } let tx = progress_tx.clone(); let chunk_num = ci + 1; let on_progress = Box::new(move |p: u8| { let overall = base.saturating_add(p.saturating_mul(span) / 100); let _ = tx.send(ProgressEvent::Progress { percent: overall, chunk: chunk_num, total: n, }); }); let (reply_tx, reply_rx) = oneshot::channel(); cmd_tx.send(WorkerCmd::Transcribe(TranscribeRequest { pcm: chunk_pcm, language: job.language.clone(), task: job.task.clone(), on_progress, reply: reply_tx, })).map_err(|_| AppError::Internal("worker command channel closed".into()))?; let (mut segs, lang) = reply_rx.await .map_err(|_| AppError::Internal("transcriber thread dropped reply".into()))??; let offset = *chunk_start; for seg in &mut segs { seg.start += offset; seg.end += offset; for word in &mut seg.words { word.start += offset; word.end += offset; } } tracing::debug!( chunk = ci + 1, of = n, start = chunk_start, end = chunk_end, segs = segs.len(), "chunk done" ); all_segments.extend(segs); if language.is_empty() { language = lang; } } for (i, 
seg) in all_segments.iter_mut().enumerate() { seg.index = i as i32; } let _ = progress_tx.send(ProgressEvent::Progress { percent: 100, chunk: n, total: n }); Ok((all_segments, language, total_secs)) } fn trim_trailing_silence(pcm: &mut Vec) { const THRESHOLD: f32 = 0.017_8; const PADDING: usize = 8_000; if let Some(last_loud) = pcm.iter().rposition(|&s| s.abs() > THRESHOLD) { let new_len = (last_loud + 1 + PADDING).min(pcm.len()); if new_len < pcm.len() { tracing::trace!( original_samples = pcm.len(), trimmed_samples = pcm.len() - new_len, "trimmed trailing silence" ); pcm.truncate(new_len); } } } async fn decode_audio(path: &std::path::Path) -> crate::Result> { use tokio::process::Command; let output = Command::new("ffmpeg") .args([ "-nostdin", "-threads", "0", "-i", path.to_str().unwrap_or(""), "-f", "f32le", "-ac", "1", "-ar", "16000", "-", ]) .output() .await .map_err(|e| AppError::Internal(format!("ffmpeg spawn failed: {e}")))?; if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); return Err(AppError::Internal(format!( "ffmpeg exited with {}: {}", output.status, stderr ))); } let bytes = output.stdout; if bytes.len() % 4 != 0 { return Err(AppError::Internal( "ffmpeg output length not a multiple of 4".into(), )); } Ok(bytes .chunks_exact(4) .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]])) .collect()) } pub fn audio_path_for(id: &JobId) -> PathBuf { let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/data".into()); PathBuf::from(data_dir).join(format!("{id}.audio")) } // ── Unit tests ──────────────────────────────────────────────────────────────── #[cfg(test)] mod tests { use super::*; #[test] fn test_snap_to_silence_uses_nearest_midpoint() { let mids = vec![55.0, 58.0, 62.0]; let cuts = snap_to_silence(&mids, 120.0, 60.0, 30.0); assert!(!cuts.is_empty()); assert!((cuts[0] - 58.0).abs() < 0.01, "expected ~58.0, got {}", cuts[0]); } #[test] fn test_snap_to_silence_hard_cut_when_no_silence() { let cuts = 
snap_to_silence(&[], 120.0, 60.0, 30.0); assert_eq!(cuts, vec![60.0]); } #[test] fn test_to_chunk_ranges_single_chunk() { let ranges = to_chunk_ranges(&[], 30.0); assert_eq!(ranges, vec![(0.0, 30.0)]); } #[test] fn test_to_chunk_ranges_two_chunks() { let ranges = to_chunk_ranges(&[60.0], 120.0); assert_eq!(ranges, vec![(0.0, 60.0), (60.0, 120.0)]); } #[test] fn test_trim_trailing_silence_all_silent() { let mut pcm = vec![0.0f32; 1000]; trim_trailing_silence(&mut pcm); assert_eq!(pcm.len(), 1000); } #[test] fn test_trim_trailing_silence_trims_to_padding() { let mut pcm = vec![0.0f32; 32_000]; pcm[10_000] = 1.0; trim_trailing_silence(&mut pcm); assert_eq!(pcm.len(), (10_001 + 8_000).min(32_000)); } }