feat: dynamic model loading/unloading with GPU polling
All checks were successful
Build & Push Docker Image / build-and-push (push) Successful in 8m41s

- Model starts unloaded (lazy); loads on first job or POST /model/load
- Auto-unloads after IDLE_TIMEOUT_SECS (default 300) of inactivity
- POST /model/unload for immediate manual release
- GPU-busy detection: on VRAM OOM, enters WaitingForGpu and retries
  every GPU_POLL_INTERVAL_SECS (default 30) indefinitely
- POST /jobs when unloaded → 503 + Retry-After header, triggers load
- AppError::OutOfMemory and AppError::ModelNotReady variants
- WorkerCmd channel (SyncSender<WorkerCmd>) replaces bare tx_req channel
- Idle timer via recv_timeout(1s) tick inside OS thread (no extra thread)
- Model lifecycle events broadcast via tokio broadcast channel (SSE + webhooks)
- webhook_registry: all clients that ever submitted a webhook_url receive
  model_ready and model_unloaded webhooks
- GPU warmup retained on every (re)load

New routes:
  GET  /model/status  — current state + VRAM stats
  POST /model/load    — trigger load (idempotent)
  POST /model/unload  — immediate unload
  GET  /model/events  — SSE stream of model lifecycle events

New env vars:
  IDLE_TIMEOUT_SECS       (default 300)
  GPU_POLL_INTERVAL_SECS  (default 30)

Tests:
  tests/test_model_lifecycle.sh — 18 integration tests (full state machine,
    SSE events, webhooks, concurrency, unload-during-load)
  tests/test_idle_timeout.sh    — 5 tests with short IDLE_TIMEOUT_SECS=5
  test_all.sh updated: loads model before job submission, asserts
    model_state in /health, adds POST /model/unload at end

Docs:
  docs/USAGE.md: model lifecycle section, new env vars, 503 retry pattern,
    updated /health response shape

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
mozempk
2026-05-08 17:57:20 +02:00
parent 78c6fab81b
commit b191fbe200
13 changed files with 2053 additions and 148 deletions

View File

