use std::{
    path::PathBuf,
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    },
};

use chrono::Utc;
use reqwest::Client;
use tokio::sync::{broadcast, mpsc, oneshot};

use crate::{
    models::{Job, JobId, JobStatus, Segment},
    storage::Storage,
    transcriber::Transcriber,
    webhook,
};

// NOTE(review): the generic type arguments in this section were missing from
// the source text and have been reconstructed from how each value is used
// further down the file — confirm against the upstream definitions.

/// Per-job broadcast channel for SSE subscribers.
pub type ProgressTx = broadcast::Sender<ProgressEvent>;

/// Events published on a job's [`ProgressTx`] while the job is processed.
#[derive(Debug, Clone)]
pub enum ProgressEvent {
    /// Overall job progress as a percentage.
    Progress(u8),
    /// The job finished; carries a snapshot of the completed job.
    Done(Box<Job>),
    /// The job failed; carries the error message.
    Error(String),
}

/// Global registry: job_id → broadcast sender.
pub type ProgressRegistry = Arc<dashmap::DashMap<JobId, ProgressTx>>;

// ── Transcription request/response types for the blocking thread ─────────────

/// One unit of work handed from the async worker to the transcriber thread.
struct TranscribeRequest {
    /// Audio samples for this chunk.
    pcm: Vec<f32>,
    /// Language hint forwarded to the transcriber; `None` when the job has none.
    language: Option<String>,
    /// Task name forwarded to the transcriber unchanged.
    task: String,
    /// Per-chunk progress callback — receives 0–100 from whisper.cpp and can
    /// scale/offset it before forwarding to the job's broadcast channel.
    on_progress: Box<dyn Fn(u8) + Send + 'static>,
    /// One-shot channel on which the thread returns `(segments, language)`
    /// or the transcription error.
    reply: oneshot::Sender<crate::Result<(Vec<Segment>, String)>>,
}

/// Spawn the single GPU worker.
///
/// Starts a dedicated OS thread named `whisper-gpu` that runs
/// [`transcriber_thread`], and a Tokio task that runs [`run`] to pull job ids
/// from `job_rx` and feed that thread.
///
/// Returns the SSE progress registry.
///
/// # Panics
///
/// Panics if the OS thread cannot be spawned. Because it calls
/// `tokio::spawn`, it must be called from within a Tokio runtime.
pub fn start(
    job_rx: mpsc::UnboundedReceiver<JobId>,
    storage: Arc<Storage>,
    model_path: PathBuf,
    queue_depth: Arc<AtomicUsize>,
    gpu_device: u32,
) -> ProgressRegistry {
    let registry: ProgressRegistry = Arc::new(dashmap::DashMap::new());
    let reg_clone = Arc::clone(&registry);

    // std (not tokio) channel: the receiving side is a plain blocking thread.
    let (tx_req, rx_req) = std::sync::mpsc::channel::<TranscribeRequest>();

    std::thread::Builder::new()
        .name("whisper-gpu".into())
        .spawn(move || transcriber_thread(rx_req, model_path, gpu_device))
        .expect("failed to spawn whisper-gpu thread");

    tokio::spawn(run(job_rx, storage, queue_depth, reg_clone, tx_req));

    registry
}

