Files
whisper-rtx2080/src/worker.rs
Giancarmine Salucci d8a73e150a
All checks were successful
Build & Push Docker Image / test (push) Successful in 6m2s
Build & Push Docker Image / build-and-push (push) Successful in 6m31s
fix(worker): port final segment cleanup
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-05-12 00:10:32 +02:00

1288 lines
41 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use std::{
collections::HashSet,
path::PathBuf,
sync::{
atomic::{AtomicUsize, Ordering},
Arc, Mutex,
},
time::{Duration, Instant},
};
use chrono::Utc;
use reqwest::Client;
use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
use crate::{
models::{Job, JobId, JobStatus, ModelEvent, ModelState, Segment},
storage::Storage,
transcriber::Transcriber,
webhook, AppError,
};
/// Per-job broadcast channel for SSE subscribers.
pub type ProgressTx = broadcast::Sender<ProgressEvent>;
/// Notification published on a job's [`ProgressTx`] while it is processed.
#[derive(Debug, Clone)]
pub enum ProgressEvent {
    /// Intermediate progress: `percent` is the overall job progress (0 to 100),
    /// `chunk` the 1-based index of the audio chunk being transcribed, and
    /// `total` the number of chunks in the job.
    Progress {
        percent: u8,
        chunk: usize,
        total: usize,
    },
    /// The job finished successfully; carries the final job record.
    Done(Box<Job>),
    /// The job failed; carries the error message.
    Error(String),
}
/// Global registry: job_id -> broadcast sender. The job runner removes a job's
/// entry once the job is finished or skipped.
pub type ProgressRegistry = Arc<dashmap::DashMap<JobId, ProgressTx>>;
// ── Worker command channel ────────────────────────────────────────────────────
/// Commands sent to the GPU worker OS thread.
///
/// The thread handles commands strictly one at a time, in arrival order. A
/// `Load`/`Unload` queued behind a running transcription waits for it to finish.
#[derive(Debug)]
pub enum WorkerCmd {
    /// Request a model load. Idempotent: if already loading/ready, ignored.
    Load,
    /// Unload the model immediately and free GPU memory. Also cancels a load
    /// that is currently waiting for free VRAM.
    Unload,
    /// Internal: run a transcription chunk. If no model is loaded, the request
    /// is answered with an error on its `reply` channel.
    Transcribe(TranscribeRequest),
}
// ── Transcription request/response types ─────────────────────────────────────
/// One chunk of audio to transcribe on the GPU thread.
pub struct TranscribeRequest {
    /// Mono PCM samples. The job runner sends audio decoded by ffmpeg at 16 kHz.
    pub pcm: Vec<f32>,
    /// Language hint forwarded to the transcriber. `None` presumably means
    /// auto-detect; TODO confirm against `Transcriber::transcribe`.
    pub language: Option<String>,
    /// Task name forwarded to the transcriber (the job runner passes `Job::task`).
    pub task: String,
    /// Invoked on the GPU thread with the chunk's progress. The job runner
    /// interprets the value as a 0-100 percentage.
    pub on_progress: Box<dyn Fn(u8) + Send + 'static>,
    /// Receives the chunk's segments plus the language string reported by the
    /// transcriber, or the transcription error.
    pub reply: oneshot::Sender<crate::Result<(Vec<Segment>, String)>>,
}
impl std::fmt::Debug for TranscribeRequest {
    /// Shows only the request metadata. The PCM buffer, the progress callback
    /// and the reply channel are elided and marked with `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("TranscribeRequest");
        out.field("language", &self.language);
        out.field("task", &self.task);
        out.finish_non_exhaustive()
    }
}
// ── Public API ────────────────────────────────────────────────────────────────
/// Spawn the single GPU worker.
///
/// Two things are started:
/// * a dedicated OS thread named `whisper-gpu` that owns the transcriber and
///   services [`WorkerCmd`]s;
/// * an async task on the current Tokio runtime that consumes job ids from
///   `job_rx` and feeds audio chunks to that thread.
///
/// Returns the SSE progress registry and a command sender for the worker thread.
/// The model starts **unloaded**; send `WorkerCmd::Load` or submit a job to
/// trigger loading.
///
/// # Panics
///
/// Panics when called outside a Tokio runtime (`Handle::current`) or when the
/// OS thread cannot be spawned.
// Every argument is an independent piece of worker configuration.
#[allow(clippy::too_many_arguments)]
pub fn start(
    job_rx: mpsc::UnboundedReceiver<JobId>,
    storage: Arc<Storage>,
    model_path: PathBuf,
    queue_depth: Arc<AtomicUsize>,
    gpu_device: u32,
    model_state: Arc<RwLock<ModelState>>,
    model_event_tx: broadcast::Sender<ModelEvent>,
    webhook_registry: Arc<Mutex<HashSet<String>>>,
    idle_timeout: Duration,
    gpu_poll_interval: Duration,
) -> (ProgressRegistry, std::sync::mpsc::SyncSender<WorkerCmd>) {
    // The GPU thread has no runtime of its own. It uses this handle to spawn
    // webhook deliveries.
    let runtime = tokio::runtime::Handle::current();
    let registry: ProgressRegistry = Arc::new(dashmap::DashMap::new());
    // Small bounded queue: load/unload are rare, and the job runner awaits each
    // transcription reply before sending the next chunk.
    let (cmd_tx, cmd_rx) = std::sync::mpsc::sync_channel::<WorkerCmd>(8);

    let gpu_thread_body = move || {
        transcriber_thread(
            cmd_rx,
            model_path,
            gpu_device,
            model_state,
            model_event_tx,
            webhook_registry,
            idle_timeout,
            gpu_poll_interval,
            runtime,
        )
    };
    std::thread::Builder::new()
        .name("whisper-gpu".into())
        .spawn(gpu_thread_body)
        .expect("failed to spawn whisper-gpu thread");

    let job_runner = run(
        job_rx,
        storage,
        queue_depth,
        Arc::clone(&registry),
        cmd_tx.clone(),
    );
    tokio::spawn(job_runner);
    (registry, cmd_tx)
}
// ── GPU OS thread ─────────────────────────────────────────────────────────────
/// The worker OS thread that owns the `Transcriber` (non-`Send`).
///
/// Handles [`WorkerCmd`]s one at a time until every sender has been dropped:
/// * `Load`: loads the model via [`try_load_with_polling`] unless one is
///   already loaded;
/// * `Unload`: drops the model and publishes the `Unloaded` state;
/// * `Transcribe`: runs inference on the loaded model and sends the result on
///   the request's `reply` channel. If no model is loaded, it replies with an
///   error; this thread never loads the model on demand.
///
/// Uses `recv_timeout` with a 1-second tick to drive the idle timer without a
/// separate thread. When no command has arrived for a tick, a model is loaded,
/// and `idle_timeout` has passed since the last successful load or
/// transcription, the model is unloaded.
// The arguments are independent pieces of worker configuration passed through from `start`.
#[allow(clippy::too_many_arguments)]
fn transcriber_thread(
    rx: std::sync::mpsc::Receiver<WorkerCmd>,
    model_path: PathBuf,
    gpu_device: u32,
    model_state: Arc<RwLock<ModelState>>,
    model_event_tx: broadcast::Sender<ModelEvent>,
    webhook_registry: Arc<Mutex<HashSet<String>>>,
    idle_timeout: Duration,
    gpu_poll_interval: Duration,
    rt: tokio::runtime::Handle,
) {
    // `None` while the model is unloaded.
    let mut transcriber: Option<Transcriber> = None;
    // Reference point for the idle timeout; reset after a successful load and
    // after every transcription.
    let mut last_job = Instant::now();
    loop {
        match rx.recv_timeout(Duration::from_secs(1)) {
            Ok(WorkerCmd::Load) => {
                if transcriber.is_some() {
                    tracing::debug!("WorkerCmd::Load ignored — model already loaded");
                    continue;
                }
                // Blocks this thread, possibly for several VRAM polling rounds,
                // until the model loads, an Unload cancels it, or loading fails.
                transcriber = try_load_with_polling(
                    &rx,
                    &model_path,
                    gpu_device,
                    &model_state,
                    &model_event_tx,
                    &webhook_registry,
                    gpu_poll_interval,
                    &rt,
                );
                if transcriber.is_some() {
                    last_job = Instant::now();
                }
            }
            Ok(WorkerCmd::Unload) => {
                do_unload(
                    &mut transcriber,
                    &model_state,
                    &model_event_tx,
                    &webhook_registry,
                    &rt,
                );
            }
            Ok(WorkerCmd::Transcribe(req)) => {
                let t = match &mut transcriber {
                    Some(t) => t,
                    None => {
                        tracing::warn!(
                            "Transcribe cmd received but model is unloaded — failing job"
                        );
                        let _ = req.reply.send(Err(AppError::Internal(
                            "model unloaded before job could run".into(),
                        )));
                        continue;
                    }
                };
                // Runs synchronously on this thread; no other command is
                // serviced until inference returns.
                let result = t.transcribe(&req.pcm, req.language.as_deref(), &req.task, move |p| {
                    (req.on_progress)(p)
                });
                last_job = Instant::now();
                // If the requester has gone away, the result is simply discarded.
                let _ = req.reply.send(result);
            }
            Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
                // Idle tick: unload if the model has been unused for too long.
                if transcriber.is_some() && last_job.elapsed() >= idle_timeout {
                    tracing::info!(
                        elapsed_secs = last_job.elapsed().as_secs(),
                        "idle timeout reached — unloading model"
                    );
                    do_unload(
                        &mut transcriber,
                        &model_state,
                        &model_event_tx,
                        &webhook_registry,
                        &rt,
                    );
                }
            }
            Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
                // All senders are gone: the loaded model (if any) is dropped
                // when `transcriber` goes out of scope.
                tracing::info!("worker command channel closed — shutting down GPU thread");
                break;
            }
        }
    }
}
/// Attempt to load the model, polling on VRAM failures.
///
/// Each attempt publishes `Loading`. On success it publishes `Ready` (state,
/// SSE and webhooks) and returns `Some(Transcriber)`. On
/// `AppError::OutOfMemory` it publishes `WaitingForGpu` with the parsed VRAM
/// figures and waits `gpu_poll_interval` before retrying.
///
/// While waiting for GPU, drains `rx` so that `WorkerCmd::Unload` cancels the
/// load attempt and `WorkerCmd::Transcribe` commands get a "model not ready"
/// rejection. `WorkerCmd::Load` is ignored.
///
/// Returns `None` when the load is cancelled by `Unload`, fails with any
/// non-OOM error (both publish `Unloaded`), or when the command channel
/// disconnects mid-wait. NOTE(review): the disconnect case leaves the shared
/// state at `WaitingForGpu`.
// The arguments mirror the worker configuration threaded through `transcriber_thread`.
#[allow(clippy::too_many_arguments)]
fn try_load_with_polling(
    rx: &std::sync::mpsc::Receiver<WorkerCmd>,
    model_path: &PathBuf,
    gpu_device: u32,
    model_state: &Arc<RwLock<ModelState>>,
    model_event_tx: &broadcast::Sender<ModelEvent>,
    webhook_registry: &Arc<Mutex<HashSet<String>>>,
    gpu_poll_interval: Duration,
    rt: &tokio::runtime::Handle,
) -> Option<Transcriber> {
    loop {
        set_state(model_state, ModelState::Loading);
        broadcast_event(model_event_tx, ModelEvent::ModelLoading);
        tracing::info!("loading whisper model...");
        match Transcriber::load(model_path, gpu_device) {
            Ok(t) => {
                // One timestamp shared by the state, the SSE event and the webhook.
                let loaded_at = Utc::now();
                set_state(model_state, ModelState::Ready { loaded_at });
                broadcast_event(model_event_tx, ModelEvent::ModelReady { loaded_at });
                fire_webhooks(webhook_registry, ModelEvent::ModelReady { loaded_at }, rt);
                tracing::info!("model loaded and ready");
                return Some(t);
            }
            Err(AppError::OutOfMemory(msg)) => {
                // Recoverable: VRAM may be freed by another process later.
                let (vram_needed_mb, vram_free_mb) = parse_oom_vram(&msg, gpu_device);
                let retry_in_secs = gpu_poll_interval.as_secs();
                tracing::warn!(
                    vram_needed_mb,
                    vram_free_mb,
                    retry_in_secs,
                    "insufficient VRAM — will retry"
                );
                set_state(
                    model_state,
                    ModelState::WaitingForGpu {
                        vram_needed_mb,
                        vram_free_mb,
                        retry_in_secs,
                    },
                );
                broadcast_event(
                    model_event_tx,
                    ModelEvent::ModelWaitingForGpu {
                        vram_needed_mb,
                        vram_free_mb,
                        retry_in_secs,
                    },
                );
                // Interruptible sleep: drain rx while waiting for gpu_poll_interval.
                // Waits are capped at 1 s per recv so the deadline is re-checked often.
                let deadline = Instant::now() + gpu_poll_interval;
                loop {
                    let remaining = deadline.saturating_duration_since(Instant::now());
                    if remaining.is_zero() {
                        break;
                    }
                    match rx.recv_timeout(remaining.min(Duration::from_secs(1))) {
                        Ok(WorkerCmd::Unload) => {
                            tracing::info!(
                                "Unload received while waiting for GPU — cancelling load"
                            );
                            set_state(model_state, ModelState::Unloaded);
                            broadcast_event(model_event_tx, ModelEvent::ModelUnloaded);
                            fire_webhooks(webhook_registry, ModelEvent::ModelUnloaded, rt);
                            return None;
                        }
                        Ok(WorkerCmd::Load) => {} // idempotent
                        Ok(WorkerCmd::Transcribe(req)) => {
                            // Reject rather than queue: no model is available yet.
                            let _ = req.reply.send(Err(AppError::ModelNotReady {
                                state: "waiting_for_gpu".into(),
                                retry_after_secs: retry_in_secs,
                            }));
                        }
                        Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {}
                        Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => return None,
                    }
                }
                // Loop back to retry load
            }
            Err(e) => {
                tracing::error!(error = %e, "model load failed with non-recoverable error");
                set_state(model_state, ModelState::Unloaded);
                broadcast_event(model_event_tx, ModelEvent::ModelUnloaded);
                fire_webhooks(webhook_registry, ModelEvent::ModelUnloaded, rt);
                return None;
            }
        }
    }
}
/// Drop the loaded transcriber (if any), then announce the `Unloaded` state to
/// the shared state, SSE subscribers and registered webhooks.
///
/// Safe to call when nothing is loaded; the notifications are still sent.
fn do_unload(
    transcriber: &mut Option<Transcriber>,
    model_state: &Arc<RwLock<ModelState>>,
    model_event_tx: &broadcast::Sender<ModelEvent>,
    webhook_registry: &Arc<Mutex<HashSet<String>>>,
    rt: &tokio::runtime::Handle,
) {
    // Release the model before telling anyone it is gone (presumably its Drop
    // impl frees the GPU memory, as the log line below states).
    drop(transcriber.take());

    let event = ModelEvent::ModelUnloaded;
    set_state(model_state, ModelState::Unloaded);
    broadcast_event(model_event_tx, ModelEvent::ModelUnloaded);
    fire_webhooks(webhook_registry, event, rt);
    tracing::info!("model unloaded — GPU memory freed");
}
// ── Helpers ───────────────────────────────────────────────────────────────────
/// Replace the shared model state from a synchronous (non-async) thread.
///
/// Uses `RwLock::blocking_write`, which panics when called from within an
/// async execution context. Only call it from the GPU OS thread.
fn set_state(arc: &Arc<RwLock<ModelState>>, state: ModelState) {
    let mut guard = arc.blocking_write();
    *guard = state;
}
/// Publish a model event to SSE subscribers.
///
/// The send result is deliberately ignored: it only fails when nobody is
/// subscribed, which is a normal condition here.
fn broadcast_event(tx: &broadcast::Sender<ModelEvent>, event: ModelEvent) {
    let _outcome = tx.send(event);
}
fn fire_webhooks(
registry: &Arc<Mutex<HashSet<String>>>,
event: ModelEvent,
rt: &tokio::runtime::Handle,
) {
if !event.is_webhook_event() {
return;
}
let urls: Vec<String> = registry
.lock()
.unwrap_or_else(|e| e.into_inner())
.iter()
.cloned()
.collect();
if urls.is_empty() {
return;
}
let payload = match serde_json::to_string(&event) {
Ok(p) => p,
Err(e) => {
tracing::error!(error = %e, "failed to serialize model event");
return;
}
};
for url in urls {
let body = payload.clone();
rt.spawn(async move {
let http = Client::builder()
.timeout(Duration::from_secs(10))
.build()
.expect("http client");
for attempt in 0..3_u32 {
match http
.post(&url)
.header("content-type", "application/json")
.body(body.clone())
.send()
.await
{
Ok(r) if r.status().is_success() => {
tracing::debug!(url, "model event webhook delivered");
return;
}
Ok(r) => tracing::warn!(url, status = r.status().as_u16(), "webhook non-2xx"),
Err(e) => tracing::warn!(url, error = %e, attempt, "webhook delivery failed"),
}
if attempt < 2 {
tokio::time::sleep(Duration::from_secs(2u64.pow(attempt))).await;
}
}
tracing::error!(url, "model event webhook failed after 3 attempts");
});
}
}
/// Gather the VRAM figures reported with a `WaitingForGpu` state.
///
/// Returns `(needed_mib, free_mib)`:
/// * `needed_mib` is the number immediately before the first `MiB` token of
///   the loader's out-of-memory message, truncated to an integer;
/// * `free_mib` is the free memory `nvidia-smi` reports for `gpu_device`.
///
/// Either value falls back to `0` when it cannot be determined. If the first
/// `MiB` token is not preceded by a number, no later occurrence is tried.
fn parse_oom_vram(msg: &str, gpu_device: u32) -> (u64, u64) {
    let tokens: Vec<&str> = msg.split_whitespace().collect();
    let needed = tokens
        .windows(2)
        .find(|pair| pair[1] == "MiB")
        .and_then(|pair| pair[0].parse::<f64>().ok())
        .map_or(0, |mib| mib as u64);

    // Runs a blocking subprocess; acceptable because the caller is an OS thread.
    let device_flag = format!("--id={gpu_device}");
    let free = std::process::Command::new("nvidia-smi")
        .args([
            device_flag.as_str(),
            "--query-gpu=memory.free",
            "--format=csv,noheader,nounits",
        ])
        .output()
        .ok()
        .and_then(|out| String::from_utf8(out.stdout).ok())
        .and_then(|text| text.trim().parse::<u64>().ok())
        .unwrap_or(0);

    (needed, free)
}
// ── Async job runner ──────────────────────────────────────────────────────────
async fn run(
mut job_rx: mpsc::UnboundedReceiver<JobId>,
storage: Arc<Storage>,
queue_depth: Arc<AtomicUsize>,
registry: ProgressRegistry,
cmd_tx: std::sync::mpsc::SyncSender<WorkerCmd>,
) {
let http = Client::builder()
.timeout(Duration::from_secs(30))
.build()
.expect("failed to build reqwest client");
while let Some(job_id) = job_rx.recv().await {
queue_depth.fetch_sub(1, Ordering::Relaxed);
let mut job = match storage.get(&job_id).await {
Ok(j) => j,
Err(e) => {
tracing::warn!(job_id = %job_id, error = %e, "job vanished before processing");
registry.remove(&job_id);
continue;
}
};
if job.status == JobStatus::Cancelled {
let _ = tokio::fs::remove_file(&audio_path_for(&job_id)).await;
registry.remove(&job_id);
continue;
}
job.status = JobStatus::Running;
if let Err(e) = storage.save(&job).await {
tracing::error!(job_id = %job_id, error = %e, "failed to persist running status");
}
let progress_tx = registry
.entry(job_id)
.or_insert_with(|| broadcast::channel(64).0)
.clone();
let audio_path = audio_path_for(&job_id);
let result = process_job(&job, &audio_path, &progress_tx, &cmd_tx, &storage).await;
let _ = tokio::fs::remove_file(&audio_path).await;
// Re-read from storage: the job may have been cancelled via DELETE /jobs/:id
// while process_job() was running. If so, discard the result entirely.
let current_status = storage.get(&job_id).await.map(|j| j.status).ok();
if current_status == Some(JobStatus::Cancelled) {
tracing::info!(job_id = %job_id, "job cancelled during inference — discarding result");
registry.remove(&job_id);
continue;
}
match result {
Ok((segments, language, duration_secs)) => {
job.status = JobStatus::Done;
job.segments = segments;
job.language = Some(language);
job.duration_secs = Some(duration_secs);
job.progress = 100;
job.completed_at = Some(Utc::now());
let _ = progress_tx.send(ProgressEvent::Done(Box::new(job.clone())));
}
Err(e) => {
let msg = e.to_string();
tracing::error!(job_id = %job_id, error = %msg, "transcription failed");
job.status = JobStatus::Failed;
job.error = Some(msg.clone());
job.completed_at = Some(Utc::now());
let _ = progress_tx.send(ProgressEvent::Error(msg));
}
}
if let Err(e) = storage.save(&job).await {
tracing::error!(job_id = %job_id, error = %e, "failed to persist final job state");
}
if let Some(url) = &job.webhook_url.clone() {
let http = http.clone();
let url = url.clone();
let job = job.clone();
tokio::spawn(async move {
webhook::fire(&http, &url, &job).await;
});
}
tokio::time::sleep(Duration::from_secs(30)).await;
registry.remove(&job_id);
}
}
// ── Silence-based chunking ────────────────────────────────────────────────────
/// Preferred chunk length in seconds. Each cut is aimed this far past the
/// previous one before being snapped to a nearby silence.
const TARGET_CHUNK_SECS: f32 = 60.0;
/// Maximum distance (seconds) a cut may move to land on a silence midpoint.
const SNAP_WINDOW_SECS: f32 = 30.0;
/// Noise threshold for ffmpeg's `silencedetect` filter (its `n=` option).
const SILENCE_DB: &str = "-35dB";
/// Minimum silence duration, in seconds, for `silencedetect` (its `d=` option).
const SILENCE_DUR: &str = "0.4";
/// Run ffmpeg's `silencedetect` over `path` and return the midpoint (seconds)
/// of every detected silence, in order.
///
/// Starts and ends are paired in the order ffmpeg reports them. An unpaired
/// trailing `silence_start` (silence running to end of file) is ignored.
/// Returns an empty list when ffmpeg cannot be spawned, so the caller falls
/// back to hard cuts.
async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
    use tokio::process::Command;
    let filter = format!("silencedetect=n={}:d={}", SILENCE_DB, SILENCE_DUR);
    let output = Command::new("ffmpeg")
        .args(["-nostdin", "-i"])
        // Pass the path as an `OsStr`: going through `to_str` would turn a
        // non-UTF-8 path into an empty input argument.
        .arg(path)
        .args(["-af", &filter, "-f", "null", "-"])
        .output()
        .await;
    let output = match output {
        Ok(o) => o,
        Err(e) => {
            tracing::warn!(error = %e, "silencedetect unavailable; using hard cuts");
            return Vec::new();
        }
    };
    // silencedetect reports on stderr, e.g.
    // `[silencedetect @ ...] silence_end: 12.3 | silence_duration: 0.5`.
    let stderr = String::from_utf8_lossy(&output.stderr);
    let mut starts: Vec<f32> = Vec::new();
    let mut ends: Vec<f32> = Vec::new();
    for line in stderr.lines() {
        if let Some(i) = line.find("silence_start: ") {
            if let Ok(t) = line[i + "silence_start: ".len()..].trim().parse::<f32>() {
                starts.push(t);
            }
        } else if let Some(i) = line.find("silence_end: ") {
            let t_str = line[i + "silence_end: ".len()..]
                .split(" |")
                .next()
                .unwrap_or("")
                .trim();
            if let Ok(t) = t_str.parse::<f32>() {
                ends.push(t);
            }
        }
    }
    let mids: Vec<f32> = starts
        .iter()
        .zip(ends.iter())
        .map(|(s, e)| (s + e) / 2.0)
        .collect();
    tracing::debug!(n = mids.len(), "silence midpoints detected");
    mids
}
/// Choose chunk cut points (seconds), preferring silence midpoints.
///
/// Starting at `target_secs`, each cut aims for "previous cut + `target_secs`"
/// and snaps to the nearest silence midpoint in `mids`. A midpoint qualifies if
/// it lies within `snap_window` of the aim and more than 10 s after the
/// previous cut; ties go to the earlier midpoint. With no candidate, the cut
/// lands exactly on the aim. Cutting stops once the aim reaches the last
/// quarter-target of the audio, so the final chunk is never tiny.
fn snap_to_silence(mids: &[f32], total_secs: f32, target_secs: f32, snap_window: f32) -> Vec<f32> {
    let stop_at = total_secs - target_secs * 0.25;
    let mut cuts: Vec<f32> = Vec::new();
    let mut aim = target_secs;
    while aim < stop_at {
        let earliest = cuts.last().copied().unwrap_or(0.0) + 10.0;
        let distance = |t: f32| (t - aim).abs();
        let snapped = mids
            .iter()
            .copied()
            .filter(|&t| t > earliest && distance(t) <= snap_window)
            .min_by(|&a, &b| distance(a).partial_cmp(&distance(b)).unwrap());
        let cut = snapped.unwrap_or(aim);
        cuts.push(cut);
        aim = cut + target_secs;
    }
    cuts
}
/// Turn cut points into `(start, end)` chunk ranges (seconds) covering the
/// whole clip.
///
/// A cut that would produce a chunk shorter than 5 s is skipped, so that audio
/// joins the following chunk. The remainder after the last accepted cut
/// becomes its own chunk when it is at least 1 s long. A shorter remainder is
/// appended to the previous chunk; for clips shorter than 1 s overall, it is
/// emitted as the only chunk. Audio is never silently dropped. Only a
/// zero-length clip yields no ranges.
fn to_chunk_ranges(cuts: &[f32], total_secs: f32) -> Vec<(f32, f32)> {
    let mut ranges: Vec<(f32, f32)> = Vec::with_capacity(cuts.len() + 1);
    let mut start = 0.0_f32;
    for &cut in cuts {
        if cut - start >= 5.0 {
            ranges.push((start, cut));
            start = cut;
        }
    }
    let tail = total_secs - start;
    if tail >= 1.0 {
        ranges.push((start, total_secs));
    } else if tail > 0.0 {
        // Too short to be a chunk of its own, but it still contains audio.
        match ranges.last_mut() {
            Some(last) => last.1 = total_secs,
            None => ranges.push((start, total_secs)),
        }
    }
    ranges
}
/// Largest gap (seconds) between a kept segment's end and the next segment's
/// start for `collapse_incremental_segments` to compare them at all.
const MAX_CHAIN_GAP_SECS: f32 = 0.15;
/// A "meaningful phrase" needs at least this many normalised words ...
const MIN_MEANINGFUL_WORDS: usize = 2;
/// ... and at least this total length (summed over its normalised words,
/// separators excluded).
const MIN_MEANINGFUL_CHARS: usize = 8;
/// Minimum suffix/prefix word overlap before the overlapping words are
/// trimmed from the start of the following segment.
const MIN_OVERLAP_WORDS: usize = 1;
/// A segment counts as a "short carry-over" if it lasts at most this long (s),
const SHORT_CARRYOVER_MAX_SECS: f32 = 0.2;
/// or has at most this many normalised words,
const SHORT_CARRYOVER_MAX_WORDS: usize = 2;
/// or its normalised words, joined by single spaces, are at most this long.
const SHORT_CARRYOVER_MAX_CHARS: usize = 16;
/// Word n-gram size used for the Jaccard similarity in `ngram_dedup`.
const NGRAM_N: usize = 6;
/// Number of trailing characters of recent output that `ngram_dedup` compares
/// each new segment against.
const LOOKBACK_CHARS: usize = 500;
/// Jaccard similarity at or above which `ngram_dedup` drops a segment.
const SIMILARITY_THRESHOLD: f32 = 0.6;
/// Split `text` on any Unicode whitespace, borrowing each word.
///
/// Runs of whitespace never produce empty words.
fn split_words(text: &str) -> Vec<&str> {
    let words: Vec<&str> = text.split_whitespace().collect();
    words
}
/// Canonical form of a word for comparisons: keep only alphanumerics and `_`,
/// lower-cased (Unicode-aware). Punctuation-only input yields `""`.
fn normalise_token(word: &str) -> String {
    let mut canonical = String::with_capacity(word.len());
    for ch in word.chars() {
        if ch.is_alphanumeric() || ch == '_' {
            canonical.extend(ch.to_lowercase());
        }
    }
    canonical
}
fn normalised_words(text: &str) -> Vec<String> {
split_words(text)
.into_iter()
.map(normalise_token)
.filter(|word| !word.is_empty())
.collect()
}
/// Remove one immediately-repeated phrase from `text`, if present.
///
/// Searches for a phrase of at least two words that is directly followed by
/// an identical copy (compared after `normalise_token`). The longest phrase is
/// tried first, then the leftmost position. The phrase as written must be at
/// least 10 characters long, separating spaces included. The second copy is
/// removed and the remaining words are re-joined with single spaces. Texts
/// with fewer than four words, or without such a repeat, are returned
/// trimmed but otherwise unchanged.
///
/// Length is measured in `char`s rather than UTF-8 bytes, so non-ASCII text
/// gets the same threshold as ASCII.
fn collapse_repeated_phrase_once(text: &str) -> String {
    const MIN_REPEAT_PHRASE_CHARS: usize = 10;
    let raw_words = split_words(text);
    if raw_words.len() < 4 {
        return text.trim().to_string();
    }
    let normalised: Vec<String> = raw_words.iter().map(|word| normalise_token(word)).collect();
    for size in (2..=raw_words.len() / 2).rev() {
        for start in 0..=raw_words.len().saturating_sub(size * 2) {
            let phrase_chars = raw_words[start..start + size]
                .iter()
                .map(|word| word.chars().count())
                .sum::<usize>()
                + size.saturating_sub(1);
            if phrase_chars < MIN_REPEAT_PHRASE_CHARS {
                continue;
            }
            if normalised[start..start + size] == normalised[start + size..start + size * 2] {
                let mut collapsed = Vec::with_capacity(raw_words.len() - size);
                collapsed.extend_from_slice(&raw_words[..start + size]);
                collapsed.extend_from_slice(&raw_words[start + size * 2..]);
                return collapsed.join(" ").trim().to_string();
            }
        }
    }
    text.trim().to_string()
}
/// Repeatedly apply `collapse_repeated_phrase_once` until the text stops
/// changing, returning the trimmed fixed point.
///
/// Every change removes at least two words, so the loop always terminates.
fn collapse_repeats(text: &str) -> String {
    let mut current = text.trim().to_string();
    loop {
        let collapsed = collapse_repeated_phrase_once(&current);
        if collapsed == current {
            break collapsed;
        }
        current = collapsed;
    }
}
/// `true` when `prefix` equals the first `prefix.len()` words of `full`
/// (an empty prefix always matches).
fn starts_with_words(full: &[String], prefix: &[String]) -> bool {
    full.starts_with(prefix)
}
/// `true` when `suffix` equals the last `suffix.len()` words of `full`
/// (an empty suffix always matches).
fn ends_with_words(full: &[String], suffix: &[String]) -> bool {
    full.ends_with(suffix)
}
/// Length of the longest run of words that ends `left` and begins `right`;
/// `0` when there is none.
fn suffix_prefix_overlap(left: &[String], right: &[String]) -> usize {
    let longest_possible = left.len().min(right.len());
    (1..=longest_possible)
        .rev()
        .find(|&k| left.ends_with(&right[..k]))
        .unwrap_or(0)
}
/// Whether normalised `words` form a phrase substantial enough to anchor a
/// merge: at least `MIN_MEANINGFUL_WORDS` words and at least
/// `MIN_MEANINGFUL_CHARS` characters in total (separators excluded).
///
/// Length is counted in `char`s, matching the constant's name. Counting UTF-8
/// bytes made non-ASCII phrases qualify with far fewer letters.
fn is_meaningful_phrase(words: &[String]) -> bool {
    words.len() >= MIN_MEANINGFUL_WORDS
        && words.iter().map(|word| word.chars().count()).sum::<usize>() >= MIN_MEANINGFUL_CHARS
}
/// Whether `seg` (with normalised `words`) is small enough to be a carry-over
/// fragment. Any one condition suffices:
/// * duration at most `SHORT_CARRYOVER_MAX_SECS`,
/// * at most `SHORT_CARRYOVER_MAX_WORDS` words, or
/// * at most `SHORT_CARRYOVER_MAX_CHARS` characters when the words are joined
///   by single spaces.
///
/// Length is counted in `char`s (not UTF-8 bytes), so non-ASCII text is judged
/// by the same limit as ASCII.
fn is_short_carryover(seg: &Segment, words: &[String]) -> bool {
    let joined_chars = words.iter().map(|word| word.chars().count()).sum::<usize>()
        + words.len().saturating_sub(1);
    seg.end - seg.start <= SHORT_CARRYOVER_MAX_SECS
        || words.len() <= SHORT_CARRYOVER_MAX_WORDS
        || joined_chars <= SHORT_CARRYOVER_MAX_CHARS
}
/// Drop the first `count` whitespace-separated words of `text`, re-joining the
/// rest with single spaces. Returns `""` if nothing remains.
fn trim_leading_words(text: &str, count: usize) -> String {
    let remaining: Vec<&str> = split_words(text).into_iter().skip(count).collect();
    // Words contain no whitespace, so the joined string needs no trimming.
    remaining.join(" ")
}
/// Merge runs of consecutive segments whose normalised words are identical.
///
/// The first segment of a run is kept (including its start). Its end grows to
/// cover the duplicates, and it adopts the word timings of the latest
/// duplicate that has any. Adjacent duplicates are merged regardless of the
/// time gap between them.
fn merge_identical_segments(segments: Vec<Segment>) -> Vec<Segment> {
    let mut merged: Vec<Segment> = Vec::with_capacity(segments.len());
    for seg in segments {
        match merged.last_mut() {
            Some(prev) if normalised_words(&prev.text) == normalised_words(&seg.text) => {
                prev.end = prev.end.max(seg.end);
                if !seg.words.is_empty() {
                    prev.words = seg.words;
                }
            }
            _ => merged.push(seg),
        }
    }
    merged
}
/// Collapse chains of adjacent segments that repeat or extend one another.
///
/// Segments are trimmed (empty ones dropped) and walked in order. Each is
/// compared with the last kept segment only when the gap between them is at
/// most `MAX_CHAIN_GAP_SECS`. Words are compared via `normalised_words`, so
/// case and punctuation are ignored:
/// * **prefix growth**: the new segment starts with all of the kept
///   segment's words and adds more. The kept segment takes the new text, end
///   and word timings; its start is preserved;
/// * **suffix repeat**: the new segment's words are a suffix of the kept
///   segment's words. The new segment is dropped and only the end is extended;
/// * **partial overlap**: the kept segment ends with at least
///   `MIN_OVERLAP_WORDS` words the new one starts with. Those words are cut
///   from the new segment, its start is moved to no earlier than the kept
///   segment's end, and its word timings are cleared because they no longer
///   match the trimmed text.
///
/// The first two rules only fire when the shorter side is a meaningful phrase
/// or a short carry-over (see `is_meaningful_phrase` / `is_short_carryover`).
fn collapse_incremental_segments(segments: Vec<Segment>) -> Vec<Segment> {
    let mut out: Vec<Segment> = Vec::with_capacity(segments.len());
    for mut seg in segments {
        seg.text = seg.text.trim().to_string();
        if seg.text.is_empty() {
            continue;
        }
        let Some(last) = out.last_mut() else {
            out.push(seg);
            continue;
        };
        // Only near-contiguous segments can be part of the same chain.
        let gap = seg.start - last.end;
        if gap > MAX_CHAIN_GAP_SECS {
            out.push(seg);
            continue;
        }
        let last_words = normalised_words(&last.text);
        let seg_words = normalised_words(&seg.text);
        if last_words.is_empty() || seg_words.is_empty() {
            out.push(seg);
            continue;
        }
        // Prefix growth: the new segment restates the kept one and continues it.
        if seg_words.len() > last_words.len()
            && starts_with_words(&seg_words, &last_words)
            && (is_meaningful_phrase(&last_words) || is_short_carryover(last, &last_words))
        {
            last.text = seg.text;
            last.end = seg.end;
            last.words = seg.words;
            continue;
        }
        // Suffix repeat: the new segment only repeats the tail of the kept one.
        if ends_with_words(&last_words, &seg_words)
            && (is_meaningful_phrase(&seg_words) || is_short_carryover(&seg, &seg_words))
        {
            last.end = last.end.max(seg.end);
            continue;
        }
        // Partial overlap: strip the repeated leading words from the new segment.
        let overlap = suffix_prefix_overlap(&last_words, &seg_words);
        if overlap >= MIN_OVERLAP_WORDS {
            let trimmed_text = trim_leading_words(&seg.text, overlap);
            if trimmed_text.is_empty() {
                last.end = last.end.max(seg.end);
                continue;
            }
            seg.start = seg.start.max(last.end);
            seg.text = trimmed_text;
            seg.words.clear();
        }
        out.push(seg);
    }
    out
}
/// Distinct word `n`-grams of `text`, each joined with single spaces.
///
/// Words are lower-cased and split on whitespace, but punctuation is kept, so
/// `"talk."` and `"talk"` differ. Texts with fewer than `n` words yield an
/// empty set.
fn ngrams(text: &str, n: usize) -> HashSet<String> {
    let lowered = text.to_lowercase();
    let tokens: Vec<&str> = lowered.split_whitespace().collect();
    if tokens.len() < n {
        return HashSet::new();
    }
    (0..=tokens.len() - n)
        .map(|first| tokens[first..first + n].join(" "))
        .collect()
}
/// Jaccard similarity (|A ∩ B| / |A ∪ B|) of the `NGRAM_N`-gram sets of the two
/// texts, in `0.0..=1.0`. Returns `0.0` when neither text has enough words to
/// form an n-gram.
fn jaccard_similarity(left: &str, right: &str) -> f32 {
    let (a, b) = (ngrams(left, NGRAM_N), ngrams(right, NGRAM_N));
    let union_size = a.union(&b).count();
    if union_size == 0 {
        return 0.0;
    }
    let shared = a.intersection(&b).count();
    shared as f32 / union_size as f32
}
/// The last `limit` characters (Unicode scalar values) of `text`, or the whole
/// text if it is shorter.
fn tail_chars(text: &str, limit: usize) -> String {
    let total = text.chars().count();
    text.chars().skip(total.saturating_sub(limit)).collect()
}
/// Drop segments that largely repeat recently kept output.
///
/// The texts of the last 20 kept segments are joined with spaces and cut to
/// their final `LOOKBACK_CHARS` characters. A segment is dropped when its
/// `jaccard_similarity` against that window reaches `SIMILARITY_THRESHOLD`.
/// The score is taken against the whole window, so a long window dilutes it;
/// in practice only near-copies of a short window are dropped. The first
/// segment is always kept.
fn ngram_dedup(segments: Vec<Segment>) -> Vec<Segment> {
    let mut kept: Vec<Segment> = Vec::with_capacity(segments.len());
    for seg in segments {
        let window_from = kept.len().saturating_sub(20);
        let window_text = kept[window_from..]
            .iter()
            .map(|prior| prior.text.as_str())
            .collect::<Vec<_>>()
            .join(" ");
        let context = tail_chars(&window_text, LOOKBACK_CHARS);
        let is_repeat = !context.is_empty()
            && jaccard_similarity(&seg.text, &context) >= SIMILARITY_THRESHOLD;
        if !is_repeat {
            kept.push(seg);
        }
    }
    kept
}
fn normalise_segments(segments: Vec<Segment>) -> Vec<Segment> {
let mut result = segments
.into_iter()
.map(|mut seg| {
seg.text = collapse_repeats(seg.text.trim());
seg
})
.filter(|seg| !seg.text.is_empty())
.collect::<Vec<_>>();
result = collapse_incremental_segments(result);
result = merge_identical_segments(result);
result = ngram_dedup(result);
result = collapse_incremental_segments(result);
merge_identical_segments(result)
}
// ── Job processing ────────────────────────────────────────────────────────────
/// Map the transcriber's progress within one chunk onto the job-wide 0..=100
/// scale.
///
/// `base` is the job-wide percentage at which the chunk starts, `span` the
/// number of percentage points the chunk covers, and `chunk_percent` the
/// progress reported for the chunk (clamped to 100).
///
/// `chunk_percent * span` routinely exceeds `u8::MAX` (e.g. 50 * 100 for a
/// single-chunk job), so the product is computed in `u32`. A saturating `u8`
/// multiply would cap intra-chunk progress at about 2 points.
fn scale_chunk_progress(base: u8, span: u8, chunk_percent: u8) -> u8 {
    let within = u32::from(chunk_percent.min(100)) * u32::from(span) / 100;
    // Clamped to 100, so the narrowing cast cannot truncate.
    (u32::from(base) + within).min(100) as u8
}

