From cb0b07b2ffefc52d3dd25bec174fc7f1e0f50cc3 Mon Sep 17 00:00:00 2001 From: Giancarmine Salucci Date: Mon, 11 May 2026 22:46:38 +0200 Subject: [PATCH] fix(worker): collapse incremental segments Normalize rolling partial-hypothesis chains before final job persistence so downstream clients receive stable transcript segments instead of echoed continuations. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- src/error.rs | 54 +++-- src/main.rs | 18 +- src/models.rs | 136 ++++++++---- src/routes/health.rs | 12 +- src/routes/jobs.rs | 116 ++++++---- src/routes/mod.rs | 20 +- src/routes/model.rs | 54 +++-- src/storage.rs | 42 ++-- src/transcriber.rs | 90 ++++---- src/worker.rs | 501 ++++++++++++++++++++++++++++++++----------- 10 files changed, 712 insertions(+), 331 deletions(-) diff --git a/src/error.rs b/src/error.rs index 8a0d63b..d2e0821 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,10 +1,10 @@ -use thiserror::Error; use axum::{ - http::{StatusCode, HeaderValue, header}, + http::{header, HeaderValue, StatusCode}, response::{IntoResponse, Response}, Json, }; use serde_json::json; +use thiserror::Error; pub type Result = std::result::Result; @@ -31,7 +31,10 @@ pub enum AppError { /// Returned when a job is submitted but the model is not yet loaded. /// Carries the current state tag and recommended Retry-After seconds. #[error("model not ready: {state}")] - ModelNotReady { state: String, retry_after_secs: u64 }, + ModelNotReady { + state: String, + retry_after_secs: u64, + }, } impl AppError { @@ -59,13 +62,20 @@ impl IntoResponse for AppError { } AppError::Internal(m) => { tracing::error!(error = %m, "internal error"); - (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": m }))).into_response() + ( + StatusCode::INTERNAL_SERVER_ERROR, + Json(json!({ "error": m })), + ) + .into_response() } AppError::OutOfMemory(m) => { tracing::warn!(error = %m, "GPU out of memory during model load"); (StatusCode::SERVICE_UNAVAILABLE, Json(json!({ "error": m }))).into_response() } - AppError::ModelNotReady { state, retry_after_secs } => { + AppError::ModelNotReady { + state, + retry_after_secs, + } => { let body = Json(json!({ "error": "model_not_ready", "state": state, @@ -117,17 +127,25 @@ mod tests { #[tokio::test] async fn test_model_not_ready_response_has_retry_after_header() { - let err = AppError::ModelNotReady { state: "loading".into(), retry_after_secs: 10 }; + let err = AppError::ModelNotReady { + state: "loading".into(), + retry_after_secs: 10, + }; let resp = err.into_response(); assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE); - let retry_after = resp.headers().get(header::RETRY_AFTER) + let retry_after = resp + .headers() + .get(header::RETRY_AFTER) .expect("Retry-After header missing"); assert_eq!(retry_after, "10"); } #[tokio::test] async fn test_model_not_ready_response_body() { - let err = AppError::ModelNotReady { state: "unloaded".into(), retry_after_secs: 30 }; + let err = AppError::ModelNotReady { + state: "unloaded".into(), + retry_after_secs: 30, + }; let resp = err.into_response(); let bytes = to_bytes(resp.into_body(), usize::MAX).await.unwrap(); let v: serde_json::Value = serde_json::from_slice(&bytes).unwrap(); @@ -138,21 +156,21 @@ mod tests { #[tokio::test] async fn test_model_not_ready_loading_retry_after_10() { - let err = AppError::ModelNotReady { state: "loading".into(), retry_after_secs: 10 }; + let err = AppError::ModelNotReady { + state: "loading".into(), + retry_after_secs: 10, + }; let resp = err.into_response(); - assert_eq!( - resp.headers().get(header::RETRY_AFTER).unwrap(), - "10" - ); + assert_eq!(resp.headers().get(header::RETRY_AFTER).unwrap(), "10"); } #[tokio::test] async fn test_model_not_ready_unloaded_retry_after_30() { - let err = AppError::ModelNotReady { state: "unloaded".into(), retry_after_secs: 30 }; + let err = AppError::ModelNotReady { + state: "unloaded".into(), + retry_after_secs: 30, + }; let resp = err.into_response(); - assert_eq!( - resp.headers().get(header::RETRY_AFTER).unwrap(), - "30" - ); + assert_eq!(resp.headers().get(header::RETRY_AFTER).unwrap(), "30"); } } diff --git a/src/main.rs b/src/main.rs index 16f72c5..67d49ae 100644 --- a/src/main.rs +++ b/src/main.rs @@ -97,10 +97,10 @@ async fn main() -> anyhow::Result<()> { .with(tracing_subscriber::fmt::layer().json()) .init(); - let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/data".into()); - let model_path = std::env::var("WHISPER_MODEL_PATH") - .unwrap_or_else(|_| "/models/ggml-large-v3.bin".into()); - let port = std::env::var("PORT").unwrap_or_else(|_| "8080".into()); + let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/data".into()); + let model_path = + std::env::var("WHISPER_MODEL_PATH").unwrap_or_else(|_| "/models/ggml-large-v3.bin".into()); + let port = std::env::var("PORT").unwrap_or_else(|_| "8080".into()); let model_name = std::env::var("WHISPER_MODEL").unwrap_or_else(|_| "large-v3".into()); let gpu_device: u32 = std::env::var("CUDA_DEVICE") .ok() @@ -132,7 +132,9 @@ async fn main() -> anyhow::Result<()> { // Model starts unloaded — lazy load on first job or POST /model/load. let model_state = Arc::new(RwLock::new(models::ModelState::Unloaded)); let (model_event_tx, _) = broadcast::channel::(32); - let webhook_registry = Arc::new(std::sync::Mutex::new(std::collections::HashSet::::new())); + let webhook_registry = Arc::new(std::sync::Mutex::new( + std::collections::HashSet::::new(), + )); // Spawn single GPU worker; get back the SSE broadcast registry and cmd channel. let (progress, cmd_tx) = worker::start( @@ -153,13 +155,13 @@ async fn main() -> anyhow::Result<()> { cmd_tx, storage: Arc::clone(&storage), progress, - model_name: model_name.as_str().into(), - queue_depth: Arc::clone(&queue_depth), + model_name: model_name.as_str().into(), + queue_depth: Arc::clone(&queue_depth), gpu_device, model_state, model_event_tx, webhook_registry, - idle_timeout: std::time::Duration::from_secs(idle_timeout_secs), + idle_timeout: std::time::Duration::from_secs(idle_timeout_secs), gpu_poll_interval: std::time::Duration::from_secs(gpu_poll_interval_secs), }; diff --git a/src/models.rs b/src/models.rs index 0cbc759..771f2aa 100644 --- a/src/models.rs +++ b/src/models.rs @@ -48,20 +48,20 @@ impl ModelState { /// Suggested `Retry-After` value (seconds) to include in 503 responses. pub fn retry_after_secs(&self) -> u64 { match self { - ModelState::Unloaded => 30, // conservative load estimate - ModelState::Loading => 10, + ModelState::Unloaded => 30, // conservative load estimate + ModelState::Loading => 10, ModelState::WaitingForGpu { retry_in_secs, .. } => *retry_in_secs, - ModelState::Ready { .. } => 0, // shouldn't 503 if ready + ModelState::Ready { .. } => 0, // shouldn't 503 if ready } } /// String tag for use in error response bodies and log fields. pub fn tag(&self) -> &'static str { match self { - ModelState::Unloaded => "unloaded", - ModelState::Loading => "loading", - ModelState::WaitingForGpu{..} => "waiting_for_gpu", - ModelState::Ready{..} => "ready", + ModelState::Unloaded => "unloaded", + ModelState::Loading => "loading", + ModelState::WaitingForGpu { .. } => "waiting_for_gpu", + ModelState::Ready { .. } => "ready", } } } @@ -77,9 +77,7 @@ impl ModelState { #[serde(tag = "type", rename_all = "snake_case")] pub enum ModelEvent { /// Model finished loading and the GPU warmup completed — ready to accept jobs. - ModelReady { - loaded_at: DateTime, - }, + ModelReady { loaded_at: DateTime }, /// Model was unloaded from GPU memory (idle timeout or manual unload). ModelUnloaded, /// Model load initiated. @@ -87,15 +85,18 @@ pub enum ModelEvent { /// Load failed due to insufficient VRAM; retrying after `retry_in_secs`. ModelWaitingForGpu { vram_needed_mb: u64, - vram_free_mb: u64, - retry_in_secs: u64, + vram_free_mb: u64, + retry_in_secs: u64, }, } impl ModelEvent { /// Returns true if this event should be delivered via webhook. pub fn is_webhook_event(&self) -> bool { - matches!(self, ModelEvent::ModelReady { .. } | ModelEvent::ModelUnloaded) + matches!( + self, + ModelEvent::ModelReady { .. } | ModelEvent::ModelUnloaded + ) } } @@ -132,11 +133,11 @@ pub enum JobStatus { #[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] pub struct Word { /// Word text - pub text: String, + pub text: String, /// Start time in seconds - pub start: f32, + pub start: f32, /// End time in seconds - pub end: f32, + pub end: f32, /// Model confidence (0–1) pub probability: f32, } @@ -148,9 +149,9 @@ pub struct Segment { /// Start time in seconds pub start: f32, /// End time in seconds - pub end: f32, + pub end: f32, /// Transcribed text - pub text: String, + pub text: String, /// Token-level word timestamps (empty when flash_attn is enabled) #[serde(default)] pub words: Vec, @@ -205,18 +206,23 @@ pub struct Job { } impl Job { - pub fn new(id: JobId, task: String, webhook_url: Option, filename: Option) -> Self { + pub fn new( + id: JobId, + task: String, + webhook_url: Option, + filename: Option, + ) -> Self { Self { id, - status: JobStatus::Queued, - language: None, + status: JobStatus::Queued, + language: None, task, duration_secs: None, - segments: vec![], - error: None, + segments: vec![], + error: None, webhook_url, - progress: 0, - created_at: Utc::now(), + progress: 0, + created_at: Utc::now(), completed_at: None, filename, } @@ -235,13 +241,13 @@ pub struct SubmitResponse { /// Response from GET /health. #[derive(Debug, Serialize, ToSchema)] pub struct HealthResponse { - pub status: String, - pub gpu_name: Option, + pub status: String, + pub gpu_name: Option, pub vram_total_mb: Option, - pub model: String, - pub queue_depth: usize, + pub model: String, + pub queue_depth: usize, /// Current state of the whisper model. - pub model_state: String, + pub model_state: String, } // ── SSE event payload ──────────────────────────────────────────────────────── @@ -257,8 +263,12 @@ pub enum SsePayload { /// Total number of silence-split chunks in this job. chunks_total: usize, }, - Done { job: Box }, - Error { message: String }, + Done { + job: Box, + }, + Error { + message: String, + }, } // ── Unit tests ─────────────────────────────────────────────────────────────── @@ -284,7 +294,11 @@ mod tests { #[test] fn test_model_state_waiting_serializes() { - let s = ModelState::WaitingForGpu { vram_needed_mb: 3000, vram_free_mb: 500, retry_in_secs: 30 }; + let s = ModelState::WaitingForGpu { + vram_needed_mb: 3000, + vram_free_mb: 500, + retry_in_secs: 30, + }; let v: Value = serde_json::to_value(&s).unwrap(); assert_eq!(v["state"], "waiting_for_gpu"); assert_eq!(v["vram_needed_mb"], 3000); @@ -305,8 +319,16 @@ mod tests { fn test_model_state_is_ready() { assert!(!ModelState::Unloaded.is_ready()); assert!(!ModelState::Loading.is_ready()); - assert!(!ModelState::WaitingForGpu { vram_needed_mb: 0, vram_free_mb: 0, retry_in_secs: 30 }.is_ready()); - assert!(ModelState::Ready { loaded_at: Utc::now() }.is_ready()); + assert!(!ModelState::WaitingForGpu { + vram_needed_mb: 0, + vram_free_mb: 0, + retry_in_secs: 30 + } + .is_ready()); + assert!(ModelState::Ready { + loaded_at: Utc::now() + } + .is_ready()); } #[test] @@ -321,13 +343,23 @@ mod tests { #[test] fn test_retry_after_waiting_for_gpu() { - let s = ModelState::WaitingForGpu { vram_needed_mb: 0, vram_free_mb: 0, retry_in_secs: 45 }; + let s = ModelState::WaitingForGpu { + vram_needed_mb: 0, + vram_free_mb: 0, + retry_in_secs: 45, + }; assert_eq!(s.retry_after_secs(), 45); } #[test] fn test_retry_after_ready_is_zero() { - assert_eq!(ModelState::Ready { loaded_at: Utc::now() }.retry_after_secs(), 0); + assert_eq!( + ModelState::Ready { + loaded_at: Utc::now() + } + .retry_after_secs(), + 0 + ); } // ── ModelEvent serialization ───────────────────────────────────────────── @@ -355,7 +387,11 @@ mod tests { #[test] fn test_model_event_waiting_serializes() { - let e = ModelEvent::ModelWaitingForGpu { vram_needed_mb: 3000, vram_free_mb: 200, retry_in_secs: 30 }; + let e = ModelEvent::ModelWaitingForGpu { + vram_needed_mb: 3000, + vram_free_mb: 200, + retry_in_secs: 30, + }; let v: Value = serde_json::to_value(&e).unwrap(); assert_eq!(v["type"], "model_waiting_for_gpu"); assert_eq!(v["vram_needed_mb"], 3000); @@ -363,10 +399,18 @@ mod tests { #[test] fn test_model_event_webhook_filter() { - assert!(ModelEvent::ModelReady { loaded_at: Utc::now() }.is_webhook_event()); + assert!(ModelEvent::ModelReady { + loaded_at: Utc::now() + } + .is_webhook_event()); assert!(ModelEvent::ModelUnloaded.is_webhook_event()); assert!(!ModelEvent::ModelLoading.is_webhook_event()); - assert!(!ModelEvent::ModelWaitingForGpu { vram_needed_mb: 0, vram_free_mb: 0, retry_in_secs: 30 }.is_webhook_event()); + assert!(!ModelEvent::ModelWaitingForGpu { + vram_needed_mb: 0, + vram_free_mb: 0, + retry_in_secs: 30 + } + .is_webhook_event()); } // ── ModelStatusResponse ────────────────────────────────────────────────── @@ -374,8 +418,10 @@ mod tests { #[test] fn test_model_status_response_roundtrip() { let r = ModelStatusResponse { - state: ModelState::Ready { loaded_at: Utc::now() }, - vram_used_mb: Some(4096), + state: ModelState::Ready { + loaded_at: Utc::now(), + }, + vram_used_mb: Some(4096), vram_total_mb: Some(8192), }; let json_str = serde_json::to_string(&r).unwrap(); @@ -387,7 +433,11 @@ mod tests { #[test] fn test_model_status_response_omits_nulls() { - let r = ModelStatusResponse { state: ModelState::Loading, vram_used_mb: None, vram_total_mb: None }; + let r = ModelStatusResponse { + state: ModelState::Loading, + vram_used_mb: None, + vram_total_mb: None, + }; let v: Value = serde_json::to_value(&r).unwrap(); assert_eq!(v["state"], "loading"); assert!(v.get("vram_used_mb").is_none()); diff --git a/src/routes/health.rs b/src/routes/health.rs index 7dcc0ab..5476b02 100644 --- a/src/routes/health.rs +++ b/src/routes/health.rs @@ -19,12 +19,12 @@ pub async fn health(State(state): State) -> Result (Option, Option) { let mut parts = line.splitn(2, ','); let name = parts.next().map(|s| s.trim().to_owned()); - let vram = parts - .next() - .and_then(|s| s.trim().parse::().ok()); + let vram = parts.next().and_then(|s| s.trim().parse::().ok()); (name, vram) } diff --git a/src/routes/jobs.rs b/src/routes/jobs.rs index b5db1e3..1fb114d 100644 --- a/src/routes/jobs.rs +++ b/src/routes/jobs.rs @@ -23,7 +23,8 @@ use crate::{ AppError, AppState, Result, }; -type SseStream = Pin> + Send>>; +type SseStream = + Pin> + Send>>; // ── POST /jobs ─────────────────────────────────────────────────────────────── @@ -53,18 +54,20 @@ pub async fn submit_job( State(state): State, mut multipart: Multipart, ) -> Result { - let mut language: Option = None; - let mut task: String = "transcribe".into(); + let mut language: Option = None; + let mut task: String = "transcribe".into(); let mut webhook_url: Option = None; - let mut filename: Option = None; + let mut filename: Option = None; let mut audio_saved = false; // Assign ID early so we know where to stream the audio bytes. let id = Uuid::new_v4(); let audio_path = audio_path_for(&id); - while let Some(field) = multipart.next_field().await.map_err(|e| { - AppError::BadRequest(format!("multipart error: {e}")) - })? { + while let Some(field) = multipart + .next_field() + .await + .map_err(|e| AppError::BadRequest(format!("multipart error: {e}")))? + { let field_name = field.name().unwrap_or("").to_owned(); match field_name.as_str() { @@ -77,9 +80,11 @@ pub async fn submit_job( })?; let mut bytes_written: u64 = 0; let mut stream = field; - while let Some(chunk) = stream.chunk().await.map_err(|e| { - AppError::BadRequest(format!("failed to read audio field: {e}")) - })? { + while let Some(chunk) = stream + .chunk() + .await + .map_err(|e| AppError::BadRequest(format!("failed to read audio field: {e}")))? + { file.write_all(&chunk).await.map_err(|e| { AppError::Internal(format!("failed to write audio chunk: {e}")) })?; @@ -90,10 +95,29 @@ pub async fn submit_job( } audio_saved = true; } - "language" => language = Some(field.text().await.map_err(|e| AppError::BadRequest(e.to_string()))?), - "task" => task = field.text().await.map_err(|e| AppError::BadRequest(e.to_string()))?, - "webhook_url" => webhook_url = Some(field.text().await.map_err(|e| AppError::BadRequest(e.to_string()))?), - _ => {} // ignore unknown fields + "language" => { + language = Some( + field + .text() + .await + .map_err(|e| AppError::BadRequest(e.to_string()))?, + ) + } + "task" => { + task = field + .text() + .await + .map_err(|e| AppError::BadRequest(e.to_string()))? + } + "webhook_url" => { + webhook_url = Some( + field + .text() + .await + .map_err(|e| AppError::BadRequest(e.to_string()))?, + ) + } + _ => {} // ignore unknown fields } } @@ -112,14 +136,16 @@ pub async fn submit_job( let ms = state.model_state.read().await; let ready = ms.is_ready(); let retry = ms.retry_after_secs(); - let tag = ms.tag().to_string(); + let tag = ms.tag().to_string(); (ready, retry, tag) }; // Register the webhook URL regardless of model state — so model lifecycle // events are delivered even if the job itself is rejected. if let Some(url) = &webhook_url { - state.webhook_registry.lock() + state + .webhook_registry + .lock() .unwrap_or_else(|e| e.into_inner()) .insert(url.clone()); } @@ -143,12 +169,16 @@ pub async fn submit_job( state.storage.create(&job).await?; // Pre-create the broadcast channel so SSE subscribers don't miss events. - state.progress.entry(id).or_insert_with(|| broadcast::channel(64).0); + state + .progress + .entry(id) + .or_insert_with(|| broadcast::channel(64).0); state.queue_depth.fetch_add(1, Ordering::Relaxed); - state.job_tx.send(id).map_err(|_| { - AppError::Internal("worker channel closed".into()) - })?; + state + .job_tx + .send(id) + .map_err(|_| AppError::Internal("worker channel closed".into()))?; tracing::info!(job_id = %id, "job queued"); @@ -168,10 +198,7 @@ pub async fn submit_job( (status = 404, description = "Not found"), ) )] -pub async fn get_job( - State(state): State, - Path(id): Path, -) -> Result> { +pub async fn get_job(State(state): State, Path(id): Path) -> Result> { let job = state.storage.get(&id).await?; Ok(Json(job)) } @@ -196,15 +223,15 @@ pub async fn get_job( )] pub async fn stream_job( State(state): State, - Path(id): Path, + Path(id): Path, ) -> Result> { // If the job is already finished, return a single done event immediately. let job = state.storage.get(&id).await?; match job.status { JobStatus::Done | JobStatus::Failed | JobStatus::Cancelled => { - let payload = serde_json::to_string( - &crate::models::SsePayload::Done { job: Box::new(job) } - ).unwrap_or_default(); + let payload = + serde_json::to_string(&crate::models::SsePayload::Done { job: Box::new(job) }) + .unwrap_or_default(); let s: SseStream = Box::pin(stream::once(async move { Ok(Event::default().event("done").data(payload)) })); @@ -222,22 +249,28 @@ pub async fn stream_job( let sse_stream: SseStream = Box::pin(BroadcastStream::new(rx).filter_map(|msg| async move { let event = match msg { - Ok(ProgressEvent::Progress { percent, chunk, total }) => { - let payload = serde_json::to_string( - &crate::models::SsePayload::Progress { percent, chunk, chunks_total: total } - ).ok()?; + Ok(ProgressEvent::Progress { + percent, + chunk, + total, + }) => { + let payload = serde_json::to_string(&crate::models::SsePayload::Progress { + percent, + chunk, + chunks_total: total, + }) + .ok()?; Event::default().event("progress").data(payload) } Ok(ProgressEvent::Done(job)) => { - let payload = serde_json::to_string( - &crate::models::SsePayload::Done { job } - ).ok()?; + let payload = + serde_json::to_string(&crate::models::SsePayload::Done { job }).ok()?; Event::default().event("done").data(payload) } Ok(ProgressEvent::Error(msg)) => { - let payload = serde_json::to_string( - &crate::models::SsePayload::Error { message: msg } - ).ok()?; + let payload = + serde_json::to_string(&crate::models::SsePayload::Error { message: msg }) + .ok()?; Event::default().event("error").data(payload) } Err(_) => return None, // lagged / channel closed @@ -264,10 +297,7 @@ pub async fn stream_job( (status = 409, description = "Job already finished"), ) )] -pub async fn delete_job( - State(state): State, - Path(id): Path, -) -> Result> { +pub async fn delete_job(State(state): State, Path(id): Path) -> Result> { let mut job = state.storage.get(&id).await?; match job.status { @@ -280,7 +310,7 @@ pub async fn delete_job( _ => {} } - job.status = JobStatus::Cancelled; + job.status = JobStatus::Cancelled; job.completed_at = Some(Utc::now()); state.storage.save(&job).await?; diff --git a/src/routes/mod.rs b/src/routes/mod.rs index 06745d1..9b20f2d 100644 --- a/src/routes/mod.rs +++ b/src/routes/mod.rs @@ -2,27 +2,33 @@ pub mod health; pub mod jobs; pub mod model; -use axum::{extract::DefaultBodyLimit, routing::{delete, get, post}, Router}; use crate::AppState; +use axum::{ + extract::DefaultBodyLimit, + routing::{delete, get, post}, + Router, +}; pub fn jobs_router() -> Router { Router::new() // No body limit on the upload route — files can be multiple GB. - .route("/jobs", post(jobs::submit_job).layer(DefaultBodyLimit::disable())) - .route("/jobs/:id", get(jobs::get_job)) + .route( + "/jobs", + post(jobs::submit_job).layer(DefaultBodyLimit::disable()), + ) + .route("/jobs/:id", get(jobs::get_job)) .route("/jobs/:id/stream", get(jobs::stream_job)) - .route("/jobs/:id", delete(jobs::delete_job)) + .route("/jobs/:id", delete(jobs::delete_job)) } pub fn health_router() -> Router { - Router::new() - .route("/health", get(health::health)) + Router::new().route("/health", get(health::health)) } pub fn model_router() -> Router { Router::new() .route("/model/status", get(model::model_status)) - .route("/model/load", post(model::model_load)) + .route("/model/load", post(model::model_load)) .route("/model/unload", post(model::model_unload)) .route("/model/events", get(model::model_events)) } diff --git a/src/routes/model.rs b/src/routes/model.rs index 65b3826..426171c 100644 --- a/src/routes/model.rs +++ b/src/routes/model.rs @@ -10,8 +10,8 @@ use axum::{ Json, }; use futures::Stream; -use tokio_stream::wrappers::BroadcastStream; use futures::StreamExt; +use tokio_stream::wrappers::BroadcastStream; use crate::{ models::{ModelEvent, ModelStatusResponse}, @@ -19,7 +19,8 @@ use crate::{ AppState, Result, }; -type SseStream = Pin> + Send>>; +type SseStream = + Pin> + Send>>; // ── GET /model/status ──────────────────────────────────────────────────────── @@ -61,11 +62,17 @@ pub async fn model_status(State(state): State) -> Result) -> impl IntoResponse { let is_ready = state.model_state.read().await.is_ready(); if is_ready { - return (StatusCode::OK, Json(serde_json::json!({"status": "already_ready"}))); + return ( + StatusCode::OK, + Json(serde_json::json!({"status": "already_ready"})), + ); } // Ignore send errors (channel full = load already in progress). let _ = state.cmd_tx.try_send(WorkerCmd::Load); - (StatusCode::ACCEPTED, Json(serde_json::json!({"status": "load_initiated"}))) + ( + StatusCode::ACCEPTED, + Json(serde_json::json!({"status": "load_initiated"})), + ) } // ── POST /model/unload ─────────────────────────────────────────────────────── @@ -82,7 +89,10 @@ pub async fn model_load(State(state): State) -> impl IntoResponse { )] pub async fn model_unload(State(state): State) -> impl IntoResponse { let _ = state.cmd_tx.try_send(WorkerCmd::Unload); - (StatusCode::OK, Json(serde_json::json!({"status": "unload_requested"}))) + ( + StatusCode::OK, + Json(serde_json::json!({"status": "unload_requested"})), + ) } // ── GET /model/events ──────────────────────────────────────────────────────── @@ -105,23 +115,21 @@ pub async fn model_unload(State(state): State) -> impl IntoResponse { pub async fn model_events(State(state): State) -> Sse { let rx = state.model_event_tx.subscribe(); - let stream: SseStream = Box::pin( - BroadcastStream::new(rx).filter_map(|msg| async move { - match msg { - Ok(event) => { - let event_type = match &event { - ModelEvent::ModelReady { .. } => "model_ready", - ModelEvent::ModelUnloaded => "model_unloaded", - ModelEvent::ModelLoading => "model_loading", - ModelEvent::ModelWaitingForGpu {..} => "model_waiting_for_gpu", - }; - let data = serde_json::to_string(&event).ok()?; - Some(Ok(Event::default().event(event_type).data(data))) - } - Err(_) => None, + let stream: SseStream = Box::pin(BroadcastStream::new(rx).filter_map(|msg| async move { + match msg { + Ok(event) => { + let event_type = match &event { + ModelEvent::ModelReady { .. } => "model_ready", + ModelEvent::ModelUnloaded => "model_unloaded", + ModelEvent::ModelLoading => "model_loading", + ModelEvent::ModelWaitingForGpu { .. } => "model_waiting_for_gpu", + }; + let data = serde_json::to_string(&event).ok()?; + Some(Ok(Event::default().event(event_type).data(data))) } - }) - ); + Err(_) => None, + } + })); Sse::new(stream).keep_alive(KeepAlive::default()) } @@ -146,13 +154,13 @@ fn vram_stats(gpu_device: u32) -> (Option, Option) { let line = String::from_utf8_lossy(&out.stdout); let line = line.trim(); let mut parts = line.splitn(2, ','); - let used = parts.next().and_then(|s| s.trim().parse::().ok())?; + let used = parts.next().and_then(|s| s.trim().parse::().ok())?; let total = parts.next().and_then(|s| s.trim().parse::().ok())?; Some((used, total)) } match inner(gpu_device) { Some((u, t)) => (Some(u), Some(t)), - None => (None, None), + None => (None, None), } } diff --git a/src/storage.rs b/src/storage.rs index 4ae5ae1..edb9fe2 100644 --- a/src/storage.rs +++ b/src/storage.rs @@ -30,20 +30,20 @@ impl Storage { // ── CRUD ───────────────────────────────────────────────────────────────── pub async fn create(&self, job: &Job) -> Result<()> { - let path = self.job_path(&job.id); - let payload = serde_json::to_vec_pretty(job) - .map_err(|e| AppError::Internal(e.to_string()))?; - fs::write(&path, payload).await.map_err(|e| { - AppError::Internal(format!("failed to write job {}: {e}", job.id)) - })?; + let path = self.job_path(&job.id); + let payload = + serde_json::to_vec_pretty(job).map_err(|e| AppError::Internal(e.to_string()))?; + fs::write(&path, payload) + .await + .map_err(|e| AppError::Internal(format!("failed to write job {}: {e}", job.id)))?; Ok(()) } pub async fn get(&self, id: &JobId) -> Result { let path = self.job_path(id); - let raw = fs::read(&path).await.map_err(|_| { - AppError::NotFound(format!("job {id} not found")) - })?; + let raw = fs::read(&path) + .await + .map_err(|_| AppError::NotFound(format!("job {id} not found")))?; serde_json::from_slice(&raw).map_err(|e| AppError::Internal(e.to_string())) } @@ -54,22 +54,24 @@ impl Storage { pub async fn delete(&self, id: &JobId) -> Result<()> { let path = self.job_path(id); - fs::remove_file(&path).await.map_err(|_| { - AppError::NotFound(format!("job {id} not found")) - })?; + fs::remove_file(&path) + .await + .map_err(|_| AppError::NotFound(format!("job {id} not found")))?; Ok(()) } /// List all job IDs present on disk. pub async fn list_ids(&self) -> Result> { - let mut entries = fs::read_dir(&self.dir).await.map_err(|e| { - AppError::Internal(format!("read_dir failed: {e}")) - })?; + let mut entries = fs::read_dir(&self.dir) + .await + .map_err(|e| AppError::Internal(format!("read_dir failed: {e}")))?; let mut ids = Vec::new(); - while let Some(entry) = entries.next_entry().await.map_err(|e| { - AppError::Internal(e.to_string()) - })? { + while let Some(entry) = entries + .next_entry() + .await + .map_err(|e| AppError::Internal(e.to_string()))? + { let name = entry.file_name(); let name = name.to_string_lossy(); if let Some(stem) = name.strip_suffix(".json") { @@ -88,8 +90,8 @@ impl Storage { if let Ok(mut job) = self.get(&id).await { if job.status == JobStatus::Running { tracing::warn!(job_id = %id, "recovering interrupted job → failed"); - job.status = JobStatus::Failed; - job.error = Some("server restarted while job was running".into()); + job.status = JobStatus::Failed; + job.error = Some("server restarted while job was running".into()); job.completed_at = Some(chrono::Utc::now()); let _ = self.save(&job).await; } diff --git a/src/transcriber.rs b/src/transcriber.rs index 93d67c1..348c476 100644 --- a/src/transcriber.rs +++ b/src/transcriber.rs @@ -37,9 +37,10 @@ impl Transcriber { /// 0 segments. The warmup forces kernel compilation at startup so all subsequent /// jobs run correctly from the very first request. pub fn load(model_path: impl AsRef, gpu_device: u32) -> Result { - let path = model_path.as_ref().to_str().ok_or_else(|| { - AppError::Internal("model path is not valid UTF-8".into()) - })?; + let path = model_path + .as_ref() + .to_str() + .ok_or_else(|| AppError::Internal("model path is not valid UTF-8".into()))?; let mut params = WhisperContextParameters::new(); params.use_gpu(true); @@ -48,25 +49,23 @@ impl Transcriber { // real-world audio (conference recordings, noisy MP3s). // params.flash_attn(true); - let ctx = WhisperContext::new_with_params(path, params) - .map_err(|e| { - let msg = format!("failed to load model: {e}"); - if AppError::is_oom(&msg) { - AppError::OutOfMemory(msg) - } else { - AppError::Internal(msg) - } - })?; + let ctx = WhisperContext::new_with_params(path, params).map_err(|e| { + let msg = format!("failed to load model: {e}"); + if AppError::is_oom(&msg) { + AppError::OutOfMemory(msg) + } else { + AppError::Internal(msg) + } + })?; - let mut state = ctx.create_state() - .map_err(|e| { - let msg = format!("failed to create whisper state: {e}"); - if AppError::is_oom(&msg) { - AppError::OutOfMemory(msg) - } else { - AppError::Internal(msg) - } - })?; + let mut state = ctx.create_state().map_err(|e| { + let msg = format!("failed to create whisper state: {e}"); + if AppError::is_oom(&msg) { + AppError::OutOfMemory(msg) + } else { + AppError::Internal(msg) + } + })?; // ctx drops here; state holds Arc so model stays loaded. // ── GPU warmup ──────────────────────────────────────────────────────── @@ -95,16 +94,16 @@ impl Transcriber { /// `no_context=true` in the params prevents KV-cache contamination between chunks. pub fn transcribe( &mut self, - pcm: &[f32], - language: Option<&str>, - task: &str, + pcm: &[f32], + language: Option<&str>, + task: &str, on_progress: impl Fn(u8) + Send + 'static, ) -> Result<(Vec, String)> { let state = &mut self.state; let mut fp = FullParams::new(SamplingStrategy::BeamSearch { beam_size: 5, - patience: 1.0, + patience: 1.0, }); fp.set_n_threads(num_cpus::get() as i32); @@ -158,40 +157,55 @@ impl Transcriber { .full(fp, pcm) .map_err(|e| AppError::Internal(format!("transcription failed: {e}")))?; - let n_segments = state.full_n_segments() + let n_segments = state + .full_n_segments() .map_err(|e| AppError::Internal(e.to_string()))?; let mut segments = Vec::with_capacity(n_segments as usize); for i in 0..n_segments { - let text = state.full_get_segment_text(i) + let text = state + .full_get_segment_text(i) .map_err(|e| AppError::Internal(e.to_string()))?; - let start = state.full_get_segment_t0(i) - .map_err(|e| AppError::Internal(e.to_string()))? as f32 / 100.0; - let end = state.full_get_segment_t1(i) - .map_err(|e| AppError::Internal(e.to_string()))? as f32 / 100.0; + let start = state + .full_get_segment_t0(i) + .map_err(|e| AppError::Internal(e.to_string()))? as f32 + / 100.0; + let end = state + .full_get_segment_t1(i) + .map_err(|e| AppError::Internal(e.to_string()))? as f32 + / 100.0; - let n_tokens = state.full_n_tokens(i) + let n_tokens = state + .full_n_tokens(i) .map_err(|e| AppError::Internal(e.to_string()))?; let mut words = Vec::new(); for t in 0..n_tokens { - let token_text = state.full_get_token_text(i, t) + let token_text = state + .full_get_token_text(i, t) .map_err(|e| AppError::Internal(e.to_string()))?; if token_text.starts_with('[') { continue; // skip special tokens ([MUSIC], [APPLAUSE], etc.) } - let data = state.full_get_token_data(i, t) + let data = state + .full_get_token_data(i, t) .map_err(|e| AppError::Internal(e.to_string()))?; words.push(Word { - text: token_text, - start: data.t0 as f32 / 100.0, - end: data.t1 as f32 / 100.0, + text: token_text, + start: data.t0 as f32 / 100.0, + end: data.t1 as f32 / 100.0, probability: data.p, }); } - segments.push(Segment { index: i, start, end, text, words }); + segments.push(Segment { + index: i, + start, + end, + text, + words, + }); } let lang = state diff --git a/src/worker.rs b/src/worker.rs index 6b2a5ce..65f75f9 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -16,8 +16,7 @@ use crate::{ models::{Job, JobId, JobStatus, ModelEvent, ModelState, Segment}, storage::Storage, transcriber::Transcriber, - webhook, - AppError, + webhook, AppError, }; /// Per-job broadcast channel for SSE subscribers. @@ -26,7 +25,11 @@ pub type ProgressTx = broadcast::Sender; #[derive(Debug, Clone)] pub enum ProgressEvent { /// `percent` — overall 0–100; `chunk` — 1-based; `total` — total chunks. - Progress { percent: u8, chunk: usize, total: usize }, + Progress { + percent: u8, + chunk: usize, + total: usize, + }, Done(Box), Error(String), } @@ -50,11 +53,11 @@ pub enum WorkerCmd { // ── Transcription request/response types ───────────────────────────────────── pub struct TranscribeRequest { - pub pcm: Vec, - pub language: Option, - pub task: String, + pub pcm: Vec, + pub language: Option, + pub task: String, pub on_progress: Box, - pub reply: oneshot::Sender, String)>>, + pub reply: oneshot::Sender, String)>>, } impl std::fmt::Debug for TranscribeRequest { @@ -75,15 +78,15 @@ impl std::fmt::Debug for TranscribeRequest { /// trigger loading. #[allow(clippy::too_many_arguments)] pub fn start( - job_rx: mpsc::UnboundedReceiver, - storage: Arc, - model_path: PathBuf, - queue_depth: Arc, - gpu_device: u32, - model_state: Arc>, - model_event_tx: broadcast::Sender, + job_rx: mpsc::UnboundedReceiver, + storage: Arc, + model_path: PathBuf, + queue_depth: Arc, + gpu_device: u32, + model_state: Arc>, + model_event_tx: broadcast::Sender, webhook_registry: Arc>>, - idle_timeout: Duration, + idle_timeout: Duration, gpu_poll_interval: Duration, ) -> (ProgressRegistry, std::sync::mpsc::SyncSender) { let registry: ProgressRegistry = Arc::new(dashmap::DashMap::new()); @@ -126,15 +129,15 @@ pub fn start( /// separate thread. #[allow(clippy::too_many_arguments)] fn transcriber_thread( - rx: std::sync::mpsc::Receiver, - model_path: PathBuf, - gpu_device: u32, - model_state: Arc>, - model_event_tx: broadcast::Sender, + rx: std::sync::mpsc::Receiver, + model_path: PathBuf, + gpu_device: u32, + model_state: Arc>, + model_event_tx: broadcast::Sender, webhook_registry: Arc>>, - idle_timeout: Duration, + idle_timeout: Duration, gpu_poll_interval: Duration, - rt: tokio::runtime::Handle, + rt: tokio::runtime::Handle, ) { let mut transcriber: Option = None; let mut last_job = Instant::now(); @@ -162,14 +165,22 @@ fn transcriber_thread( } Ok(WorkerCmd::Unload) => { - do_unload(&mut transcriber, &model_state, &model_event_tx, &webhook_registry, &rt); + do_unload( + &mut transcriber, + &model_state, + &model_event_tx, + &webhook_registry, + &rt, + ); } Ok(WorkerCmd::Transcribe(req)) => { let t = match &mut transcriber { Some(t) => t, None => { - tracing::warn!("Transcribe cmd received but model is unloaded — failing job"); + tracing::warn!( + "Transcribe cmd received but model is unloaded — failing job" + ); let _ = req.reply.send(Err(AppError::Internal( "model unloaded before job could run".into(), ))); @@ -177,12 +188,9 @@ fn transcriber_thread( } }; - let result = t.transcribe( - &req.pcm, - req.language.as_deref(), - &req.task, - move |p| (req.on_progress)(p), - ); + let result = t.transcribe(&req.pcm, req.language.as_deref(), &req.task, move |p| { + (req.on_progress)(p) + }); last_job = Instant::now(); let _ = req.reply.send(result); } @@ -218,14 +226,14 @@ fn transcriber_thread( /// rejection. Returns `Some(Transcriber)` on success, `None` if cancelled. #[allow(clippy::too_many_arguments)] fn try_load_with_polling( - rx: &std::sync::mpsc::Receiver, - model_path: &PathBuf, - gpu_device: u32, - model_state: &Arc>, - model_event_tx: &broadcast::Sender, + rx: &std::sync::mpsc::Receiver, + model_path: &PathBuf, + gpu_device: u32, + model_state: &Arc>, + model_event_tx: &broadcast::Sender, webhook_registry: &Arc>>, gpu_poll_interval: Duration, - rt: &tokio::runtime::Handle, + rt: &tokio::runtime::Handle, ) -> Option { loop { set_state(model_state, ModelState::Loading); @@ -253,25 +261,35 @@ fn try_load_with_polling( "insufficient VRAM — will retry" ); - set_state(model_state, ModelState::WaitingForGpu { - vram_needed_mb, - vram_free_mb, - retry_in_secs, - }); - broadcast_event(model_event_tx, ModelEvent::ModelWaitingForGpu { - vram_needed_mb, - vram_free_mb, - retry_in_secs, - }); + set_state( + model_state, + ModelState::WaitingForGpu { + vram_needed_mb, + vram_free_mb, + retry_in_secs, + }, + ); + broadcast_event( + model_event_tx, + ModelEvent::ModelWaitingForGpu { + vram_needed_mb, + vram_free_mb, + retry_in_secs, + }, + ); // Interruptible sleep: drain rx while waiting for gpu_poll_interval. let deadline = Instant::now() + gpu_poll_interval; loop { let remaining = deadline.saturating_duration_since(Instant::now()); - if remaining.is_zero() { break; } + if remaining.is_zero() { + break; + } match rx.recv_timeout(remaining.min(Duration::from_secs(1))) { Ok(WorkerCmd::Unload) => { - tracing::info!("Unload received while waiting for GPU — cancelling load"); + tracing::info!( + "Unload received while waiting for GPU — cancelling load" + ); set_state(model_state, ModelState::Unloaded); broadcast_event(model_event_tx, ModelEvent::ModelUnloaded); fire_webhooks(webhook_registry, ModelEvent::ModelUnloaded, rt); @@ -303,11 +321,11 @@ fn try_load_with_polling( } fn do_unload( - transcriber: &mut Option, - model_state: &Arc>, - model_event_tx: &broadcast::Sender, + transcriber: &mut Option, + model_state: &Arc>, + model_event_tx: &broadcast::Sender, webhook_registry: &Arc>>, - rt: &tokio::runtime::Handle, + rt: &tokio::runtime::Handle, ) { *transcriber = None; set_state(model_state, ModelState::Unloaded); @@ -328,8 +346,8 @@ fn broadcast_event(tx: &broadcast::Sender, event: ModelEvent) { fn fire_webhooks( registry: &Arc>>, - event: ModelEvent, - rt: &tokio::runtime::Handle, + event: ModelEvent, + rt: &tokio::runtime::Handle, ) { if !event.is_webhook_event() { return; @@ -341,11 +359,16 @@ fn fire_webhooks( .cloned() .collect(); - if urls.is_empty() { return; } + if urls.is_empty() { + return; + } let payload = match serde_json::to_string(&event) { - Ok(p) => p, - Err(e) => { tracing::error!(error = %e, "failed to serialize model event"); return; } + Ok(p) => p, + Err(e) => { + tracing::error!(error = %e, "failed to serialize model event"); + return; + } }; for url in urls { @@ -356,7 +379,8 @@ fn fire_webhooks( .build() .expect("http client"); for attempt in 0..3_u32 { - match http.post(&url) + match http + .post(&url) .header("content-type", "application/json") .body(body.clone()) .send() @@ -405,11 +429,11 @@ fn parse_oom_vram(msg: &str, gpu_device: u32) -> (u64, u64) { // ── Async job runner ────────────────────────────────────────────────────────── async fn run( - mut job_rx: mpsc::UnboundedReceiver, - storage: Arc, + mut job_rx: mpsc::UnboundedReceiver, + storage: Arc, queue_depth: Arc, - registry: ProgressRegistry, - cmd_tx: std::sync::mpsc::SyncSender, + registry: ProgressRegistry, + cmd_tx: std::sync::mpsc::SyncSender, ) { let http = Client::builder() .timeout(Duration::from_secs(30)) @@ -420,7 +444,7 @@ async fn run( queue_depth.fetch_sub(1, Ordering::Relaxed); let mut job = match storage.get(&job_id).await { - Ok(j) => j, + Ok(j) => j, Err(e) => { tracing::warn!(job_id = %job_id, error = %e, "job vanished before processing"); registry.remove(&job_id); @@ -461,19 +485,19 @@ async fn run( match result { Ok((segments, language, duration_secs)) => { - job.status = JobStatus::Done; - job.segments = segments; - job.language = Some(language); + job.status = JobStatus::Done; + job.segments = segments; + job.language = Some(language); job.duration_secs = Some(duration_secs); - job.progress = 100; - job.completed_at = Some(Utc::now()); + job.progress = 100; + job.completed_at = Some(Utc::now()); let _ = progress_tx.send(ProgressEvent::Done(Box::new(job.clone()))); } Err(e) => { let msg = e.to_string(); tracing::error!(job_id = %job_id, error = %msg, "transcription failed"); - job.status = JobStatus::Failed; - job.error = Some(msg.clone()); + job.status = JobStatus::Failed; + job.error = Some(msg.clone()); job.completed_at = Some(Utc::now()); let _ = progress_tx.send(ProgressEvent::Error(msg)); } @@ -485,9 +509,11 @@ async fn run( if let Some(url) = &job.webhook_url.clone() { let http = http.clone(); - let url = url.clone(); - let job = job.clone(); - tokio::spawn(async move { webhook::fire(&http, &url, &job).await; }); + let url = url.clone(); + let job = job.clone(); + tokio::spawn(async move { + webhook::fire(&http, &url, &job).await; + }); } tokio::time::sleep(Duration::from_secs(30)).await; @@ -498,9 +524,9 @@ async fn run( // ── Silence-based chunking ──────────────────────────────────────────────────── const TARGET_CHUNK_SECS: f32 = 60.0; -const SNAP_WINDOW_SECS: f32 = 30.0; -const SILENCE_DB: &str = "-35dB"; -const SILENCE_DUR: &str = "0.4"; +const SNAP_WINDOW_SECS: f32 = 30.0; +const SILENCE_DB: &str = "-35dB"; +const SILENCE_DUR: &str = "0.4"; async fn detect_silence_midpoints(path: &std::path::Path) -> Vec { use tokio::process::Command; @@ -509,15 +535,19 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec { let output = Command::new("ffmpeg") .args([ "-nostdin", - "-i", path.to_str().unwrap_or(""), - "-af", &filter, - "-f", "null", "-", + "-i", + path.to_str().unwrap_or(""), + "-af", + &filter, + "-f", + "null", + "-", ]) .output() .await; let output = match output { - Ok(o) => o, + Ok(o) => o, Err(e) => { tracing::warn!(error = %e, "silencedetect unavailable; using hard cuts"); return Vec::new(); @@ -526,7 +556,7 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec { let stderr = String::from_utf8_lossy(&output.stderr); let mut starts: Vec = Vec::new(); - let mut ends: Vec = Vec::new(); + let mut ends: Vec = Vec::new(); for line in stderr.lines() { if let Some(i) = line.find("silence_start: ") { @@ -545,7 +575,9 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec { } } - let mids: Vec = starts.iter().zip(ends.iter()) + let mids: Vec = starts + .iter() + .zip(ends.iter()) .map(|(s, e)| (s + e) / 2.0) .collect(); @@ -553,18 +585,15 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec { mids } -fn snap_to_silence( - mids: &[f32], - total_secs: f32, - target_secs: f32, - snap_window: f32, -) -> Vec { +fn snap_to_silence(mids: &[f32], total_secs: f32, target_secs: f32, snap_window: f32) -> Vec { let mut cuts: Vec = Vec::new(); let mut pos = target_secs; while pos < total_secs - target_secs * 0.25 { let prev_cut = cuts.last().copied().unwrap_or(0.0); - let best = mids.iter().copied() + let best = mids + .iter() + .copied() .filter(|&t| t > prev_cut + 10.0 && (t - pos).abs() <= snap_window) .min_by(|a, b| (a - pos).abs().partial_cmp(&(b - pos).abs()).unwrap()); let cut = best.unwrap_or(pos); @@ -591,20 +620,165 @@ fn to_chunk_ranges(cuts: &[f32], total_secs: f32) -> Vec<(f32, f32)> { ranges } +const MAX_CHAIN_GAP_SECS: f32 = 0.15; +const MIN_MEANINGFUL_WORDS: usize = 2; +const MIN_MEANINGFUL_CHARS: usize = 8; +const MIN_OVERLAP_WORDS: usize = 3; + +fn normalised_words(text: &str) -> Vec { + text.split_whitespace() + .map(|word| { + word.chars() + .filter(|ch| ch.is_alphanumeric() || *ch == '_') + .flat_map(|ch| ch.to_lowercase()) + .collect::() + }) + .filter(|word| !word.is_empty()) + .collect() +} + +fn starts_with_words(full: &[String], prefix: &[String]) -> bool { + prefix.len() <= full.len() && full.iter().take(prefix.len()).eq(prefix.iter()) +} + +fn ends_with_words(full: &[String], suffix: &[String]) -> bool { + suffix.len() <= full.len() + && full + .iter() + .skip(full.len() - suffix.len()) + .eq(suffix.iter()) +} + +fn suffix_prefix_overlap(left: &[String], right: &[String]) -> usize { + let max = left.len().min(right.len()); + for size in (1..=max).rev() { + if left[left.len() - size..] == right[..size] { + return size; + } + } + 0 +} + +fn is_meaningful_phrase(words: &[String]) -> bool { + words.len() >= MIN_MEANINGFUL_WORDS + && words.iter().map(|word| word.len()).sum::() >= MIN_MEANINGFUL_CHARS +} + +fn trim_leading_words(text: &str, count: usize) -> String { + text.split_whitespace() + .skip(count) + .collect::>() + .join(" ") + .trim() + .to_string() +} + +fn merge_identical_segments(segments: Vec) -> Vec { + let mut out: Vec = Vec::with_capacity(segments.len()); + + for seg in segments { + if let Some(last) = out.last_mut() { + if normalised_words(&last.text) == normalised_words(&seg.text) { + last.end = last.end.max(seg.end); + if !seg.words.is_empty() { + last.words = seg.words; + } + continue; + } + } + out.push(seg); + } + + out +} + +fn collapse_incremental_segments(segments: Vec) -> Vec { + let mut out: Vec = Vec::with_capacity(segments.len()); + + for mut seg in segments { + seg.text = seg.text.trim().to_string(); + if seg.text.is_empty() { + continue; + } + + let Some(last) = out.last_mut() else { + out.push(seg); + continue; + }; + + let gap = seg.start - last.end; + if gap > MAX_CHAIN_GAP_SECS { + out.push(seg); + continue; + } + + let last_words = normalised_words(&last.text); + let seg_words = normalised_words(&seg.text); + if last_words.is_empty() || seg_words.is_empty() { + out.push(seg); + continue; + } + + if seg_words.len() > last_words.len() + && starts_with_words(&seg_words, &last_words) + && is_meaningful_phrase(&last_words) + { + last.text = seg.text; + last.end = seg.end; + last.words = seg.words; + continue; + } + + if ends_with_words(&last_words, &seg_words) && is_meaningful_phrase(&seg_words) { + last.end = last.end.max(seg.end); + continue; + } + + let overlap = suffix_prefix_overlap(&last_words, &seg_words); + if overlap >= MIN_OVERLAP_WORDS { + let trimmed_text = trim_leading_words(&seg.text, overlap); + if trimmed_text.is_empty() { + last.end = last.end.max(seg.end); + continue; + } + + seg.start = seg.start.max(last.end); + seg.text = trimmed_text; + seg.words.clear(); + } + + out.push(seg); + } + + out +} + +fn normalise_segments(segments: Vec) -> Vec { + let mut result = collapse_incremental_segments(segments); + result = merge_identical_segments(result); + result = collapse_incremental_segments(result); + merge_identical_segments(result) +} + // ── Job processing ──────────────────────────────────────────────────────────── async fn process_job( - job: &Job, - audio_path: &std::path::Path, + job: &Job, + audio_path: &std::path::Path, progress_tx: &ProgressTx, - cmd_tx: &std::sync::mpsc::SyncSender, - storage: &Arc, + cmd_tx: &std::sync::mpsc::SyncSender, + storage: &Arc, ) -> crate::Result<(Vec, String, f32)> { let pcm = decode_audio(audio_path).await?; let total_secs = pcm.len() as f32 / 16_000.0; let silence_mids = detect_silence_midpoints(audio_path).await; - let cuts = snap_to_silence(&silence_mids, total_secs, TARGET_CHUNK_SECS, SNAP_WINDOW_SECS); + let cuts = snap_to_silence( + &silence_mids, + total_secs, + TARGET_CHUNK_SECS, + SNAP_WINDOW_SECS, + ); let chunks = to_chunk_ranges(&cuts, total_secs); let n = chunks.len(); @@ -620,12 +794,12 @@ async fn process_job( for (ci, (chunk_start, chunk_end)) in chunks.iter().enumerate() { let s0 = (*chunk_start * 16_000.0) as usize; - let s1 = ((*chunk_end * 16_000.0) as usize).min(pcm.len()); + let s1 = ((*chunk_end * 16_000.0) as usize).min(pcm.len()); let mut chunk_pcm = pcm[s0..s1].to_vec(); trim_trailing_silence(&mut chunk_pcm); - let base = (ci * 100 / n) as u8; - let span = (100usize / n).max(1) as u8; + let base = (ci * 100 / n) as u8; + let span = (100usize / n).max(1) as u8; // Save progress to disk before emitting SSE — polling clients who respond // immediately to the SSE event will then see consistent state. @@ -637,49 +811,52 @@ async fn process_job( let _ = progress_tx.send(ProgressEvent::Progress { percent: base, - chunk: ci + 1, - total: n, + chunk: ci + 1, + total: n, }); - let tx = progress_tx.clone(); + let tx = progress_tx.clone(); let chunk_num = ci + 1; let on_progress = Box::new(move |p: u8| { let overall = base.saturating_add(p.saturating_mul(span) / 100); let _ = tx.send(ProgressEvent::Progress { percent: overall, - chunk: chunk_num, - total: n, + chunk: chunk_num, + total: n, }); }); let (reply_tx, reply_rx) = oneshot::channel(); - cmd_tx.send(WorkerCmd::Transcribe(TranscribeRequest { - pcm: chunk_pcm, - language: job.language.clone(), - task: job.task.clone(), - on_progress, - reply: reply_tx, - })).map_err(|_| AppError::Internal("worker command channel closed".into()))?; + cmd_tx + .send(WorkerCmd::Transcribe(TranscribeRequest { + pcm: chunk_pcm, + language: job.language.clone(), + task: job.task.clone(), + on_progress, + reply: reply_tx, + })) + .map_err(|_| AppError::Internal("worker command channel closed".into()))?; - let (mut segs, lang) = reply_rx.await + let (mut segs, lang) = reply_rx + .await .map_err(|_| AppError::Internal("transcriber thread dropped reply".into()))??; let offset = *chunk_start; for seg in &mut segs { seg.start += offset; - seg.end += offset; + seg.end += offset; for word in &mut seg.words { word.start += offset; - word.end += offset; + word.end += offset; } } tracing::debug!( chunk = ci + 1, - of = n, + of = n, start = chunk_start, - end = chunk_end, - segs = segs.len(), + end = chunk_end, + segs = segs.len(), "chunk done" ); @@ -689,24 +866,30 @@ async fn process_job( } } + all_segments = normalise_segments(all_segments); + for (i, seg) in all_segments.iter_mut().enumerate() { seg.index = i as i32; } - let _ = progress_tx.send(ProgressEvent::Progress { percent: 100, chunk: n, total: n }); + let _ = progress_tx.send(ProgressEvent::Progress { + percent: 100, + chunk: n, + total: n, + }); Ok((all_segments, language, total_secs)) } fn trim_trailing_silence(pcm: &mut Vec) { const THRESHOLD: f32 = 0.017_8; - const PADDING: usize = 8_000; + const PADDING: usize = 8_000; if let Some(last_loud) = pcm.iter().rposition(|&s| s.abs() > THRESHOLD) { let new_len = (last_loud + 1 + PADDING).min(pcm.len()); if new_len < pcm.len() { tracing::trace!( original_samples = pcm.len(), - trimmed_samples = pcm.len() - new_len, + trimmed_samples = pcm.len() - new_len, "trimmed trailing silence" ); pcm.truncate(new_len); @@ -719,11 +902,17 @@ async fn decode_audio(path: &std::path::Path) -> crate::Result> { let output = Command::new("ffmpeg") .args([ - "-nostdin", "-threads", "0", - "-i", path.to_str().unwrap_or(""), - "-f", "f32le", - "-ac", "1", - "-ar", "16000", + "-nostdin", + "-threads", + "0", + "-i", + path.to_str().unwrap_or(""), + "-f", + "f32le", + "-ac", + "1", + "-ar", + "16000", "-", ]) .output() @@ -760,13 +949,28 @@ pub fn audio_path_for(id: &JobId) -> PathBuf { #[cfg(test)] mod tests { use super::*; + use crate::models::Word; + + fn segment(index: i32, start: f32, end: f32, text: &str) -> Segment { + Segment { + index, + start, + end, + text: text.into(), + words: Vec::::new(), + } + } #[test] fn test_snap_to_silence_uses_nearest_midpoint() { let mids = vec![55.0, 58.0, 62.0]; let cuts = snap_to_silence(&mids, 120.0, 60.0, 30.0); assert!(!cuts.is_empty()); - assert!((cuts[0] - 58.0).abs() < 0.01, "expected ~58.0, got {}", cuts[0]); + assert!( + (cuts[0] - 58.0).abs() < 0.01, + "expected ~58.0, got {}", + cuts[0] + ); } #[test] @@ -801,4 +1005,53 @@ mod tests { trim_trailing_silence(&mut pcm); assert_eq!(pcm.len(), (10_001 + 8_000).min(32_000)); } + + #[test] + fn test_normalise_segments_collapses_prefix_growth_chain() { + let input = vec![ + segment(0, 15.24, 16.6, "Hello everyone."), + segment(1, 16.6, 19.47, "Hello everyone. Um, welcome to this talk."), + segment(2, 19.47, 19.48, "Um, welcome to this talk."), + segment( + 3, + 19.48, + 21.67, + "Um, welcome to this talk. I'll be speaking about small model", + ), + segment(4, 21.67, 21.68, "I'll be speaking about small model"), + segment( + 5, + 21.68, + 24.59, + "I'll be speaking about small model inference and a gap that we've", + ), + ]; + + let result = normalise_segments(input); + + assert_eq!(result.len(), 2); + assert_eq!(result[0].text, "Hello everyone. Um, welcome to this talk."); + assert!((result[0].start - 15.24).abs() < 0.01); + assert!((result[0].end - 19.48).abs() < 0.01); + assert_eq!( + result[1].text, + "I'll be speaking about small model inference and a gap that we've" + ); + assert!((result[1].start - 19.48).abs() < 0.01); + assert!((result[1].end - 24.59).abs() < 0.01); + } + + #[test] + fn test_normalise_segments_keeps_real_gap() { + let input = vec![ + segment(0, 0.0, 1.0, "Hello everyone."), + segment(1, 2.0, 4.0, "Hello everyone. Welcome back."), + ]; + + let result = normalise_segments(input); + + assert_eq!(result.len(), 2); + assert_eq!(result[0].text, "Hello everyone."); + assert_eq!(result[1].text, "Hello everyone. Welcome back."); + } }