/// Dedicated OS thread that owns the Transcriber (non-Send) and runs inference.
fn transcriber_thread( rx: std::sync::mpsc::Receiver, model_path: PathBuf, gpu_device: u32, ) { let transcriber = match Transcriber::load(&model_path, gpu_device) { Ok(t) => t, Err(e) => { tracing::error!(error = %e, "failed to load whisper model — transcriber thread exiting"); return; } }; tracing::info!(model = %model_path.display(), "GPU worker ready"); for req in rx { let on_progress = req.on_progress; let result = transcriber.transcribe( &req.pcm, req.language.as_deref(), &req.task, move |p| on_progress(p), ); let _ = req.reply.send(result); } } pub(crate) async fn run( mut job_rx: mpsc::UnboundedReceiver, storage: Arc, queue_depth: Arc, registry: ProgressRegistry, tx_req: std::sync::mpsc::Sender, ) { let http = Client::builder() .timeout(std::time::Duration::from_secs(30)) .build() .expect("failed to build reqwest client"); while let Some(job_id) = job_rx.recv().await { queue_depth.fetch_sub(1, Ordering::Relaxed); let mut job = match storage.get(&job_id).await { Ok(j) => j, Err(e) => { tracing::warn!(job_id = %job_id, error = %e, "job vanished before processing"); registry.remove(&job_id); continue; } }; if job.status == JobStatus::Cancelled { registry.remove(&job_id); continue; } job.status = JobStatus::Running; if let Err(e) = storage.save(&job).await { tracing::error!(job_id = %job_id, error = %e, "failed to persist running status"); } let progress_tx = registry .entry(job_id) .or_insert_with(|| broadcast::channel(64).0) .clone(); let audio_path = audio_path_for(&job_id); let result = process_job(&job, &audio_path, &progress_tx, &tx_req).await; let _ = tokio::fs::remove_file(&audio_path).await; match result { Ok((segments, language, duration_secs)) => { job.status = JobStatus::Done; job.segments = segments; job.language = Some(language); job.duration_secs = Some(duration_secs); job.progress = 100; job.completed_at = Some(Utc::now()); let _ = progress_tx.send(ProgressEvent::Done(Box::new(job.clone()))); } Err(e) => { let msg = e.to_string(); 
tracing::error!(job_id = %job_id, error = %msg, "transcription failed"); job.status = JobStatus::Failed; job.error = Some(msg.clone()); job.completed_at = Some(Utc::now()); let _ = progress_tx.send(ProgressEvent::Error(msg)); } } if let Err(e) = storage.save(&job).await { tracing::error!(job_id = %job_id, error = %e, "failed to persist final job state"); } if let Some(url) = &job.webhook_url.clone() { let http = http.clone(); let url = url.clone(); let job = job.clone(); tokio::spawn(async move { webhook::fire(&http, &url, &job).await; }); } tokio::time::sleep(std::time::Duration::from_secs(30)).await; registry.remove(&job_id); } } // ── Silence-based chunking ──────────────────────────────────────────────────── /// Target chunk length. Smaller = safer (less hallucination budget per chunk). const TARGET_CHUNK_SECS: f32 = 180.0; /// How far from the target we'll snap to a silence midpoint. const SNAP_WINDOW_SECS: f32 = 30.0; /// Silence below this level (dB) counts as a split candidate. const SILENCE_DB: &str = "-35dB"; /// Minimum silence duration to register as a candidate split. const SILENCE_DUR: &str = "0.4"; /// Detect silence periods and return the midpoint (seconds) of each. /// On any error (ffmpeg missing, binary format, etc.) returns an empty vec /// so the caller can fall back to hard cuts. 
async fn detect_silence_midpoints(path: &std::path::Path) -> Vec { use tokio::process::Command; let filter = format!("silencedetect=n={}:d={}", SILENCE_DB, SILENCE_DUR); let output = Command::new("ffmpeg") .args([ "-nostdin", "-i", path.to_str().unwrap_or(""), "-af", &filter, "-f", "null", "-", ]) .output() .await; let output = match output { Ok(o) => o, Err(e) => { tracing::warn!(error = %e, "silencedetect unavailable; using hard cuts"); return Vec::new(); } }; // silencedetect logs to stderr let stderr = String::from_utf8_lossy(&output.stderr); let mut starts: Vec = Vec::new(); let mut ends: Vec = Vec::new(); for line in stderr.lines() { if let Some(i) = line.find("silence_start: ") { if let Ok(t) = line[i + "silence_start: ".len()..].trim().parse::() { starts.push(t); } } else if let Some(i) = line.find("silence_end: ") { // Format: "silence_end: 12.34 | silence_duration: 0.56" let t_str = line[i + "silence_end: ".len()..] .split(" |") .next() .unwrap_or("") .trim(); if let Ok(t) = t_str.parse::() { ends.push(t); } } } let mids: Vec = starts.iter().zip(ends.iter()) .map(|(s, e)| (s + e) / 2.0) .collect(); tracing::debug!(n = mids.len(), "silence midpoints detected"); mids } /// Build cut points every `target_secs`, snapping to the nearest silence /// midpoint within `snap_window` when one exists; otherwise a hard cut. /// Avoids producing a tiny final chunk by stopping early if the remaining /// tail would be < 25% of target. fn snap_to_silence( mids: &[f32], total_secs: f32, target_secs: f32, snap_window: f32, ) -> Vec { let mut cuts: Vec = Vec::new(); let mut pos = target_secs; while pos < total_secs - target_secs * 0.25 { let prev_cut = cuts.last().copied().unwrap_or(0.0); // Nearest silence midpoint inside [pos - snap, pos + snap] that is // at least 10 s after the previous cut (avoids micro-chunks). 
let best = mids.iter().copied() .filter(|&t| t > prev_cut + 10.0 && (t - pos).abs() <= snap_window) .min_by(|a, b| (a - pos).abs().partial_cmp(&(b - pos).abs()).unwrap()); let cut = best.unwrap_or(pos); cuts.push(cut); pos = cut + target_secs; } cuts } /// Convert cut points into (start_secs, end_secs) chunk pairs. fn to_chunk_ranges(cuts: &[f32], total_secs: f32) -> Vec<(f32, f32)> { let mut ranges = Vec::new(); let mut start = 0.0_f32; for &cut in cuts { if cut - start >= 5.0 { ranges.push((start, cut)); start = cut; } } // Last chunk if total_secs - start >= 1.0 { ranges.push((start, total_secs)); } ranges } // ── Job processing ──────────────────────────────────────────────────────────── async fn process_job( job: &Job, audio_path: &std::path::Path, progress_tx: &ProgressTx, tx_req: &std::sync::mpsc::Sender, ) -> crate::Result<(Vec, String, f32)> { // 1. Decode full audio to 16 kHz mono PCM. let pcm = decode_audio(audio_path).await?; let total_secs = pcm.len() as f32 / 16_000.0; // 2. Detect silence from the original file (fast amplitude scan). let silence_mids = detect_silence_midpoints(audio_path).await; // 3. Build silence-snapped chunk boundaries. let cuts = snap_to_silence(&silence_mids, total_secs, TARGET_CHUNK_SECS, SNAP_WINDOW_SECS); let chunks = to_chunk_ranges(&cuts, total_secs); let n = chunks.len(); tracing::info!( total_secs, n_chunks = n, silence_points = silence_mids.len(), "audio chunked by silence" ); // 4. Transcribe each chunk, applying a time offset to all timestamps. let mut all_segments: Vec = Vec::new(); let mut language = String::new(); for (ci, (chunk_start, chunk_end)) in chunks.iter().enumerate() { let s0 = (*chunk_start * 16_000.0) as usize; let s1 = ((*chunk_end * 16_000.0) as usize).min(pcm.len()); let chunk_pcm = pcm[s0..s1].to_vec(); // Scale chunk's 0-100 progress into the job's 0-100 range. 
let base = (ci * 100 / n) as u8; let span = (100usize / n).max(1) as u8; let tx = progress_tx.clone(); let on_progress = Box::new(move |p: u8| { let overall = base.saturating_add(p.saturating_mul(span) / 100); let _ = tx.send(ProgressEvent::Progress(overall)); }); let (reply_tx, reply_rx) = oneshot::channel(); tx_req.send(TranscribeRequest { pcm: chunk_pcm, language: job.language.clone(), task: job.task.clone(), on_progress, reply: reply_tx, }).map_err(|_| crate::AppError::Internal("transcriber thread gone".into()))?; let (mut segs, lang) = reply_rx.await .map_err(|_| crate::AppError::Internal("transcriber thread dropped reply".into()))??; // Shift all timestamps by chunk offset. let offset = *chunk_start; for seg in &mut segs { seg.start += offset; seg.end += offset; for word in &mut seg.words { word.start += offset; word.end += offset; } } tracing::debug!( chunk = ci + 1, of = n, start = chunk_start, end = chunk_end, segs = segs.len(), "chunk done" ); all_segments.extend(segs); if language.is_empty() { language = lang; } } // Renumber segment indices across the merged output. for (i, seg) in all_segments.iter_mut().enumerate() { seg.index = i as i32; } let _ = progress_tx.send(ProgressEvent::Progress(100)); Ok((all_segments, language, total_secs)) } /// Decode any audio file to 16 kHz mono PCM f32 using ffmpeg. 
async fn decode_audio(path: &std::path::Path) -> crate::Result> { use tokio::process::Command; let output = Command::new("ffmpeg") .args([ "-nostdin", "-threads", "0", "-i", path.to_str().unwrap_or(""), "-f", "f32le", "-ac", "1", "-ar", "16000", "-", ]) .output() .await .map_err(|e| crate::AppError::Internal(format!("ffmpeg spawn failed: {e}")))?; if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); return Err(crate::AppError::Internal(format!( "ffmpeg exited with {}: {}", output.status, stderr ))); } let bytes = output.stdout; if bytes.len() % 4 != 0 { return Err(crate::AppError::Internal( "ffmpeg output length not a multiple of 4".into(), )); } Ok(bytes .chunks_exact(4) .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]])) .collect()) } pub fn audio_path_for(id: &JobId) -> PathBuf { let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/data".into()); PathBuf::from(data_dir).join(format!("{id}.audio")) }