/// Decode, chunk and transcribe one job's audio.
///
/// The audio at `audio_path` is decoded to 16 kHz mono PCM and split near
/// silences (`snap_to_silence` + `to_chunk_ranges`). Each chunk, with trailing
/// silence trimmed, is sent to the GPU thread through `cmd_tx`, and its reply
/// is awaited before the next chunk is sent. Segment and word timestamps are
/// shifted from chunk-relative to file-relative. The combined list is then
/// cleaned by `normalise_segments` and re-indexed from 0.
///
/// At the start of every chunk, progress is persisted to `storage` and then
/// broadcast on `progress_tx`. During inference it is streamed through the
/// chunk's progress callback, and a final 100% event is sent at the end.
///
/// Returns `(segments, language, duration_secs)`. `language` is the first
/// non-empty language string reported by any chunk.
///
/// # Errors
///
/// Fails if decoding fails, if the worker command channel is closed, if the
/// GPU thread drops the reply, or if any chunk's transcription fails.
async fn process_job(
    job: &Job,
    audio_path: &std::path::Path,
    progress_tx: &ProgressTx,
    cmd_tx: &std::sync::mpsc::SyncSender<WorkerCmd>,
    storage: &Arc<Storage>,
) -> crate::Result<(Vec<Segment>, String, f32)> {
    let pcm = decode_audio(audio_path).await?;
    let total_secs = pcm.len() as f32 / 16_000.0;
    let silence_mids = detect_silence_midpoints(audio_path).await;
    let cuts = snap_to_silence(
        &silence_mids,
        total_secs,
        TARGET_CHUNK_SECS,
        SNAP_WINDOW_SECS,
    );
    let chunks = to_chunk_ranges(&cuts, total_secs);
    let n = chunks.len();
    tracing::info!(
        total_secs,
        n_chunks = n,
        silence_points = silence_mids.len(),
        "audio chunked by silence"
    );
    let mut all_segments: Vec<Segment> = Vec::new();
    let mut language = String::new();
    for (ci, &(chunk_start, chunk_end)) in chunks.iter().enumerate() {
        // Clamp both bounds so float rounding can never yield an inverted or
        // out-of-range slice.
        let s1 = ((chunk_end * 16_000.0) as usize).min(pcm.len());
        let s0 = ((chunk_start * 16_000.0) as usize).min(s1);
        let mut chunk_pcm = pcm[s0..s1].to_vec();
        trim_trailing_silence(&mut chunk_pcm);
        // Inside the loop `n >= 1`, so neither division is by zero.
        let base = (ci * 100 / n) as u8;
        let span = (100usize / n).max(1) as u8;
        // Save progress to disk before emitting SSE — polling clients who respond
        // immediately to the SSE event will then see consistent state.
        let mut snapshot = job.clone();
        snapshot.progress = base;
        if let Err(e) = storage.save(&snapshot).await {
            tracing::warn!(error = %e, "failed to persist mid-job progress");
        }
        let _ = progress_tx.send(ProgressEvent::Progress {
            percent: base,
            chunk: ci + 1,
            total: n,
        });
        let tx = progress_tx.clone();
        let chunk_num = ci + 1;
        let on_progress = Box::new(move |p: u8| {
            let _ = tx.send(ProgressEvent::Progress {
                percent: scale_chunk_progress(base, span, p),
                chunk: chunk_num,
                total: n,
            });
        });
        let (reply_tx, reply_rx) = oneshot::channel();
        cmd_tx
            .send(WorkerCmd::Transcribe(TranscribeRequest {
                pcm: chunk_pcm,
                language: job.language.clone(),
                task: job.task.clone(),
                on_progress,
                reply: reply_tx,
            }))
            .map_err(|_| AppError::Internal("worker command channel closed".into()))?;
        let (mut segs, lang) = reply_rx
            .await
            .map_err(|_| AppError::Internal("transcriber thread dropped reply".into()))??;
        // Shift timestamps from chunk-relative to file-relative.
        let offset = chunk_start;
        for seg in &mut segs {
            seg.start += offset;
            seg.end += offset;
            for word in &mut seg.words {
                word.start += offset;
                word.end += offset;
            }
        }
        tracing::debug!(
            chunk = ci + 1,
            of = n,
            start = chunk_start,
            end = chunk_end,
            segs = segs.len(),
            "chunk done"
        );
        all_segments.extend(segs);
        if language.is_empty() {
            language = lang;
        }
    }
    all_segments = normalise_segments(all_segments);
    for (i, seg) in all_segments.iter_mut().enumerate() {
        seg.index = i as i32;
    }
    let _ = progress_tx.send(ProgressEvent::Progress {
        percent: 100,
        chunk: n,
        total: n,
    });
    Ok((all_segments, language, total_secs))
}
/// Cut trailing near-silence from `pcm`, keeping a short tail of padding.
///
/// A sample is "loud" when its magnitude exceeds about 0.0178 (roughly
/// -35 dBFS). Everything more than `PADDING` samples after the last loud
/// sample is removed; that is 0.5 s at the 16 kHz rate `decode_audio`
/// produces. A buffer with no loud sample at all is left untouched.
fn trim_trailing_silence(pcm: &mut Vec<f32>) {
    const THRESHOLD: f32 = 0.017_8;
    const PADDING: usize = 8_000;
    let Some(last_loud) = pcm.iter().rposition(|&s| s.abs() > THRESHOLD) else {
        return;
    };
    let keep = pcm.len().min(last_loud + 1 + PADDING);
    if keep >= pcm.len() {
        return;
    }
    tracing::trace!(
        original_samples = pcm.len(),
        trimmed_samples = pcm.len() - keep,
        "trimmed trailing silence"
    );
    pcm.truncate(keep);
}
/// Decode the audio file at `path` to mono 16 kHz `f32` PCM using ffmpeg.
///
/// ffmpeg writes raw little-endian `f32le` samples to stdout, and they are
/// reassembled here.
///
/// # Errors
///
/// Returns `AppError::Internal` if ffmpeg cannot be spawned, exits
/// unsuccessfully (its stderr is included), or produces output that is not a
/// whole number of 4-byte samples.
async fn decode_audio(path: &std::path::Path) -> crate::Result<Vec<f32>> {
    use tokio::process::Command;
    let output = Command::new("ffmpeg")
        .args(["-nostdin", "-threads", "0", "-i"])
        // Pass the path as an `OsStr`: going through `to_str` would turn a
        // non-UTF-8 path into an empty input argument.
        .arg(path)
        .args(["-f", "f32le", "-ac", "1", "-ar", "16000", "-"])
        .output()
        .await
        .map_err(|e| AppError::Internal(format!("ffmpeg spawn failed: {e}")))?;
    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        return Err(AppError::Internal(format!(
            "ffmpeg exited with {}: {}",
            output.status, stderr
        )));
    }
    let bytes = output.stdout;
    if bytes.len() % 4 != 0 {
        return Err(AppError::Internal(
            "ffmpeg output length not a multiple of 4".into(),
        ));
    }
    Ok(bytes
        .chunks_exact(4)
        .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
        .collect())
}
/// Location of the uploaded audio for job `id`: `<DATA_DIR>/<id>.audio`.
///
/// `DATA_DIR` is read from the environment on every call. It falls back to
/// `/data` when unset or not valid Unicode.
pub fn audio_path_for(id: &JobId) -> PathBuf {
    let mut path = std::env::var("DATA_DIR").map_or_else(|_| PathBuf::from("/data"), PathBuf::from);
    path.push(format!("{id}.audio"));
    path
}
// ── Unit tests ────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::Word;

    /// Build a segment with no word-level timings.
    fn segment(index: i32, start: f32, end: f32, text: &str) -> Segment {
        Segment {
            index,
            start,
            end,
            text: text.into(),
            words: Vec::<Word>::new(),
        }
    }

    // 58.0 and 62.0 are equally close to the 60 s target; the earlier one wins.
    #[test]
    fn test_snap_to_silence_uses_nearest_midpoint() {
        let mids = vec![55.0, 58.0, 62.0];
        let cuts = snap_to_silence(&mids, 120.0, 60.0, 30.0);
        assert!(!cuts.is_empty());
        assert!(
            (cuts[0] - 58.0).abs() < 0.01,
            "expected ~58.0, got {}",
            cuts[0]
        );
    }

    // Without any silence, cuts fall exactly on multiples of the target.
    #[test]
    fn test_snap_to_silence_hard_cut_when_no_silence() {
        let cuts = snap_to_silence(&[], 120.0, 60.0, 30.0);
        assert_eq!(cuts, vec![60.0]);
    }

    #[test]
    fn test_to_chunk_ranges_single_chunk() {
        let ranges = to_chunk_ranges(&[], 30.0);
        assert_eq!(ranges, vec![(0.0, 30.0)]);
    }

    #[test]
    fn test_to_chunk_ranges_two_chunks() {
        let ranges = to_chunk_ranges(&[60.0], 120.0);
        assert_eq!(ranges, vec![(0.0, 60.0), (60.0, 120.0)]);
    }

    // A buffer with no sample above the threshold is left alone.
    #[test]
    fn test_trim_trailing_silence_all_silent() {
        let mut pcm = vec![0.0f32; 1000];
        trim_trailing_silence(&mut pcm);
        assert_eq!(pcm.len(), 1000);
    }

    // Keeps the last loud sample plus 8 000 samples of padding.
    #[test]
    fn test_trim_trailing_silence_trims_to_padding() {
        let mut pcm = vec![0.0f32; 32_000];
        pcm[10_000] = 1.0;
        trim_trailing_silence(&mut pcm);
        assert_eq!(pcm.len(), (10_001 + 8_000).min(32_000));
    }

    // A chain where each segment restates the previous one and adds text
    // collapses into one segment per distinct sentence.
    #[test]
    fn test_normalise_segments_collapses_prefix_growth_chain() {
        let input = vec![
            segment(0, 15.24, 16.6, "Hello everyone."),
            segment(1, 16.6, 19.47, "Hello everyone. Um, welcome to this talk."),
            segment(2, 19.47, 19.48, "Um, welcome to this talk."),
            segment(
                3,
                19.48,
                21.67,
                "Um, welcome to this talk. I'll be speaking about small model",
            ),
            segment(4, 21.67, 21.68, "I'll be speaking about small model"),
            segment(
                5,
                21.68,
                24.59,
                "I'll be speaking about small model inference and a gap that we've",
            ),
        ];
        let result = normalise_segments(input);
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].text, "Hello everyone. Um, welcome to this talk.");
        assert!((result[0].start - 15.24).abs() < 0.01);
        assert!((result[0].end - 19.48).abs() < 0.01);
        assert_eq!(
            result[1].text,
            "I'll be speaking about small model inference and a gap that we've"
        );
        assert!((result[1].start - 19.48).abs() < 0.01);
        assert!((result[1].end - 24.59).abs() < 0.01);
    }

    // An immediately repeated phrase inside one segment is removed once.
    #[test]
    fn test_normalise_segments_collapses_repeated_phrase_inside_segment() {
        let input = vec![segment(
            0,
            0.0,
            5.0,
            "the quick brown fox the quick brown fox jumps over the fence",
        )];
        let result = normalise_segments(input);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].text, "the quick brown fox jumps over the fence");
    }

    // A 1 s gap exceeds MAX_CHAIN_GAP_SECS, so prefix growth is not collapsed.
    #[test]
    fn test_normalise_segments_keeps_real_gap() {
        let input = vec![
            segment(0, 0.0, 1.0, "Hello everyone."),
            segment(1, 2.0, 4.0, "Hello everyone. Welcome back."),
        ];
        let result = normalise_segments(input);
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].text, "Hello everyone.");
        assert_eq!(result[1].text, "Hello everyone. Welcome back.");
    }

    // One-word fragments followed by a segment that starts with them are
    // absorbed via the short carry-over rule.
    #[test]
    fn test_normalise_segments_collapses_tiny_carry_over_segments() {
        let input = vec![
            segment(0, 94.8, 96.4, "world."),
            segment(
                1,
                96.4,
                98.96,
                "world. And that aspect that I overlooked was",
            ),
            segment(2, 98.96, 100.72, "inference."),
            segment(
                3,
                100.72,
                103.92,
                "inference. So, as someone who kind of wants to",
            ),
            segment(4, 107.19, 107.2, "and"),
            segment(
                5,
                107.2,
                109.56,
                "and work to understand the problems and the",
            ),
        ];
        let result = normalise_segments(input);
        assert_eq!(result.len(), 3);
        assert_eq!(
            result[0].text,
            "world. And that aspect that I overlooked was"
        );
        assert_eq!(
            result[1].text,
            "inference. So, as someone who kind of wants to"
        );
        assert_eq!(
            result[2].text,
            "and work to understand the problems and the"
        );
    }

    // A single overlapping word across an adjacent boundary is trimmed from
    // the start of the later segment.
    #[test]
    fn test_normalise_segments_trims_single_word_adjacent_overlap() {
        let input = vec![
            segment(0, 94.8, 96.4, "world."),
            segment(
                1,
                96.4,
                98.96,
                "world. And that aspect that I overlooked was",
            ),
            segment(2, 120.12, 123.71, "to find more about inference."),
            segment(
                3,
                123.72,
                126.92,
                "inference. So, I've done a lot of work with VLAM,",
            ),
        ];
        let result = normalise_segments(input);
        assert_eq!(result.len(), 3);
        assert_eq!(
            result[0].text,
            "world. And that aspect that I overlooked was"
        );
        assert_eq!(result[2].text, "So, I've done a lot of work with VLAM,");
    }
}