fix(worker): collapse incremental segments
All checks were successful
Build & Push Docker Image / test (push) Successful in 6m20s
Build & Push Docker Image / build-and-push (push) Successful in 6m29s

Normalize rolling partial-hypothesis chains before final job persistence so downstream clients receive stable transcript segments instead of echoed continuations.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
2026-05-11 22:46:38 +02:00
parent d3a67f11b3
commit cb0b07b2ff
10 changed files with 712 additions and 331 deletions

View File

@@ -1,10 +1,10 @@
use thiserror::Error;
use axum::{
http::{StatusCode, HeaderValue, header},
http::{header, HeaderValue, StatusCode},
response::{IntoResponse, Response},
Json,
};
use serde_json::json;
use thiserror::Error;
pub type Result<T> = std::result::Result<T, AppError>;
@@ -31,7 +31,10 @@ pub enum AppError {
/// Returned when a job is submitted but the model is not yet loaded.
/// Carries the current state tag and recommended Retry-After seconds.
#[error("model not ready: {state}")]
ModelNotReady { state: String, retry_after_secs: u64 },
ModelNotReady {
state: String,
retry_after_secs: u64,
},
}
impl AppError {
@@ -59,13 +62,20 @@ impl IntoResponse for AppError {
}
AppError::Internal(m) => {
tracing::error!(error = %m, "internal error");
(StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": m }))).into_response()
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({ "error": m })),
)
.into_response()
}
AppError::OutOfMemory(m) => {
tracing::warn!(error = %m, "GPU out of memory during model load");
(StatusCode::SERVICE_UNAVAILABLE, Json(json!({ "error": m }))).into_response()
}
AppError::ModelNotReady { state, retry_after_secs } => {
AppError::ModelNotReady {
state,
retry_after_secs,
} => {
let body = Json(json!({
"error": "model_not_ready",
"state": state,
@@ -117,17 +127,25 @@ mod tests {
#[tokio::test]
async fn test_model_not_ready_response_has_retry_after_header() {
let err = AppError::ModelNotReady { state: "loading".into(), retry_after_secs: 10 };
let err = AppError::ModelNotReady {
state: "loading".into(),
retry_after_secs: 10,
};
let resp = err.into_response();
assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
let retry_after = resp.headers().get(header::RETRY_AFTER)
let retry_after = resp
.headers()
.get(header::RETRY_AFTER)
.expect("Retry-After header missing");
assert_eq!(retry_after, "10");
}
#[tokio::test]
async fn test_model_not_ready_response_body() {
let err = AppError::ModelNotReady { state: "unloaded".into(), retry_after_secs: 30 };
let err = AppError::ModelNotReady {
state: "unloaded".into(),
retry_after_secs: 30,
};
let resp = err.into_response();
let bytes = to_bytes(resp.into_body(), usize::MAX).await.unwrap();
let v: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
@@ -138,21 +156,21 @@ mod tests {
#[tokio::test]
async fn test_model_not_ready_loading_retry_after_10() {
let err = AppError::ModelNotReady { state: "loading".into(), retry_after_secs: 10 };
let err = AppError::ModelNotReady {
state: "loading".into(),
retry_after_secs: 10,
};
let resp = err.into_response();
assert_eq!(
resp.headers().get(header::RETRY_AFTER).unwrap(),
"10"
);
assert_eq!(resp.headers().get(header::RETRY_AFTER).unwrap(), "10");
}
#[tokio::test]
async fn test_model_not_ready_unloaded_retry_after_30() {
let err = AppError::ModelNotReady { state: "unloaded".into(), retry_after_secs: 30 };
let err = AppError::ModelNotReady {
state: "unloaded".into(),
retry_after_secs: 30,
};
let resp = err.into_response();
assert_eq!(
resp.headers().get(header::RETRY_AFTER).unwrap(),
"30"
);
assert_eq!(resp.headers().get(header::RETRY_AFTER).unwrap(), "30");
}
}

View File

@@ -98,8 +98,8 @@ async fn main() -> anyhow::Result<()> {
.init();
let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/data".into());
let model_path = std::env::var("WHISPER_MODEL_PATH")
.unwrap_or_else(|_| "/models/ggml-large-v3.bin".into());
let model_path =
std::env::var("WHISPER_MODEL_PATH").unwrap_or_else(|_| "/models/ggml-large-v3.bin".into());
let port = std::env::var("PORT").unwrap_or_else(|_| "8080".into());
let model_name = std::env::var("WHISPER_MODEL").unwrap_or_else(|_| "large-v3".into());
let gpu_device: u32 = std::env::var("CUDA_DEVICE")
@@ -132,7 +132,9 @@ async fn main() -> anyhow::Result<()> {
// Model starts unloaded — lazy load on first job or POST /model/load.
let model_state = Arc::new(RwLock::new(models::ModelState::Unloaded));
let (model_event_tx, _) = broadcast::channel::<models::ModelEvent>(32);
let webhook_registry = Arc::new(std::sync::Mutex::new(std::collections::HashSet::<String>::new()));
let webhook_registry = Arc::new(std::sync::Mutex::new(
std::collections::HashSet::<String>::new(),
));
// Spawn single GPU worker; get back the SSE broadcast registry and cmd channel.
let (progress, cmd_tx) = worker::start(

View File

@@ -77,9 +77,7 @@ impl ModelState {
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ModelEvent {
/// Model finished loading and the GPU warmup completed — ready to accept jobs.
ModelReady {
loaded_at: DateTime<Utc>,
},
ModelReady { loaded_at: DateTime<Utc> },
/// Model was unloaded from GPU memory (idle timeout or manual unload).
ModelUnloaded,
/// Model load initiated.
@@ -95,7 +93,10 @@ pub enum ModelEvent {
impl ModelEvent {
/// Returns true if this event should be delivered via webhook.
pub fn is_webhook_event(&self) -> bool {
matches!(self, ModelEvent::ModelReady { .. } | ModelEvent::ModelUnloaded)
matches!(
self,
ModelEvent::ModelReady { .. } | ModelEvent::ModelUnloaded
)
}
}
@@ -205,7 +206,12 @@ pub struct Job {
}
impl Job {
pub fn new(id: JobId, task: String, webhook_url: Option<String>, filename: Option<String>) -> Self {
pub fn new(
id: JobId,
task: String,
webhook_url: Option<String>,
filename: Option<String>,
) -> Self {
Self {
id,
status: JobStatus::Queued,
@@ -257,8 +263,12 @@ pub enum SsePayload {
/// Total number of silence-split chunks in this job.
chunks_total: usize,
},
Done { job: Box<Job> },
Error { message: String },
Done {
job: Box<Job>,
},
Error {
message: String,
},
}
// ── Unit tests ───────────────────────────────────────────────────────────────
@@ -284,7 +294,11 @@ mod tests {
#[test]
fn test_model_state_waiting_serializes() {
let s = ModelState::WaitingForGpu { vram_needed_mb: 3000, vram_free_mb: 500, retry_in_secs: 30 };
let s = ModelState::WaitingForGpu {
vram_needed_mb: 3000,
vram_free_mb: 500,
retry_in_secs: 30,
};
let v: Value = serde_json::to_value(&s).unwrap();
assert_eq!(v["state"], "waiting_for_gpu");
assert_eq!(v["vram_needed_mb"], 3000);
@@ -305,8 +319,16 @@ mod tests {
fn test_model_state_is_ready() {
assert!(!ModelState::Unloaded.is_ready());
assert!(!ModelState::Loading.is_ready());
assert!(!ModelState::WaitingForGpu { vram_needed_mb: 0, vram_free_mb: 0, retry_in_secs: 30 }.is_ready());
assert!(ModelState::Ready { loaded_at: Utc::now() }.is_ready());
assert!(!ModelState::WaitingForGpu {
vram_needed_mb: 0,
vram_free_mb: 0,
retry_in_secs: 30
}
.is_ready());
assert!(ModelState::Ready {
loaded_at: Utc::now()
}
.is_ready());
}
#[test]
@@ -321,13 +343,23 @@ mod tests {
#[test]
fn test_retry_after_waiting_for_gpu() {
let s = ModelState::WaitingForGpu { vram_needed_mb: 0, vram_free_mb: 0, retry_in_secs: 45 };
let s = ModelState::WaitingForGpu {
vram_needed_mb: 0,
vram_free_mb: 0,
retry_in_secs: 45,
};
assert_eq!(s.retry_after_secs(), 45);
}
#[test]
fn test_retry_after_ready_is_zero() {
assert_eq!(ModelState::Ready { loaded_at: Utc::now() }.retry_after_secs(), 0);
assert_eq!(
ModelState::Ready {
loaded_at: Utc::now()
}
.retry_after_secs(),
0
);
}
// ── ModelEvent serialization ─────────────────────────────────────────────
@@ -355,7 +387,11 @@ mod tests {
#[test]
fn test_model_event_waiting_serializes() {
let e = ModelEvent::ModelWaitingForGpu { vram_needed_mb: 3000, vram_free_mb: 200, retry_in_secs: 30 };
let e = ModelEvent::ModelWaitingForGpu {
vram_needed_mb: 3000,
vram_free_mb: 200,
retry_in_secs: 30,
};
let v: Value = serde_json::to_value(&e).unwrap();
assert_eq!(v["type"], "model_waiting_for_gpu");
assert_eq!(v["vram_needed_mb"], 3000);
@@ -363,10 +399,18 @@ mod tests {
#[test]
fn test_model_event_webhook_filter() {
assert!(ModelEvent::ModelReady { loaded_at: Utc::now() }.is_webhook_event());
assert!(ModelEvent::ModelReady {
loaded_at: Utc::now()
}
.is_webhook_event());
assert!(ModelEvent::ModelUnloaded.is_webhook_event());
assert!(!ModelEvent::ModelLoading.is_webhook_event());
assert!(!ModelEvent::ModelWaitingForGpu { vram_needed_mb: 0, vram_free_mb: 0, retry_in_secs: 30 }.is_webhook_event());
assert!(!ModelEvent::ModelWaitingForGpu {
vram_needed_mb: 0,
vram_free_mb: 0,
retry_in_secs: 30
}
.is_webhook_event());
}
// ── ModelStatusResponse ──────────────────────────────────────────────────
@@ -374,7 +418,9 @@ mod tests {
#[test]
fn test_model_status_response_roundtrip() {
let r = ModelStatusResponse {
state: ModelState::Ready { loaded_at: Utc::now() },
state: ModelState::Ready {
loaded_at: Utc::now(),
},
vram_used_mb: Some(4096),
vram_total_mb: Some(8192),
};
@@ -387,7 +433,11 @@ mod tests {
#[test]
fn test_model_status_response_omits_nulls() {
let r = ModelStatusResponse { state: ModelState::Loading, vram_used_mb: None, vram_total_mb: None };
let r = ModelStatusResponse {
state: ModelState::Loading,
vram_used_mb: None,
vram_total_mb: None,
};
let v: Value = serde_json::to_value(&r).unwrap();
assert_eq!(v["state"], "loading");
assert!(v.get("vram_used_mb").is_none());

View File

@@ -50,9 +50,7 @@ fn gpu_info(device: u32) -> (Option<String>, Option<u64>) {
let mut parts = line.splitn(2, ',');
let name = parts.next().map(|s| s.trim().to_owned());
let vram = parts
.next()
.and_then(|s| s.trim().parse::<u64>().ok());
let vram = parts.next().and_then(|s| s.trim().parse::<u64>().ok());
(name, vram)
}

View File

@@ -23,7 +23,8 @@ use crate::{
AppError, AppState, Result,
};
type SseStream = Pin<Box<dyn Stream<Item = std::result::Result<Event, std::convert::Infallible>> + Send>>;
type SseStream =
Pin<Box<dyn Stream<Item = std::result::Result<Event, std::convert::Infallible>> + Send>>;
// ── POST /jobs ───────────────────────────────────────────────────────────────
@@ -62,9 +63,11 @@ pub async fn submit_job(
let id = Uuid::new_v4();
let audio_path = audio_path_for(&id);
while let Some(field) = multipart.next_field().await.map_err(|e| {
AppError::BadRequest(format!("multipart error: {e}"))
})? {
while let Some(field) = multipart
.next_field()
.await
.map_err(|e| AppError::BadRequest(format!("multipart error: {e}")))?
{
let field_name = field.name().unwrap_or("").to_owned();
match field_name.as_str() {
@@ -77,9 +80,11 @@ pub async fn submit_job(
})?;
let mut bytes_written: u64 = 0;
let mut stream = field;
while let Some(chunk) = stream.chunk().await.map_err(|e| {
AppError::BadRequest(format!("failed to read audio field: {e}"))
})? {
while let Some(chunk) = stream
.chunk()
.await
.map_err(|e| AppError::BadRequest(format!("failed to read audio field: {e}")))?
{
file.write_all(&chunk).await.map_err(|e| {
AppError::Internal(format!("failed to write audio chunk: {e}"))
})?;
@@ -90,9 +95,28 @@ pub async fn submit_job(
}
audio_saved = true;
}
"language" => language = Some(field.text().await.map_err(|e| AppError::BadRequest(e.to_string()))?),
"task" => task = field.text().await.map_err(|e| AppError::BadRequest(e.to_string()))?,
"webhook_url" => webhook_url = Some(field.text().await.map_err(|e| AppError::BadRequest(e.to_string()))?),
"language" => {
language = Some(
field
.text()
.await
.map_err(|e| AppError::BadRequest(e.to_string()))?,
)
}
"task" => {
task = field
.text()
.await
.map_err(|e| AppError::BadRequest(e.to_string()))?
}
"webhook_url" => {
webhook_url = Some(
field
.text()
.await
.map_err(|e| AppError::BadRequest(e.to_string()))?,
)
}
_ => {} // ignore unknown fields
}
}
@@ -119,7 +143,9 @@ pub async fn submit_job(
// Register the webhook URL regardless of model state — so model lifecycle
// events are delivered even if the job itself is rejected.
if let Some(url) = &webhook_url {
state.webhook_registry.lock()
state
.webhook_registry
.lock()
.unwrap_or_else(|e| e.into_inner())
.insert(url.clone());
}
@@ -143,12 +169,16 @@ pub async fn submit_job(
state.storage.create(&job).await?;
// Pre-create the broadcast channel so SSE subscribers don't miss events.
state.progress.entry(id).or_insert_with(|| broadcast::channel(64).0);
state
.progress
.entry(id)
.or_insert_with(|| broadcast::channel(64).0);
state.queue_depth.fetch_add(1, Ordering::Relaxed);
state.job_tx.send(id).map_err(|_| {
AppError::Internal("worker channel closed".into())
})?;
state
.job_tx
.send(id)
.map_err(|_| AppError::Internal("worker channel closed".into()))?;
tracing::info!(job_id = %id, "job queued");
@@ -168,10 +198,7 @@ pub async fn submit_job(
(status = 404, description = "Not found"),
)
)]
pub async fn get_job(
State(state): State<AppState>,
Path(id): Path<JobId>,
) -> Result<Json<Job>> {
pub async fn get_job(State(state): State<AppState>, Path(id): Path<JobId>) -> Result<Json<Job>> {
let job = state.storage.get(&id).await?;
Ok(Json(job))
}
@@ -202,9 +229,9 @@ pub async fn stream_job(
let job = state.storage.get(&id).await?;
match job.status {
JobStatus::Done | JobStatus::Failed | JobStatus::Cancelled => {
let payload = serde_json::to_string(
&crate::models::SsePayload::Done { job: Box::new(job) }
).unwrap_or_default();
let payload =
serde_json::to_string(&crate::models::SsePayload::Done { job: Box::new(job) })
.unwrap_or_default();
let s: SseStream = Box::pin(stream::once(async move {
Ok(Event::default().event("done").data(payload))
}));
@@ -222,22 +249,28 @@ pub async fn stream_job(
let sse_stream: SseStream = Box::pin(BroadcastStream::new(rx).filter_map(|msg| async move {
let event = match msg {
Ok(ProgressEvent::Progress { percent, chunk, total }) => {
let payload = serde_json::to_string(
&crate::models::SsePayload::Progress { percent, chunk, chunks_total: total }
).ok()?;
Ok(ProgressEvent::Progress {
percent,
chunk,
total,
}) => {
let payload = serde_json::to_string(&crate::models::SsePayload::Progress {
percent,
chunk,
chunks_total: total,
})
.ok()?;
Event::default().event("progress").data(payload)
}
Ok(ProgressEvent::Done(job)) => {
let payload = serde_json::to_string(
&crate::models::SsePayload::Done { job }
).ok()?;
let payload =
serde_json::to_string(&crate::models::SsePayload::Done { job }).ok()?;
Event::default().event("done").data(payload)
}
Ok(ProgressEvent::Error(msg)) => {
let payload = serde_json::to_string(
&crate::models::SsePayload::Error { message: msg }
).ok()?;
let payload =
serde_json::to_string(&crate::models::SsePayload::Error { message: msg })
.ok()?;
Event::default().event("error").data(payload)
}
Err(_) => return None, // lagged / channel closed
@@ -264,10 +297,7 @@ pub async fn stream_job(
(status = 409, description = "Job already finished"),
)
)]
pub async fn delete_job(
State(state): State<AppState>,
Path(id): Path<JobId>,
) -> Result<Json<Job>> {
pub async fn delete_job(State(state): State<AppState>, Path(id): Path<JobId>) -> Result<Json<Job>> {
let mut job = state.storage.get(&id).await?;
match job.status {

View File

@@ -2,21 +2,27 @@ pub mod health;
pub mod jobs;
pub mod model;
use axum::{extract::DefaultBodyLimit, routing::{delete, get, post}, Router};
use crate::AppState;
use axum::{
extract::DefaultBodyLimit,
routing::{delete, get, post},
Router,
};
pub fn jobs_router() -> Router<AppState> {
Router::new()
// No body limit on the upload route — files can be multiple GB.
.route("/jobs", post(jobs::submit_job).layer(DefaultBodyLimit::disable()))
.route(
"/jobs",
post(jobs::submit_job).layer(DefaultBodyLimit::disable()),
)
.route("/jobs/:id", get(jobs::get_job))
.route("/jobs/:id/stream", get(jobs::stream_job))
.route("/jobs/:id", delete(jobs::delete_job))
}
pub fn health_router() -> Router<AppState> {
Router::new()
.route("/health", get(health::health))
Router::new().route("/health", get(health::health))
}
pub fn model_router() -> Router<AppState> {

View File

@@ -10,8 +10,8 @@ use axum::{
Json,
};
use futures::Stream;
use tokio_stream::wrappers::BroadcastStream;
use futures::StreamExt;
use tokio_stream::wrappers::BroadcastStream;
use crate::{
models::{ModelEvent, ModelStatusResponse},
@@ -19,7 +19,8 @@ use crate::{
AppState, Result,
};
type SseStream = Pin<Box<dyn Stream<Item = std::result::Result<Event, std::convert::Infallible>> + Send>>;
type SseStream =
Pin<Box<dyn Stream<Item = std::result::Result<Event, std::convert::Infallible>> + Send>>;
// ── GET /model/status ────────────────────────────────────────────────────────
@@ -61,11 +62,17 @@ pub async fn model_status(State(state): State<AppState>) -> Result<Json<ModelSta
pub async fn model_load(State(state): State<AppState>) -> impl IntoResponse {
let is_ready = state.model_state.read().await.is_ready();
if is_ready {
return (StatusCode::OK, Json(serde_json::json!({"status": "already_ready"})));
return (
StatusCode::OK,
Json(serde_json::json!({"status": "already_ready"})),
);
}
// Ignore send errors (channel full = load already in progress).
let _ = state.cmd_tx.try_send(WorkerCmd::Load);
(StatusCode::ACCEPTED, Json(serde_json::json!({"status": "load_initiated"})))
(
StatusCode::ACCEPTED,
Json(serde_json::json!({"status": "load_initiated"})),
)
}
// ── POST /model/unload ───────────────────────────────────────────────────────
@@ -82,7 +89,10 @@ pub async fn model_load(State(state): State<AppState>) -> impl IntoResponse {
)]
pub async fn model_unload(State(state): State<AppState>) -> impl IntoResponse {
let _ = state.cmd_tx.try_send(WorkerCmd::Unload);
(StatusCode::OK, Json(serde_json::json!({"status": "unload_requested"})))
(
StatusCode::OK,
Json(serde_json::json!({"status": "unload_requested"})),
)
}
// ── GET /model/events ────────────────────────────────────────────────────────
@@ -105,8 +115,7 @@ pub async fn model_unload(State(state): State<AppState>) -> impl IntoResponse {
pub async fn model_events(State(state): State<AppState>) -> Sse<SseStream> {
let rx = state.model_event_tx.subscribe();
let stream: SseStream = Box::pin(
BroadcastStream::new(rx).filter_map(|msg| async move {
let stream: SseStream = Box::pin(BroadcastStream::new(rx).filter_map(|msg| async move {
match msg {
Ok(event) => {
let event_type = match &event {
@@ -120,8 +129,7 @@ pub async fn model_events(State(state): State<AppState>) -> Sse<SseStream> {
}
Err(_) => None,
}
})
);
}));
Sse::new(stream).keep_alive(KeepAlive::default())
}

View File

@@ -31,19 +31,19 @@ impl Storage {
pub async fn create(&self, job: &Job) -> Result<()> {
let path = self.job_path(&job.id);
let payload = serde_json::to_vec_pretty(job)
.map_err(|e| AppError::Internal(e.to_string()))?;
fs::write(&path, payload).await.map_err(|e| {
AppError::Internal(format!("failed to write job {}: {e}", job.id))
})?;
let payload =
serde_json::to_vec_pretty(job).map_err(|e| AppError::Internal(e.to_string()))?;
fs::write(&path, payload)
.await
.map_err(|e| AppError::Internal(format!("failed to write job {}: {e}", job.id)))?;
Ok(())
}
pub async fn get(&self, id: &JobId) -> Result<Job> {
let path = self.job_path(id);
let raw = fs::read(&path).await.map_err(|_| {
AppError::NotFound(format!("job {id} not found"))
})?;
let raw = fs::read(&path)
.await
.map_err(|_| AppError::NotFound(format!("job {id} not found")))?;
serde_json::from_slice(&raw).map_err(|e| AppError::Internal(e.to_string()))
}
@@ -54,22 +54,24 @@ impl Storage {
pub async fn delete(&self, id: &JobId) -> Result<()> {
let path = self.job_path(id);
fs::remove_file(&path).await.map_err(|_| {
AppError::NotFound(format!("job {id} not found"))
})?;
fs::remove_file(&path)
.await
.map_err(|_| AppError::NotFound(format!("job {id} not found")))?;
Ok(())
}
/// List all job IDs present on disk.
pub async fn list_ids(&self) -> Result<Vec<JobId>> {
let mut entries = fs::read_dir(&self.dir).await.map_err(|e| {
AppError::Internal(format!("read_dir failed: {e}"))
})?;
let mut entries = fs::read_dir(&self.dir)
.await
.map_err(|e| AppError::Internal(format!("read_dir failed: {e}")))?;
let mut ids = Vec::new();
while let Some(entry) = entries.next_entry().await.map_err(|e| {
AppError::Internal(e.to_string())
})? {
while let Some(entry) = entries
.next_entry()
.await
.map_err(|e| AppError::Internal(e.to_string()))?
{
let name = entry.file_name();
let name = name.to_string_lossy();
if let Some(stem) = name.strip_suffix(".json") {

View File

@@ -37,9 +37,10 @@ impl Transcriber {
/// 0 segments. The warmup forces kernel compilation at startup so all subsequent
/// jobs run correctly from the very first request.
pub fn load(model_path: impl AsRef<Path>, gpu_device: u32) -> Result<Self> {
let path = model_path.as_ref().to_str().ok_or_else(|| {
AppError::Internal("model path is not valid UTF-8".into())
})?;
let path = model_path
.as_ref()
.to_str()
.ok_or_else(|| AppError::Internal("model path is not valid UTF-8".into()))?;
let mut params = WhisperContextParameters::new();
params.use_gpu(true);
@@ -48,8 +49,7 @@ impl Transcriber {
// real-world audio (conference recordings, noisy MP3s).
// params.flash_attn(true);
let ctx = WhisperContext::new_with_params(path, params)
.map_err(|e| {
let ctx = WhisperContext::new_with_params(path, params).map_err(|e| {
let msg = format!("failed to load model: {e}");
if AppError::is_oom(&msg) {
AppError::OutOfMemory(msg)
@@ -58,8 +58,7 @@ impl Transcriber {
}
})?;
let mut state = ctx.create_state()
.map_err(|e| {
let mut state = ctx.create_state().map_err(|e| {
let msg = format!("failed to create whisper state: {e}");
if AppError::is_oom(&msg) {
AppError::OutOfMemory(msg)
@@ -158,30 +157,39 @@ impl Transcriber {
.full(fp, pcm)
.map_err(|e| AppError::Internal(format!("transcription failed: {e}")))?;
let n_segments = state.full_n_segments()
let n_segments = state
.full_n_segments()
.map_err(|e| AppError::Internal(e.to_string()))?;
let mut segments = Vec::with_capacity(n_segments as usize);
for i in 0..n_segments {
let text = state.full_get_segment_text(i)
let text = state
.full_get_segment_text(i)
.map_err(|e| AppError::Internal(e.to_string()))?;
let start = state.full_get_segment_t0(i)
.map_err(|e| AppError::Internal(e.to_string()))? as f32 / 100.0;
let end = state.full_get_segment_t1(i)
.map_err(|e| AppError::Internal(e.to_string()))? as f32 / 100.0;
let start = state
.full_get_segment_t0(i)
.map_err(|e| AppError::Internal(e.to_string()))? as f32
/ 100.0;
let end = state
.full_get_segment_t1(i)
.map_err(|e| AppError::Internal(e.to_string()))? as f32
/ 100.0;
let n_tokens = state.full_n_tokens(i)
let n_tokens = state
.full_n_tokens(i)
.map_err(|e| AppError::Internal(e.to_string()))?;
let mut words = Vec::new();
for t in 0..n_tokens {
let token_text = state.full_get_token_text(i, t)
let token_text = state
.full_get_token_text(i, t)
.map_err(|e| AppError::Internal(e.to_string()))?;
if token_text.starts_with('[') {
continue; // skip special tokens ([MUSIC], [APPLAUSE], etc.)
}
let data = state.full_get_token_data(i, t)
let data = state
.full_get_token_data(i, t)
.map_err(|e| AppError::Internal(e.to_string()))?;
words.push(Word {
text: token_text,
@@ -191,7 +199,13 @@ impl Transcriber {
});
}
segments.push(Segment { index: i, start, end, text, words });
segments.push(Segment {
index: i,
start,
end,
text,
words,
});
}
let lang = state

View File

@@ -16,8 +16,7 @@ use crate::{
models::{Job, JobId, JobStatus, ModelEvent, ModelState, Segment},
storage::Storage,
transcriber::Transcriber,
webhook,
AppError,
webhook, AppError,
};
/// Per-job broadcast channel for SSE subscribers.
@@ -26,7 +25,11 @@ pub type ProgressTx = broadcast::Sender<ProgressEvent>;
#[derive(Debug, Clone)]
pub enum ProgressEvent {
/// `percent` — overall 0–100; `chunk` — 1-based; `total` — total chunks.
Progress { percent: u8, chunk: usize, total: usize },
Progress {
percent: u8,
chunk: usize,
total: usize,
},
Done(Box<Job>),
Error(String),
}
@@ -162,14 +165,22 @@ fn transcriber_thread(
}
Ok(WorkerCmd::Unload) => {
do_unload(&mut transcriber, &model_state, &model_event_tx, &webhook_registry, &rt);
do_unload(
&mut transcriber,
&model_state,
&model_event_tx,
&webhook_registry,
&rt,
);
}
Ok(WorkerCmd::Transcribe(req)) => {
let t = match &mut transcriber {
Some(t) => t,
None => {
tracing::warn!("Transcribe cmd received but model is unloaded — failing job");
tracing::warn!(
"Transcribe cmd received but model is unloaded — failing job"
);
let _ = req.reply.send(Err(AppError::Internal(
"model unloaded before job could run".into(),
)));
@@ -177,12 +188,9 @@ fn transcriber_thread(
}
};
let result = t.transcribe(
&req.pcm,
req.language.as_deref(),
&req.task,
move |p| (req.on_progress)(p),
);
let result = t.transcribe(&req.pcm, req.language.as_deref(), &req.task, move |p| {
(req.on_progress)(p)
});
last_job = Instant::now();
let _ = req.reply.send(result);
}
@@ -253,25 +261,35 @@ fn try_load_with_polling(
"insufficient VRAM — will retry"
);
set_state(model_state, ModelState::WaitingForGpu {
set_state(
model_state,
ModelState::WaitingForGpu {
vram_needed_mb,
vram_free_mb,
retry_in_secs,
});
broadcast_event(model_event_tx, ModelEvent::ModelWaitingForGpu {
},
);
broadcast_event(
model_event_tx,
ModelEvent::ModelWaitingForGpu {
vram_needed_mb,
vram_free_mb,
retry_in_secs,
});
},
);
// Interruptible sleep: drain rx while waiting for gpu_poll_interval.
let deadline = Instant::now() + gpu_poll_interval;
loop {
let remaining = deadline.saturating_duration_since(Instant::now());
if remaining.is_zero() { break; }
if remaining.is_zero() {
break;
}
match rx.recv_timeout(remaining.min(Duration::from_secs(1))) {
Ok(WorkerCmd::Unload) => {
tracing::info!("Unload received while waiting for GPU — cancelling load");
tracing::info!(
"Unload received while waiting for GPU — cancelling load"
);
set_state(model_state, ModelState::Unloaded);
broadcast_event(model_event_tx, ModelEvent::ModelUnloaded);
fire_webhooks(webhook_registry, ModelEvent::ModelUnloaded, rt);
@@ -341,11 +359,16 @@ fn fire_webhooks(
.cloned()
.collect();
if urls.is_empty() { return; }
if urls.is_empty() {
return;
}
let payload = match serde_json::to_string(&event) {
Ok(p) => p,
Err(e) => { tracing::error!(error = %e, "failed to serialize model event"); return; }
Err(e) => {
tracing::error!(error = %e, "failed to serialize model event");
return;
}
};
for url in urls {
@@ -356,7 +379,8 @@ fn fire_webhooks(
.build()
.expect("http client");
for attempt in 0..3_u32 {
match http.post(&url)
match http
.post(&url)
.header("content-type", "application/json")
.body(body.clone())
.send()
@@ -487,7 +511,9 @@ async fn run(
let http = http.clone();
let url = url.clone();
let job = job.clone();
tokio::spawn(async move { webhook::fire(&http, &url, &job).await; });
tokio::spawn(async move {
webhook::fire(&http, &url, &job).await;
});
}
tokio::time::sleep(Duration::from_secs(30)).await;
@@ -509,9 +535,13 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
let output = Command::new("ffmpeg")
.args([
"-nostdin",
"-i", path.to_str().unwrap_or(""),
"-af", &filter,
"-f", "null", "-",
"-i",
path.to_str().unwrap_or(""),
"-af",
&filter,
"-f",
"null",
"-",
])
.output()
.await;
@@ -545,7 +575,9 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
}
}
let mids: Vec<f32> = starts.iter().zip(ends.iter())
let mids: Vec<f32> = starts
.iter()
.zip(ends.iter())
.map(|(s, e)| (s + e) / 2.0)
.collect();
@@ -553,18 +585,15 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
mids
}
fn snap_to_silence(
mids: &[f32],
total_secs: f32,
target_secs: f32,
snap_window: f32,
) -> Vec<f32> {
fn snap_to_silence(mids: &[f32], total_secs: f32, target_secs: f32, snap_window: f32) -> Vec<f32> {
let mut cuts: Vec<f32> = Vec::new();
let mut pos = target_secs;
while pos < total_secs - target_secs * 0.25 {
let prev_cut = cuts.last().copied().unwrap_or(0.0);
let best = mids.iter().copied()
let best = mids
.iter()
.copied()
.filter(|&t| t > prev_cut + 10.0 && (t - pos).abs() <= snap_window)
.min_by(|a, b| (a - pos).abs().partial_cmp(&(b - pos).abs()).unwrap());
let cut = best.unwrap_or(pos);
@@ -591,6 +620,146 @@ fn to_chunk_ranges(cuts: &[f32], total_secs: f32) -> Vec<(f32, f32)> {
ranges
}
/// Largest gap, in seconds, between the end of the previous segment and the
/// start of the next one for the two to be treated as part of the same chain.
const MAX_CHAIN_GAP_SECS: f32 = 0.15;
/// Minimum number of normalised words a phrase needs to count as meaningful.
const MIN_MEANINGFUL_WORDS: usize = 2;
/// Minimum combined length of a phrase's normalised words to count as meaningful.
const MIN_MEANINGFUL_CHARS: usize = 8;
/// Minimum suffix/prefix overlap, in words, before a segment's repeated
/// lead-in is trimmed away.
const MIN_OVERLAP_WORDS: usize = 3;
/// Splits `text` on whitespace and reduces every token to its lowercase
/// alphanumeric/underscore characters.
///
/// Tokens that contain no such character (pure punctuation, for example) are
/// dropped, so the result only holds non-empty words.
fn normalised_words(text: &str) -> Vec<String> {
    let mut words = Vec::new();

    for token in text.split_whitespace() {
        let mut cleaned = String::with_capacity(token.len());
        for ch in token.chars() {
            if ch.is_alphanumeric() || ch == '_' {
                cleaned.extend(ch.to_lowercase());
            }
        }
        if !cleaned.is_empty() {
            words.push(cleaned);
        }
    }

    words
}
/// Returns `true` when `prefix` matches the leading words of `full`.
///
/// An empty `prefix` always matches; a `prefix` longer than `full` never does.
fn starts_with_words(full: &[String], prefix: &[String]) -> bool {
    full.starts_with(prefix)
}
/// Returns `true` when `suffix` matches the trailing words of `full`.
///
/// An empty `suffix` always matches; a `suffix` longer than `full` never does.
fn ends_with_words(full: &[String], suffix: &[String]) -> bool {
    full.ends_with(suffix)
}
/// Returns the length of the longest run of words that is both a suffix of
/// `left` and a prefix of `right`, or `0` when there is no such run.
fn suffix_prefix_overlap(left: &[String], right: &[String]) -> usize {
    let longest = left.len().min(right.len());

    // Try the largest candidate first so the first hit is the longest overlap.
    (1..=longest)
        .rev()
        .find(|&len| left.ends_with(&right[..len]))
        .unwrap_or(0)
}
/// Returns `true` when `words` is long enough to be trusted as a phrase:
/// at least `MIN_MEANINGFUL_WORDS` words totalling at least
/// `MIN_MEANINGFUL_CHARS` characters.
///
/// Length is measured in Unicode scalar values rather than UTF-8 bytes, so
/// the threshold behaves the same for non-ASCII transcripts as for ASCII ones.
fn is_meaningful_phrase(words: &[String]) -> bool {
    words.len() >= MIN_MEANINGFUL_WORDS
        && words
            .iter()
            .map(|word| word.chars().count())
            .sum::<usize>()
            >= MIN_MEANINGFUL_CHARS
}
/// Removes the first `count` words from `text` and returns the remainder
/// joined by single spaces.
///
/// `count` is expressed in normalised words, i.e. tokens containing at least
/// one alphanumeric or underscore character. Punctuation-only tokens that
/// appear before the `count`-th word is consumed are dropped too, but do not
/// count towards `count`; this keeps the trim aligned with overlap lengths
/// computed on normalised word lists. Returns an empty string when `text` has
/// no more than `count` words.
fn trim_leading_words(text: &str, count: usize) -> String {
    let mut remaining = count;

    text.split_whitespace()
        .skip_while(|token| {
            if remaining == 0 {
                return false;
            }
            // Only tokens that survive normalisation use up the budget.
            if token.chars().any(|ch| ch.is_alphanumeric() || ch == '_') {
                remaining -= 1;
            }
            true
        })
        .collect::<Vec<_>>()
        .join(" ")
}
fn merge_identical_segments(segments: Vec<Segment>) -> Vec<Segment> {
let mut out: Vec<Segment> = Vec::with_capacity(segments.len());
for seg in segments {
if let Some(last) = out.last_mut() {
if normalised_words(&last.text) == normalised_words(&seg.text) {
last.end = last.end.max(seg.end);
if !seg.words.is_empty() {
last.words = seg.words;
}
continue;
}
}
out.push(seg);
}
out
}
/// Collapses chains of rolling partial hypotheses into stable segments.
///
/// Each incoming segment is compared against the last segment already kept.
/// When the two are close in time (gap of at most `MAX_CHAIN_GAP_SECS`), the
/// new segment is either folded into the previous one, dropped as an echo, or
/// stripped of the words it repeats from the previous one. Segments with
/// empty (after trimming) text are discarded.
fn collapse_incremental_segments(segments: Vec<Segment>) -> Vec<Segment> {
let mut out: Vec<Segment> = Vec::with_capacity(segments.len());
for mut seg in segments {
// Normalise surrounding whitespace and skip segments with no text at all.
seg.text = seg.text.trim().to_string();
if seg.text.is_empty() {
continue;
}
// The very first kept segment has nothing to be compared against.
let Some(last) = out.last_mut() else {
out.push(seg);
continue;
};
// A gap larger than the threshold means the segment is not part of the
// previous chain and is kept unchanged.
let gap = seg.start - last.end;
if gap > MAX_CHAIN_GAP_SECS {
out.push(seg);
continue;
}
let last_words = normalised_words(&last.text);
let seg_words = normalised_words(&seg.text);
// Without comparable words on both sides no chain can be detected.
if last_words.is_empty() || seg_words.is_empty() {
out.push(seg);
continue;
}
// Prefix growth: the new segment repeats the previous one and continues
// it, so it replaces the previous text, end time and words.
if seg_words.len() > last_words.len()
&& starts_with_words(&seg_words, &last_words)
&& is_meaningful_phrase(&last_words)
{
last.text = seg.text;
last.end = seg.end;
last.words = seg.words;
continue;
}
// Echo: the new segment only repeats the tail of the previous one, so it
// is absorbed and merely extends the previous end time.
if ends_with_words(&last_words, &seg_words) && is_meaningful_phrase(&seg_words) {
last.end = last.end.max(seg.end);
continue;
}
// Partial overlap: the new segment starts with words that end the
// previous one; strip that repeated lead-in.
let overlap = suffix_prefix_overlap(&last_words, &seg_words);
if overlap >= MIN_OVERLAP_WORDS {
let trimmed_text = trim_leading_words(&seg.text, overlap);
if trimmed_text.is_empty() {
last.end = last.end.max(seg.end);
continue;
}
// Keep the trimmed segment from starting before the previous one ends.
seg.start = seg.start.max(last.end);
seg.text = trimmed_text;
// NOTE(review): presumably cleared because the per-word entries still
// describe the untrimmed text — confirm against consumers of `words`.
seg.words.clear();
}
out.push(seg);
}
out
}
/// Normalises a transcript by running the incremental-collapse pass followed
/// by the identical-merge pass, twice in a row.
fn normalise_segments(segments: Vec<Segment>) -> Vec<Segment> {
    let first_pass = merge_identical_segments(collapse_incremental_segments(segments));
    let second_pass = collapse_incremental_segments(first_pass);
    merge_identical_segments(second_pass)
}
// ── Job processing ────────────────────────────────────────────────────────────
async fn process_job(
@@ -604,7 +773,12 @@ async fn process_job(
let total_secs = pcm.len() as f32 / 16_000.0;
let silence_mids = detect_silence_midpoints(audio_path).await;
let cuts = snap_to_silence(&silence_mids, total_secs, TARGET_CHUNK_SECS, SNAP_WINDOW_SECS);
let cuts = snap_to_silence(
&silence_mids,
total_secs,
TARGET_CHUNK_SECS,
SNAP_WINDOW_SECS,
);
let chunks = to_chunk_ranges(&cuts, total_secs);
let n = chunks.len();
@@ -653,15 +827,18 @@ async fn process_job(
});
let (reply_tx, reply_rx) = oneshot::channel();
cmd_tx.send(WorkerCmd::Transcribe(TranscribeRequest {
cmd_tx
.send(WorkerCmd::Transcribe(TranscribeRequest {
pcm: chunk_pcm,
language: job.language.clone(),
task: job.task.clone(),
on_progress,
reply: reply_tx,
})).map_err(|_| AppError::Internal("worker command channel closed".into()))?;
}))
.map_err(|_| AppError::Internal("worker command channel closed".into()))?;
let (mut segs, lang) = reply_rx.await
let (mut segs, lang) = reply_rx
.await
.map_err(|_| AppError::Internal("transcriber thread dropped reply".into()))??;
let offset = *chunk_start;
@@ -689,11 +866,17 @@ async fn process_job(
}
}
all_segments = normalise_segments(all_segments);
for (i, seg) in all_segments.iter_mut().enumerate() {
seg.index = i as i32;
}
let _ = progress_tx.send(ProgressEvent::Progress { percent: 100, chunk: n, total: n });
let _ = progress_tx.send(ProgressEvent::Progress {
percent: 100,
chunk: n,
total: n,
});
Ok((all_segments, language, total_secs))
}
@@ -719,11 +902,17 @@ async fn decode_audio(path: &std::path::Path) -> crate::Result<Vec<f32>> {
let output = Command::new("ffmpeg")
.args([
"-nostdin", "-threads", "0",
"-i", path.to_str().unwrap_or(""),
"-f", "f32le",
"-ac", "1",
"-ar", "16000",
"-nostdin",
"-threads",
"0",
"-i",
path.to_str().unwrap_or(""),
"-f",
"f32le",
"-ac",
"1",
"-ar",
"16000",
"-",
])
.output()
@@ -760,13 +949,28 @@ pub fn audio_path_for(id: &JobId) -> PathBuf {
#[cfg(test)]
mod tests {
use super::*;
use crate::models::Word;
/// Test helper: builds a `Segment` with the given index, time range and text
/// and an empty word list.
fn segment(index: i32, start: f32, end: f32, text: &str) -> Segment {
Segment {
index,
start,
end,
text: text.into(),
words: Vec::<Word>::new(),
}
}
#[test]
fn test_snap_to_silence_uses_nearest_midpoint() {
let mids = vec![55.0, 58.0, 62.0];
let cuts = snap_to_silence(&mids, 120.0, 60.0, 30.0);
assert!(!cuts.is_empty());
assert!((cuts[0] - 58.0).abs() < 0.01, "expected ~58.0, got {}", cuts[0]);
assert!(
(cuts[0] - 58.0).abs() < 0.01,
"expected ~58.0, got {}",
cuts[0]
);
}
#[test]
@@ -801,4 +1005,53 @@ mod tests {
trim_trailing_silence(&mut pcm);
assert_eq!(pcm.len(), (10_001 + 8_000).min(32_000));
}
// A chain of back-to-back segments in which each one either extends or echoes
// its predecessor must be reduced to two stable segments.
#[test]
fn test_normalise_segments_collapses_prefix_growth_chain() {
let input = vec![
segment(0, 15.24, 16.6, "Hello everyone."),
segment(1, 16.6, 19.47, "Hello everyone. Um, welcome to this talk."),
segment(2, 19.47, 19.48, "Um, welcome to this talk."),
segment(
3,
19.48,
21.67,
"Um, welcome to this talk. I'll be speaking about small model",
),
segment(4, 21.67, 21.68, "I'll be speaking about small model"),
segment(
5,
21.68,
24.59,
"I'll be speaking about small model inference and a gap that we've",
),
];
let result = normalise_segments(input);
assert_eq!(result.len(), 2);
// First output segment: the grown greeting, ending where its echo ended.
assert_eq!(result[0].text, "Hello everyone. Um, welcome to this talk.");
assert!((result[0].start - 15.24).abs() < 0.01);
assert!((result[0].end - 19.48).abs() < 0.01);
// Second output segment: the overlap with the first has been trimmed away.
assert_eq!(
result[1].text,
"I'll be speaking about small model inference and a gap that we've"
);
assert!((result[1].start - 19.48).abs() < 0.01);
assert!((result[1].end - 24.59).abs() < 0.01);
}
// Segments separated by a one-second gap must stay separate even though the
// second one starts with the text of the first.
#[test]
fn test_normalise_segments_keeps_real_gap() {
let input = vec![
segment(0, 0.0, 1.0, "Hello everyone."),
segment(1, 2.0, 4.0, "Hello everyone. Welcome back."),
];
let result = normalise_segments(input);
assert_eq!(result.len(), 2);
assert_eq!(result[0].text, "Hello everyone.");
assert_eq!(result[1].text, "Hello everyone. Welcome back.");
}
}
}