Files
whisper-rtx2080/src/worker.rs
mozempk b191fbe200
All checks were successful
Build & Push Docker Image / build-and-push (push) Successful in 8m41s
feat: dynamic model loading/unloading with GPU polling
- Model starts unloaded (lazy); loads on first job or POST /model/load
- Auto-unloads after IDLE_TIMEOUT_SECS (default 300) of inactivity
- POST /model/unload for immediate manual release
- GPU-busy detection: on VRAM OOM, enters WaitingForGpu and retries
  every GPU_POLL_INTERVAL_SECS (default 30) indefinitely
- POST /jobs when unloaded → 503 + Retry-After header, triggers load
- AppError::OutOfMemory and AppError::ModelNotReady variants
- WorkerCmd channel (SyncSender<WorkerCmd>) replaces bare tx_req channel
- Idle timer via recv_timeout(1s) tick inside OS thread (no extra thread)
- Model lifecycle events broadcast via tokio broadcast channel (SSE + webhooks)
- webhook_registry: all clients that ever submitted a webhook_url receive
  model_ready and model_unloaded webhooks
- GPU warmup retained on every (re)load

New routes:
  GET  /model/status  — current state + VRAM stats
  POST /model/load    — trigger load (idempotent)
  POST /model/unload  — immediate unload
  GET  /model/events  — SSE stream of model lifecycle events

New env vars:
  IDLE_TIMEOUT_SECS       (default 300)
  GPU_POLL_INTERVAL_SECS  (default 30)

Tests:
  tests/test_model_lifecycle.sh — 18 integration tests (full state machine,
    SSE events, webhooks, concurrency, unload-during-load)
  tests/test_idle_timeout.sh    — 5 tests with short IDLE_TIMEOUT_SECS=5
  test_all.sh updated: loads model before job submission, asserts
    model_state in /health, adds POST /model/unload at end

Docs:
  docs/USAGE.md: model lifecycle section, new env vars, 503 retry pattern,
    updated /health response shape

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-05-08 17:57:20 +02:00

