Files
whisper-rtx2080/src/worker.rs
mozempk fb8556441c
All checks were successful
Build & Push Docker Image / build-and-push (push) Successful in 6m40s
feat: silence-based audio chunking before transcription
Run ffmpeg silencedetect (n=-35dB, d=0.4s) on the original audio to
find silence midpoints. Build chunk boundaries every 180s, snapping to
the nearest silence midpoint within ±30s (fallback: hard cut).

Each chunk is transcribed independently with its own CUDA context;
timestamps are shifted by chunk_start before merging. Progress is
scaled per-chunk across the overall 0-100% job range.

Result on 101-min YouTube audio (34 chunks, 1714 silence points):
- Previous: x1025 'Yeah.' + x1008 sentence-length loops (hallucinations)
- After:    x4 max consecutive run, all repetitions verified genuine

Also refactored TranscribeRequest to carry on_progress: Box<dyn Fn(u8)>
instead of a raw ProgressTx so each chunk can independently scale its
contribution to the job's broadcast channel.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-05-06 01:08:06 +02:00

425 lines
14 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use std::{
path::PathBuf,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
};
use chrono::Utc;
use reqwest::Client;
use tokio::sync::{broadcast, mpsc, oneshot};
use crate::{
models::{Job, JobId, JobStatus, Segment},
storage::Storage,
transcriber::Transcriber,
webhook,
};
/// Per-job broadcast channel for SSE subscribers.
pub type ProgressTx = broadcast::Sender<ProgressEvent>;
/// Events published on a job's [`ProgressTx`] while the worker processes it.
#[derive(Debug, Clone)]
pub enum ProgressEvent {
/// Overall job progress as a percentage; the worker sends 100 once every
/// chunk has been transcribed.
Progress(u8),
/// The job finished successfully; carries a clone of the job taken after
/// its result fields were filled in.
Done(Box<Job>),
/// The job failed; carries the error formatted as a string.
Error(String),
}
/// Global registry: job_id → broadcast sender.
pub type ProgressRegistry = Arc<dashmap::DashMap<JobId, ProgressTx>>;
// ── Transcription request/response types for the blocking thread ─────────────
/// One unit of work sent to the transcriber thread: a PCM buffer, the options
/// to transcribe it with, and the channels used to report back.
struct TranscribeRequest {
/// Audio samples handed to `Transcriber::transcribe`. `process_job` fills
/// this with one chunk of the decoded 16 kHz mono audio.
pcm: Vec<f32>,
/// Language forwarded to the transcriber (copied from the job); `None` is
/// passed through unchanged.
language: Option<String>,
/// Task string forwarded to the transcriber (copied from the job).
task: String,
/// Per-chunk progress callback — receives 0-100 from whisper.cpp and can
/// scale/offset it before forwarding to the job's broadcast channel.
on_progress: Box<dyn Fn(u8) + Send + 'static>,
/// One-shot channel on which the transcriber thread returns either the
/// segments plus a language string, or the error.
reply: oneshot::Sender<crate::Result<(Vec<Segment>, String)>>,
}
/// Launch the GPU worker and hand back the SSE progress registry.
///
/// Two things are started:
/// * a dedicated OS thread named `whisper-gpu` running [`transcriber_thread`],
///   fed through a `std::sync::mpsc` channel of [`TranscribeRequest`]s;
/// * a tokio task running [`run`], which consumes `job_rx` and forwards work
///   to that thread.
///
/// The returned registry is shared with the spawned task.
///
/// # Panics
/// Panics if the OS thread cannot be spawned.
pub fn start(
    job_rx: mpsc::UnboundedReceiver<JobId>,
    storage: Arc<Storage>,
    model_path: PathBuf,
    queue_depth: Arc<AtomicUsize>,
    gpu_device: u32,
) -> ProgressRegistry {
    let (request_tx, request_rx) = std::sync::mpsc::channel::<TranscribeRequest>();

    let gpu_thread = std::thread::Builder::new().name(String::from("whisper-gpu"));
    gpu_thread
        .spawn(move || transcriber_thread(request_rx, model_path, gpu_device))
        .expect("failed to spawn whisper-gpu thread");

    let shared: ProgressRegistry = Arc::new(dashmap::DashMap::new());
    tokio::spawn(run(
        job_rx,
        storage,
        queue_depth,
        Arc::clone(&shared),
        request_tx,
    ));
    shared
}
/// Dedicated OS thread that owns the Transcriber (non-Send) and runs inference.
fn transcriber_thread(
rx: std::sync::mpsc::Receiver<TranscribeRequest>,
model_path: PathBuf,
gpu_device: u32,
) {
let transcriber = match Transcriber::load(&model_path, gpu_device) {
Ok(t) => t,
Err(e) => {
tracing::error!(error = %e, "failed to load whisper model — transcriber thread exiting");
return;
}
};
tracing::info!(model = %model_path.display(), "GPU worker ready");
for req in rx {
let on_progress = req.on_progress;
let result = transcriber.transcribe(
&req.pcm,
req.language.as_deref(),
&req.task,
move |p| on_progress(p),
);
let _ = req.reply.send(result);
}
}
pub(crate) async fn run(
mut job_rx: mpsc::UnboundedReceiver<JobId>,
storage: Arc<Storage>,
queue_depth: Arc<AtomicUsize>,
registry: ProgressRegistry,
tx_req: std::sync::mpsc::Sender<TranscribeRequest>,
) {
let http = Client::builder()
.timeout(std::time::Duration::from_secs(30))
.build()
.expect("failed to build reqwest client");
while let Some(job_id) = job_rx.recv().await {
queue_depth.fetch_sub(1, Ordering::Relaxed);
let mut job = match storage.get(&job_id).await {
Ok(j) => j,
Err(e) => {
tracing::warn!(job_id = %job_id, error = %e, "job vanished before processing");
registry.remove(&job_id);
continue;
}
};
if job.status == JobStatus::Cancelled {
registry.remove(&job_id);
continue;
}
job.status = JobStatus::Running;
if let Err(e) = storage.save(&job).await {
tracing::error!(job_id = %job_id, error = %e, "failed to persist running status");
}
let progress_tx = registry
.entry(job_id)
.or_insert_with(|| broadcast::channel(64).0)
.clone();
let audio_path = audio_path_for(&job_id);
let result = process_job(&job, &audio_path, &progress_tx, &tx_req).await;
let _ = tokio::fs::remove_file(&audio_path).await;
match result {
Ok((segments, language, duration_secs)) => {
job.status = JobStatus::Done;
job.segments = segments;
job.language = Some(language);
job.duration_secs = Some(duration_secs);
job.progress = 100;
job.completed_at = Some(Utc::now());
let _ = progress_tx.send(ProgressEvent::Done(Box::new(job.clone())));
}
Err(e) => {
let msg = e.to_string();
tracing::error!(job_id = %job_id, error = %msg, "transcription failed");
job.status = JobStatus::Failed;
job.error = Some(msg.clone());
job.completed_at = Some(Utc::now());
let _ = progress_tx.send(ProgressEvent::Error(msg));
}
}
if let Err(e) = storage.save(&job).await {
tracing::error!(job_id = %job_id, error = %e, "failed to persist final job state");
}
if let Some(url) = &job.webhook_url.clone() {
let http = http.clone();
let url = url.clone();
let job = job.clone();
tokio::spawn(async move { webhook::fire(&http, &url, &job).await; });
}
tokio::time::sleep(std::time::Duration::from_secs(30)).await;
registry.remove(&job_id);
}
}
// ── Silence-based chunking ────────────────────────────────────────────────────
/// Target chunk length in seconds. Smaller = safer (less hallucination budget
/// per chunk). Passed to `snap_to_silence` as the spacing between cut points.
const TARGET_CHUNK_SECS: f32 = 180.0;
/// How far from the target (in seconds, either side) a cut may move in order
/// to snap to a silence midpoint.
const SNAP_WINDOW_SECS: f32 = 30.0;
/// Silence below this level (dB) counts as a split candidate. Inserted
/// verbatim as the `n=` value of ffmpeg's `silencedetect` filter.
const SILENCE_DB: &str = "-35dB";
/// Minimum silence duration to register as a candidate split. Inserted
/// verbatim as the `d=` value of ffmpeg's `silencedetect` filter.
const SILENCE_DUR: &str = "0.4";
/// Detect silence periods in the file at `path` and return the midpoint
/// (seconds) of each, in the order ffmpeg reports them.
///
/// Runs `ffmpeg -i <path> -vn -af silencedetect=... -f null -` and parses the
/// `silence_start:` / `silence_end:` lines the filter logs to stderr.
///
/// * `-vn` keeps ffmpeg from decoding any video stream into the null output;
///   only the audio is needed here.
/// * The path is passed as an OS string, so non-UTF-8 paths reach ffmpeg
///   unchanged instead of becoming an empty argument.
/// * Each `silence_end` is paired with the `silence_start` immediately before
///   it. A line whose timestamp cannot be parsed (or is not finite) drops only
///   that one silence rather than shifting every later pair.
///
/// If ffmpeg cannot be spawned, a warning is logged and an empty vec is
/// returned so the caller can fall back to hard cuts. ffmpeg's exit status is
/// not checked: a failed run normally yields no matching lines and therefore
/// an empty vec as well.
async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
    use tokio::process::Command;

    const START_TAG: &str = "silence_start: ";
    const END_TAG: &str = "silence_end: ";

    let filter = format!("silencedetect=n={}:d={}", SILENCE_DB, SILENCE_DUR);
    let output = Command::new("ffmpeg")
        .arg("-nostdin")
        .arg("-i")
        .arg(path)
        .args(["-vn", "-af", filter.as_str(), "-f", "null", "-"])
        .output()
        .await;
    let output = match output {
        Ok(o) => o,
        Err(e) => {
            tracing::warn!(error = %e, "silencedetect unavailable; using hard cuts");
            return Vec::new();
        }
    };

    // silencedetect logs to stderr
    let stderr = String::from_utf8_lossy(&output.stderr);

    let mut mids: Vec<f32> = Vec::new();
    // Start of the silence currently open, if its timestamp parsed cleanly.
    let mut pending_start: Option<f32> = None;
    for line in stderr.lines() {
        if let Some(i) = line.find(START_TAG) {
            pending_start = line[i + START_TAG.len()..]
                .trim()
                .parse::<f32>()
                .ok()
                .filter(|t| t.is_finite());
        } else if let Some(i) = line.find(END_TAG) {
            // Format: "silence_end: 12.34 | silence_duration: 0.56"
            let end = line[i + END_TAG.len()..]
                .split(" |")
                .next()
                .unwrap_or("")
                .trim()
                .parse::<f32>()
                .ok()
                .filter(|t| t.is_finite());
            // `take()` always closes the open silence, even when `end` failed
            // to parse, so a bad line can never pair with a later one.
            if let (Some(start), Some(end)) = (pending_start.take(), end) {
                mids.push((start + end) / 2.0);
            }
        }
    }

    tracing::debug!(n = mids.len(), "silence midpoints detected");
    mids
}
/// Build cut points roughly every `target_secs`, snapping each one to the
/// nearest silence midpoint within `snap_window` when one exists; otherwise
/// a hard cut at the target position.
///
/// * `mids` – candidate silence midpoints in seconds (any order).
/// * `total_secs` – length of the audio.
/// * `target_secs` – desired spacing between cuts.
/// * `snap_window` – maximum distance a cut may move to reach a silence.
///
/// Returns the cut positions in ascending order. Cutting stops once the next
/// target would leave a tail shorter than 25% of `target_secs`, which avoids
/// producing a tiny final chunk.
///
/// A `target_secs` that is zero, negative or NaN yields no cuts (the audio is
/// then treated as one chunk). Without this guard a non-positive target never
/// advances on the hard-cut path and the loop would not terminate.
fn snap_to_silence(
    mids: &[f32],
    total_secs: f32,
    target_secs: f32,
    snap_window: f32,
) -> Vec<f32> {
    if target_secs.is_nan() || target_secs <= 0.0 {
        return Vec::new();
    }

    // Targets at or beyond this point would leave too short a tail.
    let last_target = total_secs - target_secs * 0.25;

    let mut cuts: Vec<f32> = Vec::new();
    let mut pos = target_secs;
    while pos < last_target {
        let prev_cut = cuts.last().copied().unwrap_or(0.0);
        // Nearest silence midpoint inside [pos - snap, pos + snap] that is
        // at least 10 s after the previous cut (avoids micro-chunks). NaN
        // candidates fail both comparisons and are filtered out here.
        let best = mids
            .iter()
            .copied()
            .filter(|&t| t > prev_cut + 10.0 && (t - pos).abs() <= snap_window)
            .min_by(|a, b| (a - pos).abs().total_cmp(&(b - pos).abs()));
        let cut = best.unwrap_or(pos);
        cuts.push(cut);
        pos = cut + target_secs;
    }
    cuts
}
/// Turn a list of cut points into consecutive `(start_secs, end_secs)` chunks
/// covering the audio from 0 to `total_secs`.
///
/// A cut that would produce a chunk shorter than 5 s is ignored (the chunk
/// simply extends to the next accepted cut). The final chunk, running from the
/// last accepted cut to `total_secs`, is emitted only if it is at least 1 s
/// long.
fn to_chunk_ranges(cuts: &[f32], total_secs: f32) -> Vec<(f32, f32)> {
    const MIN_CHUNK_SECS: f32 = 5.0;
    const MIN_TAIL_SECS: f32 = 1.0;

    // Start of the chunk currently being built.
    let mut cursor = 0.0_f32;

    let mut chunks: Vec<(f32, f32)> = cuts
        .iter()
        .filter_map(|&boundary| {
            if boundary - cursor >= MIN_CHUNK_SECS {
                // Close the chunk at `boundary` and open the next one there.
                Some((std::mem::replace(&mut cursor, boundary), boundary))
            } else {
                None
            }
        })
        .collect();

    // Last chunk
    if total_secs - cursor >= MIN_TAIL_SECS {
        chunks.push((cursor, total_secs));
    }
    chunks
}
// ── Job processing ────────────────────────────────────────────────────────────
/// Transcribe one job's audio file in silence-aligned chunks.
///
/// Steps:
/// 1. decode the whole file to 16 kHz mono PCM;
/// 2. detect silence midpoints in the original file;
/// 3. build chunk boundaries from them;
/// 4. send each chunk to the transcriber thread via `tx_req`, wait for its
///    reply, and shift the returned timestamps by the chunk's start time;
/// 5. renumber the merged segments and publish a final `Progress(100)`.
///
/// While a chunk is running, its 0-100 progress is mapped onto the job's
/// overall 0-100 range and broadcast on `progress_tx`.
///
/// Returns `(segments, language, total_secs)`, where `language` is the first
/// non-empty language string returned by a chunk.
///
/// # Errors
/// Propagates decoding and transcription errors, and returns
/// `AppError::Internal` if the transcriber thread is gone or drops the reply
/// channel.
async fn process_job(
    job: &Job,
    audio_path: &std::path::Path,
    progress_tx: &ProgressTx,
    tx_req: &std::sync::mpsc::Sender<TranscribeRequest>,
) -> crate::Result<(Vec<Segment>, String, f32)> {
    /// Sample rate of the PCM produced by `decode_audio`.
    const SAMPLE_RATE: f32 = 16_000.0;

    // 1. Decode full audio to 16 kHz mono PCM.
    let pcm = decode_audio(audio_path).await?;
    let total_secs = pcm.len() as f32 / SAMPLE_RATE;

    // 2. Detect silence from the original file (fast amplitude scan).
    let silence_mids = detect_silence_midpoints(audio_path).await;

    // 3. Build silence-snapped chunk boundaries.
    let cuts = snap_to_silence(&silence_mids, total_secs, TARGET_CHUNK_SECS, SNAP_WINDOW_SECS);
    let chunks = to_chunk_ranges(&cuts, total_secs);
    let n = chunks.len();
    tracing::info!(
        total_secs,
        n_chunks = n,
        silence_points = silence_mids.len(),
        "audio chunked by silence"
    );

    // 4. Transcribe each chunk, applying a time offset to all timestamps.
    let mut all_segments: Vec<Segment> = Vec::new();
    let mut language = String::new();
    for (ci, &(chunk_start, chunk_end)) in chunks.iter().enumerate() {
        // Float→usize casts saturate; clamp both ends so rounding can never
        // produce an out-of-bounds or inverted range.
        let s1 = ((chunk_end * SAMPLE_RATE) as usize).min(pcm.len());
        let s0 = ((chunk_start * SAMPLE_RATE) as usize).min(s1);
        let chunk_pcm = pcm[s0..s1].to_vec();

        // Scale chunk's 0-100 progress into the job's 0-100 range:
        // overall = (ci * 100 + p) / n. This is done in `usize` because the
        // intermediate product does not fit in a `u8`; saturating there
        // pinned the reported progress at 2% for jobs with few chunks.
        let tx = progress_tx.clone();
        let on_progress = Box::new(move |p: u8| {
            let done = ci * 100 + usize::from(p.min(100));
            // `done / n` is at most 100 here, so the cast is lossless.
            let overall = (done / n).min(100) as u8;
            let _ = tx.send(ProgressEvent::Progress(overall));
        });

        let (reply_tx, reply_rx) = oneshot::channel();
        tx_req
            .send(TranscribeRequest {
                pcm: chunk_pcm,
                language: job.language.clone(),
                task: job.task.clone(),
                on_progress,
                reply: reply_tx,
            })
            .map_err(|_| crate::AppError::Internal("transcriber thread gone".into()))?;

        let (mut segs, lang) = reply_rx
            .await
            .map_err(|_| crate::AppError::Internal("transcriber thread dropped reply".into()))??;

        // Shift all timestamps by chunk offset.
        for seg in &mut segs {
            seg.start += chunk_start;
            seg.end += chunk_start;
            for word in &mut seg.words {
                word.start += chunk_start;
                word.end += chunk_start;
            }
        }

        tracing::debug!(
            chunk = ci + 1,
            of = n,
            start = chunk_start,
            end = chunk_end,
            segs = segs.len(),
            "chunk done"
        );

        all_segments.extend(segs);
        if language.is_empty() {
            language = lang;
        }
    }

    // Renumber segment indices across the merged output.
    for (i, seg) in all_segments.iter_mut().enumerate() {
        seg.index = i as i32;
    }

    let _ = progress_tx.send(ProgressEvent::Progress(100));
    Ok((all_segments, language, total_secs))
}
/// Decode any audio file to 16 kHz mono PCM f32 using ffmpeg.
///
/// Runs `ffmpeg -i <path> -f f32le -ac 1 -ar 16000 -` and converts the
/// little-endian bytes on stdout into samples. The path is passed as an OS
/// string, so non-UTF-8 paths reach ffmpeg unchanged instead of becoming an
/// empty argument.
///
/// # Errors
/// Returns `AppError::Internal` if ffmpeg cannot be spawned, exits with a
/// non-zero status (its stderr is included in the message), or produces
/// output whose length is not a multiple of 4 bytes.
async fn decode_audio(path: &std::path::Path) -> crate::Result<Vec<f32>> {
    use tokio::process::Command;

    let output = Command::new("ffmpeg")
        .args(["-nostdin", "-threads", "0", "-i"])
        .arg(path)
        .args(["-f", "f32le", "-ac", "1", "-ar", "16000", "-"])
        .output()
        .await
        .map_err(|e| crate::AppError::Internal(format!("ffmpeg spawn failed: {e}")))?;

    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        return Err(crate::AppError::Internal(format!(
            "ffmpeg exited with {}: {}",
            output.status, stderr
        )));
    }

    let bytes = output.stdout;
    if bytes.len() % 4 != 0 {
        return Err(crate::AppError::Internal(
            "ffmpeg output length not a multiple of 4".into(),
        ));
    }

    Ok(bytes
        .chunks_exact(4)
        .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
        .collect())
}
/// Path of the audio file for job `id`: `<DATA_DIR>/<id>.audio`.
///
/// `DATA_DIR` is read from the environment on every call and falls back to
/// `/data` when the variable cannot be read.
pub fn audio_path_for(id: &JobId) -> PathBuf {
    let mut path = match std::env::var("DATA_DIR") {
        Ok(dir) => PathBuf::from(dir),
        Err(_) => PathBuf::from("/data"),
    };
    path.push(format!("{id}.audio"));
    path
}