fix(worker): collapse incremental segments
All checks were successful
Build & Push Docker Image / test (push) Successful in 6m20s
Build & Push Docker Image / build-and-push (push) Successful in 6m29s

Normalize rolling partial-hypothesis chains before final job persistence so downstream clients receive stable transcript segments instead of echoed continuations.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
2026-05-11 22:46:38 +02:00
parent d3a67f11b3
commit cb0b07b2ff
10 changed files with 712 additions and 331 deletions

View File

@@ -1,10 +1,10 @@
use thiserror::Error;
use axum::{ use axum::{
http::{StatusCode, HeaderValue, header}, http::{header, HeaderValue, StatusCode},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
Json, Json,
}; };
use serde_json::json; use serde_json::json;
use thiserror::Error;
pub type Result<T> = std::result::Result<T, AppError>; pub type Result<T> = std::result::Result<T, AppError>;
@@ -31,7 +31,10 @@ pub enum AppError {
/// Returned when a job is submitted but the model is not yet loaded. /// Returned when a job is submitted but the model is not yet loaded.
/// Carries the current state tag and recommended Retry-After seconds. /// Carries the current state tag and recommended Retry-After seconds.
#[error("model not ready: {state}")] #[error("model not ready: {state}")]
ModelNotReady { state: String, retry_after_secs: u64 }, ModelNotReady {
state: String,
retry_after_secs: u64,
},
} }
impl AppError { impl AppError {
@@ -59,13 +62,20 @@ impl IntoResponse for AppError {
} }
AppError::Internal(m) => { AppError::Internal(m) => {
tracing::error!(error = %m, "internal error"); tracing::error!(error = %m, "internal error");
(StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": m }))).into_response() (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": m })),
)
.into_response()
} }
AppError::OutOfMemory(m) => { AppError::OutOfMemory(m) => {
tracing::warn!(error = %m, "GPU out of memory during model load"); tracing::warn!(error = %m, "GPU out of memory during model load");
(StatusCode::SERVICE_UNAVAILABLE, Json(json!({ "error": m }))).into_response() (StatusCode::SERVICE_UNAVAILABLE, Json(json!({ "error": m }))).into_response()
} }
AppError::ModelNotReady { state, retry_after_secs } => { AppError::ModelNotReady {
state,
retry_after_secs,
} => {
let body = Json(json!({ let body = Json(json!({
"error": "model_not_ready", "error": "model_not_ready",
"state": state, "state": state,
@@ -117,17 +127,25 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_model_not_ready_response_has_retry_after_header() { async fn test_model_not_ready_response_has_retry_after_header() {
let err = AppError::ModelNotReady { state: "loading".into(), retry_after_secs: 10 }; let err = AppError::ModelNotReady {
state: "loading".into(),
retry_after_secs: 10,
};
let resp = err.into_response(); let resp = err.into_response();
assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE); assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
let retry_after = resp.headers().get(header::RETRY_AFTER) let retry_after = resp
.headers()
.get(header::RETRY_AFTER)
.expect("Retry-After header missing"); .expect("Retry-After header missing");
assert_eq!(retry_after, "10"); assert_eq!(retry_after, "10");
} }
#[tokio::test] #[tokio::test]
async fn test_model_not_ready_response_body() { async fn test_model_not_ready_response_body() {
let err = AppError::ModelNotReady { state: "unloaded".into(), retry_after_secs: 30 }; let err = AppError::ModelNotReady {
state: "unloaded".into(),
retry_after_secs: 30,
};
let resp = err.into_response(); let resp = err.into_response();
let bytes = to_bytes(resp.into_body(), usize::MAX).await.unwrap(); let bytes = to_bytes(resp.into_body(), usize::MAX).await.unwrap();
let v: serde_json::Value = serde_json::from_slice(&bytes).unwrap(); let v: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
@@ -138,21 +156,21 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn test_model_not_ready_loading_retry_after_10() { async fn test_model_not_ready_loading_retry_after_10() {
let err = AppError::ModelNotReady { state: "loading".into(), retry_after_secs: 10 }; let err = AppError::ModelNotReady {
state: "loading".into(),
retry_after_secs: 10,
};
let resp = err.into_response(); let resp = err.into_response();
assert_eq!( assert_eq!(resp.headers().get(header::RETRY_AFTER).unwrap(), "10");
resp.headers().get(header::RETRY_AFTER).unwrap(),
"10"
);
} }
#[tokio::test] #[tokio::test]
async fn test_model_not_ready_unloaded_retry_after_30() { async fn test_model_not_ready_unloaded_retry_after_30() {
let err = AppError::ModelNotReady { state: "unloaded".into(), retry_after_secs: 30 }; let err = AppError::ModelNotReady {
state: "unloaded".into(),
retry_after_secs: 30,
};
let resp = err.into_response(); let resp = err.into_response();
assert_eq!( assert_eq!(resp.headers().get(header::RETRY_AFTER).unwrap(), "30");
resp.headers().get(header::RETRY_AFTER).unwrap(),
"30"
);
} }
} }

View File

@@ -97,10 +97,10 @@ async fn main() -> anyhow::Result<()> {
.with(tracing_subscriber::fmt::layer().json()) .with(tracing_subscriber::fmt::layer().json())
.init(); .init();
let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/data".into()); let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/data".into());
let model_path = std::env::var("WHISPER_MODEL_PATH") let model_path =
.unwrap_or_else(|_| "/models/ggml-large-v3.bin".into()); std::env::var("WHISPER_MODEL_PATH").unwrap_or_else(|_| "/models/ggml-large-v3.bin".into());
let port = std::env::var("PORT").unwrap_or_else(|_| "8080".into()); let port = std::env::var("PORT").unwrap_or_else(|_| "8080".into());
let model_name = std::env::var("WHISPER_MODEL").unwrap_or_else(|_| "large-v3".into()); let model_name = std::env::var("WHISPER_MODEL").unwrap_or_else(|_| "large-v3".into());
let gpu_device: u32 = std::env::var("CUDA_DEVICE") let gpu_device: u32 = std::env::var("CUDA_DEVICE")
.ok() .ok()
@@ -132,7 +132,9 @@ async fn main() -> anyhow::Result<()> {
// Model starts unloaded — lazy load on first job or POST /model/load. // Model starts unloaded — lazy load on first job or POST /model/load.
let model_state = Arc::new(RwLock::new(models::ModelState::Unloaded)); let model_state = Arc::new(RwLock::new(models::ModelState::Unloaded));
let (model_event_tx, _) = broadcast::channel::<models::ModelEvent>(32); let (model_event_tx, _) = broadcast::channel::<models::ModelEvent>(32);
let webhook_registry = Arc::new(std::sync::Mutex::new(std::collections::HashSet::<String>::new())); let webhook_registry = Arc::new(std::sync::Mutex::new(
std::collections::HashSet::<String>::new(),
));
// Spawn single GPU worker; get back the SSE broadcast registry and cmd channel. // Spawn single GPU worker; get back the SSE broadcast registry and cmd channel.
let (progress, cmd_tx) = worker::start( let (progress, cmd_tx) = worker::start(
@@ -153,13 +155,13 @@ async fn main() -> anyhow::Result<()> {
cmd_tx, cmd_tx,
storage: Arc::clone(&storage), storage: Arc::clone(&storage),
progress, progress,
model_name: model_name.as_str().into(), model_name: model_name.as_str().into(),
queue_depth: Arc::clone(&queue_depth), queue_depth: Arc::clone(&queue_depth),
gpu_device, gpu_device,
model_state, model_state,
model_event_tx, model_event_tx,
webhook_registry, webhook_registry,
idle_timeout: std::time::Duration::from_secs(idle_timeout_secs), idle_timeout: std::time::Duration::from_secs(idle_timeout_secs),
gpu_poll_interval: std::time::Duration::from_secs(gpu_poll_interval_secs), gpu_poll_interval: std::time::Duration::from_secs(gpu_poll_interval_secs),
}; };

View File

@@ -48,20 +48,20 @@ impl ModelState {
/// Suggested `Retry-After` value (seconds) to include in 503 responses. /// Suggested `Retry-After` value (seconds) to include in 503 responses.
pub fn retry_after_secs(&self) -> u64 { pub fn retry_after_secs(&self) -> u64 {
match self { match self {
ModelState::Unloaded => 30, // conservative load estimate ModelState::Unloaded => 30, // conservative load estimate
ModelState::Loading => 10, ModelState::Loading => 10,
ModelState::WaitingForGpu { retry_in_secs, .. } => *retry_in_secs, ModelState::WaitingForGpu { retry_in_secs, .. } => *retry_in_secs,
ModelState::Ready { .. } => 0, // shouldn't 503 if ready ModelState::Ready { .. } => 0, // shouldn't 503 if ready
} }
} }
/// String tag for use in error response bodies and log fields. /// String tag for use in error response bodies and log fields.
pub fn tag(&self) -> &'static str { pub fn tag(&self) -> &'static str {
match self { match self {
ModelState::Unloaded => "unloaded", ModelState::Unloaded => "unloaded",
ModelState::Loading => "loading", ModelState::Loading => "loading",
ModelState::WaitingForGpu{..} => "waiting_for_gpu", ModelState::WaitingForGpu { .. } => "waiting_for_gpu",
ModelState::Ready{..} => "ready", ModelState::Ready { .. } => "ready",
} }
} }
} }
@@ -77,9 +77,7 @@ impl ModelState {
#[serde(tag = "type", rename_all = "snake_case")] #[serde(tag = "type", rename_all = "snake_case")]
pub enum ModelEvent { pub enum ModelEvent {
/// Model finished loading and the GPU warmup completed — ready to accept jobs. /// Model finished loading and the GPU warmup completed — ready to accept jobs.
ModelReady { ModelReady { loaded_at: DateTime<Utc> },
loaded_at: DateTime<Utc>,
},
/// Model was unloaded from GPU memory (idle timeout or manual unload). /// Model was unloaded from GPU memory (idle timeout or manual unload).
ModelUnloaded, ModelUnloaded,
/// Model load initiated. /// Model load initiated.
@@ -87,15 +85,18 @@ pub enum ModelEvent {
/// Load failed due to insufficient VRAM; retrying after `retry_in_secs`. /// Load failed due to insufficient VRAM; retrying after `retry_in_secs`.
ModelWaitingForGpu { ModelWaitingForGpu {
vram_needed_mb: u64, vram_needed_mb: u64,
vram_free_mb: u64, vram_free_mb: u64,
retry_in_secs: u64, retry_in_secs: u64,
}, },
} }
impl ModelEvent { impl ModelEvent {
/// Returns true if this event should be delivered via webhook. /// Returns true if this event should be delivered via webhook.
pub fn is_webhook_event(&self) -> bool { pub fn is_webhook_event(&self) -> bool {
matches!(self, ModelEvent::ModelReady { .. } | ModelEvent::ModelUnloaded) matches!(
self,
ModelEvent::ModelReady { .. } | ModelEvent::ModelUnloaded
)
} }
} }
@@ -132,11 +133,11 @@ pub enum JobStatus {
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] #[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
pub struct Word { pub struct Word {
/// Word text /// Word text
pub text: String, pub text: String,
/// Start time in seconds /// Start time in seconds
pub start: f32, pub start: f32,
/// End time in seconds /// End time in seconds
pub end: f32, pub end: f32,
/// Model confidence (01) /// Model confidence (01)
pub probability: f32, pub probability: f32,
} }
@@ -148,9 +149,9 @@ pub struct Segment {
/// Start time in seconds /// Start time in seconds
pub start: f32, pub start: f32,
/// End time in seconds /// End time in seconds
pub end: f32, pub end: f32,
/// Transcribed text /// Transcribed text
pub text: String, pub text: String,
/// Token-level word timestamps (empty when flash_attn is enabled) /// Token-level word timestamps (empty when flash_attn is enabled)
#[serde(default)] #[serde(default)]
pub words: Vec<Word>, pub words: Vec<Word>,
@@ -205,18 +206,23 @@ pub struct Job {
} }
impl Job { impl Job {
pub fn new(id: JobId, task: String, webhook_url: Option<String>, filename: Option<String>) -> Self { pub fn new(
id: JobId,
task: String,
webhook_url: Option<String>,
filename: Option<String>,
) -> Self {
Self { Self {
id, id,
status: JobStatus::Queued, status: JobStatus::Queued,
language: None, language: None,
task, task,
duration_secs: None, duration_secs: None,
segments: vec![], segments: vec![],
error: None, error: None,
webhook_url, webhook_url,
progress: 0, progress: 0,
created_at: Utc::now(), created_at: Utc::now(),
completed_at: None, completed_at: None,
filename, filename,
} }
@@ -235,13 +241,13 @@ pub struct SubmitResponse {
/// Response from GET /health. /// Response from GET /health.
#[derive(Debug, Serialize, ToSchema)] #[derive(Debug, Serialize, ToSchema)]
pub struct HealthResponse { pub struct HealthResponse {
pub status: String, pub status: String,
pub gpu_name: Option<String>, pub gpu_name: Option<String>,
pub vram_total_mb: Option<u64>, pub vram_total_mb: Option<u64>,
pub model: String, pub model: String,
pub queue_depth: usize, pub queue_depth: usize,
/// Current state of the whisper model. /// Current state of the whisper model.
pub model_state: String, pub model_state: String,
} }
// ── SSE event payload ──────────────────────────────────────────────────────── // ── SSE event payload ────────────────────────────────────────────────────────
@@ -257,8 +263,12 @@ pub enum SsePayload {
/// Total number of silence-split chunks in this job. /// Total number of silence-split chunks in this job.
chunks_total: usize, chunks_total: usize,
}, },
Done { job: Box<Job> }, Done {
Error { message: String }, job: Box<Job>,
},
Error {
message: String,
},
} }
// ── Unit tests ─────────────────────────────────────────────────────────────── // ── Unit tests ───────────────────────────────────────────────────────────────
@@ -284,7 +294,11 @@ mod tests {
#[test] #[test]
fn test_model_state_waiting_serializes() { fn test_model_state_waiting_serializes() {
let s = ModelState::WaitingForGpu { vram_needed_mb: 3000, vram_free_mb: 500, retry_in_secs: 30 }; let s = ModelState::WaitingForGpu {
vram_needed_mb: 3000,
vram_free_mb: 500,
retry_in_secs: 30,
};
let v: Value = serde_json::to_value(&s).unwrap(); let v: Value = serde_json::to_value(&s).unwrap();
assert_eq!(v["state"], "waiting_for_gpu"); assert_eq!(v["state"], "waiting_for_gpu");
assert_eq!(v["vram_needed_mb"], 3000); assert_eq!(v["vram_needed_mb"], 3000);
@@ -305,8 +319,16 @@ mod tests {
fn test_model_state_is_ready() { fn test_model_state_is_ready() {
assert!(!ModelState::Unloaded.is_ready()); assert!(!ModelState::Unloaded.is_ready());
assert!(!ModelState::Loading.is_ready()); assert!(!ModelState::Loading.is_ready());
assert!(!ModelState::WaitingForGpu { vram_needed_mb: 0, vram_free_mb: 0, retry_in_secs: 30 }.is_ready()); assert!(!ModelState::WaitingForGpu {
assert!(ModelState::Ready { loaded_at: Utc::now() }.is_ready()); vram_needed_mb: 0,
vram_free_mb: 0,
retry_in_secs: 30
}
.is_ready());
assert!(ModelState::Ready {
loaded_at: Utc::now()
}
.is_ready());
} }
#[test] #[test]
@@ -321,13 +343,23 @@ mod tests {
#[test] #[test]
fn test_retry_after_waiting_for_gpu() { fn test_retry_after_waiting_for_gpu() {
let s = ModelState::WaitingForGpu { vram_needed_mb: 0, vram_free_mb: 0, retry_in_secs: 45 }; let s = ModelState::WaitingForGpu {
vram_needed_mb: 0,
vram_free_mb: 0,
retry_in_secs: 45,
};
assert_eq!(s.retry_after_secs(), 45); assert_eq!(s.retry_after_secs(), 45);
} }
#[test] #[test]
fn test_retry_after_ready_is_zero() { fn test_retry_after_ready_is_zero() {
assert_eq!(ModelState::Ready { loaded_at: Utc::now() }.retry_after_secs(), 0); assert_eq!(
ModelState::Ready {
loaded_at: Utc::now()
}
.retry_after_secs(),
0
);
} }
// ── ModelEvent serialization ───────────────────────────────────────────── // ── ModelEvent serialization ─────────────────────────────────────────────
@@ -355,7 +387,11 @@ mod tests {
#[test] #[test]
fn test_model_event_waiting_serializes() { fn test_model_event_waiting_serializes() {
let e = ModelEvent::ModelWaitingForGpu { vram_needed_mb: 3000, vram_free_mb: 200, retry_in_secs: 30 }; let e = ModelEvent::ModelWaitingForGpu {
vram_needed_mb: 3000,
vram_free_mb: 200,
retry_in_secs: 30,
};
let v: Value = serde_json::to_value(&e).unwrap(); let v: Value = serde_json::to_value(&e).unwrap();
assert_eq!(v["type"], "model_waiting_for_gpu"); assert_eq!(v["type"], "model_waiting_for_gpu");
assert_eq!(v["vram_needed_mb"], 3000); assert_eq!(v["vram_needed_mb"], 3000);
@@ -363,10 +399,18 @@ mod tests {
#[test] #[test]
fn test_model_event_webhook_filter() { fn test_model_event_webhook_filter() {
assert!(ModelEvent::ModelReady { loaded_at: Utc::now() }.is_webhook_event()); assert!(ModelEvent::ModelReady {
loaded_at: Utc::now()
}
.is_webhook_event());
assert!(ModelEvent::ModelUnloaded.is_webhook_event()); assert!(ModelEvent::ModelUnloaded.is_webhook_event());
assert!(!ModelEvent::ModelLoading.is_webhook_event()); assert!(!ModelEvent::ModelLoading.is_webhook_event());
assert!(!ModelEvent::ModelWaitingForGpu { vram_needed_mb: 0, vram_free_mb: 0, retry_in_secs: 30 }.is_webhook_event()); assert!(!ModelEvent::ModelWaitingForGpu {
vram_needed_mb: 0,
vram_free_mb: 0,
retry_in_secs: 30
}
.is_webhook_event());
} }
// ── ModelStatusResponse ────────────────────────────────────────────────── // ── ModelStatusResponse ──────────────────────────────────────────────────
@@ -374,8 +418,10 @@ mod tests {
#[test] #[test]
fn test_model_status_response_roundtrip() { fn test_model_status_response_roundtrip() {
let r = ModelStatusResponse { let r = ModelStatusResponse {
state: ModelState::Ready { loaded_at: Utc::now() }, state: ModelState::Ready {
vram_used_mb: Some(4096), loaded_at: Utc::now(),
},
vram_used_mb: Some(4096),
vram_total_mb: Some(8192), vram_total_mb: Some(8192),
}; };
let json_str = serde_json::to_string(&r).unwrap(); let json_str = serde_json::to_string(&r).unwrap();
@@ -387,7 +433,11 @@ mod tests {
#[test] #[test]
fn test_model_status_response_omits_nulls() { fn test_model_status_response_omits_nulls() {
let r = ModelStatusResponse { state: ModelState::Loading, vram_used_mb: None, vram_total_mb: None }; let r = ModelStatusResponse {
state: ModelState::Loading,
vram_used_mb: None,
vram_total_mb: None,
};
let v: Value = serde_json::to_value(&r).unwrap(); let v: Value = serde_json::to_value(&r).unwrap();
assert_eq!(v["state"], "loading"); assert_eq!(v["state"], "loading");
assert!(v.get("vram_used_mb").is_none()); assert!(v.get("vram_used_mb").is_none());

View File

@@ -19,12 +19,12 @@ pub async fn health(State(state): State<AppState>) -> Result<Json<HealthResponse
let model_state_tag = state.model_state.read().await.tag().to_string(); let model_state_tag = state.model_state.read().await.tag().to_string();
Ok(Json(HealthResponse { Ok(Json(HealthResponse {
status: "ok".into(), status: "ok".into(),
gpu_name, gpu_name,
vram_total_mb, vram_total_mb,
model: state.model_name.to_string(), model: state.model_name.to_string(),
queue_depth: state.queue_depth.load(Ordering::Relaxed), queue_depth: state.queue_depth.load(Ordering::Relaxed),
model_state: model_state_tag, model_state: model_state_tag,
})) }))
} }
@@ -50,9 +50,7 @@ fn gpu_info(device: u32) -> (Option<String>, Option<u64>) {
let mut parts = line.splitn(2, ','); let mut parts = line.splitn(2, ',');
let name = parts.next().map(|s| s.trim().to_owned()); let name = parts.next().map(|s| s.trim().to_owned());
let vram = parts let vram = parts.next().and_then(|s| s.trim().parse::<u64>().ok());
.next()
.and_then(|s| s.trim().parse::<u64>().ok());
(name, vram) (name, vram)
} }

View File

@@ -23,7 +23,8 @@ use crate::{
AppError, AppState, Result, AppError, AppState, Result,
}; };
type SseStream = Pin<Box<dyn Stream<Item = std::result::Result<Event, std::convert::Infallible>> + Send>>; type SseStream =
Pin<Box<dyn Stream<Item = std::result::Result<Event, std::convert::Infallible>> + Send>>;
// ── POST /jobs ─────────────────────────────────────────────────────────────── // ── POST /jobs ───────────────────────────────────────────────────────────────
@@ -53,18 +54,20 @@ pub async fn submit_job(
State(state): State<AppState>, State(state): State<AppState>,
mut multipart: Multipart, mut multipart: Multipart,
) -> Result<impl IntoResponse> { ) -> Result<impl IntoResponse> {
let mut language: Option<String> = None; let mut language: Option<String> = None;
let mut task: String = "transcribe".into(); let mut task: String = "transcribe".into();
let mut webhook_url: Option<String> = None; let mut webhook_url: Option<String> = None;
let mut filename: Option<String> = None; let mut filename: Option<String> = None;
let mut audio_saved = false; let mut audio_saved = false;
// Assign ID early so we know where to stream the audio bytes. // Assign ID early so we know where to stream the audio bytes.
let id = Uuid::new_v4(); let id = Uuid::new_v4();
let audio_path = audio_path_for(&id); let audio_path = audio_path_for(&id);
while let Some(field) = multipart.next_field().await.map_err(|e| { while let Some(field) = multipart
AppError::BadRequest(format!("multipart error: {e}")) .next_field()
})? { .await
.map_err(|e| AppError::BadRequest(format!("multipart error: {e}")))?
{
let field_name = field.name().unwrap_or("").to_owned(); let field_name = field.name().unwrap_or("").to_owned();
match field_name.as_str() { match field_name.as_str() {
@@ -77,9 +80,11 @@ pub async fn submit_job(
})?; })?;
let mut bytes_written: u64 = 0; let mut bytes_written: u64 = 0;
let mut stream = field; let mut stream = field;
while let Some(chunk) = stream.chunk().await.map_err(|e| { while let Some(chunk) = stream
AppError::BadRequest(format!("failed to read audio field: {e}")) .chunk()
})? { .await
.map_err(|e| AppError::BadRequest(format!("failed to read audio field: {e}")))?
{
file.write_all(&chunk).await.map_err(|e| { file.write_all(&chunk).await.map_err(|e| {
AppError::Internal(format!("failed to write audio chunk: {e}")) AppError::Internal(format!("failed to write audio chunk: {e}"))
})?; })?;
@@ -90,10 +95,29 @@ pub async fn submit_job(
} }
audio_saved = true; audio_saved = true;
} }
"language" => language = Some(field.text().await.map_err(|e| AppError::BadRequest(e.to_string()))?), "language" => {
"task" => task = field.text().await.map_err(|e| AppError::BadRequest(e.to_string()))?, language = Some(
"webhook_url" => webhook_url = Some(field.text().await.map_err(|e| AppError::BadRequest(e.to_string()))?), field
_ => {} // ignore unknown fields .text()
.await
.map_err(|e| AppError::BadRequest(e.to_string()))?,
)
}
"task" => {
task = field
.text()
.await
.map_err(|e| AppError::BadRequest(e.to_string()))?
}
"webhook_url" => {
webhook_url = Some(
field
.text()
.await
.map_err(|e| AppError::BadRequest(e.to_string()))?,
)
}
_ => {} // ignore unknown fields
} }
} }
@@ -112,14 +136,16 @@ pub async fn submit_job(
let ms = state.model_state.read().await; let ms = state.model_state.read().await;
let ready = ms.is_ready(); let ready = ms.is_ready();
let retry = ms.retry_after_secs(); let retry = ms.retry_after_secs();
let tag = ms.tag().to_string(); let tag = ms.tag().to_string();
(ready, retry, tag) (ready, retry, tag)
}; };
// Register the webhook URL regardless of model state — so model lifecycle // Register the webhook URL regardless of model state — so model lifecycle
// events are delivered even if the job itself is rejected. // events are delivered even if the job itself is rejected.
if let Some(url) = &webhook_url { if let Some(url) = &webhook_url {
state.webhook_registry.lock() state
.webhook_registry
.lock()
.unwrap_or_else(|e| e.into_inner()) .unwrap_or_else(|e| e.into_inner())
.insert(url.clone()); .insert(url.clone());
} }
@@ -143,12 +169,16 @@ pub async fn submit_job(
state.storage.create(&job).await?; state.storage.create(&job).await?;
// Pre-create the broadcast channel so SSE subscribers don't miss events. // Pre-create the broadcast channel so SSE subscribers don't miss events.
state.progress.entry(id).or_insert_with(|| broadcast::channel(64).0); state
.progress
.entry(id)
.or_insert_with(|| broadcast::channel(64).0);
state.queue_depth.fetch_add(1, Ordering::Relaxed); state.queue_depth.fetch_add(1, Ordering::Relaxed);
state.job_tx.send(id).map_err(|_| { state
AppError::Internal("worker channel closed".into()) .job_tx
})?; .send(id)
.map_err(|_| AppError::Internal("worker channel closed".into()))?;
tracing::info!(job_id = %id, "job queued"); tracing::info!(job_id = %id, "job queued");
@@ -168,10 +198,7 @@ pub async fn submit_job(
(status = 404, description = "Not found"), (status = 404, description = "Not found"),
) )
)] )]
pub async fn get_job( pub async fn get_job(State(state): State<AppState>, Path(id): Path<JobId>) -> Result<Json<Job>> {
State(state): State<AppState>,
Path(id): Path<JobId>,
) -> Result<Json<Job>> {
let job = state.storage.get(&id).await?; let job = state.storage.get(&id).await?;
Ok(Json(job)) Ok(Json(job))
} }
@@ -196,15 +223,15 @@ pub async fn get_job(
)] )]
pub async fn stream_job( pub async fn stream_job(
State(state): State<AppState>, State(state): State<AppState>,
Path(id): Path<JobId>, Path(id): Path<JobId>,
) -> Result<Sse<SseStream>> { ) -> Result<Sse<SseStream>> {
// If the job is already finished, return a single done event immediately. // If the job is already finished, return a single done event immediately.
let job = state.storage.get(&id).await?; let job = state.storage.get(&id).await?;
match job.status { match job.status {
JobStatus::Done | JobStatus::Failed | JobStatus::Cancelled => { JobStatus::Done | JobStatus::Failed | JobStatus::Cancelled => {
let payload = serde_json::to_string( let payload =
&crate::models::SsePayload::Done { job: Box::new(job) } serde_json::to_string(&crate::models::SsePayload::Done { job: Box::new(job) })
).unwrap_or_default(); .unwrap_or_default();
let s: SseStream = Box::pin(stream::once(async move { let s: SseStream = Box::pin(stream::once(async move {
Ok(Event::default().event("done").data(payload)) Ok(Event::default().event("done").data(payload))
})); }));
@@ -222,22 +249,28 @@ pub async fn stream_job(
let sse_stream: SseStream = Box::pin(BroadcastStream::new(rx).filter_map(|msg| async move { let sse_stream: SseStream = Box::pin(BroadcastStream::new(rx).filter_map(|msg| async move {
let event = match msg { let event = match msg {
Ok(ProgressEvent::Progress { percent, chunk, total }) => { Ok(ProgressEvent::Progress {
let payload = serde_json::to_string( percent,
&crate::models::SsePayload::Progress { percent, chunk, chunks_total: total } chunk,
).ok()?; total,
}) => {
let payload = serde_json::to_string(&crate::models::SsePayload::Progress {
percent,
chunk,
chunks_total: total,
})
.ok()?;
Event::default().event("progress").data(payload) Event::default().event("progress").data(payload)
} }
Ok(ProgressEvent::Done(job)) => { Ok(ProgressEvent::Done(job)) => {
let payload = serde_json::to_string( let payload =
&crate::models::SsePayload::Done { job } serde_json::to_string(&crate::models::SsePayload::Done { job }).ok()?;
).ok()?;
Event::default().event("done").data(payload) Event::default().event("done").data(payload)
} }
Ok(ProgressEvent::Error(msg)) => { Ok(ProgressEvent::Error(msg)) => {
let payload = serde_json::to_string( let payload =
&crate::models::SsePayload::Error { message: msg } serde_json::to_string(&crate::models::SsePayload::Error { message: msg })
).ok()?; .ok()?;
Event::default().event("error").data(payload) Event::default().event("error").data(payload)
} }
Err(_) => return None, // lagged / channel closed Err(_) => return None, // lagged / channel closed
@@ -264,10 +297,7 @@ pub async fn stream_job(
(status = 409, description = "Job already finished"), (status = 409, description = "Job already finished"),
) )
)] )]
pub async fn delete_job( pub async fn delete_job(State(state): State<AppState>, Path(id): Path<JobId>) -> Result<Json<Job>> {
State(state): State<AppState>,
Path(id): Path<JobId>,
) -> Result<Json<Job>> {
let mut job = state.storage.get(&id).await?; let mut job = state.storage.get(&id).await?;
match job.status { match job.status {
@@ -280,7 +310,7 @@ pub async fn delete_job(
_ => {} _ => {}
} }
job.status = JobStatus::Cancelled; job.status = JobStatus::Cancelled;
job.completed_at = Some(Utc::now()); job.completed_at = Some(Utc::now());
state.storage.save(&job).await?; state.storage.save(&job).await?;

View File

@@ -2,27 +2,33 @@ pub mod health;
pub mod jobs; pub mod jobs;
pub mod model; pub mod model;
use axum::{extract::DefaultBodyLimit, routing::{delete, get, post}, Router};
use crate::AppState; use crate::AppState;
use axum::{
extract::DefaultBodyLimit,
routing::{delete, get, post},
Router,
};
pub fn jobs_router() -> Router<AppState> { pub fn jobs_router() -> Router<AppState> {
Router::new() Router::new()
// No body limit on the upload route — files can be multiple GB. // No body limit on the upload route — files can be multiple GB.
.route("/jobs", post(jobs::submit_job).layer(DefaultBodyLimit::disable())) .route(
.route("/jobs/:id", get(jobs::get_job)) "/jobs",
post(jobs::submit_job).layer(DefaultBodyLimit::disable()),
)
.route("/jobs/:id", get(jobs::get_job))
.route("/jobs/:id/stream", get(jobs::stream_job)) .route("/jobs/:id/stream", get(jobs::stream_job))
.route("/jobs/:id", delete(jobs::delete_job)) .route("/jobs/:id", delete(jobs::delete_job))
} }
pub fn health_router() -> Router<AppState> { pub fn health_router() -> Router<AppState> {
Router::new() Router::new().route("/health", get(health::health))
.route("/health", get(health::health))
} }
pub fn model_router() -> Router<AppState> { pub fn model_router() -> Router<AppState> {
Router::new() Router::new()
.route("/model/status", get(model::model_status)) .route("/model/status", get(model::model_status))
.route("/model/load", post(model::model_load)) .route("/model/load", post(model::model_load))
.route("/model/unload", post(model::model_unload)) .route("/model/unload", post(model::model_unload))
.route("/model/events", get(model::model_events)) .route("/model/events", get(model::model_events))
} }

View File

@@ -10,8 +10,8 @@ use axum::{
Json, Json,
}; };
use futures::Stream; use futures::Stream;
use tokio_stream::wrappers::BroadcastStream;
use futures::StreamExt; use futures::StreamExt;
use tokio_stream::wrappers::BroadcastStream;
use crate::{ use crate::{
models::{ModelEvent, ModelStatusResponse}, models::{ModelEvent, ModelStatusResponse},
@@ -19,7 +19,8 @@ use crate::{
AppState, Result, AppState, Result,
}; };
type SseStream = Pin<Box<dyn Stream<Item = std::result::Result<Event, std::convert::Infallible>> + Send>>; type SseStream =
Pin<Box<dyn Stream<Item = std::result::Result<Event, std::convert::Infallible>> + Send>>;
// ── GET /model/status ──────────────────────────────────────────────────────── // ── GET /model/status ────────────────────────────────────────────────────────
@@ -61,11 +62,17 @@ pub async fn model_status(State(state): State<AppState>) -> Result<Json<ModelSta
pub async fn model_load(State(state): State<AppState>) -> impl IntoResponse { pub async fn model_load(State(state): State<AppState>) -> impl IntoResponse {
let is_ready = state.model_state.read().await.is_ready(); let is_ready = state.model_state.read().await.is_ready();
if is_ready { if is_ready {
return (StatusCode::OK, Json(serde_json::json!({"status": "already_ready"}))); return (
StatusCode::OK,
Json(serde_json::json!({"status": "already_ready"})),
);
} }
// Ignore send errors (channel full = load already in progress). // Ignore send errors (channel full = load already in progress).
let _ = state.cmd_tx.try_send(WorkerCmd::Load); let _ = state.cmd_tx.try_send(WorkerCmd::Load);
(StatusCode::ACCEPTED, Json(serde_json::json!({"status": "load_initiated"}))) (
StatusCode::ACCEPTED,
Json(serde_json::json!({"status": "load_initiated"})),
)
} }
// ── POST /model/unload ─────────────────────────────────────────────────────── // ── POST /model/unload ───────────────────────────────────────────────────────
@@ -82,7 +89,10 @@ pub async fn model_load(State(state): State<AppState>) -> impl IntoResponse {
)] )]
pub async fn model_unload(State(state): State<AppState>) -> impl IntoResponse { pub async fn model_unload(State(state): State<AppState>) -> impl IntoResponse {
let _ = state.cmd_tx.try_send(WorkerCmd::Unload); let _ = state.cmd_tx.try_send(WorkerCmd::Unload);
(StatusCode::OK, Json(serde_json::json!({"status": "unload_requested"}))) (
StatusCode::OK,
Json(serde_json::json!({"status": "unload_requested"})),
)
} }
// ── GET /model/events ──────────────────────────────────────────────────────── // ── GET /model/events ────────────────────────────────────────────────────────
@@ -105,23 +115,21 @@ pub async fn model_unload(State(state): State<AppState>) -> impl IntoResponse {
pub async fn model_events(State(state): State<AppState>) -> Sse<SseStream> { pub async fn model_events(State(state): State<AppState>) -> Sse<SseStream> {
let rx = state.model_event_tx.subscribe(); let rx = state.model_event_tx.subscribe();
let stream: SseStream = Box::pin( let stream: SseStream = Box::pin(BroadcastStream::new(rx).filter_map(|msg| async move {
BroadcastStream::new(rx).filter_map(|msg| async move { match msg {
match msg { Ok(event) => {
Ok(event) => { let event_type = match &event {
let event_type = match &event { ModelEvent::ModelReady { .. } => "model_ready",
ModelEvent::ModelReady { .. } => "model_ready", ModelEvent::ModelUnloaded => "model_unloaded",
ModelEvent::ModelUnloaded => "model_unloaded", ModelEvent::ModelLoading => "model_loading",
ModelEvent::ModelLoading => "model_loading", ModelEvent::ModelWaitingForGpu { .. } => "model_waiting_for_gpu",
ModelEvent::ModelWaitingForGpu {..} => "model_waiting_for_gpu", };
}; let data = serde_json::to_string(&event).ok()?;
let data = serde_json::to_string(&event).ok()?; Some(Ok(Event::default().event(event_type).data(data)))
Some(Ok(Event::default().event(event_type).data(data)))
}
Err(_) => None,
} }
}) Err(_) => None,
); }
}));
Sse::new(stream).keep_alive(KeepAlive::default()) Sse::new(stream).keep_alive(KeepAlive::default())
} }
@@ -146,13 +154,13 @@ fn vram_stats(gpu_device: u32) -> (Option<u64>, Option<u64>) {
let line = String::from_utf8_lossy(&out.stdout); let line = String::from_utf8_lossy(&out.stdout);
let line = line.trim(); let line = line.trim();
let mut parts = line.splitn(2, ','); let mut parts = line.splitn(2, ',');
let used = parts.next().and_then(|s| s.trim().parse::<u64>().ok())?; let used = parts.next().and_then(|s| s.trim().parse::<u64>().ok())?;
let total = parts.next().and_then(|s| s.trim().parse::<u64>().ok())?; let total = parts.next().and_then(|s| s.trim().parse::<u64>().ok())?;
Some((used, total)) Some((used, total))
} }
match inner(gpu_device) { match inner(gpu_device) {
Some((u, t)) => (Some(u), Some(t)), Some((u, t)) => (Some(u), Some(t)),
None => (None, None), None => (None, None),
} }
} }

View File

@@ -30,20 +30,20 @@ impl Storage {
// ── CRUD ───────────────────────────────────────────────────────────────── // ── CRUD ─────────────────────────────────────────────────────────────────
pub async fn create(&self, job: &Job) -> Result<()> { pub async fn create(&self, job: &Job) -> Result<()> {
let path = self.job_path(&job.id); let path = self.job_path(&job.id);
let payload = serde_json::to_vec_pretty(job) let payload =
.map_err(|e| AppError::Internal(e.to_string()))?; serde_json::to_vec_pretty(job).map_err(|e| AppError::Internal(e.to_string()))?;
fs::write(&path, payload).await.map_err(|e| { fs::write(&path, payload)
AppError::Internal(format!("failed to write job {}: {e}", job.id)) .await
})?; .map_err(|e| AppError::Internal(format!("failed to write job {}: {e}", job.id)))?;
Ok(()) Ok(())
} }
pub async fn get(&self, id: &JobId) -> Result<Job> { pub async fn get(&self, id: &JobId) -> Result<Job> {
let path = self.job_path(id); let path = self.job_path(id);
let raw = fs::read(&path).await.map_err(|_| { let raw = fs::read(&path)
AppError::NotFound(format!("job {id} not found")) .await
})?; .map_err(|_| AppError::NotFound(format!("job {id} not found")))?;
serde_json::from_slice(&raw).map_err(|e| AppError::Internal(e.to_string())) serde_json::from_slice(&raw).map_err(|e| AppError::Internal(e.to_string()))
} }
@@ -54,22 +54,24 @@ impl Storage {
pub async fn delete(&self, id: &JobId) -> Result<()> { pub async fn delete(&self, id: &JobId) -> Result<()> {
let path = self.job_path(id); let path = self.job_path(id);
fs::remove_file(&path).await.map_err(|_| { fs::remove_file(&path)
AppError::NotFound(format!("job {id} not found")) .await
})?; .map_err(|_| AppError::NotFound(format!("job {id} not found")))?;
Ok(()) Ok(())
} }
/// List all job IDs present on disk. /// List all job IDs present on disk.
pub async fn list_ids(&self) -> Result<Vec<JobId>> { pub async fn list_ids(&self) -> Result<Vec<JobId>> {
let mut entries = fs::read_dir(&self.dir).await.map_err(|e| { let mut entries = fs::read_dir(&self.dir)
AppError::Internal(format!("read_dir failed: {e}")) .await
})?; .map_err(|e| AppError::Internal(format!("read_dir failed: {e}")))?;
let mut ids = Vec::new(); let mut ids = Vec::new();
while let Some(entry) = entries.next_entry().await.map_err(|e| { while let Some(entry) = entries
AppError::Internal(e.to_string()) .next_entry()
})? { .await
.map_err(|e| AppError::Internal(e.to_string()))?
{
let name = entry.file_name(); let name = entry.file_name();
let name = name.to_string_lossy(); let name = name.to_string_lossy();
if let Some(stem) = name.strip_suffix(".json") { if let Some(stem) = name.strip_suffix(".json") {
@@ -88,8 +90,8 @@ impl Storage {
if let Ok(mut job) = self.get(&id).await { if let Ok(mut job) = self.get(&id).await {
if job.status == JobStatus::Running { if job.status == JobStatus::Running {
tracing::warn!(job_id = %id, "recovering interrupted job → failed"); tracing::warn!(job_id = %id, "recovering interrupted job → failed");
job.status = JobStatus::Failed; job.status = JobStatus::Failed;
job.error = Some("server restarted while job was running".into()); job.error = Some("server restarted while job was running".into());
job.completed_at = Some(chrono::Utc::now()); job.completed_at = Some(chrono::Utc::now());
let _ = self.save(&job).await; let _ = self.save(&job).await;
} }

View File

@@ -37,9 +37,10 @@ impl Transcriber {
/// 0 segments. The warmup forces kernel compilation at startup so all subsequent /// 0 segments. The warmup forces kernel compilation at startup so all subsequent
/// jobs run correctly from the very first request. /// jobs run correctly from the very first request.
pub fn load(model_path: impl AsRef<Path>, gpu_device: u32) -> Result<Self> { pub fn load(model_path: impl AsRef<Path>, gpu_device: u32) -> Result<Self> {
let path = model_path.as_ref().to_str().ok_or_else(|| { let path = model_path
AppError::Internal("model path is not valid UTF-8".into()) .as_ref()
})?; .to_str()
.ok_or_else(|| AppError::Internal("model path is not valid UTF-8".into()))?;
let mut params = WhisperContextParameters::new(); let mut params = WhisperContextParameters::new();
params.use_gpu(true); params.use_gpu(true);
@@ -48,25 +49,23 @@ impl Transcriber {
// real-world audio (conference recordings, noisy MP3s). // real-world audio (conference recordings, noisy MP3s).
// params.flash_attn(true); // params.flash_attn(true);
let ctx = WhisperContext::new_with_params(path, params) let ctx = WhisperContext::new_with_params(path, params).map_err(|e| {
.map_err(|e| { let msg = format!("failed to load model: {e}");
let msg = format!("failed to load model: {e}"); if AppError::is_oom(&msg) {
if AppError::is_oom(&msg) { AppError::OutOfMemory(msg)
AppError::OutOfMemory(msg) } else {
} else { AppError::Internal(msg)
AppError::Internal(msg) }
} })?;
})?;
let mut state = ctx.create_state() let mut state = ctx.create_state().map_err(|e| {
.map_err(|e| { let msg = format!("failed to create whisper state: {e}");
let msg = format!("failed to create whisper state: {e}"); if AppError::is_oom(&msg) {
if AppError::is_oom(&msg) { AppError::OutOfMemory(msg)
AppError::OutOfMemory(msg) } else {
} else { AppError::Internal(msg)
AppError::Internal(msg) }
} })?;
})?;
// ctx drops here; state holds Arc<WhisperInnerContext> so model stays loaded. // ctx drops here; state holds Arc<WhisperInnerContext> so model stays loaded.
// ── GPU warmup ──────────────────────────────────────────────────────── // ── GPU warmup ────────────────────────────────────────────────────────
@@ -95,16 +94,16 @@ impl Transcriber {
/// `no_context=true` in the params prevents KV-cache contamination between chunks. /// `no_context=true` in the params prevents KV-cache contamination between chunks.
pub fn transcribe( pub fn transcribe(
&mut self, &mut self,
pcm: &[f32], pcm: &[f32],
language: Option<&str>, language: Option<&str>,
task: &str, task: &str,
on_progress: impl Fn(u8) + Send + 'static, on_progress: impl Fn(u8) + Send + 'static,
) -> Result<(Vec<Segment>, String)> { ) -> Result<(Vec<Segment>, String)> {
let state = &mut self.state; let state = &mut self.state;
let mut fp = FullParams::new(SamplingStrategy::BeamSearch { let mut fp = FullParams::new(SamplingStrategy::BeamSearch {
beam_size: 5, beam_size: 5,
patience: 1.0, patience: 1.0,
}); });
fp.set_n_threads(num_cpus::get() as i32); fp.set_n_threads(num_cpus::get() as i32);
@@ -158,40 +157,55 @@ impl Transcriber {
.full(fp, pcm) .full(fp, pcm)
.map_err(|e| AppError::Internal(format!("transcription failed: {e}")))?; .map_err(|e| AppError::Internal(format!("transcription failed: {e}")))?;
let n_segments = state.full_n_segments() let n_segments = state
.full_n_segments()
.map_err(|e| AppError::Internal(e.to_string()))?; .map_err(|e| AppError::Internal(e.to_string()))?;
let mut segments = Vec::with_capacity(n_segments as usize); let mut segments = Vec::with_capacity(n_segments as usize);
for i in 0..n_segments { for i in 0..n_segments {
let text = state.full_get_segment_text(i) let text = state
.full_get_segment_text(i)
.map_err(|e| AppError::Internal(e.to_string()))?; .map_err(|e| AppError::Internal(e.to_string()))?;
let start = state.full_get_segment_t0(i) let start = state
.map_err(|e| AppError::Internal(e.to_string()))? as f32 / 100.0; .full_get_segment_t0(i)
let end = state.full_get_segment_t1(i) .map_err(|e| AppError::Internal(e.to_string()))? as f32
.map_err(|e| AppError::Internal(e.to_string()))? as f32 / 100.0; / 100.0;
let end = state
.full_get_segment_t1(i)
.map_err(|e| AppError::Internal(e.to_string()))? as f32
/ 100.0;
let n_tokens = state.full_n_tokens(i) let n_tokens = state
.full_n_tokens(i)
.map_err(|e| AppError::Internal(e.to_string()))?; .map_err(|e| AppError::Internal(e.to_string()))?;
let mut words = Vec::new(); let mut words = Vec::new();
for t in 0..n_tokens { for t in 0..n_tokens {
let token_text = state.full_get_token_text(i, t) let token_text = state
.full_get_token_text(i, t)
.map_err(|e| AppError::Internal(e.to_string()))?; .map_err(|e| AppError::Internal(e.to_string()))?;
if token_text.starts_with('[') { if token_text.starts_with('[') {
continue; // skip special tokens ([MUSIC], [APPLAUSE], etc.) continue; // skip special tokens ([MUSIC], [APPLAUSE], etc.)
} }
let data = state.full_get_token_data(i, t) let data = state
.full_get_token_data(i, t)
.map_err(|e| AppError::Internal(e.to_string()))?; .map_err(|e| AppError::Internal(e.to_string()))?;
words.push(Word { words.push(Word {
text: token_text, text: token_text,
start: data.t0 as f32 / 100.0, start: data.t0 as f32 / 100.0,
end: data.t1 as f32 / 100.0, end: data.t1 as f32 / 100.0,
probability: data.p, probability: data.p,
}); });
} }
segments.push(Segment { index: i, start, end, text, words }); segments.push(Segment {
index: i,
start,
end,
text,
words,
});
} }
let lang = state let lang = state

View File

@@ -16,8 +16,7 @@ use crate::{
models::{Job, JobId, JobStatus, ModelEvent, ModelState, Segment}, models::{Job, JobId, JobStatus, ModelEvent, ModelState, Segment},
storage::Storage, storage::Storage,
transcriber::Transcriber, transcriber::Transcriber,
webhook, webhook, AppError,
AppError,
}; };
/// Per-job broadcast channel for SSE subscribers. /// Per-job broadcast channel for SSE subscribers.
@@ -26,7 +25,11 @@ pub type ProgressTx = broadcast::Sender<ProgressEvent>;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum ProgressEvent { pub enum ProgressEvent {
/// `percent` — overall 0100; `chunk` — 1-based; `total` — total chunks. /// `percent` — overall 0100; `chunk` — 1-based; `total` — total chunks.
Progress { percent: u8, chunk: usize, total: usize }, Progress {
percent: u8,
chunk: usize,
total: usize,
},
Done(Box<Job>), Done(Box<Job>),
Error(String), Error(String),
} }
@@ -50,11 +53,11 @@ pub enum WorkerCmd {
// ── Transcription request/response types ───────────────────────────────────── // ── Transcription request/response types ─────────────────────────────────────
pub struct TranscribeRequest { pub struct TranscribeRequest {
pub pcm: Vec<f32>, pub pcm: Vec<f32>,
pub language: Option<String>, pub language: Option<String>,
pub task: String, pub task: String,
pub on_progress: Box<dyn Fn(u8) + Send + 'static>, pub on_progress: Box<dyn Fn(u8) + Send + 'static>,
pub reply: oneshot::Sender<crate::Result<(Vec<Segment>, String)>>, pub reply: oneshot::Sender<crate::Result<(Vec<Segment>, String)>>,
} }
impl std::fmt::Debug for TranscribeRequest { impl std::fmt::Debug for TranscribeRequest {
@@ -75,15 +78,15 @@ impl std::fmt::Debug for TranscribeRequest {
/// trigger loading. /// trigger loading.
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn start( pub fn start(
job_rx: mpsc::UnboundedReceiver<JobId>, job_rx: mpsc::UnboundedReceiver<JobId>,
storage: Arc<Storage>, storage: Arc<Storage>,
model_path: PathBuf, model_path: PathBuf,
queue_depth: Arc<AtomicUsize>, queue_depth: Arc<AtomicUsize>,
gpu_device: u32, gpu_device: u32,
model_state: Arc<RwLock<ModelState>>, model_state: Arc<RwLock<ModelState>>,
model_event_tx: broadcast::Sender<ModelEvent>, model_event_tx: broadcast::Sender<ModelEvent>,
webhook_registry: Arc<Mutex<HashSet<String>>>, webhook_registry: Arc<Mutex<HashSet<String>>>,
idle_timeout: Duration, idle_timeout: Duration,
gpu_poll_interval: Duration, gpu_poll_interval: Duration,
) -> (ProgressRegistry, std::sync::mpsc::SyncSender<WorkerCmd>) { ) -> (ProgressRegistry, std::sync::mpsc::SyncSender<WorkerCmd>) {
let registry: ProgressRegistry = Arc::new(dashmap::DashMap::new()); let registry: ProgressRegistry = Arc::new(dashmap::DashMap::new());
@@ -126,15 +129,15 @@ pub fn start(
/// separate thread. /// separate thread.
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn transcriber_thread( fn transcriber_thread(
rx: std::sync::mpsc::Receiver<WorkerCmd>, rx: std::sync::mpsc::Receiver<WorkerCmd>,
model_path: PathBuf, model_path: PathBuf,
gpu_device: u32, gpu_device: u32,
model_state: Arc<RwLock<ModelState>>, model_state: Arc<RwLock<ModelState>>,
model_event_tx: broadcast::Sender<ModelEvent>, model_event_tx: broadcast::Sender<ModelEvent>,
webhook_registry: Arc<Mutex<HashSet<String>>>, webhook_registry: Arc<Mutex<HashSet<String>>>,
idle_timeout: Duration, idle_timeout: Duration,
gpu_poll_interval: Duration, gpu_poll_interval: Duration,
rt: tokio::runtime::Handle, rt: tokio::runtime::Handle,
) { ) {
let mut transcriber: Option<Transcriber> = None; let mut transcriber: Option<Transcriber> = None;
let mut last_job = Instant::now(); let mut last_job = Instant::now();
@@ -162,14 +165,22 @@ fn transcriber_thread(
} }
Ok(WorkerCmd::Unload) => { Ok(WorkerCmd::Unload) => {
do_unload(&mut transcriber, &model_state, &model_event_tx, &webhook_registry, &rt); do_unload(
&mut transcriber,
&model_state,
&model_event_tx,
&webhook_registry,
&rt,
);
} }
Ok(WorkerCmd::Transcribe(req)) => { Ok(WorkerCmd::Transcribe(req)) => {
let t = match &mut transcriber { let t = match &mut transcriber {
Some(t) => t, Some(t) => t,
None => { None => {
tracing::warn!("Transcribe cmd received but model is unloaded — failing job"); tracing::warn!(
"Transcribe cmd received but model is unloaded — failing job"
);
let _ = req.reply.send(Err(AppError::Internal( let _ = req.reply.send(Err(AppError::Internal(
"model unloaded before job could run".into(), "model unloaded before job could run".into(),
))); )));
@@ -177,12 +188,9 @@ fn transcriber_thread(
} }
}; };
let result = t.transcribe( let result = t.transcribe(&req.pcm, req.language.as_deref(), &req.task, move |p| {
&req.pcm, (req.on_progress)(p)
req.language.as_deref(), });
&req.task,
move |p| (req.on_progress)(p),
);
last_job = Instant::now(); last_job = Instant::now();
let _ = req.reply.send(result); let _ = req.reply.send(result);
} }
@@ -218,14 +226,14 @@ fn transcriber_thread(
/// rejection. Returns `Some(Transcriber)` on success, `None` if cancelled. /// rejection. Returns `Some(Transcriber)` on success, `None` if cancelled.
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
fn try_load_with_polling( fn try_load_with_polling(
rx: &std::sync::mpsc::Receiver<WorkerCmd>, rx: &std::sync::mpsc::Receiver<WorkerCmd>,
model_path: &PathBuf, model_path: &PathBuf,
gpu_device: u32, gpu_device: u32,
model_state: &Arc<RwLock<ModelState>>, model_state: &Arc<RwLock<ModelState>>,
model_event_tx: &broadcast::Sender<ModelEvent>, model_event_tx: &broadcast::Sender<ModelEvent>,
webhook_registry: &Arc<Mutex<HashSet<String>>>, webhook_registry: &Arc<Mutex<HashSet<String>>>,
gpu_poll_interval: Duration, gpu_poll_interval: Duration,
rt: &tokio::runtime::Handle, rt: &tokio::runtime::Handle,
) -> Option<Transcriber> { ) -> Option<Transcriber> {
loop { loop {
set_state(model_state, ModelState::Loading); set_state(model_state, ModelState::Loading);
@@ -253,25 +261,35 @@ fn try_load_with_polling(
"insufficient VRAM — will retry" "insufficient VRAM — will retry"
); );
set_state(model_state, ModelState::WaitingForGpu { set_state(
vram_needed_mb, model_state,
vram_free_mb, ModelState::WaitingForGpu {
retry_in_secs, vram_needed_mb,
}); vram_free_mb,
broadcast_event(model_event_tx, ModelEvent::ModelWaitingForGpu { retry_in_secs,
vram_needed_mb, },
vram_free_mb, );
retry_in_secs, broadcast_event(
}); model_event_tx,
ModelEvent::ModelWaitingForGpu {
vram_needed_mb,
vram_free_mb,
retry_in_secs,
},
);
// Interruptible sleep: drain rx while waiting for gpu_poll_interval. // Interruptible sleep: drain rx while waiting for gpu_poll_interval.
let deadline = Instant::now() + gpu_poll_interval; let deadline = Instant::now() + gpu_poll_interval;
loop { loop {
let remaining = deadline.saturating_duration_since(Instant::now()); let remaining = deadline.saturating_duration_since(Instant::now());
if remaining.is_zero() { break; } if remaining.is_zero() {
break;
}
match rx.recv_timeout(remaining.min(Duration::from_secs(1))) { match rx.recv_timeout(remaining.min(Duration::from_secs(1))) {
Ok(WorkerCmd::Unload) => { Ok(WorkerCmd::Unload) => {
tracing::info!("Unload received while waiting for GPU — cancelling load"); tracing::info!(
"Unload received while waiting for GPU — cancelling load"
);
set_state(model_state, ModelState::Unloaded); set_state(model_state, ModelState::Unloaded);
broadcast_event(model_event_tx, ModelEvent::ModelUnloaded); broadcast_event(model_event_tx, ModelEvent::ModelUnloaded);
fire_webhooks(webhook_registry, ModelEvent::ModelUnloaded, rt); fire_webhooks(webhook_registry, ModelEvent::ModelUnloaded, rt);
@@ -303,11 +321,11 @@ fn try_load_with_polling(
} }
fn do_unload( fn do_unload(
transcriber: &mut Option<Transcriber>, transcriber: &mut Option<Transcriber>,
model_state: &Arc<RwLock<ModelState>>, model_state: &Arc<RwLock<ModelState>>,
model_event_tx: &broadcast::Sender<ModelEvent>, model_event_tx: &broadcast::Sender<ModelEvent>,
webhook_registry: &Arc<Mutex<HashSet<String>>>, webhook_registry: &Arc<Mutex<HashSet<String>>>,
rt: &tokio::runtime::Handle, rt: &tokio::runtime::Handle,
) { ) {
*transcriber = None; *transcriber = None;
set_state(model_state, ModelState::Unloaded); set_state(model_state, ModelState::Unloaded);
@@ -328,8 +346,8 @@ fn broadcast_event(tx: &broadcast::Sender<ModelEvent>, event: ModelEvent) {
fn fire_webhooks( fn fire_webhooks(
registry: &Arc<Mutex<HashSet<String>>>, registry: &Arc<Mutex<HashSet<String>>>,
event: ModelEvent, event: ModelEvent,
rt: &tokio::runtime::Handle, rt: &tokio::runtime::Handle,
) { ) {
if !event.is_webhook_event() { if !event.is_webhook_event() {
return; return;
@@ -341,11 +359,16 @@ fn fire_webhooks(
.cloned() .cloned()
.collect(); .collect();
if urls.is_empty() { return; } if urls.is_empty() {
return;
}
let payload = match serde_json::to_string(&event) { let payload = match serde_json::to_string(&event) {
Ok(p) => p, Ok(p) => p,
Err(e) => { tracing::error!(error = %e, "failed to serialize model event"); return; } Err(e) => {
tracing::error!(error = %e, "failed to serialize model event");
return;
}
}; };
for url in urls { for url in urls {
@@ -356,7 +379,8 @@ fn fire_webhooks(
.build() .build()
.expect("http client"); .expect("http client");
for attempt in 0..3_u32 { for attempt in 0..3_u32 {
match http.post(&url) match http
.post(&url)
.header("content-type", "application/json") .header("content-type", "application/json")
.body(body.clone()) .body(body.clone())
.send() .send()
@@ -405,11 +429,11 @@ fn parse_oom_vram(msg: &str, gpu_device: u32) -> (u64, u64) {
// ── Async job runner ────────────────────────────────────────────────────────── // ── Async job runner ──────────────────────────────────────────────────────────
async fn run( async fn run(
mut job_rx: mpsc::UnboundedReceiver<JobId>, mut job_rx: mpsc::UnboundedReceiver<JobId>,
storage: Arc<Storage>, storage: Arc<Storage>,
queue_depth: Arc<AtomicUsize>, queue_depth: Arc<AtomicUsize>,
registry: ProgressRegistry, registry: ProgressRegistry,
cmd_tx: std::sync::mpsc::SyncSender<WorkerCmd>, cmd_tx: std::sync::mpsc::SyncSender<WorkerCmd>,
) { ) {
let http = Client::builder() let http = Client::builder()
.timeout(Duration::from_secs(30)) .timeout(Duration::from_secs(30))
@@ -420,7 +444,7 @@ async fn run(
queue_depth.fetch_sub(1, Ordering::Relaxed); queue_depth.fetch_sub(1, Ordering::Relaxed);
let mut job = match storage.get(&job_id).await { let mut job = match storage.get(&job_id).await {
Ok(j) => j, Ok(j) => j,
Err(e) => { Err(e) => {
tracing::warn!(job_id = %job_id, error = %e, "job vanished before processing"); tracing::warn!(job_id = %job_id, error = %e, "job vanished before processing");
registry.remove(&job_id); registry.remove(&job_id);
@@ -461,19 +485,19 @@ async fn run(
match result { match result {
Ok((segments, language, duration_secs)) => { Ok((segments, language, duration_secs)) => {
job.status = JobStatus::Done; job.status = JobStatus::Done;
job.segments = segments; job.segments = segments;
job.language = Some(language); job.language = Some(language);
job.duration_secs = Some(duration_secs); job.duration_secs = Some(duration_secs);
job.progress = 100; job.progress = 100;
job.completed_at = Some(Utc::now()); job.completed_at = Some(Utc::now());
let _ = progress_tx.send(ProgressEvent::Done(Box::new(job.clone()))); let _ = progress_tx.send(ProgressEvent::Done(Box::new(job.clone())));
} }
Err(e) => { Err(e) => {
let msg = e.to_string(); let msg = e.to_string();
tracing::error!(job_id = %job_id, error = %msg, "transcription failed"); tracing::error!(job_id = %job_id, error = %msg, "transcription failed");
job.status = JobStatus::Failed; job.status = JobStatus::Failed;
job.error = Some(msg.clone()); job.error = Some(msg.clone());
job.completed_at = Some(Utc::now()); job.completed_at = Some(Utc::now());
let _ = progress_tx.send(ProgressEvent::Error(msg)); let _ = progress_tx.send(ProgressEvent::Error(msg));
} }
@@ -485,9 +509,11 @@ async fn run(
if let Some(url) = &job.webhook_url.clone() { if let Some(url) = &job.webhook_url.clone() {
let http = http.clone(); let http = http.clone();
let url = url.clone(); let url = url.clone();
let job = job.clone(); let job = job.clone();
tokio::spawn(async move { webhook::fire(&http, &url, &job).await; }); tokio::spawn(async move {
webhook::fire(&http, &url, &job).await;
});
} }
tokio::time::sleep(Duration::from_secs(30)).await; tokio::time::sleep(Duration::from_secs(30)).await;
@@ -498,9 +524,9 @@ async fn run(
// ── Silence-based chunking ──────────────────────────────────────────────────── // ── Silence-based chunking ────────────────────────────────────────────────────
const TARGET_CHUNK_SECS: f32 = 60.0; const TARGET_CHUNK_SECS: f32 = 60.0;
const SNAP_WINDOW_SECS: f32 = 30.0; const SNAP_WINDOW_SECS: f32 = 30.0;
const SILENCE_DB: &str = "-35dB"; const SILENCE_DB: &str = "-35dB";
const SILENCE_DUR: &str = "0.4"; const SILENCE_DUR: &str = "0.4";
async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> { async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
use tokio::process::Command; use tokio::process::Command;
@@ -509,15 +535,19 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
let output = Command::new("ffmpeg") let output = Command::new("ffmpeg")
.args([ .args([
"-nostdin", "-nostdin",
"-i", path.to_str().unwrap_or(""), "-i",
"-af", &filter, path.to_str().unwrap_or(""),
"-f", "null", "-", "-af",
&filter,
"-f",
"null",
"-",
]) ])
.output() .output()
.await; .await;
let output = match output { let output = match output {
Ok(o) => o, Ok(o) => o,
Err(e) => { Err(e) => {
tracing::warn!(error = %e, "silencedetect unavailable; using hard cuts"); tracing::warn!(error = %e, "silencedetect unavailable; using hard cuts");
return Vec::new(); return Vec::new();
@@ -526,7 +556,7 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
let stderr = String::from_utf8_lossy(&output.stderr); let stderr = String::from_utf8_lossy(&output.stderr);
let mut starts: Vec<f32> = Vec::new(); let mut starts: Vec<f32> = Vec::new();
let mut ends: Vec<f32> = Vec::new(); let mut ends: Vec<f32> = Vec::new();
for line in stderr.lines() { for line in stderr.lines() {
if let Some(i) = line.find("silence_start: ") { if let Some(i) = line.find("silence_start: ") {
@@ -545,7 +575,9 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
} }
} }
let mids: Vec<f32> = starts.iter().zip(ends.iter()) let mids: Vec<f32> = starts
.iter()
.zip(ends.iter())
.map(|(s, e)| (s + e) / 2.0) .map(|(s, e)| (s + e) / 2.0)
.collect(); .collect();
@@ -553,18 +585,15 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
mids mids
} }
fn snap_to_silence( fn snap_to_silence(mids: &[f32], total_secs: f32, target_secs: f32, snap_window: f32) -> Vec<f32> {
mids: &[f32],
total_secs: f32,
target_secs: f32,
snap_window: f32,
) -> Vec<f32> {
let mut cuts: Vec<f32> = Vec::new(); let mut cuts: Vec<f32> = Vec::new();
let mut pos = target_secs; let mut pos = target_secs;
while pos < total_secs - target_secs * 0.25 { while pos < total_secs - target_secs * 0.25 {
let prev_cut = cuts.last().copied().unwrap_or(0.0); let prev_cut = cuts.last().copied().unwrap_or(0.0);
let best = mids.iter().copied() let best = mids
.iter()
.copied()
.filter(|&t| t > prev_cut + 10.0 && (t - pos).abs() <= snap_window) .filter(|&t| t > prev_cut + 10.0 && (t - pos).abs() <= snap_window)
.min_by(|a, b| (a - pos).abs().partial_cmp(&(b - pos).abs()).unwrap()); .min_by(|a, b| (a - pos).abs().partial_cmp(&(b - pos).abs()).unwrap());
let cut = best.unwrap_or(pos); let cut = best.unwrap_or(pos);
@@ -591,20 +620,165 @@ fn to_chunk_ranges(cuts: &[f32], total_secs: f32) -> Vec<(f32, f32)> {
ranges ranges
} }
const MAX_CHAIN_GAP_SECS: f32 = 0.15;
const MIN_MEANINGFUL_WORDS: usize = 2;
const MIN_MEANINGFUL_CHARS: usize = 8;
const MIN_OVERLAP_WORDS: usize = 3;
fn normalised_words(text: &str) -> Vec<String> {
text.split_whitespace()
.map(|word| {
word.chars()
.filter(|ch| ch.is_alphanumeric() || *ch == '_')
.flat_map(|ch| ch.to_lowercase())
.collect::<String>()
})
.filter(|word| !word.is_empty())
.collect()
}
fn starts_with_words(full: &[String], prefix: &[String]) -> bool {
prefix.len() <= full.len() && full.iter().take(prefix.len()).eq(prefix.iter())
}
fn ends_with_words(full: &[String], suffix: &[String]) -> bool {
suffix.len() <= full.len()
&& full
.iter()
.skip(full.len() - suffix.len())
.eq(suffix.iter())
}
fn suffix_prefix_overlap(left: &[String], right: &[String]) -> usize {
let max = left.len().min(right.len());
for size in (1..=max).rev() {
if left[left.len() - size..] == right[..size] {
return size;
}
}
0
}
fn is_meaningful_phrase(words: &[String]) -> bool {
words.len() >= MIN_MEANINGFUL_WORDS
&& words.iter().map(|word| word.len()).sum::<usize>() >= MIN_MEANINGFUL_CHARS
}
fn trim_leading_words(text: &str, count: usize) -> String {
text.split_whitespace()
.skip(count)
.collect::<Vec<_>>()
.join(" ")
.trim()
.to_string()
}
fn merge_identical_segments(segments: Vec<Segment>) -> Vec<Segment> {
let mut out: Vec<Segment> = Vec::with_capacity(segments.len());
for seg in segments {
if let Some(last) = out.last_mut() {
if normalised_words(&last.text) == normalised_words(&seg.text) {
last.end = last.end.max(seg.end);
if !seg.words.is_empty() {
last.words = seg.words;
}
continue;
}
}
out.push(seg);
}
out
}
fn collapse_incremental_segments(segments: Vec<Segment>) -> Vec<Segment> {
let mut out: Vec<Segment> = Vec::with_capacity(segments.len());
for mut seg in segments {
seg.text = seg.text.trim().to_string();
if seg.text.is_empty() {
continue;
}
let Some(last) = out.last_mut() else {
out.push(seg);
continue;
};
let gap = seg.start - last.end;
if gap > MAX_CHAIN_GAP_SECS {
out.push(seg);
continue;
}
let last_words = normalised_words(&last.text);
let seg_words = normalised_words(&seg.text);
if last_words.is_empty() || seg_words.is_empty() {
out.push(seg);
continue;
}
if seg_words.len() > last_words.len()
&& starts_with_words(&seg_words, &last_words)
&& is_meaningful_phrase(&last_words)
{
last.text = seg.text;
last.end = seg.end;
last.words = seg.words;
continue;
}
if ends_with_words(&last_words, &seg_words) && is_meaningful_phrase(&seg_words) {
last.end = last.end.max(seg.end);
continue;
}
let overlap = suffix_prefix_overlap(&last_words, &seg_words);
if overlap >= MIN_OVERLAP_WORDS {
let trimmed_text = trim_leading_words(&seg.text, overlap);
if trimmed_text.is_empty() {
last.end = last.end.max(seg.end);
continue;
}
seg.start = seg.start.max(last.end);
seg.text = trimmed_text;
seg.words.clear();
}
out.push(seg);
}
out
}
fn normalise_segments(segments: Vec<Segment>) -> Vec<Segment> {
let mut result = collapse_incremental_segments(segments);
result = merge_identical_segments(result);
result = collapse_incremental_segments(result);
merge_identical_segments(result)
}
// ── Job processing ──────────────────────────────────────────────────────────── // ── Job processing ────────────────────────────────────────────────────────────
async fn process_job( async fn process_job(
job: &Job, job: &Job,
audio_path: &std::path::Path, audio_path: &std::path::Path,
progress_tx: &ProgressTx, progress_tx: &ProgressTx,
cmd_tx: &std::sync::mpsc::SyncSender<WorkerCmd>, cmd_tx: &std::sync::mpsc::SyncSender<WorkerCmd>,
storage: &Arc<Storage>, storage: &Arc<Storage>,
) -> crate::Result<(Vec<Segment>, String, f32)> { ) -> crate::Result<(Vec<Segment>, String, f32)> {
let pcm = decode_audio(audio_path).await?; let pcm = decode_audio(audio_path).await?;
let total_secs = pcm.len() as f32 / 16_000.0; let total_secs = pcm.len() as f32 / 16_000.0;
let silence_mids = detect_silence_midpoints(audio_path).await; let silence_mids = detect_silence_midpoints(audio_path).await;
let cuts = snap_to_silence(&silence_mids, total_secs, TARGET_CHUNK_SECS, SNAP_WINDOW_SECS); let cuts = snap_to_silence(
&silence_mids,
total_secs,
TARGET_CHUNK_SECS,
SNAP_WINDOW_SECS,
);
let chunks = to_chunk_ranges(&cuts, total_secs); let chunks = to_chunk_ranges(&cuts, total_secs);
let n = chunks.len(); let n = chunks.len();
@@ -620,12 +794,12 @@ async fn process_job(
for (ci, (chunk_start, chunk_end)) in chunks.iter().enumerate() { for (ci, (chunk_start, chunk_end)) in chunks.iter().enumerate() {
let s0 = (*chunk_start * 16_000.0) as usize; let s0 = (*chunk_start * 16_000.0) as usize;
let s1 = ((*chunk_end * 16_000.0) as usize).min(pcm.len()); let s1 = ((*chunk_end * 16_000.0) as usize).min(pcm.len());
let mut chunk_pcm = pcm[s0..s1].to_vec(); let mut chunk_pcm = pcm[s0..s1].to_vec();
trim_trailing_silence(&mut chunk_pcm); trim_trailing_silence(&mut chunk_pcm);
let base = (ci * 100 / n) as u8; let base = (ci * 100 / n) as u8;
let span = (100usize / n).max(1) as u8; let span = (100usize / n).max(1) as u8;
// Save progress to disk before emitting SSE — polling clients who respond // Save progress to disk before emitting SSE — polling clients who respond
// immediately to the SSE event will then see consistent state. // immediately to the SSE event will then see consistent state.
@@ -637,49 +811,52 @@ async fn process_job(
let _ = progress_tx.send(ProgressEvent::Progress { let _ = progress_tx.send(ProgressEvent::Progress {
percent: base, percent: base,
chunk: ci + 1, chunk: ci + 1,
total: n, total: n,
}); });
let tx = progress_tx.clone(); let tx = progress_tx.clone();
let chunk_num = ci + 1; let chunk_num = ci + 1;
let on_progress = Box::new(move |p: u8| { let on_progress = Box::new(move |p: u8| {
let overall = base.saturating_add(p.saturating_mul(span) / 100); let overall = base.saturating_add(p.saturating_mul(span) / 100);
let _ = tx.send(ProgressEvent::Progress { let _ = tx.send(ProgressEvent::Progress {
percent: overall, percent: overall,
chunk: chunk_num, chunk: chunk_num,
total: n, total: n,
}); });
}); });
let (reply_tx, reply_rx) = oneshot::channel(); let (reply_tx, reply_rx) = oneshot::channel();
cmd_tx.send(WorkerCmd::Transcribe(TranscribeRequest { cmd_tx
pcm: chunk_pcm, .send(WorkerCmd::Transcribe(TranscribeRequest {
language: job.language.clone(), pcm: chunk_pcm,
task: job.task.clone(), language: job.language.clone(),
on_progress, task: job.task.clone(),
reply: reply_tx, on_progress,
})).map_err(|_| AppError::Internal("worker command channel closed".into()))?; reply: reply_tx,
}))
.map_err(|_| AppError::Internal("worker command channel closed".into()))?;
let (mut segs, lang) = reply_rx.await let (mut segs, lang) = reply_rx
.await
.map_err(|_| AppError::Internal("transcriber thread dropped reply".into()))??; .map_err(|_| AppError::Internal("transcriber thread dropped reply".into()))??;
let offset = *chunk_start; let offset = *chunk_start;
for seg in &mut segs { for seg in &mut segs {
seg.start += offset; seg.start += offset;
seg.end += offset; seg.end += offset;
for word in &mut seg.words { for word in &mut seg.words {
word.start += offset; word.start += offset;
word.end += offset; word.end += offset;
} }
} }
tracing::debug!( tracing::debug!(
chunk = ci + 1, chunk = ci + 1,
of = n, of = n,
start = chunk_start, start = chunk_start,
end = chunk_end, end = chunk_end,
segs = segs.len(), segs = segs.len(),
"chunk done" "chunk done"
); );
@@ -689,24 +866,30 @@ async fn process_job(
} }
} }
all_segments = normalise_segments(all_segments);
for (i, seg) in all_segments.iter_mut().enumerate() { for (i, seg) in all_segments.iter_mut().enumerate() {
seg.index = i as i32; seg.index = i as i32;
} }
let _ = progress_tx.send(ProgressEvent::Progress { percent: 100, chunk: n, total: n }); let _ = progress_tx.send(ProgressEvent::Progress {
percent: 100,
chunk: n,
total: n,
});
Ok((all_segments, language, total_secs)) Ok((all_segments, language, total_secs))
} }
fn trim_trailing_silence(pcm: &mut Vec<f32>) { fn trim_trailing_silence(pcm: &mut Vec<f32>) {
const THRESHOLD: f32 = 0.017_8; const THRESHOLD: f32 = 0.017_8;
const PADDING: usize = 8_000; const PADDING: usize = 8_000;
if let Some(last_loud) = pcm.iter().rposition(|&s| s.abs() > THRESHOLD) { if let Some(last_loud) = pcm.iter().rposition(|&s| s.abs() > THRESHOLD) {
let new_len = (last_loud + 1 + PADDING).min(pcm.len()); let new_len = (last_loud + 1 + PADDING).min(pcm.len());
if new_len < pcm.len() { if new_len < pcm.len() {
tracing::trace!( tracing::trace!(
original_samples = pcm.len(), original_samples = pcm.len(),
trimmed_samples = pcm.len() - new_len, trimmed_samples = pcm.len() - new_len,
"trimmed trailing silence" "trimmed trailing silence"
); );
pcm.truncate(new_len); pcm.truncate(new_len);
@@ -719,11 +902,17 @@ async fn decode_audio(path: &std::path::Path) -> crate::Result<Vec<f32>> {
let output = Command::new("ffmpeg") let output = Command::new("ffmpeg")
.args([ .args([
"-nostdin", "-threads", "0", "-nostdin",
"-i", path.to_str().unwrap_or(""), "-threads",
"-f", "f32le", "0",
"-ac", "1", "-i",
"-ar", "16000", path.to_str().unwrap_or(""),
"-f",
"f32le",
"-ac",
"1",
"-ar",
"16000",
"-", "-",
]) ])
.output() .output()
@@ -760,13 +949,28 @@ pub fn audio_path_for(id: &JobId) -> PathBuf {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::models::Word;
fn segment(index: i32, start: f32, end: f32, text: &str) -> Segment {
Segment {
index,
start,
end,
text: text.into(),
words: Vec::<Word>::new(),
}
}
#[test] #[test]
fn test_snap_to_silence_uses_nearest_midpoint() { fn test_snap_to_silence_uses_nearest_midpoint() {
let mids = vec![55.0, 58.0, 62.0]; let mids = vec![55.0, 58.0, 62.0];
let cuts = snap_to_silence(&mids, 120.0, 60.0, 30.0); let cuts = snap_to_silence(&mids, 120.0, 60.0, 30.0);
assert!(!cuts.is_empty()); assert!(!cuts.is_empty());
assert!((cuts[0] - 58.0).abs() < 0.01, "expected ~58.0, got {}", cuts[0]); assert!(
(cuts[0] - 58.0).abs() < 0.01,
"expected ~58.0, got {}",
cuts[0]
);
} }
#[test] #[test]
@@ -801,4 +1005,53 @@ mod tests {
trim_trailing_silence(&mut pcm); trim_trailing_silence(&mut pcm);
assert_eq!(pcm.len(), (10_001 + 8_000).min(32_000)); assert_eq!(pcm.len(), (10_001 + 8_000).min(32_000));
} }
#[test]
fn test_normalise_segments_collapses_prefix_growth_chain() {
let input = vec![
segment(0, 15.24, 16.6, "Hello everyone."),
segment(1, 16.6, 19.47, "Hello everyone. Um, welcome to this talk."),
segment(2, 19.47, 19.48, "Um, welcome to this talk."),
segment(
3,
19.48,
21.67,
"Um, welcome to this talk. I'll be speaking about small model",
),
segment(4, 21.67, 21.68, "I'll be speaking about small model"),
segment(
5,
21.68,
24.59,
"I'll be speaking about small model inference and a gap that we've",
),
];
let result = normalise_segments(input);
assert_eq!(result.len(), 2);
assert_eq!(result[0].text, "Hello everyone. Um, welcome to this talk.");
assert!((result[0].start - 15.24).abs() < 0.01);
assert!((result[0].end - 19.48).abs() < 0.01);
assert_eq!(
result[1].text,
"I'll be speaking about small model inference and a gap that we've"
);
assert!((result[1].start - 19.48).abs() < 0.01);
assert!((result[1].end - 24.59).abs() < 0.01);
}
#[test]
fn test_normalise_segments_keeps_real_gap() {
let input = vec![
segment(0, 0.0, 1.0, "Hello everyone."),
segment(1, 2.0, 4.0, "Hello everyone. Welcome back."),
];
let result = normalise_segments(input);
assert_eq!(result.len(), 2);
assert_eq!(result[0].text, "Hello everyone.");
assert_eq!(result[1].text, "Hello everyone. Welcome back.");
}
} }