790 lines
27 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
use std::{
collections::HashSet,
path::PathBuf,
sync::{
atomic::{AtomicUsize, Ordering},
Arc, Mutex,
},
time::{Duration, Instant},
};
use chrono::Utc;
use reqwest::Client;
use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
use crate::{
models::{Job, JobId, JobStatus, ModelEvent, ModelState, Segment},
storage::Storage,
transcriber::Transcriber,
webhook,
AppError,
};
/// Per-job broadcast channel for SSE subscribers.
///
/// The job runner publishes [`ProgressEvent`]s on it while a job runs; SSE
/// handlers (not shown in this module) are expected to `subscribe()` to it.
pub type ProgressTx = broadcast::Sender<ProgressEvent>;
/// A notification published on a job's [`ProgressTx`].
#[derive(Debug, Clone)]
pub enum ProgressEvent {
    /// Incremental progress. `percent` is overall job progress (0 to 100);
    /// `chunk` is the 1-based index of the audio chunk being transcribed;
    /// `total` is the total number of chunks.
    Progress { percent: u8, chunk: usize, total: usize },
    /// Terminal success event carrying the finished job record.
    Done(Box<Job>),
    /// Terminal failure event carrying the error message.
    Error(String),
}
/// Global registry: job_id → broadcast sender.
///
/// The job runner looks up (or lazily creates) a job's entry when the job
/// starts running. It removes the entry immediately for jobs it skips
/// (vanished or cancelled), and after a delay for jobs it processed.
pub type ProgressRegistry = Arc<dashmap::DashMap<JobId, ProgressTx>>;
// ── Worker command channel ────────────────────────────────────────────────────
/// Commands sent to the GPU worker OS thread.
///
/// The worker handles commands strictly one at a time. For example, an
/// `Unload` queued behind a long `Transcribe` only takes effect after that
/// transcription finishes.
#[derive(Debug)]
pub enum WorkerCmd {
    /// Request a model load. Idempotent: ignored if the model is already
    /// loaded, or if a load is already waiting for GPU memory.
    Load,
    /// Unload the model immediately and free GPU memory. If a load is waiting
    /// for GPU memory, this cancels that load instead.
    Unload,
    /// Internal: run a transcription chunk. The request is answered with an
    /// error if the model is not loaded (or is still waiting for GPU memory).
    Transcribe(TranscribeRequest),
}
// ── Transcription request/response types ─────────────────────────────────────
/// One chunk of transcription work, sent to the GPU thread inside
/// [`WorkerCmd::Transcribe`].
pub struct TranscribeRequest {
    /// Mono PCM samples for this chunk (the job runner decodes at 16 kHz).
    pub pcm: Vec<f32>,
    /// Requested language copied from the job; `None` presumably means
    /// auto-detect (TODO confirm in `Transcriber::transcribe`).
    pub language: Option<String>,
    /// Task string copied from the job and passed through to the transcriber.
    pub task: String,
    /// Called with chunk-local progress; the job runner maps it onto overall
    /// job progress.
    pub on_progress: Box<dyn Fn(u8) + Send + 'static>,
    /// Receives the chunk's segments (timestamps relative to the chunk start;
    /// the runner adds the chunk offset) and a language string.
    pub reply: oneshot::Sender<crate::Result<(Vec<Segment>, String)>>,
}
impl std::fmt::Debug for TranscribeRequest {
    /// Shows only `language` and `task`. The PCM buffer, callback and reply
    /// channel are omitted: they are large or not `Debug`, so the output is
    /// marked non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("TranscribeRequest");
        builder.field("language", &self.language);
        builder.field("task", &self.task);
        builder.finish_non_exhaustive()
    }
}
// ── Public API ────────────────────────────────────────────────────────────────
/// Spawn the single GPU worker.
///
/// Two pieces run side by side:
/// * a dedicated OS thread named `whisper-gpu`, which owns the model and
///   executes [`WorkerCmd`]s, and
/// * an async task on the current Tokio runtime, which pulls job ids from
///   `job_rx` and feeds their audio chunks to that thread.
///
/// Returns the SSE progress registry and a command sender for the worker
/// thread. The model starts **unloaded**; send `WorkerCmd::Load` or submit a
/// job to trigger loading.
///
/// # Panics
/// Panics if called outside a Tokio runtime, or if the OS thread cannot be
/// spawned.
// The parameter list is a flat fan-in of service configuration; grouping it
// into a struct would ripple into every caller.
#[allow(clippy::too_many_arguments)]
pub fn start(
    job_rx: mpsc::UnboundedReceiver<JobId>,
    storage: Arc<Storage>,
    model_path: PathBuf,
    queue_depth: Arc<AtomicUsize>,
    gpu_device: u32,
    model_state: Arc<RwLock<ModelState>>,
    model_event_tx: broadcast::Sender<ModelEvent>,
    webhook_registry: Arc<Mutex<HashSet<String>>>,
    idle_timeout: Duration,
    gpu_poll_interval: Duration,
) -> (ProgressRegistry, std::sync::mpsc::SyncSender<WorkerCmd>) {
    let progress: ProgressRegistry = Arc::new(dashmap::DashMap::new());

    // Small bounded queue: Load/Unload are rare, and the async runner keeps at
    // most one Transcribe in flight (it awaits each reply before sending more).
    let (cmd_tx, cmd_rx) = std::sync::mpsc::sync_channel::<WorkerCmd>(8);

    // The GPU thread is not a Tokio worker, so it gets an explicit runtime
    // handle for spawning async work (webhook deliveries).
    let rt_handle = tokio::runtime::Handle::current();

    let gpu_thread = std::thread::Builder::new().name("whisper-gpu".into());
    gpu_thread
        .spawn(move || {
            transcriber_thread(
                cmd_rx,
                model_path,
                gpu_device,
                model_state,
                model_event_tx,
                webhook_registry,
                idle_timeout,
                gpu_poll_interval,
                rt_handle,
            )
        })
        .expect("failed to spawn whisper-gpu thread");

    tokio::spawn(run(
        job_rx,
        storage,
        queue_depth,
        Arc::clone(&progress),
        cmd_tx.clone(),
    ));

    (progress, cmd_tx)
}
// ── GPU OS thread ─────────────────────────────────────────────────────────────
/// The worker OS thread that owns the `Transcriber` (non-`Send`).
///
/// Commands are received with `recv_timeout` on a 1-second tick, which drives
/// the idle timer without a separate thread. The idle check runs after *every*
/// loop iteration, not only on tick timeouts. Otherwise a steady stream of
/// commands that do not count as activity (e.g. repeated `WorkerCmd::Load`
/// while the model is already loaded) could postpone the idle unload
/// indefinitely. Only a successful load or a completed transcription resets
/// the idle clock.
///
/// The thread exits when every command sender has been dropped.
// Receives the full worker configuration from `start`; see the note there.
#[allow(clippy::too_many_arguments)]
fn transcriber_thread(
    rx: std::sync::mpsc::Receiver<WorkerCmd>,
    model_path: PathBuf,
    gpu_device: u32,
    model_state: Arc<RwLock<ModelState>>,
    model_event_tx: broadcast::Sender<ModelEvent>,
    webhook_registry: Arc<Mutex<HashSet<String>>>,
    idle_timeout: Duration,
    gpu_poll_interval: Duration,
    rt: tokio::runtime::Handle,
) {
    let mut transcriber: Option<Transcriber> = None;
    let mut last_job = Instant::now();
    loop {
        match rx.recv_timeout(Duration::from_secs(1)) {
            Ok(WorkerCmd::Load) => {
                if transcriber.is_some() {
                    tracing::debug!("WorkerCmd::Load ignored — model already loaded");
                } else {
                    transcriber = try_load_with_polling(
                        &rx,
                        &model_path,
                        gpu_device,
                        &model_state,
                        &model_event_tx,
                        &webhook_registry,
                        gpu_poll_interval,
                        &rt,
                    );
                    if transcriber.is_some() {
                        // A fresh load counts as activity so it is not
                        // unloaded before the first job arrives.
                        last_job = Instant::now();
                    }
                }
            }
            Ok(WorkerCmd::Unload) => {
                do_unload(&mut transcriber, &model_state, &model_event_tx, &webhook_registry, &rt);
            }
            Ok(WorkerCmd::Transcribe(req)) => {
                if let Some(t) = transcriber.as_mut() {
                    let result = t.transcribe(
                        &req.pcm,
                        req.language.as_deref(),
                        &req.task,
                        move |p| (req.on_progress)(p),
                    );
                    last_job = Instant::now();
                    let _ = req.reply.send(result);
                } else {
                    tracing::warn!("Transcribe cmd received but model is unloaded — failing job");
                    let _ = req.reply.send(Err(AppError::Internal(
                        "model unloaded before job could run".into(),
                    )));
                }
            }
            // Nothing arrived this tick; fall through to the idle check.
            Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {}
            Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
                tracing::info!("worker command channel closed — shutting down GPU thread");
                break;
            }
        }

        if transcriber.is_some() && last_job.elapsed() >= idle_timeout {
            tracing::info!(
                elapsed_secs = last_job.elapsed().as_secs(),
                "idle timeout reached — unloading model"
            );
            do_unload(
                &mut transcriber,
                &model_state,
                &model_event_tx,
                &webhook_registry,
                &rt,
            );
        }
    }
}
/// Attempt to load the model, polling on VRAM failures.
///
/// Each attempt sets the state to `Loading` and broadcasts `ModelLoading`.
/// Then, depending on the outcome:
/// * **Success**: the state becomes `Ready`, `ModelReady` is broadcast and
///   sent to webhooks, and the transcriber is returned.
/// * **`AppError::OutOfMemory`**: the state becomes `WaitingForGpu` and the
///   event is broadcast. The thread then waits `gpu_poll_interval` before
///   retrying. While waiting it keeps draining `rx`:
///   * `Unload` cancels the load,
///   * `Load` is ignored,
///   * `Transcribe` requests are answered with `AppError::ModelNotReady`.
/// * **Any other error**: the load is abandoned, the state becomes `Unloaded`,
///   and `ModelUnloaded` is broadcast and sent to webhooks. Listeners that saw
///   `ModelLoading` therefore learn that no `ModelReady` is coming.
///
/// Returns `Some(Transcriber)` on success. Returns `None` if the load was
/// cancelled, failed, or the command channel disconnected. On disconnect the
/// state is left as-is, since the thread is shutting down.
// Shares `transcriber_thread`'s configuration; see the note on `start`.
#[allow(clippy::too_many_arguments)]
fn try_load_with_polling(
    rx: &std::sync::mpsc::Receiver<WorkerCmd>,
    model_path: &PathBuf,
    gpu_device: u32,
    model_state: &Arc<RwLock<ModelState>>,
    model_event_tx: &broadcast::Sender<ModelEvent>,
    webhook_registry: &Arc<Mutex<HashSet<String>>>,
    gpu_poll_interval: Duration,
    rt: &tokio::runtime::Handle,
) -> Option<Transcriber> {
    loop {
        set_state(model_state, ModelState::Loading);
        broadcast_event(model_event_tx, ModelEvent::ModelLoading);
        tracing::info!("loading whisper model...");
        match Transcriber::load(model_path, gpu_device) {
            Ok(t) => {
                let loaded_at = Utc::now();
                set_state(model_state, ModelState::Ready { loaded_at });
                broadcast_event(model_event_tx, ModelEvent::ModelReady { loaded_at });
                fire_webhooks(webhook_registry, ModelEvent::ModelReady { loaded_at }, rt);
                tracing::info!("model loaded and ready");
                return Some(t);
            }
            Err(AppError::OutOfMemory(msg)) => {
                let (vram_needed_mb, vram_free_mb) = parse_oom_vram(&msg, gpu_device);
                let retry_in_secs = gpu_poll_interval.as_secs();
                tracing::warn!(
                    vram_needed_mb,
                    vram_free_mb,
                    retry_in_secs,
                    "insufficient VRAM — will retry"
                );
                set_state(model_state, ModelState::WaitingForGpu {
                    vram_needed_mb,
                    vram_free_mb,
                    retry_in_secs,
                });
                broadcast_event(model_event_tx, ModelEvent::ModelWaitingForGpu {
                    vram_needed_mb,
                    vram_free_mb,
                    retry_in_secs,
                });
                // Interruptible sleep: drain rx in <=1 s slices until the
                // poll deadline, so Unload can cancel promptly.
                let deadline = Instant::now() + gpu_poll_interval;
                loop {
                    let remaining = deadline.saturating_duration_since(Instant::now());
                    if remaining.is_zero() {
                        break;
                    }
                    match rx.recv_timeout(remaining.min(Duration::from_secs(1))) {
                        Ok(WorkerCmd::Unload) => {
                            tracing::info!("Unload received while waiting for GPU — cancelling load");
                            set_state(model_state, ModelState::Unloaded);
                            broadcast_event(model_event_tx, ModelEvent::ModelUnloaded);
                            fire_webhooks(webhook_registry, ModelEvent::ModelUnloaded, rt);
                            return None;
                        }
                        Ok(WorkerCmd::Load) => {} // idempotent
                        Ok(WorkerCmd::Transcribe(req)) => {
                            let _ = req.reply.send(Err(AppError::ModelNotReady {
                                state: "waiting_for_gpu".into(),
                                retry_after_secs: retry_in_secs,
                            }));
                        }
                        Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {}
                        Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => return None,
                    }
                }
                // Loop back to retry load
            }
            Err(e) => {
                tracing::error!(error = %e, "model load failed with non-recoverable error");
                set_state(model_state, ModelState::Unloaded);
                // Mirror every other transition to `Unloaded`, so SSE and
                // webhook consumers are not left waiting after `ModelLoading`.
                broadcast_event(model_event_tx, ModelEvent::ModelUnloaded);
                fire_webhooks(webhook_registry, ModelEvent::ModelUnloaded, rt);
                return None;
            }
        }
    }
}
/// Release the model and announce the `Unloaded` state.
///
/// Drops the loaded `Transcriber` (if any) first, then:
/// 1. sets the shared state to `Unloaded`,
/// 2. broadcasts `ModelUnloaded` to SSE subscribers,
/// 3. sends it to every registered webhook.
///
/// The event and webhooks are emitted even when no model was loaded.
fn do_unload(
    transcriber: &mut Option<Transcriber>,
    model_state: &Arc<RwLock<ModelState>>,
    model_event_tx: &broadcast::Sender<ModelEvent>,
    webhook_registry: &Arc<Mutex<HashSet<String>>>,
    rt: &tokio::runtime::Handle,
) {
    // Drop the model before publishing `Unloaded`, so observers never see the
    // new state while the old model is still resident.
    drop(transcriber.take());

    set_state(model_state, ModelState::Unloaded);
    broadcast_event(model_event_tx, ModelEvent::ModelUnloaded);
    fire_webhooks(webhook_registry, ModelEvent::ModelUnloaded, rt);
    tracing::info!("model unloaded — GPU memory freed");
}
// ── Helpers ───────────────────────────────────────────────────────────────────
/// Overwrite the shared model state.
///
/// Uses Tokio's `blocking_write`, which must not be called from inside an
/// async execution context (Tokio panics there). All callers in this file run
/// on the `whisper-gpu` OS thread.
fn set_state(arc: &Arc<RwLock<ModelState>>, state: ModelState) {
    let mut guard = arc.blocking_write();
    *guard = state;
}
/// Publish a model lifecycle event to SSE subscribers.
///
/// `broadcast::Sender::send` only fails when there are no live receivers,
/// which simply means nobody is listening, so the result is discarded.
fn broadcast_event(tx: &broadcast::Sender<ModelEvent>, event: ModelEvent) {
    _ = tx.send(event);
}
fn fire_webhooks(
registry: &Arc<Mutex<HashSet<String>>>,
event: ModelEvent,
rt: &tokio::runtime::Handle,
) {
if !event.is_webhook_event() {
return;
}
let urls: Vec<String> = registry
.lock()
.unwrap_or_else(|e| e.into_inner())
.iter()
.cloned()
.collect();
if urls.is_empty() { return; }
let payload = match serde_json::to_string(&event) {
Ok(p) => p,
Err(e) => { tracing::error!(error = %e, "failed to serialize model event"); return; }
};
for url in urls {
let body = payload.clone();
rt.spawn(async move {
let http = Client::builder()
.timeout(Duration::from_secs(10))
.build()
.expect("http client");
for attempt in 0..3_u32 {
match http.post(&url)
.header("content-type", "application/json")
.body(body.clone())
.send()
.await
{
Ok(r) if r.status().is_success() => {
tracing::debug!(url, "model event webhook delivered");
return;
}
Ok(r) => tracing::warn!(url, status = r.status().as_u16(), "webhook non-2xx"),
Err(e) => tracing::warn!(url, error = %e, attempt, "webhook delivery failed"),
}
if attempt < 2 {
tokio::time::sleep(Duration::from_secs(2u64.pow(attempt))).await;
}
}
tracing::error!(url, "model event webhook failed after 3 attempts");
});
}
}
/// Extract `(needed_mb, free_mb)` for an out-of-memory report.
///
/// * `needed_mb` comes from `msg`: the first whitespace-separated token that
///   is immediately followed by a standalone `MiB` token, parsed as a float
///   and truncated. Only that first candidate is tried; if it does not
///   parse, or there is none, the result is `0`.
/// * `free_mb` is queried from `nvidia-smi` for `gpu_device`. It is `0` if
///   the tool is missing or its output does not parse.
fn parse_oom_vram(msg: &str, gpu_device: u32) -> (u64, u64) {
    let words: Vec<&str> = msg.split_whitespace().collect();
    let needed_mb = words
        .windows(2)
        .find(|pair| pair[1] == "MiB")
        .and_then(|pair| pair[0].parse::<f64>().ok())
        .map_or(0, |mib| mib as u64);

    let id_flag = format!("--id={gpu_device}");
    let free_mb = std::process::Command::new("nvidia-smi")
        .arg(&id_flag)
        .arg("--query-gpu=memory.free")
        .arg("--format=csv,noheader,nounits")
        .output()
        .ok()
        .and_then(|out| String::from_utf8(out.stdout).ok())
        .and_then(|text| text.trim().parse::<u64>().ok())
        .unwrap_or(0);

    (needed_mb, free_mb)
}
// ── Async job runner ──────────────────────────────────────────────────────────
async fn run(
mut job_rx: mpsc::UnboundedReceiver<JobId>,
storage: Arc<Storage>,
queue_depth: Arc<AtomicUsize>,
registry: ProgressRegistry,
cmd_tx: std::sync::mpsc::SyncSender<WorkerCmd>,
) {
let http = Client::builder()
.timeout(Duration::from_secs(30))
.build()
.expect("failed to build reqwest client");
while let Some(job_id) = job_rx.recv().await {
queue_depth.fetch_sub(1, Ordering::Relaxed);
let mut job = match storage.get(&job_id).await {
Ok(j) => j,
Err(e) => {
tracing::warn!(job_id = %job_id, error = %e, "job vanished before processing");
registry.remove(&job_id);
continue;
}
};
if job.status == JobStatus::Cancelled {
registry.remove(&job_id);
continue;
}
job.status = JobStatus::Running;
if let Err(e) = storage.save(&job).await {
tracing::error!(job_id = %job_id, error = %e, "failed to persist running status");
}
let progress_tx = registry
.entry(job_id)
.or_insert_with(|| broadcast::channel(64).0)
.clone();
let audio_path = audio_path_for(&job_id);
let result = process_job(&job, &audio_path, &progress_tx, &cmd_tx, &storage).await;
let _ = tokio::fs::remove_file(&audio_path).await;
match result {
Ok((segments, language, duration_secs)) => {
job.status = JobStatus::Done;
job.segments = segments;
job.language = Some(language);
job.duration_secs = Some(duration_secs);
job.progress = 100;
job.completed_at = Some(Utc::now());
let _ = progress_tx.send(ProgressEvent::Done(Box::new(job.clone())));
}
Err(e) => {
let msg = e.to_string();
tracing::error!(job_id = %job_id, error = %msg, "transcription failed");
job.status = JobStatus::Failed;
job.error = Some(msg.clone());
job.completed_at = Some(Utc::now());
let _ = progress_tx.send(ProgressEvent::Error(msg));
}
}
if let Err(e) = storage.save(&job).await {
tracing::error!(job_id = %job_id, error = %e, "failed to persist final job state");
}
if let Some(url) = &job.webhook_url.clone() {
let http = http.clone();
let url = url.clone();
let job = job.clone();
tokio::spawn(async move { webhook::fire(&http, &url, &job).await; });
}
tokio::time::sleep(Duration::from_secs(30)).await;
registry.remove(&job_id);
}
}
// ── Silence-based chunking ────────────────────────────────────────────────────
/// Preferred chunk length in seconds, before snapping cuts to silence.
const TARGET_CHUNK_SECS: f32 = 60.0;
/// Maximum distance (seconds) a cut may move from its target position to
/// land on a silence midpoint.
const SNAP_WINDOW_SECS: f32 = 30.0;
/// ffmpeg `silencedetect` noise threshold (the filter's `n=` option).
const SILENCE_DB: &str = "-35dB";
/// ffmpeg `silencedetect` minimum silence duration in seconds (the `d=` option).
const SILENCE_DUR: &str = "0.4";
/// Run ffmpeg's `silencedetect` filter over `path` and return the midpoint
/// (in seconds) of each detected silence interval, in report order.
///
/// Best effort: if ffmpeg cannot be spawned, an empty list is returned and
/// the caller falls back to hard cuts. ffmpeg's exit status is not checked;
/// whatever `silence_start` / `silence_end` lines appear on stderr are used.
/// Starts and ends are paired positionally, so a trailing `silence_start`
/// without a matching `silence_end` is dropped.
async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
    use tokio::process::Command;
    let filter = format!("silencedetect=n={}:d={}", SILENCE_DB, SILENCE_DUR);
    // Hand the path to ffmpeg as an OsStr, so non-UTF-8 paths reach it intact
    // instead of degrading to an empty input name.
    let output = Command::new("ffmpeg")
        .args(["-nostdin", "-i"])
        .arg(path)
        .args(["-af", filter.as_str(), "-f", "null", "-"])
        .output()
        .await;
    let output = match output {
        Ok(o) => o,
        Err(e) => {
            tracing::warn!(error = %e, "silencedetect unavailable; using hard cuts");
            return Vec::new();
        }
    };
    let stderr = String::from_utf8_lossy(&output.stderr);
    let mut starts: Vec<f32> = Vec::new();
    let mut ends: Vec<f32> = Vec::new();
    for line in stderr.lines() {
        if let Some(i) = line.find("silence_start: ") {
            if let Ok(t) = line[i + "silence_start: ".len()..].trim().parse::<f32>() {
                starts.push(t);
            }
        } else if let Some(i) = line.find("silence_end: ") {
            // e.g. "silence_end: 14.2 | silence_duration: 1.8"
            let t_str = line[i + "silence_end: ".len()..]
                .split(" |")
                .next()
                .unwrap_or("")
                .trim();
            if let Ok(t) = t_str.parse::<f32>() {
                ends.push(t);
            }
        }
    }
    let mids: Vec<f32> = starts
        .iter()
        .zip(ends.iter())
        .map(|(s, e)| (s + e) / 2.0)
        .collect();
    tracing::debug!(n = mids.len(), "silence midpoints detected");
    mids
}
/// Choose chunk cut points about every `target_secs`, snapped to silence.
///
/// Starting at `target_secs`, each cut moves to the nearest silence midpoint
/// within `snap_window` of its target position. Eligible midpoints must also
/// lie more than 10 s after the previous cut. On ties, the earliest candidate
/// in `mids` wins. If no midpoint qualifies, a hard cut is made at the target
/// position.
///
/// The next target is one `target_secs` after the chosen cut. Cutting stops
/// once the remaining tail would be shorter than a quarter of a chunk.
///
/// Returns no cuts (a single chunk) when `target_secs` is not positive, or
/// when either duration is not finite. Such inputs would otherwise never
/// advance the loop.
fn snap_to_silence(
    mids: &[f32],
    total_secs: f32,
    target_secs: f32,
    snap_window: f32,
) -> Vec<f32> {
    let usable = target_secs.is_finite() && target_secs > 0.0 && total_secs.is_finite();
    if !usable {
        return Vec::new();
    }

    let mut cuts: Vec<f32> = Vec::new();
    let mut pos = target_secs;
    while pos < total_secs - target_secs * 0.25 {
        let prev_cut = cuts.last().copied().unwrap_or(0.0);
        // Comparison predicates are false for NaN, so NaN midpoints are
        // filtered out here; `total_cmp` then gives a panic-free total order.
        let nearest = mids
            .iter()
            .copied()
            .filter(|&t| t > prev_cut + 10.0 && (t - pos).abs() <= snap_window)
            .min_by(|a, b| (a - pos).abs().total_cmp(&(b - pos).abs()));
        let cut = nearest.unwrap_or(pos);
        cuts.push(cut);
        pos = cut + target_secs;
    }
    cuts
}
/// Turn cut points into `(start, end)` chunk ranges covering `[0, total_secs]`.
///
/// A cut less than 5 s after the current chunk start is skipped, which merges
/// that span into the following chunk. The final range runs to `total_secs`
/// when at least 1 s remains. A shorter tail is folded into the previous
/// chunk, so trailing audio is never silently dropped. If there is no
/// previous chunk (total audio under 1 s), no ranges are returned.
fn to_chunk_ranges(cuts: &[f32], total_secs: f32) -> Vec<(f32, f32)> {
    let mut ranges: Vec<(f32, f32)> = Vec::with_capacity(cuts.len() + 1);
    let mut start = 0.0_f32;
    for &cut in cuts {
        if cut - start >= 5.0 {
            ranges.push((start, cut));
            start = cut;
        }
    }
    let tail = total_secs - start;
    if tail >= 1.0 {
        ranges.push((start, total_secs));
    } else if tail > 0.0 {
        if let Some(last) = ranges.last_mut() {
            last.1 = total_secs;
        }
    }
    ranges
}
// ── Job processing ────────────────────────────────────────────────────────────
async fn process_job(
job: &Job,
audio_path: &std::path::Path,
progress_tx: &ProgressTx,
cmd_tx: &std::sync::mpsc::SyncSender<WorkerCmd>,
storage: &Arc<Storage>,
) -> crate::Result<(Vec<Segment>, String, f32)> {
let pcm = decode_audio(audio_path).await?;
let total_secs = pcm.len() as f32 / 16_000.0;
let silence_mids = detect_silence_midpoints(audio_path).await;
let cuts = snap_to_silence(&silence_mids, total_secs, TARGET_CHUNK_SECS, SNAP_WINDOW_SECS);
let chunks = to_chunk_ranges(&cuts, total_secs);
let n = chunks.len();
tracing::info!(
total_secs,
n_chunks = n,
silence_points = silence_mids.len(),
"audio chunked by silence"
);
let mut all_segments: Vec<Segment> = Vec::new();
let mut language = String::new();
for (ci, (chunk_start, chunk_end)) in chunks.iter().enumerate() {
let s0 = (*chunk_start * 16_000.0) as usize;
let s1 = ((*chunk_end * 16_000.0) as usize).min(pcm.len());
let mut chunk_pcm = pcm[s0..s1].to_vec();
trim_trailing_silence(&mut chunk_pcm);
let base = (ci * 100 / n) as u8;
let span = (100usize / n).max(1) as u8;
let _ = progress_tx.send(ProgressEvent::Progress {
percent: base,
chunk: ci + 1,
total: n,
});
let mut snapshot = job.clone();
snapshot.progress = base;
if let Err(e) = storage.save(&snapshot).await {
tracing::warn!(error = %e, "failed to persist mid-job progress");
}
let tx = progress_tx.clone();
let chunk_num = ci + 1;
let on_progress = Box::new(move |p: u8| {
let overall = base.saturating_add(p.saturating_mul(span) / 100);
let _ = tx.send(ProgressEvent::Progress {
percent: overall,
chunk: chunk_num,
total: n,
});
});
let (reply_tx, reply_rx) = oneshot::channel();
cmd_tx.send(WorkerCmd::Transcribe(TranscribeRequest {
pcm: chunk_pcm,
language: job.language.clone(),
task: job.task.clone(),
on_progress,
reply: reply_tx,
})).map_err(|_| AppError::Internal("worker command channel closed".into()))?;
let (mut segs, lang) = reply_rx.await
.map_err(|_| AppError::Internal("transcriber thread dropped reply".into()))??;
let offset = *chunk_start;
for seg in &mut segs {
seg.start += offset;
seg.end += offset;
for word in &mut seg.words {
word.start += offset;
word.end += offset;
}
}
tracing::debug!(
chunk = ci + 1,
of = n,
start = chunk_start,
end = chunk_end,
segs = segs.len(),
"chunk done"
);
all_segments.extend(segs);
if language.is_empty() {
language = lang;
}
}
for (i, seg) in all_segments.iter_mut().enumerate() {
seg.index = i as i32;
}
let _ = progress_tx.send(ProgressEvent::Progress { percent: 100, chunk: n, total: n });
Ok((all_segments, language, total_secs))
}
/// Cut trailing near-silence from a chunk, keeping a short tail of padding.
///
/// Samples with `|s| > THRESHOLD` count as audible. `THRESHOLD` is about
/// -35 dBFS, since 20·log10(0.0178) ≈ -35. Everything after the last audible
/// sample plus `PADDING` samples (0.5 s at the 16 kHz rate used by this
/// module) is removed. A buffer with no audible sample is left untouched.
fn trim_trailing_silence(pcm: &mut Vec<f32>) {
    const THRESHOLD: f32 = 0.017_8;
    const PADDING: usize = 8_000;

    let last_loud = match pcm.iter().rposition(|&s| s.abs() > THRESHOLD) {
        Some(idx) => idx,
        None => return,
    };
    let keep = pcm.len().min(last_loud + 1 + PADDING);
    if keep >= pcm.len() {
        return;
    }
    tracing::trace!(
        original_samples = pcm.len(),
        trimmed_samples = pcm.len() - keep,
        "trimmed trailing silence"
    );
    pcm.truncate(keep);
}
/// Decode any ffmpeg-readable file at `path` into mono 16 kHz `f32` PCM.
///
/// ffmpeg writes raw little-endian `f32` samples to stdout, and they are
/// reassembled here.
///
/// # Errors
/// Returns `AppError::Internal` if:
/// * ffmpeg cannot be spawned,
/// * ffmpeg exits unsuccessfully (the error includes its stderr), or
/// * the output is not a whole number of 4-byte samples.
async fn decode_audio(path: &std::path::Path) -> crate::Result<Vec<f32>> {
    use tokio::process::Command;
    // Hand the path to ffmpeg as an OsStr, so non-UTF-8 paths reach it intact
    // instead of degrading to an empty input name.
    let output = Command::new("ffmpeg")
        .args(["-nostdin", "-threads", "0", "-i"])
        .arg(path)
        .args(["-f", "f32le", "-ac", "1", "-ar", "16000", "-"])
        .output()
        .await
        .map_err(|e| AppError::Internal(format!("ffmpeg spawn failed: {e}")))?;
    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        return Err(AppError::Internal(format!(
            "ffmpeg exited with {}: {}",
            output.status, stderr
        )));
    }
    let bytes = output.stdout;
    if bytes.len() % 4 != 0 {
        return Err(AppError::Internal(
            "ffmpeg output length not a multiple of 4".into(),
        ));
    }
    Ok(bytes
        .chunks_exact(4)
        .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
        .collect())
}
/// Location of the uploaded audio for job `id`: `$DATA_DIR/<id>.audio`.
///
/// `DATA_DIR` defaults to `/data` when it is unset or not valid Unicode. The
/// environment is read on every call.
pub fn audio_path_for(id: &JobId) -> PathBuf {
    let mut path = std::env::var("DATA_DIR").map_or_else(|_| PathBuf::from("/data"), PathBuf::from);
    path.push(format!("{id}.audio"));
    path
}
// ── Unit tests ────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
    use super::*;

    // Candidates at 55/58/62 s around the 60 s target: the nearest wins, and
    // the 58-vs-62 tie resolves to the earlier candidate.
    #[test]
    fn test_snap_to_silence_uses_nearest_midpoint() {
        let midpoints = [55.0_f32, 58.0, 62.0];
        let cuts = snap_to_silence(&midpoints, 120.0, 60.0, 30.0);
        let first = cuts.first().copied().expect("at least one cut");
        assert!((first - 58.0).abs() < 0.01, "expected ~58.0, got {}", first);
    }

    // With no silence at all, the cut falls exactly on the target.
    #[test]
    fn test_snap_to_silence_hard_cut_when_no_silence() {
        let cuts = snap_to_silence(&[], 120.0, 60.0, 30.0);
        assert_eq!(cuts, vec![60.0]);
    }

    // No cuts: the whole file is one chunk.
    #[test]
    fn test_to_chunk_ranges_single_chunk() {
        let ranges = to_chunk_ranges(&[], 30.0);
        assert_eq!(ranges, vec![(0.0, 30.0)]);
    }

    // One cut in the middle splits the file into two adjacent ranges.
    #[test]
    fn test_to_chunk_ranges_two_chunks() {
        let ranges = to_chunk_ranges(&[60.0], 120.0);
        assert_eq!(ranges, vec![(0.0, 60.0), (60.0, 120.0)]);
    }

    // A buffer with no audible sample is left untouched.
    #[test]
    fn test_trim_trailing_silence_all_silent() {
        let mut pcm = vec![0.0_f32; 1_000];
        trim_trailing_silence(&mut pcm);
        assert_eq!(pcm.len(), 1_000);
    }

    // Keep up to and including the last loud sample, plus 8 000 samples of padding.
    #[test]
    fn test_trim_trailing_silence_trims_to_padding() {
        let mut pcm = vec![0.0_f32; 32_000];
        pcm[10_000] = 1.0;
        trim_trailing_silence(&mut pcm);
        assert_eq!(pcm.len(), (10_001 + 8_000).min(32_000));
    }
}