1288 lines
41 KiB
Rust
1288 lines
41 KiB
Rust
use std::{
|
||
collections::HashSet,
|
||
path::PathBuf,
|
||
sync::{
|
||
atomic::{AtomicUsize, Ordering},
|
||
Arc, Mutex,
|
||
},
|
||
time::{Duration, Instant},
|
||
};
|
||
|
||
use chrono::Utc;
|
||
use reqwest::Client;
|
||
use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
|
||
|
||
use crate::{
|
||
models::{Job, JobId, JobStatus, ModelEvent, ModelState, Segment},
|
||
storage::Storage,
|
||
transcriber::Transcriber,
|
||
webhook, AppError,
|
||
};
|
||
|
||
/// Per-job broadcast channel for SSE subscribers.
|
||
pub type ProgressTx = broadcast::Sender<ProgressEvent>;
|
||
|
||
#[derive(Debug, Clone)]
|
||
pub enum ProgressEvent {
|
||
/// `percent` — overall 0–100; `chunk` — 1-based; `total` — total chunks.
|
||
Progress {
|
||
percent: u8,
|
||
chunk: usize,
|
||
total: usize,
|
||
},
|
||
Done(Box<Job>),
|
||
Error(String),
|
||
}
|
||
|
||
/// Global registry: job_id → broadcast sender.
|
||
pub type ProgressRegistry = Arc<dashmap::DashMap<JobId, ProgressTx>>;
|
||
|
||
// ── Worker command channel ────────────────────────────────────────────────────
|
||
|
||
/// Commands sent to the GPU worker OS thread.
|
||
#[derive(Debug)]
|
||
pub enum WorkerCmd {
|
||
/// Request a model load. Idempotent: if already loading/ready, ignored.
|
||
Load,
|
||
/// Unload the model immediately and free GPU memory.
|
||
Unload,
|
||
/// Internal: run a transcription chunk.
|
||
Transcribe(TranscribeRequest),
|
||
}
|
||
|
||
// ── Transcription request/response types ─────────────────────────────────────
|
||
|
||
pub struct TranscribeRequest {
|
||
pub pcm: Vec<f32>,
|
||
pub language: Option<String>,
|
||
pub task: String,
|
||
pub on_progress: Box<dyn Fn(u8) + Send + 'static>,
|
||
pub reply: oneshot::Sender<crate::Result<(Vec<Segment>, String)>>,
|
||
}
|
||
|
||
impl std::fmt::Debug for TranscribeRequest {
|
||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||
f.debug_struct("TranscribeRequest")
|
||
.field("language", &self.language)
|
||
.field("task", &self.task)
|
||
.finish_non_exhaustive()
|
||
}
|
||
}
|
||
|
||
// ── Public API ────────────────────────────────────────────────────────────────
|
||
|
||
/// Spawn the single GPU worker.
|
||
///
|
||
/// Returns the SSE progress registry and a command sender for the worker thread.
|
||
/// The model starts **unloaded**; send `WorkerCmd::Load` or submit a job to
|
||
/// trigger loading.
|
||
#[allow(clippy::too_many_arguments)]
|
||
pub fn start(
|
||
job_rx: mpsc::UnboundedReceiver<JobId>,
|
||
storage: Arc<Storage>,
|
||
model_path: PathBuf,
|
||
queue_depth: Arc<AtomicUsize>,
|
||
gpu_device: u32,
|
||
model_state: Arc<RwLock<ModelState>>,
|
||
model_event_tx: broadcast::Sender<ModelEvent>,
|
||
webhook_registry: Arc<Mutex<HashSet<String>>>,
|
||
idle_timeout: Duration,
|
||
gpu_poll_interval: Duration,
|
||
) -> (ProgressRegistry, std::sync::mpsc::SyncSender<WorkerCmd>) {
|
||
let registry: ProgressRegistry = Arc::new(dashmap::DashMap::new());
|
||
let reg_clone = Arc::clone(®istry);
|
||
|
||
// Bounded sync channel: capacity 8 is plenty (load/unload are rare).
|
||
let (cmd_tx, cmd_rx) = std::sync::mpsc::sync_channel::<WorkerCmd>(8);
|
||
let cmd_tx_clone = cmd_tx.clone();
|
||
|
||
// Capture Tokio runtime handle so the OS thread can spawn async tasks.
|
||
let rt_handle = tokio::runtime::Handle::current();
|
||
|
||
std::thread::Builder::new()
|
||
.name("whisper-gpu".into())
|
||
.spawn(move || {
|
||
transcriber_thread(
|
||
cmd_rx,
|
||
model_path,
|
||
gpu_device,
|
||
model_state,
|
||
model_event_tx,
|
||
webhook_registry,
|
||
idle_timeout,
|
||
gpu_poll_interval,
|
||
rt_handle,
|
||
);
|
||
})
|
||
.expect("failed to spawn whisper-gpu thread");
|
||
|
||
tokio::spawn(run(job_rx, storage, queue_depth, reg_clone, cmd_tx_clone));
|
||
|
||
(registry, cmd_tx)
|
||
}
|
||
|
||
// ── GPU OS thread ─────────────────────────────────────────────────────────────
|
||
|
||
/// The worker OS thread that owns the `Transcriber` (non-`Send`).
|
||
///
|
||
/// Uses `recv_timeout` with a 1-second tick to drive the idle timer without a
|
||
/// separate thread.
|
||
#[allow(clippy::too_many_arguments)]
|
||
fn transcriber_thread(
|
||
rx: std::sync::mpsc::Receiver<WorkerCmd>,
|
||
model_path: PathBuf,
|
||
gpu_device: u32,
|
||
model_state: Arc<RwLock<ModelState>>,
|
||
model_event_tx: broadcast::Sender<ModelEvent>,
|
||
webhook_registry: Arc<Mutex<HashSet<String>>>,
|
||
idle_timeout: Duration,
|
||
gpu_poll_interval: Duration,
|
||
rt: tokio::runtime::Handle,
|
||
) {
|
||
let mut transcriber: Option<Transcriber> = None;
|
||
let mut last_job = Instant::now();
|
||
|
||
loop {
|
||
match rx.recv_timeout(Duration::from_secs(1)) {
|
||
Ok(WorkerCmd::Load) => {
|
||
if transcriber.is_some() {
|
||
tracing::debug!("WorkerCmd::Load ignored — model already loaded");
|
||
continue;
|
||
}
|
||
transcriber = try_load_with_polling(
|
||
&rx,
|
||
&model_path,
|
||
gpu_device,
|
||
&model_state,
|
||
&model_event_tx,
|
||
&webhook_registry,
|
||
gpu_poll_interval,
|
||
&rt,
|
||
);
|
||
if transcriber.is_some() {
|
||
last_job = Instant::now();
|
||
}
|
||
}
|
||
|
||
Ok(WorkerCmd::Unload) => {
|
||
do_unload(
|
||
&mut transcriber,
|
||
&model_state,
|
||
&model_event_tx,
|
||
&webhook_registry,
|
||
&rt,
|
||
);
|
||
}
|
||
|
||
Ok(WorkerCmd::Transcribe(req)) => {
|
||
let t = match &mut transcriber {
|
||
Some(t) => t,
|
||
None => {
|
||
tracing::warn!(
|
||
"Transcribe cmd received but model is unloaded — failing job"
|
||
);
|
||
let _ = req.reply.send(Err(AppError::Internal(
|
||
"model unloaded before job could run".into(),
|
||
)));
|
||
continue;
|
||
}
|
||
};
|
||
|
||
let result = t.transcribe(&req.pcm, req.language.as_deref(), &req.task, move |p| {
|
||
(req.on_progress)(p)
|
||
});
|
||
last_job = Instant::now();
|
||
let _ = req.reply.send(result);
|
||
}
|
||
|
||
Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
|
||
if transcriber.is_some() && last_job.elapsed() >= idle_timeout {
|
||
tracing::info!(
|
||
elapsed_secs = last_job.elapsed().as_secs(),
|
||
"idle timeout reached — unloading model"
|
||
);
|
||
do_unload(
|
||
&mut transcriber,
|
||
&model_state,
|
||
&model_event_tx,
|
||
&webhook_registry,
|
||
&rt,
|
||
);
|
||
}
|
||
}
|
||
|
||
Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
|
||
tracing::info!("worker command channel closed — shutting down GPU thread");
|
||
break;
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
/// Attempt to load the model, polling while VRAM is insufficient.
///
/// Every attempt first publishes `Loading` (shared state + SSE). The outcome
/// of `Transcriber::load` then decides what happens:
/// - success: publishes `Ready { loaded_at }` (state, SSE and model webhooks)
///   and returns `Some(Transcriber)`;
/// - `AppError::OutOfMemory`: publishes `WaitingForGpu` (state + SSE only, no
///   webhook), waits `gpu_poll_interval` while servicing `rx`, then retries;
/// - any other error: publishes `Unloaded` (state, SSE and webhooks) and
///   returns `None`.
///
/// While waiting for GPU memory, `rx` is drained:
/// - `WorkerCmd::Unload` cancels the load (publishing `Unloaded`, returning `None`);
/// - `WorkerCmd::Load` is ignored;
/// - `WorkerCmd::Transcribe` is rejected with `AppError::ModelNotReady`.
///
/// If the command channel disconnects during that wait, `None` is returned
/// without updating the published state.
#[allow(clippy::too_many_arguments)]
fn try_load_with_polling(
    rx: &std::sync::mpsc::Receiver<WorkerCmd>,
    model_path: &PathBuf,
    gpu_device: u32,
    model_state: &Arc<RwLock<ModelState>>,
    model_event_tx: &broadcast::Sender<ModelEvent>,
    webhook_registry: &Arc<Mutex<HashSet<String>>>,
    gpu_poll_interval: Duration,
    rt: &tokio::runtime::Handle,
) -> Option<Transcriber> {
    loop {
        set_state(model_state, ModelState::Loading);
        broadcast_event(model_event_tx, ModelEvent::ModelLoading);
        tracing::info!("loading whisper model...");

        match Transcriber::load(model_path, gpu_device) {
            Ok(t) => {
                let loaded_at = Utc::now();
                set_state(model_state, ModelState::Ready { loaded_at });
                broadcast_event(model_event_tx, ModelEvent::ModelReady { loaded_at });
                fire_webhooks(webhook_registry, ModelEvent::ModelReady { loaded_at }, rt);
                tracing::info!("model loaded and ready");
                return Some(t);
            }

            Err(AppError::OutOfMemory(msg)) => {
                // Best-effort figures for clients. This shells out to
                // nvidia-smi (blocking), which is acceptable on this OS thread.
                let (vram_needed_mb, vram_free_mb) = parse_oom_vram(&msg, gpu_device);
                // Whole seconds only: a sub-second poll interval is reported as 0.
                let retry_in_secs = gpu_poll_interval.as_secs();

                tracing::warn!(
                    vram_needed_mb,
                    vram_free_mb,
                    retry_in_secs,
                    "insufficient VRAM — will retry"
                );

                set_state(
                    model_state,
                    ModelState::WaitingForGpu {
                        vram_needed_mb,
                        vram_free_mb,
                        retry_in_secs,
                    },
                );
                broadcast_event(
                    model_event_tx,
                    ModelEvent::ModelWaitingForGpu {
                        vram_needed_mb,
                        vram_free_mb,
                        retry_in_secs,
                    },
                );

                // Interruptible sleep: drain rx while waiting for gpu_poll_interval.
                // Each wait is capped at 1 s; recv_timeout still returns as soon
                // as a command arrives.
                let deadline = Instant::now() + gpu_poll_interval;
                loop {
                    let remaining = deadline.saturating_duration_since(Instant::now());
                    if remaining.is_zero() {
                        break;
                    }
                    match rx.recv_timeout(remaining.min(Duration::from_secs(1))) {
                        Ok(WorkerCmd::Unload) => {
                            tracing::info!(
                                "Unload received while waiting for GPU — cancelling load"
                            );
                            set_state(model_state, ModelState::Unloaded);
                            broadcast_event(model_event_tx, ModelEvent::ModelUnloaded);
                            fire_webhooks(webhook_registry, ModelEvent::ModelUnloaded, rt);
                            return None;
                        }
                        Ok(WorkerCmd::Load) => {} // idempotent
                        Ok(WorkerCmd::Transcribe(req)) => {
                            // Reject rather than queue: the thread cannot run
                            // inference until a load succeeds.
                            let _ = req.reply.send(Err(AppError::ModelNotReady {
                                state: "waiting_for_gpu".into(),
                                retry_after_secs: retry_in_secs,
                            }));
                        }
                        Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {}
                        // All senders gone: stop polling; published state is left as-is.
                        Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => return None,
                    }
                }
                // Loop back to retry load
            }

            Err(e) => {
                tracing::error!(error = %e, "model load failed with non-recoverable error");
                set_state(model_state, ModelState::Unloaded);
                broadcast_event(model_event_tx, ModelEvent::ModelUnloaded);
                fire_webhooks(webhook_registry, ModelEvent::ModelUnloaded, rt);
                return None;
            }
        }
    }
}
|
||
|
||
fn do_unload(
|
||
transcriber: &mut Option<Transcriber>,
|
||
model_state: &Arc<RwLock<ModelState>>,
|
||
model_event_tx: &broadcast::Sender<ModelEvent>,
|
||
webhook_registry: &Arc<Mutex<HashSet<String>>>,
|
||
rt: &tokio::runtime::Handle,
|
||
) {
|
||
*transcriber = None;
|
||
set_state(model_state, ModelState::Unloaded);
|
||
broadcast_event(model_event_tx, ModelEvent::ModelUnloaded);
|
||
fire_webhooks(webhook_registry, ModelEvent::ModelUnloaded, rt);
|
||
tracing::info!("model unloaded — GPU memory freed");
|
||
}
|
||
|
||
// ── Helpers ───────────────────────────────────────────────────────────────────
|
||
|
||
/// Replaces the shared model state from synchronous code.
///
/// Uses tokio's `blocking_write`, which panics inside an async execution
/// context. Here it is only called from the `whisper-gpu` OS thread.
fn set_state(arc: &Arc<RwLock<ModelState>>, state: ModelState) {
    let mut guard = arc.blocking_write();
    *guard = state;
}
|
||
|
||
/// Publishes a model lifecycle event to all current SSE subscribers.
///
/// A send error only means nobody is subscribed right now, so it is ignored.
fn broadcast_event(tx: &broadcast::Sender<ModelEvent>, event: ModelEvent) {
    tx.send(event).ok();
}
|
||
|
||
fn fire_webhooks(
|
||
registry: &Arc<Mutex<HashSet<String>>>,
|
||
event: ModelEvent,
|
||
rt: &tokio::runtime::Handle,
|
||
) {
|
||
if !event.is_webhook_event() {
|
||
return;
|
||
}
|
||
let urls: Vec<String> = registry
|
||
.lock()
|
||
.unwrap_or_else(|e| e.into_inner())
|
||
.iter()
|
||
.cloned()
|
||
.collect();
|
||
|
||
if urls.is_empty() {
|
||
return;
|
||
}
|
||
|
||
let payload = match serde_json::to_string(&event) {
|
||
Ok(p) => p,
|
||
Err(e) => {
|
||
tracing::error!(error = %e, "failed to serialize model event");
|
||
return;
|
||
}
|
||
};
|
||
|
||
for url in urls {
|
||
let body = payload.clone();
|
||
rt.spawn(async move {
|
||
let http = Client::builder()
|
||
.timeout(Duration::from_secs(10))
|
||
.build()
|
||
.expect("http client");
|
||
for attempt in 0..3_u32 {
|
||
match http
|
||
.post(&url)
|
||
.header("content-type", "application/json")
|
||
.body(body.clone())
|
||
.send()
|
||
.await
|
||
{
|
||
Ok(r) if r.status().is_success() => {
|
||
tracing::debug!(url, "model event webhook delivered");
|
||
return;
|
||
}
|
||
Ok(r) => tracing::warn!(url, status = r.status().as_u16(), "webhook non-2xx"),
|
||
Err(e) => tracing::warn!(url, error = %e, attempt, "webhook delivery failed"),
|
||
}
|
||
if attempt < 2 {
|
||
tokio::time::sleep(Duration::from_secs(2u64.pow(attempt))).await;
|
||
}
|
||
}
|
||
tracing::error!(url, "model event webhook failed after 3 attempts");
|
||
});
|
||
}
|
||
}
|
||
|
||
/// Best-effort VRAM figures for the `WaitingForGpu` state: `(needed, free)`.
///
/// - `needed`: comes from the first whitespace-separated token in `msg` that
///   is directly followed by a standalone `MiB` token. It is parsed as a
///   number and truncated to an integer, or 0 when no such token exists or it
///   does not parse.
/// - `free`: comes from `nvidia-smi --query-gpu=memory.free` for `gpu_device`
///   (nvidia-smi reports MiB), or 0 when the command fails or its output does
///   not parse.
///
/// Runs a blocking subprocess, so it should only be called off the async runtime.
fn parse_oom_vram(msg: &str, gpu_device: u32) -> (u64, u64) {
    let tokens: Vec<&str> = msg.split_whitespace().collect();
    let needed = tokens
        .windows(2)
        .find(|pair| pair[1] == "MiB")
        .and_then(|pair| pair[0].parse::<f64>().ok())
        .map_or(0, |mib| mib as u64);

    let id_arg = format!("--id={gpu_device}");
    let free = match std::process::Command::new("nvidia-smi")
        .args([
            id_arg.as_str(),
            "--query-gpu=memory.free",
            "--format=csv,noheader,nounits",
        ])
        .output()
    {
        Ok(out) => String::from_utf8(out.stdout)
            .ok()
            .and_then(|text| text.trim().parse::<u64>().ok())
            .unwrap_or(0),
        Err(_) => 0,
    };

    (needed, free)
}
|
||
|
||
// ── Async job runner ──────────────────────────────────────────────────────────
|
||
|
||
async fn run(
|
||
mut job_rx: mpsc::UnboundedReceiver<JobId>,
|
||
storage: Arc<Storage>,
|
||
queue_depth: Arc<AtomicUsize>,
|
||
registry: ProgressRegistry,
|
||
cmd_tx: std::sync::mpsc::SyncSender<WorkerCmd>,
|
||
) {
|
||
let http = Client::builder()
|
||
.timeout(Duration::from_secs(30))
|
||
.build()
|
||
.expect("failed to build reqwest client");
|
||
|
||
while let Some(job_id) = job_rx.recv().await {
|
||
queue_depth.fetch_sub(1, Ordering::Relaxed);
|
||
|
||
let mut job = match storage.get(&job_id).await {
|
||
Ok(j) => j,
|
||
Err(e) => {
|
||
tracing::warn!(job_id = %job_id, error = %e, "job vanished before processing");
|
||
registry.remove(&job_id);
|
||
continue;
|
||
}
|
||
};
|
||
|
||
if job.status == JobStatus::Cancelled {
|
||
let _ = tokio::fs::remove_file(&audio_path_for(&job_id)).await;
|
||
registry.remove(&job_id);
|
||
continue;
|
||
}
|
||
|
||
job.status = JobStatus::Running;
|
||
if let Err(e) = storage.save(&job).await {
|
||
tracing::error!(job_id = %job_id, error = %e, "failed to persist running status");
|
||
}
|
||
|
||
let progress_tx = registry
|
||
.entry(job_id)
|
||
.or_insert_with(|| broadcast::channel(64).0)
|
||
.clone();
|
||
|
||
let audio_path = audio_path_for(&job_id);
|
||
|
||
let result = process_job(&job, &audio_path, &progress_tx, &cmd_tx, &storage).await;
|
||
|
||
let _ = tokio::fs::remove_file(&audio_path).await;
|
||
|
||
// Re-read from storage: the job may have been cancelled via DELETE /jobs/:id
|
||
// while process_job() was running. If so, discard the result entirely.
|
||
let current_status = storage.get(&job_id).await.map(|j| j.status).ok();
|
||
if current_status == Some(JobStatus::Cancelled) {
|
||
tracing::info!(job_id = %job_id, "job cancelled during inference — discarding result");
|
||
registry.remove(&job_id);
|
||
continue;
|
||
}
|
||
|
||
match result {
|
||
Ok((segments, language, duration_secs)) => {
|
||
job.status = JobStatus::Done;
|
||
job.segments = segments;
|
||
job.language = Some(language);
|
||
job.duration_secs = Some(duration_secs);
|
||
job.progress = 100;
|
||
job.completed_at = Some(Utc::now());
|
||
let _ = progress_tx.send(ProgressEvent::Done(Box::new(job.clone())));
|
||
}
|
||
Err(e) => {
|
||
let msg = e.to_string();
|
||
tracing::error!(job_id = %job_id, error = %msg, "transcription failed");
|
||
job.status = JobStatus::Failed;
|
||
job.error = Some(msg.clone());
|
||
job.completed_at = Some(Utc::now());
|
||
let _ = progress_tx.send(ProgressEvent::Error(msg));
|
||
}
|
||
}
|
||
|
||
if let Err(e) = storage.save(&job).await {
|
||
tracing::error!(job_id = %job_id, error = %e, "failed to persist final job state");
|
||
}
|
||
|
||
if let Some(url) = &job.webhook_url.clone() {
|
||
let http = http.clone();
|
||
let url = url.clone();
|
||
let job = job.clone();
|
||
tokio::spawn(async move {
|
||
webhook::fire(&http, &url, &job).await;
|
||
});
|
||
}
|
||
|
||
tokio::time::sleep(Duration::from_secs(30)).await;
|
||
registry.remove(&job_id);
|
||
}
|
||
}
|
||
|
||
// ── Silence-based chunking ────────────────────────────────────────────────────
|
||
|
||
/// Desired chunk length in seconds; each cut is then snapped to nearby silence.
const TARGET_CHUNK_SECS: f32 = 60.0;
/// Maximum distance, in seconds, a cut may move away from its target to land
/// on a silence midpoint.
const SNAP_WINDOW_SECS: f32 = 30.0;

/// Noise floor passed as `n=` to ffmpeg's `silencedetect` filter.
const SILENCE_DB: &str = "-35dB";
/// Minimum silence duration in seconds, passed as `d=` to `silencedetect`.
const SILENCE_DUR: &str = "0.4";
|
||
|
||
async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
|
||
use tokio::process::Command;
|
||
|
||
let filter = format!("silencedetect=n={}:d={}", SILENCE_DB, SILENCE_DUR);
|
||
let output = Command::new("ffmpeg")
|
||
.args([
|
||
"-nostdin",
|
||
"-i",
|
||
path.to_str().unwrap_or(""),
|
||
"-af",
|
||
&filter,
|
||
"-f",
|
||
"null",
|
||
"-",
|
||
])
|
||
.output()
|
||
.await;
|
||
|
||
let output = match output {
|
||
Ok(o) => o,
|
||
Err(e) => {
|
||
tracing::warn!(error = %e, "silencedetect unavailable; using hard cuts");
|
||
return Vec::new();
|
||
}
|
||
};
|
||
|
||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||
let mut starts: Vec<f32> = Vec::new();
|
||
let mut ends: Vec<f32> = Vec::new();
|
||
|
||
for line in stderr.lines() {
|
||
if let Some(i) = line.find("silence_start: ") {
|
||
if let Ok(t) = line[i + "silence_start: ".len()..].trim().parse::<f32>() {
|
||
starts.push(t);
|
||
}
|
||
} else if let Some(i) = line.find("silence_end: ") {
|
||
let t_str = line[i + "silence_end: ".len()..]
|
||
.split(" |")
|
||
.next()
|
||
.unwrap_or("")
|
||
.trim();
|
||
if let Ok(t) = t_str.parse::<f32>() {
|
||
ends.push(t);
|
||
}
|
||
}
|
||
}
|
||
|
||
let mids: Vec<f32> = starts
|
||
.iter()
|
||
.zip(ends.iter())
|
||
.map(|(s, e)| (s + e) / 2.0)
|
||
.collect();
|
||
|
||
tracing::debug!(n = mids.len(), "silence midpoints detected");
|
||
mids
|
||
}
|
||
|
||
/// Chooses chunk cut points roughly every `target_secs`, snapped to silence.
///
/// Starting from a target at `target_secs`, each cut is placed at the silence
/// midpoint nearest the target (ties go to the earlier entry in `mids`). Only
/// midpoints within `snap_window` seconds of the target and more than 10 s
/// after the previous cut are considered. When none qualifies, the cut is the
/// target itself. The next target is `cut + target_secs`. Cutting stops once
/// a target reaches the last quarter-chunk before `total_secs`, so the final
/// chunk is never tiny.
fn snap_to_silence(mids: &[f32], total_secs: f32, target_secs: f32, snap_window: f32) -> Vec<f32> {
    let stop_at = total_secs - target_secs * 0.25;
    let mut cuts = Vec::new();
    let mut target = target_secs;

    while target < stop_at {
        let floor = cuts.last().copied().unwrap_or(0.0) + 10.0;
        let nearest = mids
            .iter()
            .copied()
            .filter(|&m| m > floor && (m - target).abs() <= snap_window)
            .min_by(|a, b| (a - target).abs().total_cmp(&(b - target).abs()));
        let cut = nearest.unwrap_or(target);
        cuts.push(cut);
        target = cut + target_secs;
    }

    cuts
}
|
||
|
||
/// Turns cut points into contiguous `(start, end)` second ranges over `[0, total_secs]`.
///
/// - A cut less than 5 s after the current range start is skipped, so it is
///   merged into the following range.
/// - The remainder after the last accepted cut becomes the final range when it
///   is at least 1 s long.
/// - A shorter, non-empty remainder is folded into the previous range instead
///   of being dropped, so trailing audio is still transcribed.
/// - Audio shorter than 1 s in total yields no ranges.
fn to_chunk_ranges(cuts: &[f32], total_secs: f32) -> Vec<(f32, f32)> {
    let mut ranges = Vec::with_capacity(cuts.len() + 1);
    let mut start = 0.0_f32;

    for &cut in cuts {
        if cut - start >= 5.0 {
            ranges.push((start, cut));
            start = cut;
        }
    }

    let tail = total_secs - start;
    if tail >= 1.0 {
        ranges.push((start, total_secs));
    } else if tail > 0.0 {
        if let Some(last) = ranges.last_mut() {
            last.1 = total_secs;
        }
    }

    ranges
}
|
||
|
||
/// Largest gap (seconds) between one segment's end and the next one's start
/// for the pair to be checked as an incremental chain.
const MAX_CHAIN_GAP_SECS: f32 = 0.15;

/// `is_meaningful_phrase`: minimum number of normalised words.
const MIN_MEANINGFUL_WORDS: usize = 2;
/// `is_meaningful_phrase`: minimum summed length of the normalised words.
const MIN_MEANINGFUL_CHARS: usize = 8;

/// Minimum suffix/prefix word overlap that triggers trimming of the next segment.
const MIN_OVERLAP_WORDS: usize = 1;

/// `is_short_carryover`: a segment this short (seconds) counts as a carry-over.
const SHORT_CARRYOVER_MAX_SECS: f32 = 0.2;
/// `is_short_carryover`: ... or with at most this many words.
const SHORT_CARRYOVER_MAX_WORDS: usize = 2;
/// `is_short_carryover`: ... or whose words joined by spaces are at most this long.
const SHORT_CARRYOVER_MAX_CHARS: usize = 16;

/// Word n-gram size used by the Jaccard near-duplicate check.
const NGRAM_N: usize = 6;
/// Number of trailing characters of recent output compared against each new segment.
const LOOKBACK_CHARS: usize = 500;
/// Jaccard similarity at or above which a segment is dropped as a repeat.
const SIMILARITY_THRESHOLD: f32 = 0.6;
|
||
|
||
/// Splits `text` on any Unicode whitespace into borrowed, non-empty words.
fn split_words(text: &str) -> Vec<&str> {
    // `split_whitespace` never yields empty pieces, so no extra filtering is needed.
    Vec::from_iter(text.split_whitespace())
}
|
||
|
||
/// Canonical comparison form of a word.
///
/// Keeps only alphanumeric characters and `_`, lower-cased. A punctuation-only
/// word becomes the empty string.
fn normalise_token(word: &str) -> String {
    let mut token = String::with_capacity(word.len());
    for ch in word.chars() {
        if ch == '_' || ch.is_alphanumeric() {
            token.extend(ch.to_lowercase());
        }
    }
    token
}
|
||
|
||
fn normalised_words(text: &str) -> Vec<String> {
|
||
split_words(text)
|
||
.into_iter()
|
||
.map(normalise_token)
|
||
.filter(|word| !word.is_empty())
|
||
.collect()
|
||
}
|
||
|
||
fn collapse_repeated_phrase_once(text: &str) -> String {
|
||
let raw_words = split_words(text);
|
||
if raw_words.len() < 4 {
|
||
return text.trim().to_string();
|
||
}
|
||
|
||
let normalised: Vec<String> = raw_words.iter().map(|word| normalise_token(word)).collect();
|
||
|
||
for size in (2..=raw_words.len() / 2).rev() {
|
||
for start in 0..=raw_words.len().saturating_sub(size * 2) {
|
||
let phrase_chars = raw_words[start..start + size]
|
||
.iter()
|
||
.map(|word| word.len())
|
||
.sum::<usize>()
|
||
+ size.saturating_sub(1);
|
||
if phrase_chars < 10 {
|
||
continue;
|
||
}
|
||
|
||
if normalised[start..start + size] == normalised[start + size..start + size * 2] {
|
||
let mut collapsed = Vec::with_capacity(raw_words.len() - size);
|
||
collapsed.extend_from_slice(&raw_words[..start + size]);
|
||
collapsed.extend_from_slice(&raw_words[start + size * 2..]);
|
||
return collapsed.join(" ").trim().to_string();
|
||
}
|
||
}
|
||
}
|
||
|
||
text.trim().to_string()
|
||
}
|
||
|
||
fn collapse_repeats(text: &str) -> String {
|
||
let mut current = text.trim().to_string();
|
||
loop {
|
||
let next = collapse_repeated_phrase_once(¤t);
|
||
if next == current {
|
||
return next;
|
||
}
|
||
current = next;
|
||
}
|
||
}
|
||
|
||
/// True when `full` begins with exactly the words in `prefix`. An empty prefix
/// always matches.
fn starts_with_words(full: &[String], prefix: &[String]) -> bool {
    full.starts_with(prefix)
}
|
||
|
||
/// True when `full` ends with exactly the words in `suffix`. An empty suffix
/// always matches.
fn ends_with_words(full: &[String], suffix: &[String]) -> bool {
    full.ends_with(suffix)
}
|
||
|
||
/// Length of the longest run of words that ends `left` and also starts
/// `right`. Returns 0 when there is none.
fn suffix_prefix_overlap(left: &[String], right: &[String]) -> usize {
    let longest = left.len().min(right.len());
    (1..=longest)
        .rev()
        .find(|&k| left[left.len() - k..] == right[..k])
        .unwrap_or(0)
}
|
||
|
||
fn is_meaningful_phrase(words: &[String]) -> bool {
|
||
words.len() >= MIN_MEANINGFUL_WORDS
|
||
&& words.iter().map(|word| word.len()).sum::<usize>() >= MIN_MEANINGFUL_CHARS
|
||
}
|
||
|
||
fn is_short_carryover(seg: &Segment, words: &[String]) -> bool {
|
||
seg.end - seg.start <= SHORT_CARRYOVER_MAX_SECS
|
||
|| words.len() <= SHORT_CARRYOVER_MAX_WORDS
|
||
|| words.iter().map(|word| word.len()).sum::<usize>() + words.len().saturating_sub(1)
|
||
<= SHORT_CARRYOVER_MAX_CHARS
|
||
}
|
||
|
||
fn trim_leading_words(text: &str, count: usize) -> String {
|
||
split_words(text)
|
||
.into_iter()
|
||
.skip(count)
|
||
.collect::<Vec<_>>()
|
||
.join(" ")
|
||
.trim()
|
||
.to_string()
|
||
}
|
||
|
||
fn merge_identical_segments(segments: Vec<Segment>) -> Vec<Segment> {
|
||
let mut out: Vec<Segment> = Vec::with_capacity(segments.len());
|
||
|
||
for seg in segments {
|
||
if let Some(last) = out.last_mut() {
|
||
if normalised_words(&last.text) == normalised_words(&seg.text) {
|
||
last.end = last.end.max(seg.end);
|
||
if !seg.words.is_empty() {
|
||
last.words = seg.words;
|
||
}
|
||
continue;
|
||
}
|
||
}
|
||
out.push(seg);
|
||
}
|
||
|
||
out
|
||
}
|
||
|
||
/// Collapses chains of incrementally-growing or overlapping segments.
///
/// Segments are trimmed and empty ones dropped. A remaining segment is
/// compared with the last kept one only when it starts no more than
/// `MAX_CHAIN_GAP_SECS` after that one ends (overlaps included).
/// Comparisons use `normalised_words`. Those cases are checked in order:
/// 1. The new segment strictly extends the previous one as a word prefix,
///    and the previous one is meaningful or a short carry-over. The previous
///    segment takes over the new text, end time and word timings.
/// 2. The new segment's words are a suffix of the previous segment's, and it
///    is meaningful or a short carry-over. It is dropped and the previous end
///    time extended.
/// 3. The previous segment's trailing words overlap the new segment's leading
///    words by at least `MIN_OVERLAP_WORDS`. Those leading words are removed
///    with `trim_leading_words`. If nothing remains, the segment is absorbed as
///    in (2); otherwise its start is clamped to the previous end and its word
///    timings are cleared.
///
/// Any other segment is kept as-is.
fn collapse_incremental_segments(segments: Vec<Segment>) -> Vec<Segment> {
    let mut out: Vec<Segment> = Vec::with_capacity(segments.len());

    for mut seg in segments {
        seg.text = seg.text.trim().to_string();
        if seg.text.is_empty() {
            continue;
        }

        let Some(last) = out.last_mut() else {
            out.push(seg);
            continue;
        };

        // Only segments that (nearly) abut the previous one can be part of a chain.
        let gap = seg.start - last.end;
        if gap > MAX_CHAIN_GAP_SECS {
            out.push(seg);
            continue;
        }

        let last_words = normalised_words(&last.text);
        let seg_words = normalised_words(&seg.text);
        if last_words.is_empty() || seg_words.is_empty() {
            out.push(seg);
            continue;
        }

        // Case 1: prefix growth — the new segment repeats the previous one and adds more.
        if seg_words.len() > last_words.len()
            && starts_with_words(&seg_words, &last_words)
            && (is_meaningful_phrase(&last_words) || is_short_carryover(last, &last_words))
        {
            last.text = seg.text;
            last.end = seg.end;
            last.words = seg.words;
            continue;
        }

        // Case 2: the new segment only repeats the tail of the previous one.
        if ends_with_words(&last_words, &seg_words)
            && (is_meaningful_phrase(&seg_words) || is_short_carryover(&seg, &seg_words))
        {
            last.end = last.end.max(seg.end);
            continue;
        }

        // Case 3: partial overlap — strip the repeated leading words from the new segment.
        let overlap = suffix_prefix_overlap(&last_words, &seg_words);
        if overlap >= MIN_OVERLAP_WORDS {
            let trimmed_text = trim_leading_words(&seg.text, overlap);
            if trimmed_text.is_empty() {
                last.end = last.end.max(seg.end);
                continue;
            }

            seg.start = seg.start.max(last.end);
            seg.text = trimmed_text;
            // Presumably cleared because the word timings no longer match the trimmed text.
            seg.words.clear();
        }

        out.push(seg);
    }

    out
}
|
||
|
||
/// Set of lower-cased word `n`-grams of `text`, each joined with single spaces.
///
/// Returns an empty set when `text` has fewer than `n` words. Note that
/// `n == 0` yields the single empty string.
fn ngrams(text: &str, n: usize) -> HashSet<String> {
    let lowered = text.to_lowercase();
    let tokens: Vec<&str> = lowered.split_whitespace().collect();
    if tokens.len() < n {
        return HashSet::new();
    }
    (0..=tokens.len() - n)
        .map(|i| tokens[i..i + n].join(" "))
        .collect()
}
|
||
|
||
fn jaccard_similarity(left: &str, right: &str) -> f32 {
|
||
let left_grams = ngrams(left, NGRAM_N);
|
||
let right_grams = ngrams(right, NGRAM_N);
|
||
|
||
if left_grams.is_empty() && right_grams.is_empty() {
|
||
return 0.0;
|
||
}
|
||
|
||
let intersection = left_grams.intersection(&right_grams).count();
|
||
let union = left_grams.union(&right_grams).count();
|
||
|
||
if union == 0 {
|
||
0.0
|
||
} else {
|
||
intersection as f32 / union as f32
|
||
}
|
||
}
|
||
|
||
/// The last `limit` characters (Unicode scalar values) of `text`, or all of it
/// when it is shorter.
fn tail_chars(text: &str, limit: usize) -> String {
    let total = text.chars().count();
    text.chars().skip(total.saturating_sub(limit)).collect()
}
|
||
|
||
fn ngram_dedup(segments: Vec<Segment>) -> Vec<Segment> {
|
||
let mut out = Vec::with_capacity(segments.len());
|
||
|
||
for seg in segments {
|
||
let window_text = out
|
||
.iter()
|
||
.skip(out.len().saturating_sub(20))
|
||
.map(|segment: &Segment| segment.text.as_str())
|
||
.collect::<Vec<_>>()
|
||
.join(" ");
|
||
let recent_context = tail_chars(&window_text, LOOKBACK_CHARS);
|
||
|
||
if !recent_context.is_empty()
|
||
&& jaccard_similarity(&seg.text, &recent_context) >= SIMILARITY_THRESHOLD
|
||
{
|
||
continue;
|
||
}
|
||
|
||
out.push(seg);
|
||
}
|
||
|
||
out
|
||
}
|
||
|
||
fn normalise_segments(segments: Vec<Segment>) -> Vec<Segment> {
|
||
let mut result = segments
|
||
.into_iter()
|
||
.map(|mut seg| {
|
||
seg.text = collapse_repeats(seg.text.trim());
|
||
seg
|
||
})
|
||
.filter(|seg| !seg.text.is_empty())
|
||
.collect::<Vec<_>>();
|
||
|
||
result = collapse_incremental_segments(result);
|
||
result = merge_identical_segments(result);
|
||
result = ngram_dedup(result);
|
||
result = collapse_incremental_segments(result);
|
||
merge_identical_segments(result)
|
||
}
|
||
|
||
// ── Job processing ────────────────────────────────────────────────────────────
|
||
|
||
async fn process_job(
|
||
job: &Job,
|
||
audio_path: &std::path::Path,
|
||
progress_tx: &ProgressTx,
|
||
cmd_tx: &std::sync::mpsc::SyncSender<WorkerCmd>,
|
||
storage: &Arc<Storage>,
|
||
) -> crate::Result<(Vec<Segment>, String, f32)> {
|
||
let pcm = decode_audio(audio_path).await?;
|
||
let total_secs = pcm.len() as f32 / 16_000.0;
|
||
|
||
let silence_mids = detect_silence_midpoints(audio_path).await;
|
||
let cuts = snap_to_silence(
|
||
&silence_mids,
|
||
total_secs,
|
||
TARGET_CHUNK_SECS,
|
||
SNAP_WINDOW_SECS,
|
||
);
|
||
let chunks = to_chunk_ranges(&cuts, total_secs);
|
||
let n = chunks.len();
|
||
|
||
tracing::info!(
|
||
total_secs,
|
||
n_chunks = n,
|
||
silence_points = silence_mids.len(),
|
||
"audio chunked by silence"
|
||
);
|
||
|
||
let mut all_segments: Vec<Segment> = Vec::new();
|
||
let mut language = String::new();
|
||
|
||
for (ci, (chunk_start, chunk_end)) in chunks.iter().enumerate() {
|
||
let s0 = (*chunk_start * 16_000.0) as usize;
|
||
let s1 = ((*chunk_end * 16_000.0) as usize).min(pcm.len());
|
||
let mut chunk_pcm = pcm[s0..s1].to_vec();
|
||
trim_trailing_silence(&mut chunk_pcm);
|
||
|
||
let base = (ci * 100 / n) as u8;
|
||
let span = (100usize / n).max(1) as u8;
|
||
|
||
// Save progress to disk before emitting SSE — polling clients who respond
|
||
// immediately to the SSE event will then see consistent state.
|
||
let mut snapshot = job.clone();
|
||
snapshot.progress = base;
|
||
if let Err(e) = storage.save(&snapshot).await {
|
||
tracing::warn!(error = %e, "failed to persist mid-job progress");
|
||
}
|
||
|
||
let _ = progress_tx.send(ProgressEvent::Progress {
|
||
percent: base,
|
||
chunk: ci + 1,
|
||
total: n,
|
||
});
|
||
|
||
let tx = progress_tx.clone();
|
||
let chunk_num = ci + 1;
|
||
let on_progress = Box::new(move |p: u8| {
|
||
let overall = base.saturating_add(p.saturating_mul(span) / 100);
|
||
let _ = tx.send(ProgressEvent::Progress {
|
||
percent: overall,
|
||
chunk: chunk_num,
|
||
total: n,
|
||
});
|
||
});
|
||
|
||
let (reply_tx, reply_rx) = oneshot::channel();
|
||
cmd_tx
|
||
.send(WorkerCmd::Transcribe(TranscribeRequest {
|
||
pcm: chunk_pcm,
|
||
language: job.language.clone(),
|
||
task: job.task.clone(),
|
||
on_progress,
|
||
reply: reply_tx,
|
||
}))
|
||
.map_err(|_| AppError::Internal("worker command channel closed".into()))?;
|
||
|
||
let (mut segs, lang) = reply_rx
|
||
.await
|
||
.map_err(|_| AppError::Internal("transcriber thread dropped reply".into()))??;
|
||
|
||
let offset = *chunk_start;
|
||
for seg in &mut segs {
|
||
seg.start += offset;
|
||
seg.end += offset;
|
||
for word in &mut seg.words {
|
||
word.start += offset;
|
||
word.end += offset;
|
||
}
|
||
}
|
||
|
||
tracing::debug!(
|
||
chunk = ci + 1,
|
||
of = n,
|
||
start = chunk_start,
|
||
end = chunk_end,
|
||
segs = segs.len(),
|
||
"chunk done"
|
||
);
|
||
|
||
all_segments.extend(segs);
|
||
if language.is_empty() {
|
||
language = lang;
|
||
}
|
||
}
|
||
|
||
all_segments = normalise_segments(all_segments);
|
||
|
||
for (i, seg) in all_segments.iter_mut().enumerate() {
|
||
seg.index = i as i32;
|
||
}
|
||
|
||
let _ = progress_tx.send(ProgressEvent::Progress {
|
||
percent: 100,
|
||
chunk: n,
|
||
total: n,
|
||
});
|
||
Ok((all_segments, language, total_secs))
|
||
}
|
||
|
||
fn trim_trailing_silence(pcm: &mut Vec<f32>) {
|
||
const THRESHOLD: f32 = 0.017_8;
|
||
const PADDING: usize = 8_000;
|
||
|
||
if let Some(last_loud) = pcm.iter().rposition(|&s| s.abs() > THRESHOLD) {
|
||
let new_len = (last_loud + 1 + PADDING).min(pcm.len());
|
||
if new_len < pcm.len() {
|
||
tracing::trace!(
|
||
original_samples = pcm.len(),
|
||
trimmed_samples = pcm.len() - new_len,
|
||
"trimmed trailing silence"
|
||
);
|
||
pcm.truncate(new_len);
|
||
}
|
||
}
|
||
}
|
||
|
||
async fn decode_audio(path: &std::path::Path) -> crate::Result<Vec<f32>> {
|
||
use tokio::process::Command;
|
||
|
||
let output = Command::new("ffmpeg")
|
||
.args([
|
||
"-nostdin",
|
||
"-threads",
|
||
"0",
|
||
"-i",
|
||
path.to_str().unwrap_or(""),
|
||
"-f",
|
||
"f32le",
|
||
"-ac",
|
||
"1",
|
||
"-ar",
|
||
"16000",
|
||
"-",
|
||
])
|
||
.output()
|
||
.await
|
||
.map_err(|e| AppError::Internal(format!("ffmpeg spawn failed: {e}")))?;
|
||
|
||
if !output.status.success() {
|
||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||
return Err(AppError::Internal(format!(
|
||
"ffmpeg exited with {}: {}",
|
||
output.status, stderr
|
||
)));
|
||
}
|
||
|
||
let bytes = output.stdout;
|
||
if bytes.len() % 4 != 0 {
|
||
return Err(AppError::Internal(
|
||
"ffmpeg output length not a multiple of 4".into(),
|
||
));
|
||
}
|
||
Ok(bytes
|
||
.chunks_exact(4)
|
||
.map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
|
||
.collect())
|
||
}
|
||
|
||
pub fn audio_path_for(id: &JobId) -> PathBuf {
|
||
let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/data".into());
|
||
PathBuf::from(data_dir).join(format!("{id}.audio"))
|
||
}
|
||
|
||
// ── Unit tests ────────────────────────────────────────────────────────────────
|
||
|
||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::Word;

    /// Builds a [`Segment`] that carries no word-level timings.
    fn segment(index: i32, start: f32, end: f32, text: &str) -> Segment {
        let words = Vec::<Word>::new();
        Segment {
            index,
            start,
            end,
            text: text.into(),
            words,
        }
    }

    /// Asserts that a segment timestamp is within 0.01 s of `expected`.
    fn assert_time(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 0.01,
            "expected ~{expected}, got {actual}"
        );
    }

    /// Given several silence midpoints, the first cut should land on 58.0.
    #[test]
    fn test_snap_to_silence_uses_nearest_midpoint() {
        let midpoints = vec![55.0, 58.0, 62.0];
        let cuts = snap_to_silence(&midpoints, 120.0, 60.0, 30.0);
        let first = *cuts.first().expect("snap_to_silence returned no cuts");
        assert!((first - 58.0).abs() < 0.01, "expected ~58.0, got {}", first);
    }

    /// Without any silence midpoints the single cut is placed at 60.0.
    #[test]
    fn test_snap_to_silence_hard_cut_when_no_silence() {
        assert_eq!(snap_to_silence(&[], 120.0, 60.0, 30.0), vec![60.0]);
    }

    /// No cut points means one range covering the whole duration.
    #[test]
    fn test_to_chunk_ranges_single_chunk() {
        assert_eq!(to_chunk_ranges(&[], 30.0), vec![(0.0, 30.0)]);
    }

    /// A single cut point splits the duration into two adjacent ranges.
    #[test]
    fn test_to_chunk_ranges_two_chunks() {
        assert_eq!(
            to_chunk_ranges(&[60.0], 120.0),
            vec![(0.0, 60.0), (60.0, 120.0)]
        );
    }

    /// A buffer with no sample above the threshold is left at full length.
    #[test]
    fn test_trim_trailing_silence_all_silent() {
        let mut pcm = vec![0.0f32; 1000];
        trim_trailing_silence(&mut pcm);
        assert_eq!(pcm.len(), 1000);
    }

    /// Everything after the last loud sample plus the padding is dropped.
    #[test]
    fn test_trim_trailing_silence_trims_to_padding() {
        const TOTAL: usize = 32_000;
        const LOUD_AT: usize = 10_000;
        const PADDING: usize = 8_000;

        let mut pcm = vec![0.0f32; TOTAL];
        pcm[LOUD_AT] = 1.0;
        trim_trailing_silence(&mut pcm);

        assert_eq!(pcm.len(), (LOUD_AT + 1 + PADDING).min(TOTAL));
    }

    /// Overlapping "growing prefix" segments collapse into two clean segments.
    #[test]
    fn test_normalise_segments_collapses_prefix_growth_chain() {
        let input = vec![
            segment(0, 15.24, 16.6, "Hello everyone."),
            segment(1, 16.6, 19.47, "Hello everyone. Um, welcome to this talk."),
            segment(2, 19.47, 19.48, "Um, welcome to this talk."),
            segment(
                3,
                19.48,
                21.67,
                "Um, welcome to this talk. I'll be speaking about small model",
            ),
            segment(4, 21.67, 21.68, "I'll be speaking about small model"),
            segment(
                5,
                21.68,
                24.59,
                "I'll be speaking about small model inference and a gap that we've",
            ),
        ];

        let result = normalise_segments(input);
        assert_eq!(result.len(), 2);

        let (first, second) = (&result[0], &result[1]);

        assert_eq!(first.text, "Hello everyone. Um, welcome to this talk.");
        assert_time(first.start, 15.24);
        assert_time(first.end, 19.48);

        assert_eq!(
            second.text,
            "I'll be speaking about small model inference and a gap that we've"
        );
        assert_time(second.start, 19.48);
        assert_time(second.end, 24.59);
    }

    /// A phrase repeated back-to-back inside one segment is emitted once.
    #[test]
    fn test_normalise_segments_collapses_repeated_phrase_inside_segment() {
        let input = vec![segment(
            0,
            0.0,
            5.0,
            "the quick brown fox the quick brown fox jumps over the fence",
        )];

        let result = normalise_segments(input);

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].text, "the quick brown fox jumps over the fence");
    }

    /// Segments separated by a real time gap are kept as they are.
    #[test]
    fn test_normalise_segments_keeps_real_gap() {
        let input = vec![
            segment(0, 0.0, 1.0, "Hello everyone."),
            segment(1, 2.0, 4.0, "Hello everyone. Welcome back."),
        ];
        let expected = ["Hello everyone.", "Hello everyone. Welcome back."];

        let result = normalise_segments(input);

        assert_eq!(result.len(), expected.len());
        for (seg, want) in result.iter().zip(expected) {
            assert_eq!(seg.text, want);
        }
    }

    /// Short carry-over segments are folded into the segment that repeats them.
    #[test]
    fn test_normalise_segments_collapses_tiny_carry_over_segments() {
        let input = vec![
            segment(0, 94.8, 96.4, "world."),
            segment(1, 96.4, 98.96, "world. And that aspect that I overlooked was"),
            segment(2, 98.96, 100.72, "inference."),
            segment(
                3,
                100.72,
                103.92,
                "inference. So, as someone who kind of wants to",
            ),
            segment(4, 107.19, 107.2, "and"),
            segment(5, 107.2, 109.56, "and work to understand the problems and the"),
        ];
        let expected = [
            "world. And that aspect that I overlooked was",
            "inference. So, as someone who kind of wants to",
            "and work to understand the problems and the",
        ];

        let result = normalise_segments(input);

        assert_eq!(result.len(), expected.len());
        for (seg, want) in result.iter().zip(expected) {
            assert_eq!(seg.text, want);
        }
    }

    /// A single word repeated across adjacent segments is trimmed from the later one.
    #[test]
    fn test_normalise_segments_trims_single_word_adjacent_overlap() {
        let input = vec![
            segment(0, 94.8, 96.4, "world."),
            segment(1, 96.4, 98.96, "world. And that aspect that I overlooked was"),
            segment(2, 120.12, 123.71, "to find more about inference."),
            segment(
                3,
                123.72,
                126.92,
                "inference. So, I've done a lot of work with VLAM,",
            ),
        ];

        let result = normalise_segments(input);

        assert_eq!(result.len(), 3);
        assert_eq!(
            result[0].text,
            "world. And that aspect that I overlooked was"
        );
        assert_eq!(result[2].text, "So, I've done a lot of work with VLAM,");
    }
}
|