@@ -1,20 +1,23 @@
use std::{
collections::HashSet,
path::PathBuf,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
Arc, Mutex,
},
time::{Duration, Instant},
};
use chrono::Utc;
use reqwest::Client;
use tokio::sync::{broadcast, mpsc, oneshot};
use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
use crate::{
models::{Job, JobId, JobStatus, Segment},
models::{Job, JobId, JobStatus, ModelEvent, ModelState, Segment},
storage::Storage,
transcriber::Transcriber,
webhook,
AppError,
};
/// Per-job broadcast channel for SSE subscribers.
@@ -31,83 +34,383 @@ pub enum ProgressEvent {
/// Global registry: job_id → broadcast sender.
pub type ProgressRegistry = Arc<dashmap::DashMap<JobId, ProgressTx>>;
// ── Transcription request/response types for the blocking thread ─────────────
// ── Worker command channel ────────────────────────────────────────────────────
struct TranscribeRequest {
pcm: Vec<f32>,
language: Option<String>,
task: String,
/// Per-chunk progress callback — receives 0–100 from whisper.cpp and can
/// scale/offset it before forwarding to the job's broadcast channel.
on_progress: Box<dyn Fn(u8) + Send + 'static>,
reply: oneshot::Sender<crate::Result<(Vec<Segment>, String)>>,
/// Commands sent to the GPU worker OS thread over the bounded command channel.
///
/// The worker thread is the only owner of the `Transcriber`, so every
/// operation that touches the model is expressed as one of these messages.
#[derive(Debug)]
pub enum WorkerCmd {
/// Request a model load. Idempotent: the worker ignores it when a model is
/// already loaded, and treats it as a no-op while it is waiting for GPU memory.
Load,
/// Unload the model immediately and free GPU memory. Received while the
/// worker is waiting for GPU memory, it cancels the pending load instead.
Unload,
/// Internal: run a transcription chunk. The outcome (or an error when no
/// model is loaded) is delivered through the request's `reply` channel.
Transcribe(TranscribeRequest),
}
// ── Transcription request/response types ─────────────────────────────────────
/// One unit of transcription work handed to the GPU worker thread.
///
/// Built by `process_job` for each audio chunk and wrapped in
/// `WorkerCmd::Transcribe`.
pub struct TranscribeRequest {
/// Audio samples for this chunk (`process_job` slices them from the decoded
/// 16 kHz mono PCM buffer).
pub pcm: Vec<f32>,
/// Optional language taken from the job; forwarded to the transcriber as
/// `Option<&str>`.
pub language: Option<String>,
/// Task string taken from the job and forwarded to the transcriber unchanged.
pub task: String,
/// Per-chunk progress callback, invoked by the transcriber with a `u8`
/// progress value; `process_job` scales it into the job's overall percentage.
pub on_progress: Box<dyn Fn(u8) + Send + 'static>,
/// One-shot channel on which the worker returns the chunk's segments and a
/// language string, or the error that prevented transcription.
pub reply: oneshot::Sender<crate::Result<(Vec<Segment>, String)>>,
}
/// Manual `Debug` implementation: only `language` and `task` are printed.
/// The remaining fields (PCM buffer, progress callback, reply channel) are
/// omitted, and the output is marked non-exhaustive to signal that.
impl std::fmt::Debug for TranscribeRequest {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("TranscribeRequest");
        builder.field("language", &self.language);
        builder.field("task", &self.task);
        builder.finish_non_exhaustive()
    }
}
// ── Public API ────────────────────────────────────────────────────────────────
/// Spawn the single GPU worker.
///
/// Starts a dedicated OS thread named `whisper-gpu` running
/// `transcriber_thread`, plus a Tokio task running `run` that consumes
/// `job_rx`.
///
/// Returns the SSE progress registry and a command sender for the worker thread.
/// The model starts **unloaded**; send `WorkerCmd::Load` on the returned sender
/// to trigger loading.
/// NOTE(review): a `WorkerCmd::Transcribe` that reaches the worker while the
/// model is unloaded is failed rather than triggering a load, so loading on job
/// submission presumably happens in the HTTP layer — confirm against the caller.
///
/// # Panics
///
/// Panics if the OS thread cannot be spawned. `Handle::current()` also requires
/// this function to be called from inside a Tokio runtime.
#[allow(clippy::too_many_arguments)]
pub fn start(
job_rx: mpsc::UnboundedReceiver<JobId>,
storage: Arc<Storage>,
model_path: PathBuf,
queue_depth: Arc<AtomicUsize>,
gpu_device: u32,
) -> ProgressRegistry {
job_rx: mpsc::UnboundedReceiver<JobId>,
storage: Arc<Storage>,
model_path: PathBuf,
queue_depth: Arc<AtomicUsize>,
gpu_device: u32,
model_state: Arc<RwLock<ModelState>>,
model_event_tx: broadcast::Sender<ModelEvent>,
webhook_registry: Arc<Mutex<HashSet<String>>>,
idle_timeout: Duration,
gpu_poll_interval: Duration,
) -> (ProgressRegistry, std::sync::mpsc::SyncSender<WorkerCmd>) {
let registry: ProgressRegistry = Arc::new(dashmap::DashMap::new());
let reg_clone = Arc::clone(&registry);
let (tx_req, rx_req) = std::sync::mpsc::channel::<TranscribeRequest>();
// Bounded sync channel: capacity 8 is plenty (load/unload are rare).
let (cmd_tx, cmd_rx) = std::sync::mpsc::sync_channel::<WorkerCmd>(8);
// One sender goes to the async job runner, the other is returned to the caller.
let cmd_tx_clone = cmd_tx.clone();
// Capture Tokio runtime handle so the OS thread can spawn async tasks.
let rt_handle = tokio::runtime::Handle::current();
std::thread::Builder::new()
.name("whisper-gpu".into())
.spawn(move || transcriber_thread(rx_req, model_path, gpu_device))
.spawn(move || {
transcriber_thread(
cmd_rx,
model_path,
gpu_device,
model_state,
model_event_tx,
webhook_registry,
idle_timeout,
gpu_poll_interval,
rt_handle,
);
})
.expect("failed to spawn whisper-gpu thread");
tokio::spawn(run(job_rx, storage, queue_depth, reg_clone, tx_req));
tokio::spawn(run(job_rx, storage, queue_depth, reg_clone, cmd_tx_clone));
registry
(registry, cmd_tx)
}
/// Dedicated OS thread that owns the Transcriber (non-Send) and runs inference.
///
/// The Transcriber holds a single `WhisperState` that is reused for every chunk.
/// GPU compute buffers (~700 MB) are allocated once at startup rather than on
/// every call, eliminating per-chunk `whisper_init_state` overhead and the
/// VRAM churn that caused intermittent 0-segment results.
fn transcriber_thread(
rx: std::sync::mpsc::Receiver<TranscribeRequest>,
model_path: PathBuf,
gpu_device: u32,
) {
let mut transcriber = match Transcriber::load(&model_path, gpu_device) {
Ok(t) => t,
Err(e) => {
tracing::error!(error = %e, "failed to load whisper model — transcriber thread exiting");
return;
}
};
tracing::info!(model = %model_path.display(), "GPU worker ready");
// ── GPU OS thread ─────────────────────────────────────────────────────────────
for req in rx {
let on_progress = req.on_progress;
let result = transcriber.transcribe(
&req.pcm,
req.language.as_deref(),
&req.task,
move |p| on_progress(p),
);
let _ = req.reply.send(result);
/// The worker OS thread that owns the `Transcriber` (non-`Send`).
///
/// Uses `recv_timeout` with a 1-second tick to drive the idle timer without a
/// separate thread.
#[allow(clippy::too_many_arguments)]
fn transcriber_thread(
rx: std::sync::mpsc::Receiver<WorkerCmd>,
model_path: PathBuf,
gpu_device: u32,
model_state: Arc<RwLock<ModelState>>,
model_event_tx: broadcast::Sender<ModelEvent>,
webhook_registry: Arc<Mutex<HashSet<String>>>,
idle_timeout: Duration,
gpu_poll_interval: Duration,
rt: tokio::runtime::Handle,
) {
let mut transcriber: Option<Transcriber> = None;
let mut last_job = Instant::now();
loop {
match rx.recv_timeout(Duration::from_secs(1)) {
Ok(WorkerCmd::Load) => {
if transcriber.is_some() {
tracing::debug!("WorkerCmd::Load ignored — model already loaded");
continue;
}
transcriber = try_load_with_polling(
&rx,
&model_path,
gpu_device,
&model_state,
&model_event_tx,
&webhook_registry,
gpu_poll_interval,
&rt,
);
if transcriber.is_some() {
last_job = Instant::now();
}
}
Ok(WorkerCmd::Unload) => {
do_unload(&mut transcriber, &model_state, &model_event_tx, &webhook_registry, &rt);
}
Ok(WorkerCmd::Transcribe(req)) => {
let t = match &mut transcriber {
Some(t) => t,
None => {
tracing::warn!("Transcribe cmd received but model is unloaded — failing job");
let _ = req.reply.send(Err(AppError::Internal(
"model unloaded before job could run".into(),
)));
continue;
}
};
let result = t.transcribe(
&req.pcm,
req.language.as_deref(),
&req.task,
move |p| (req.on_progress)(p),
);
last_job = Instant::now();
let _ = req.reply.send(result);
}
Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
if transcriber.is_some() && last_job.elapsed() >= idle_timeout {
tracing::info!(
elapsed_secs = last_job.elapsed().as_secs(),
"idle timeout reached — unloading model"
);
do_unload(
&mut transcriber,
&model_state,
&model_event_tx,
&webhook_registry,
&rt,
);
}
}
Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
tracing::info!("worker command channel closed — shutting down GPU thread");
break;
}
}
}
}
/// Attempt to load the model, retrying for as long as the GPU is out of memory.
///
/// Each attempt publishes `ModelState::Loading` / `ModelEvent::ModelLoading`
/// and then calls `Transcriber::load`:
///
/// * **Success** — state becomes `Ready`, `ModelReady` is broadcast and sent to
///   registered webhooks, and the transcriber is returned.
/// * **`AppError::OutOfMemory`** — state becomes `WaitingForGpu`, then the
///   function waits `gpu_poll_interval` before retrying. During that wait it
///   keeps draining `rx`: `WorkerCmd::Unload` cancels the load,
///   `WorkerCmd::Load` is a no-op, and `WorkerCmd::Transcribe` is rejected
///   with `AppError::ModelNotReady`.
/// * **Any other error** — the load is abandoned, state returns to `Unloaded`,
///   and `ModelUnloaded` is published so that observers who saw `ModelLoading`
///   receive a terminal event.
///
/// Returns `Some(Transcriber)` on success and `None` if the load was
/// cancelled, failed fatally, or the command channel was disconnected.
#[allow(clippy::too_many_arguments)]
fn try_load_with_polling(
    rx: &std::sync::mpsc::Receiver<WorkerCmd>,
    model_path: &PathBuf,
    gpu_device: u32,
    model_state: &Arc<RwLock<ModelState>>,
    model_event_tx: &broadcast::Sender<ModelEvent>,
    webhook_registry: &Arc<Mutex<HashSet<String>>>,
    gpu_poll_interval: Duration,
    rt: &tokio::runtime::Handle,
) -> Option<Transcriber> {
    loop {
        set_state(model_state, ModelState::Loading);
        broadcast_event(model_event_tx, ModelEvent::ModelLoading);
        tracing::info!("loading whisper model...");

        match Transcriber::load(model_path, gpu_device) {
            Ok(t) => {
                let loaded_at = Utc::now();
                set_state(model_state, ModelState::Ready { loaded_at });
                broadcast_event(model_event_tx, ModelEvent::ModelReady { loaded_at });
                fire_webhooks(webhook_registry, ModelEvent::ModelReady { loaded_at }, rt);
                tracing::info!("model loaded and ready");
                return Some(t);
            }
            Err(AppError::OutOfMemory(msg)) => {
                let (vram_needed_mb, vram_free_mb) = parse_oom_vram(&msg, gpu_device);
                let retry_in_secs = gpu_poll_interval.as_secs();
                tracing::warn!(
                    vram_needed_mb,
                    vram_free_mb,
                    retry_in_secs,
                    "insufficient VRAM — will retry"
                );
                set_state(model_state, ModelState::WaitingForGpu {
                    vram_needed_mb,
                    vram_free_mb,
                    retry_in_secs,
                });
                broadcast_event(model_event_tx, ModelEvent::ModelWaitingForGpu {
                    vram_needed_mb,
                    vram_free_mb,
                    retry_in_secs,
                });

                // Interruptible sleep: drain rx while waiting for gpu_poll_interval.
                // Polling in slices of at most one second keeps Unload responsive.
                let deadline = Instant::now() + gpu_poll_interval;
                loop {
                    let remaining = deadline.saturating_duration_since(Instant::now());
                    if remaining.is_zero() {
                        break;
                    }
                    match rx.recv_timeout(remaining.min(Duration::from_secs(1))) {
                        Ok(WorkerCmd::Unload) => {
                            tracing::info!("Unload received while waiting for GPU — cancelling load");
                            set_state(model_state, ModelState::Unloaded);
                            broadcast_event(model_event_tx, ModelEvent::ModelUnloaded);
                            fire_webhooks(webhook_registry, ModelEvent::ModelUnloaded, rt);
                            return None;
                        }
                        // A load is already in progress, so another Load is a no-op.
                        Ok(WorkerCmd::Load) => {}
                        Ok(WorkerCmd::Transcribe(req)) => {
                            let _ = req.reply.send(Err(AppError::ModelNotReady {
                                state: "waiting_for_gpu".into(),
                                retry_after_secs: retry_in_secs,
                            }));
                        }
                        Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {}
                        // All senders are gone: the worker is shutting down.
                        Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => return None,
                    }
                }
                // Loop back to retry load
            }
            Err(e) => {
                tracing::error!(error = %e, "model load failed with non-recoverable error");
                // Publish the return to Unloaded the same way every other
                // transition to Unloaded does. Without this, observers that saw
                // ModelLoading would wait forever for a follow-up event.
                set_state(model_state, ModelState::Unloaded);
                broadcast_event(model_event_tx, ModelEvent::ModelUnloaded);
                fire_webhooks(webhook_registry, ModelEvent::ModelUnloaded, rt);
                return None;
            }
        }
    }
}
/// Drop the current transcriber (if any) and announce that the model is unloaded.
///
/// The state is set to `ModelState::Unloaded`, `ModelEvent::ModelUnloaded` is
/// broadcast, and the same event is sent to every registered webhook URL.
/// All of this happens even when no model was loaded on entry.
fn do_unload(
    transcriber: &mut Option<Transcriber>,
    model_state: &Arc<RwLock<ModelState>>,
    model_event_tx: &broadcast::Sender<ModelEvent>,
    webhook_registry: &Arc<Mutex<HashSet<String>>>,
    rt: &tokio::runtime::Handle,
) {
    // Taking the value out and dropping it releases whatever the transcriber owns.
    drop(transcriber.take());

    set_state(model_state, ModelState::Unloaded);
    broadcast_event(model_event_tx, ModelEvent::ModelUnloaded);
    fire_webhooks(webhook_registry, ModelEvent::ModelUnloaded, rt);

    tracing::info!("model unloaded — GPU memory freed");
}
// ── Helpers ───────────────────────────────────────────────────────────────────
/// Overwrite the shared model state from the worker OS thread.
///
/// Uses the async `RwLock`'s blocking write API, so this is meant to be called
/// from a plain OS thread rather than from async code.
fn set_state(arc: &Arc<RwLock<ModelState>>, state: ModelState) {
    let mut current = arc.blocking_write();
    *current = state;
}
/// Publish a model lifecycle event to all current broadcast subscribers.
///
/// Delivery is best-effort: a send error (which the broadcast channel reports
/// when nobody is subscribed) is deliberately ignored.
fn broadcast_event(tx: &broadcast::Sender<ModelEvent>, event: ModelEvent) {
    if let Err(_no_subscribers) = tx.send(event) {
        // Nothing to do: lifecycle events are purely informational.
    }
}
fn fire_webhooks(
registry: &Arc<Mutex<HashSet<String>>>,
event: ModelEvent,
rt: &tokio::runtime::Handle,
) {
if !event.is_webhook_event() {
return;
}
let urls: Vec<String> = registry
.lock()
.unwrap_or_else(|e| e.into_inner())
.iter()
.cloned()
.collect();
if urls.is_empty() { return; }
let payload = match serde_json::to_string(&event) {
Ok(p) => p,
Err(e) => { tracing::error!(error = %e, "failed to serialize model event"); return; }
};
for url in urls {
let body = payload.clone();
rt.spawn(async move {
let http = Client::builder()
.timeout(Duration::from_secs(10))
.build()
.expect("http client");
for attempt in 0..3_u32 {
match http.post(&url)
.header("content-type", "application/json")
.body(body.clone())
.send()
.await
{
Ok(r) if r.status().is_success() => {
tracing::debug!(url, "model event webhook delivered");
return;
}
Ok(r) => tracing::warn!(url, status = r.status().as_u16(), "webhook non-2xx"),
Err(e) => tracing::warn!(url, error = %e, attempt, "webhook delivery failed"),
}
if attempt < 2 {
tokio::time::sleep(Duration::from_secs(2u64.pow(attempt))).await;
}
}
tracing::error!(url, "model event webhook failed after 3 attempts");
});
}
}
/// Extract VRAM figures for an out-of-memory report.
///
/// Returns `(needed_mib, free_mib)`:
///
/// * `needed_mib` — the first number in `msg` that is immediately followed by
///   the whitespace-separated token `MiB` (fractional values are truncated).
///   Tokens before `MiB` that are not finite, non-negative numbers are
///   skipped, so the scan continues to later `<number> MiB` pairs. `0` when
///   no such pair exists.
/// * `free_mib` — free memory reported by `nvidia-smi` for `gpu_device`. `0`
///   when the tool is missing, fails, or prints something unparseable.
///
/// Both values are best-effort diagnostics; this function never fails.
fn parse_oom_vram(msg: &str, gpu_device: u32) -> (u64, u64) {
    let tokens: Vec<&str> = msg.split_whitespace().collect();
    let needed = tokens
        .windows(2)
        .filter(|pair| pair[1] == "MiB")
        // find_map keeps scanning when the token before "MiB" is not a number.
        .find_map(|pair| {
            pair[0]
                .parse::<f64>()
                .ok()
                .filter(|v| v.is_finite() && *v >= 0.0)
        })
        .map(|v| v as u64)
        .unwrap_or(0);

    let free = std::process::Command::new("nvidia-smi")
        .args([
            &format!("--id={gpu_device}"),
            "--query-gpu=memory.free",
            "--format=csv,noheader,nounits",
        ])
        .output()
        .ok()
        .and_then(|o| String::from_utf8(o.stdout).ok())
        // Only the first line is relevant; tolerate trailing output.
        .and_then(|s| s.lines().next().and_then(|line| line.trim().parse::<u64>().ok()))
        .unwrap_or(0);

    (needed, free)
}
// ── Async job runner ──────────────────────────────────────────────────────────
async fn run(
mut job_rx: mpsc::UnboundedReceiver<JobId>,
storage: Arc<Storage>,
queue_depth: Arc<AtomicUsize>,
registry: ProgressRegistry,
tx_req: std::sync::mpsc::Sender<TranscribeRequest>,
cmd_tx: std::sync::mpsc::SyncSender<WorkerCmd>,
) {
let http = Client::builder()
.timeout(std::time::Duration::from_secs(30))
.timeout(Duration::from_secs(30))
.build()
.expect("failed to build reqwest client");
@@ -140,7 +443,7 @@ async fn run(
let audio_path = audio_path_for(&job_id);
let result = process_job(&job, &audio_path, &progress_tx, &tx_req, &storage).await;
let result = process_job(&job, &audio_path, &progress_tx, &cmd_tx, &storage).await;
let _ = tokio::fs::remove_file(&audio_path).await;
@@ -175,26 +478,18 @@ async fn run(
tokio::spawn(async move { webhook::fire(&http, &url, &job).await; });
}
tokio::time::sleep(std::time::Duration::from_secs(30)).await;
tokio::time::sleep(Duration::from_secs(30)).await;
registry.remove(&job_id);
}
}
// ── Silence-based chunking ────────────────────────────────────────────────────
/// Target chunk length in seconds, passed to `snap_to_silence` as the spacing
/// between cut points. 60s ≈ 2× whisper's native 30s window — short enough
/// that a hallucinated phrase can't compound beyond a single window.
const TARGET_CHUNK_SECS: f32 = 60.0;
/// How far (in seconds) from the target position a cut may move in order to
/// land on a silence midpoint.
const SNAP_WINDOW_SECS: f32 = 30.0;
/// Silence below this level (dB) counts as a split candidate.
/// NOTE(review): presumably handed to ffmpeg's silence detection as its noise
/// threshold — the call site is not visible here; confirm.
const SILENCE_DB: &str = "-35dB";
/// Minimum silence duration to register as a candidate split.
/// NOTE(review): presumably in seconds, as expected by ffmpeg — confirm.
const SILENCE_DUR: &str = "0.4";
/// Detect silence periods and return the midpoint (seconds) of each.
/// On any error (ffmpeg missing, binary format, etc.) returns an empty vec
/// so the caller can fall back to hard cuts.
async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
use tokio::process::Command;
@@ -217,7 +512,6 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
}
};
// silencedetect logs to stderr
let stderr = String::from_utf8_lossy(&output.stderr);
let mut starts: Vec<f32> = Vec::new();
let mut ends: Vec<f32> = Vec::new();
@@ -228,7 +522,6 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
starts.push(t);
}
} else if let Some(i) = line.find("silence_end: ") {
// Format: "silence_end: 12.34 | silence_duration: 0.56"
let t_str = line[i + "silence_end: ".len()..]
.split(" |")
.next()
@@ -248,10 +541,6 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
mids
}
/// Build cut points every `target_secs`, snapping to the nearest silence
/// midpoint within `snap_window` when one exists; otherwise a hard cut.
/// Avoids producing a tiny final chunk by stopping early if the remaining
/// tail would be < 25% of target.
fn snap_to_silence(
mids: &[f32],
total_secs: f32,
@@ -263,13 +552,9 @@ fn snap_to_silence(
while pos < total_secs - target_secs * 0.25 {
let prev_cut = cuts.last().copied().unwrap_or(0.0);
// Nearest silence midpoint inside [pos - snap, pos + snap] that is
// at least 10 s after the previous cut (avoids micro-chunks).
let best = mids.iter().copied()
.filter(|&t| t > prev_cut + 10.0 && (t - pos).abs() <= snap_window)
.min_by(|a, b| (a - pos).abs().partial_cmp(&(b - pos).abs()).unwrap());
let cut = best.unwrap_or(pos);
cuts.push(cut);
pos = cut + target_secs;
@@ -278,7 +563,6 @@ fn snap_to_silence(
cuts
}
/// Convert cut points into (start_secs, end_secs) chunk pairs.
fn to_chunk_ranges(cuts: &[f32], total_secs: f32) -> Vec<(f32, f32)> {
let mut ranges = Vec::new();
let mut start = 0.0_f32;
@@ -289,7 +573,6 @@ fn to_chunk_ranges(cuts: &[f32], total_secs: f32) -> Vec<(f32, f32)> {
start = cut;
}
}
// Last chunk
if total_secs - start >= 1.0 {
ranges.push((start, total_secs));
}
@@ -302,17 +585,13 @@ async fn process_job(
job: &Job,
audio_path: &std::path::Path,
progress_tx: &ProgressTx,
tx_req: &std::sync::mpsc::Sender<TranscribeRequest>,
cmd_tx: &std::sync::mpsc::SyncSender<WorkerCmd>,
storage: &Arc<Storage>,
) -> crate::Result<(Vec<Segment>, String, f32)> {
// 1. Decode full audio to 16 kHz mono PCM.
let pcm = decode_audio(audio_path).await?;
let total_secs = pcm.len() as f32 / 16_000.0;
// 2. Detect silence midpoints from original file.
let silence_mids = detect_silence_midpoints(audio_path).await;
// 3. Build silence-snapped chunk boundaries.
let cuts = snap_to_silence(&silence_mids, total_secs, TARGET_CHUNK_SECS, SNAP_WINDOW_SECS);
let chunks = to_chunk_ranges(&cuts, total_secs);
let n = chunks.len();
@@ -324,7 +603,6 @@ async fn process_job(
"audio chunked by silence"
);
// 4. Transcribe each chunk, applying a time offset to all timestamps.
let mut all_segments: Vec<Segment> = Vec::new();
let mut language = String::new();
@@ -334,11 +612,9 @@ async fn process_job(
let mut chunk_pcm = pcm[s0..s1].to_vec();
trim_trailing_silence(&mut chunk_pcm);
// Base percent this chunk starts at.
let base = (ci * 100 / n) as u8;
let span = (100usize / n).max(1) as u8;
// Emit a progress event and persist it at the start of every chunk.
let _ = progress_tx.send(ProgressEvent::Progress {
percent: base,
chunk: ci + 1,
@@ -350,8 +626,7 @@ async fn process_job(
tracing::warn!(error = %e, "failed to persist mid-job progress");
}
// Scale whisper's per-chunk 0–100 into the job's overall range.
let tx = progress_tx.clone();
let tx = progress_tx.clone();
let chunk_num = ci + 1;
let on_progress = Box::new(move |p: u8| {
let overall = base.saturating_add(p.saturating_mul(span) / 100);
@@ -363,18 +638,17 @@ async fn process_job(
});
let (reply_tx, reply_rx) = oneshot::channel();
tx_req.send(TranscribeRequest {
cmd_tx.send(WorkerCmd::Transcribe(TranscribeRequest {
pcm: chunk_pcm,
language: job.language.clone(),
task: job.task.clone(),
on_progress,
reply: reply_tx,
}).map_err(|_| crate::AppError::Internal("transcriber thread gone".into()))?;
})).map_err(|_| AppError::Internal("worker command channel closed".into()))?;
let (mut segs, lang) = reply_rx.await
.map_err(|_| crate::AppError::Internal("transcriber thread dropped reply".into()))??;
.map_err(|_| AppError::Internal("transcriber thread dropped reply".into()))??;
// Shift all timestamps by chunk offset.
let offset = *chunk_start;
for seg in &mut segs {
seg.start += offset;
@@ -400,7 +674,6 @@ async fn process_job(
}
}
// Renumber segment indices across the merged output.
for (i, seg) in all_segments.iter_mut().enumerate() {
seg.index = i as i32;
}
@@ -409,14 +682,9 @@ async fn process_job(
Ok((all_segments, language, total_secs))
}
/// Trim trailing silence from a 16 kHz mono PCM buffer.
///
/// Scans backwards to find the last sample above −35 dB, then keeps
/// 0.5 s of padding after it. This prevents whisper from hallucinating
/// filler tokens into end-of-chunk silence.
fn trim_trailing_silence(pcm: &mut Vec<f32>) {
const THRESHOLD: f32 = 0.017_8; // 35 dB (10^(35/20))
const PADDING: usize = 8_000; // 0.5 s at 16 kHz
const THRESHOLD: f32 = 0.017_8;
const PADDING: usize = 8_000;
if let Some(last_loud) = pcm.iter().rposition(|&s| s.abs() > THRESHOLD) {
let new_len = (last_loud + 1 + PADDING).min(pcm.len());
@@ -429,10 +697,8 @@ fn trim_trailing_silence(pcm: &mut Vec<f32>) {
pcm.truncate(new_len);
}
}
// All-silent chunk: keep as-is — whisper will produce zero segments, which is correct.
}
/// Decode any audio file to 16 kHz mono PCM f32 using ffmpeg.
async fn decode_audio(path: &std::path::Path) -> crate::Result<Vec<f32>> {
use tokio::process::Command;
@@ -447,11 +713,11 @@ async fn decode_audio(path: &std::path::Path) -> crate::Result<Vec<f32>> {
])
.output()
.await
.map_err(|e| crate::AppError::Internal(format!("ffmpeg spawn failed: {e}")))?;
.map_err(|e| AppError::Internal(format!("ffmpeg spawn failed: {e}")))?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(crate::AppError::Internal(format!(
return Err(AppError::Internal(format!(
"ffmpeg exited with {}: {}",
output.status, stderr
)));
@@ -459,7 +725,7 @@ async fn decode_audio(path: &std::path::Path) -> crate::Result<Vec<f32>> {
let bytes = output.stdout;
if bytes.len() % 4 != 0 {
return Err(crate::AppError::Internal(
return Err(AppError::Internal(
"ffmpeg output length not a multiple of 4".into(),
));
}
@@ -473,3 +739,51 @@ pub fn audio_path_for(id: &JobId) -> PathBuf {
let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/data".into());
PathBuf::from(data_dir).join(format!("{id}.audio"))
}
// ── Unit tests ────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;

    /// With a 60 s target and candidates at 55, 58 and 62 s, the first cut
    /// lands on 58 s. Note that 58 and 62 are equidistant from the target, so
    /// this also pins which of two equally near candidates wins.
    #[test]
    fn test_snap_to_silence_uses_nearest_midpoint() {
        let candidates: [f32; 3] = [55.0, 58.0, 62.0];
        let cuts = snap_to_silence(&candidates, 120.0, 60.0, 30.0);
        assert!(!cuts.is_empty());
        let first = cuts[0];
        assert!((first - 58.0).abs() < 0.01, "expected ~58.0, got {}", first);
    }

    /// Without any silence candidates the cut falls exactly on the target.
    #[test]
    fn test_snap_to_silence_hard_cut_when_no_silence() {
        let no_silence: [f32; 0] = [];
        let cuts = snap_to_silence(&no_silence, 120.0, 60.0, 30.0);
        assert_eq!(cuts, vec![60.0]);
    }

    /// No cut points means the whole recording is one chunk.
    #[test]
    fn test_to_chunk_ranges_single_chunk() {
        let expected = vec![(0.0, 30.0)];
        assert_eq!(to_chunk_ranges(&[], 30.0), expected);
    }

    /// A single cut splits the recording into two adjacent ranges.
    #[test]
    fn test_to_chunk_ranges_two_chunks() {
        let expected = vec![(0.0, 60.0), (60.0, 120.0)];
        assert_eq!(to_chunk_ranges(&[60.0], 120.0), expected);
    }

    /// A buffer with no sample above the threshold is left untouched.
    #[test]
    fn test_trim_trailing_silence_all_silent() {
        let mut samples = vec![0.0f32; 1000];
        trim_trailing_silence(&mut samples);
        assert_eq!(samples.len(), 1000);
    }

    /// Everything after the last loud sample plus 8 000 samples of padding is removed.
    #[test]
    fn test_trim_trailing_silence_trims_to_padding() {
        let mut samples = vec![0.0f32; 32_000];
        samples[10_000] = 1.0;
        trim_trailing_silence(&mut samples);
        let expected_len = (10_001 + 8_000).min(32_000);
        assert_eq!(samples.len(), expected_len);
    }
}