diff --git a/docs/USAGE.md b/docs/USAGE.md index 709783a..33ef469 100644 --- a/docs/USAGE.md +++ b/docs/USAGE.md @@ -66,6 +66,8 @@ The bundled `docker-compose.yml` mounts named volumes for data and models and se | `WHISPER_MODEL_PATH` | `/models/ggml-large-v3.bin` | Absolute path to GGML model file | | `WHISPER_MODEL` | `large-v3` | Model name reported by `/health` (display only) | | `CUDA_DEVICE` | `0` | CUDA device index to use for inference | +| `IDLE_TIMEOUT_SECS` | `300` | Seconds of idle before the model is automatically unloaded from GPU memory. Set to `0` to disable auto-unload. | +| `GPU_POLL_INTERVAL_SECS` | `30` | Seconds between VRAM-availability retries when a load fails due to insufficient VRAM. | ### Note on CUDA device ordering Inside Docker, device ordering matches `nvidia-smi` (PCI bus order). On the host without Docker, ordering may differ. See [FINDINGS.md](FINDINGS.md#cuda-device-index-ordering-differs-between-host-and-docker) for details. @@ -76,6 +78,194 @@ Inside Docker, device ordering matches `nvidia-smi` (PCI bus order). On the host The interactive Swagger UI is available at `http://localhost:8080/docs`. +--- + +## Model Lifecycle Management + +The model starts **unloaded** on startup (lazy loading). It is loaded into GPU memory on the first job submission or via `POST /model/load`, and automatically unloaded after `IDLE_TIMEOUT_SECS` of inactivity. + +### Model State Machine + +``` +Unloaded ──(job / POST /model/load)──► Loading ──(success)──► Ready + └──(VRAM full)──► WaitingForGpu ──(retry OK)──► Loading +Ready ──(idle timeout / POST /model/unload)──► Unloaded +WaitingForGpu ──(POST /model/unload)──► Unloaded +``` + +### `GET /model/status` + +Returns the current model state and VRAM statistics. 
+ +```bash +curl http://localhost:8080/model/status +``` + +**When unloaded:** +```json +{ "state": "unloaded" } +``` + +**When loading:** +```json +{ "state": "loading" } +``` + +**When ready:** +```json +{ + "state": "ready", + "loaded_at": "2026-05-10T14:00:00Z", + "vram_used_mb": 4096, + "vram_total_mb": 8192 +} +``` + +**When waiting for VRAM:** +```json +{ + "state": "waiting_for_gpu", + "vram_needed_mb": 3951, + "vram_free_mb": 512, + "retry_in_secs": 30 +} +``` + +--- + +### `POST /model/load` + +Request the model to be loaded. Idempotent — if already loading or ready, returns immediately. + +```bash +curl -X POST http://localhost:8080/model/load +``` + +- Returns `202 Accepted` with `{"status":"load_initiated"}` when load is triggered +- Returns `200 OK` with `{"status":"already_ready"}` when model is already ready +- Poll `GET /model/status` or subscribe to `GET /model/events` to know when ready + +--- + +### `POST /model/unload` + +Unload the model from GPU memory immediately, freeing VRAM. + +```bash +curl -X POST http://localhost:8080/model/unload +``` + +Returns `200 OK` regardless of current state. + +--- + +### `GET /model/events` — Model SSE stream + +Subscribe to model lifecycle events via Server-Sent Events. 
+ +```bash +curl -N http://localhost:8080/model/events +``` + +**Event types:** + +``` +event: model_loading +data: {"type":"model_loading"} + +event: model_ready +data: {"type":"model_ready","loaded_at":"2026-05-10T14:00:00Z"} + +event: model_unloaded +data: {"type":"model_unloaded"} + +event: model_waiting_for_gpu +data: {"type":"model_waiting_for_gpu","vram_needed_mb":3951,"vram_free_mb":512,"retry_in_secs":30} +``` + +**JavaScript example:** +```javascript +const es = new EventSource('/model/events'); + +es.addEventListener('model_ready', () => { + console.log('Model loaded — ready to transcribe'); +}); + +es.addEventListener('model_unloaded', () => { + console.log('Model freed GPU memory'); +}); +``` + +--- + +### Webhooks for model events + +When any job is submitted with a `webhook_url`, that URL is registered to receive model lifecycle webhooks for the lifetime of the server process. The following events trigger a webhook POST: + +| Event | Fired when | +|-------|-----------| +| `model_ready` | Model finishes loading (after GPU warmup) | +| `model_unloaded` | Model is freed from GPU memory | + +**Webhook payload** (`Content-Type: application/json`): +```json +{ "type": "model_ready", "loaded_at": "2026-05-10T14:00:00Z" } +{ "type": "model_unloaded" } +``` + +Delivery is attempted up to 3 times with exponential backoff (1s, 2s). 
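The retry rule stated above (up to 3 attempts, waiting 1s and then 2s between them) can be sketched as follows. This is only an illustration under assumptions: the server's actual `webhook` module is not shown in this diff, so `backoff_delays` and `deliver_with_retry` are hypothetical names, and the doubling rule is inferred from the documented "(1s, 2s)" schedule rather than taken from the implementation.

```rust
use std::time::Duration;

/// Hypothetical helper: waits between attempts, doubling from 1s.
/// `max_attempts = 3` gives two waits: 1s, then 2s.
fn backoff_delays(max_attempts: u32) -> Vec<Duration> {
    (0..max_attempts.saturating_sub(1))
        // checked_shl avoids a shift overflow for very large attempt counts.
        .map(|i| Duration::from_secs(1u64.checked_shl(i).unwrap_or(u64::MAX)))
        .collect()
}

/// Hypothetical helper: calls `attempt` (1-based attempt number) until it
/// succeeds or `max_attempts` is reached, calling `sleep` between failures.
/// Returns the number of the successful attempt, or the last error.
fn deliver_with_retry<F, S>(max_attempts: u32, mut attempt: F, mut sleep: S) -> Result<u32, String>
where
    F: FnMut(u32) -> Result<(), String>,
    S: FnMut(Duration),
{
    let delays = backoff_delays(max_attempts);
    let mut last_err = String::from("no attempts made");
    for n in 0..max_attempts {
        match attempt(n + 1) {
            Ok(()) => return Ok(n + 1),
            Err(e) => last_err = e,
        }
        // No wait after the final attempt: `delays` has one entry fewer
        // than the number of attempts.
        if let Some(d) = delays.get(n as usize) {
            sleep(*d);
        }
    }
    Err(last_err)
}

fn main() {
    // Simulated endpoint that fails twice, then accepts the third delivery.
    let mut waited = Vec::new();
    let result = deliver_with_retry(
        3,
        |n| if n < 3 { Err(format!("attempt {n} failed")) } else { Ok(()) },
        |d| waited.push(d), // record the delay instead of sleeping
    );
    println!("result = {result:?}, waited = {waited:?}");
}
```

Passing the sleep function as a parameter keeps the sketch testable without real delays; a real sender would presumably await an async timer and perform an HTTP POST inside `attempt`.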
+ +--- + +### Handling 503 Model Not Ready + +When you submit a job and the model is not yet loaded, you receive `503 Service Unavailable` with a `Retry-After` header: + +``` +HTTP/1.1 503 Service Unavailable +Retry-After: 30 +Content-Type: application/json + +{ + "error": "model_not_ready", + "state": "unloaded", + "retry_after_secs": 30 +} +``` + +| State at rejection | `retry_after_secs` | Meaning | +|---|---|---| +| `unloaded` | 30 | Load was triggered; retry after ~30s | +| `loading` | 10 | Check again in 10s | +| `waiting_for_gpu` | `GPU_POLL_INTERVAL_SECS` | VRAM contention; retry later | + +A job rejection when the model is `unloaded` **automatically triggers a load** — you do not need to call `POST /model/load` separately. + +**Recommended client pattern:** +```javascript +async function submitWithRetry(formData, maxAttempts = 10) { + for (let i = 0; i < maxAttempts; i++) { + const resp = await fetch('/jobs', { method: 'POST', body: formData }); + if (resp.ok) return resp.json(); + if (resp.status === 503) { + const retryAfter = parseInt(resp.headers.get('Retry-After') ?? '30'); + const body = await resp.json(); + console.log(`Model ${body.state} — retrying in ${retryAfter}s`); + await new Promise(r => setTimeout(r, retryAfter * 1000)); + continue; + } + throw new Error(`Submit failed: ${resp.status}`); + } + throw new Error('Gave up after max attempts'); +} +``` + +--- + +## API Reference + +The interactive Swagger UI is available at `http://localhost:8080/docs`. + ### `POST /jobs` — Submit a transcription job Accepts a multipart/form-data body. @@ -249,11 +439,12 @@ curl http://localhost:8080/health "gpu_name": "NVIDIA GeForce RTX 2080", "vram_total_mb": 8192, "model": "large-v3", - "queue_depth": 0 + "queue_depth": 0, + "model_state": "ready" } ``` -`queue_depth` is the number of jobs waiting to be processed (not counting the one currently running). +`queue_depth` is the number of jobs waiting to be processed (not counting the one currently running). 
`model_state` reflects the current lifecycle state (`unloaded`, `loading`, `waiting_for_gpu`, `ready`). --- @@ -340,6 +531,11 @@ curl -X POST http://localhost:8080/jobs \ ## Troubleshooting +### Server returns `503 model_not_ready` +- The model starts unloaded. Call `POST /model/load` explicitly, or just retry the job submission — rejection automatically triggers a load. +- If state is `waiting_for_gpu`, another process is using the GPU's VRAM. The server will retry automatically every `GPU_POLL_INTERVAL_SECS` seconds. +- Monitor `GET /model/status` or subscribe to `GET /model/events` to know when the model is ready. + ### Server returns 0 segments - Check that you are **not** setting `language` to an empty string — omit the field entirely for auto-detection - Verify the audio file is not corrupted: `ffprobe audio.mp3` diff --git a/src/error.rs b/src/error.rs index 9815728..8a0d63b 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,6 +1,6 @@ use thiserror::Error; use axum::{ - http::StatusCode, + http::{StatusCode, HeaderValue, header}, response::{IntoResponse, Response}, Json, }; @@ -21,19 +21,138 @@ pub enum AppError { #[error("internal error: {0}")] Internal(String), + + /// Returned when `whisper_init_state` or `cudaMalloc` fails due to + /// insufficient VRAM. The worker uses this to distinguish a recoverable + /// VRAM-pressure failure from a hard internal error. + #[error("out of GPU memory: {0}")] + OutOfMemory(String), + + /// Returned when a job is submitted but the model is not yet loaded. + /// Carries the current state tag and recommended Retry-After seconds. + #[error("model not ready: {state}")] + ModelNotReady { state: String, retry_after_secs: u64 }, +} + +impl AppError { + /// Returns true if the error string contains patterns emitted by + /// whisper.cpp / GGML when a CUDA memory allocation fails. 
+ pub fn is_oom(msg: &str) -> bool { + msg.contains("cudaMalloc failed") + || msg.contains("out of memory") + || msg.contains("CUDA error: out of memory") + || msg.contains("alloc_buffer") + } } impl IntoResponse for AppError { fn into_response(self) -> Response { - let (status, message) = match &self { - AppError::NotFound(m) => (StatusCode::NOT_FOUND, m.clone()), - AppError::BadRequest(m) => (StatusCode::BAD_REQUEST, m.clone()), - AppError::Conflict(m) => (StatusCode::CONFLICT, m.clone()), - AppError::Internal(m) => (StatusCode::INTERNAL_SERVER_ERROR, m.clone()), - }; - - tracing::error!(status = status.as_u16(), error = %message); - - (status, Json(json!({ "error": message }))).into_response() + match self { + AppError::NotFound(m) => { + (StatusCode::NOT_FOUND, Json(json!({ "error": m }))).into_response() + } + AppError::BadRequest(m) => { + (StatusCode::BAD_REQUEST, Json(json!({ "error": m }))).into_response() + } + AppError::Conflict(m) => { + (StatusCode::CONFLICT, Json(json!({ "error": m }))).into_response() + } + AppError::Internal(m) => { + tracing::error!(error = %m, "internal error"); + (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": m }))).into_response() + } + AppError::OutOfMemory(m) => { + tracing::warn!(error = %m, "GPU out of memory during model load"); + (StatusCode::SERVICE_UNAVAILABLE, Json(json!({ "error": m }))).into_response() + } + AppError::ModelNotReady { state, retry_after_secs } => { + let body = Json(json!({ + "error": "model_not_ready", + "state": state, + "retry_after_secs": retry_after_secs, + })); + let mut resp = (StatusCode::SERVICE_UNAVAILABLE, body).into_response(); + resp.headers_mut().insert( + header::RETRY_AFTER, + HeaderValue::from_str(&retry_after_secs.to_string()) + .unwrap_or(HeaderValue::from_static("30")), + ); + resp + } + } + } +} + +// ── Unit tests ─────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use axum::body::to_bytes; + + #[test] + fn 
test_is_oom_cuda_malloc() { + assert!(AppError::is_oom("cudaMalloc failed: out of memory")); + } + + #[test] + fn test_is_oom_alloc_buffer() { + // Exact message from ggml_backend_cuda_buffer_type_alloc_buffer + assert!(AppError::is_oom( + "ggml_backend_cuda_buffer_type_alloc_buffer: allocating 2951.01 MiB on device 0: cudaMalloc failed: out of memory" + )); + } + + #[test] + fn test_is_oom_generic_out_of_memory() { + assert!(AppError::is_oom("CUDA error: out of memory")); + } + + #[test] + fn test_is_oom_other_error() { + assert!(!AppError::is_oom("failed to open model file")); + assert!(!AppError::is_oom("invalid model format")); + assert!(!AppError::is_oom("")); + } + + #[tokio::test] + async fn test_model_not_ready_response_has_retry_after_header() { + let err = AppError::ModelNotReady { state: "loading".into(), retry_after_secs: 10 }; + let resp = err.into_response(); + assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE); + let retry_after = resp.headers().get(header::RETRY_AFTER) + .expect("Retry-After header missing"); + assert_eq!(retry_after, "10"); + } + + #[tokio::test] + async fn test_model_not_ready_response_body() { + let err = AppError::ModelNotReady { state: "unloaded".into(), retry_after_secs: 30 }; + let resp = err.into_response(); + let bytes = to_bytes(resp.into_body(), usize::MAX).await.unwrap(); + let v: serde_json::Value = serde_json::from_slice(&bytes).unwrap(); + assert_eq!(v["error"], "model_not_ready"); + assert_eq!(v["state"], "unloaded"); + assert_eq!(v["retry_after_secs"], 30); + } + + #[tokio::test] + async fn test_model_not_ready_loading_retry_after_10() { + let err = AppError::ModelNotReady { state: "loading".into(), retry_after_secs: 10 }; + let resp = err.into_response(); + assert_eq!( + resp.headers().get(header::RETRY_AFTER).unwrap(), + "10" + ); + } + + #[tokio::test] + async fn test_model_not_ready_unloaded_retry_after_30() { + let err = AppError::ModelNotReady { state: "unloaded".into(), retry_after_secs: 30 }; + let 
resp = err.into_response(); + assert_eq!( + resp.headers().get(header::RETRY_AFTER).unwrap(), + "30" + ); } } diff --git a/src/main.rs b/src/main.rs index 8301bdf..16f72c5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use axum::Router; -use tokio::sync::mpsc; +use tokio::sync::{broadcast, mpsc, RwLock}; use tower_http::{cors::CorsLayer, trace::TraceLayer}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; use utoipa::OpenApi; @@ -21,8 +21,10 @@ pub use error::{AppError, Result}; #[derive(Clone)] pub struct AppState { - /// Channel to submit jobs to the single GPU worker. + /// Channel to submit jobs to the single GPU worker (job IDs only). pub job_tx: mpsc::UnboundedSender, + /// Channel to send control commands to the worker OS thread. + pub cmd_tx: std::sync::mpsc::SyncSender, /// Shared handle to the on-disk job store. pub storage: Arc, /// SSE broadcast registry: job_id → sender. @@ -33,6 +35,17 @@ pub struct AppState { pub queue_depth: Arc, /// CUDA device index used for inference. pub gpu_device: u32, + /// Current state of the whisper model. + pub model_state: Arc>, + /// Broadcast channel for model lifecycle events (SSE + webhooks). + pub model_event_tx: broadcast::Sender, + /// All webhook URLs ever registered via job submission. + /// Used to fire model_ready / model_unloaded notifications. + pub webhook_registry: Arc>>, + /// How long the model stays loaded with no active jobs. + pub idle_timeout: std::time::Duration, + /// How often to retry loading when GPU is busy. 
+ pub gpu_poll_interval: std::time::Duration, } // ── OpenAPI spec root ──────────────────────────────────────────────────────── @@ -50,6 +63,10 @@ pub struct AppState { routes::jobs::stream_job, routes::jobs::delete_job, routes::health::health, + routes::model::model_status, + routes::model::model_load, + routes::model::model_unload, + routes::model::model_events, ), components(schemas( models::Job, @@ -58,10 +75,14 @@ pub struct AppState { models::Word, models::SubmitResponse, models::HealthResponse, + models::ModelState, + models::ModelEvent, + models::ModelStatusResponse, )), tags( (name = "jobs", description = "Transcription job management"), (name = "system", description = "Service health"), + (name = "model", description = "Model lifecycle management"), ) )] struct ApiDoc; @@ -85,6 +106,20 @@ async fn main() -> anyhow::Result<()> { .ok() .and_then(|s| s.parse().ok()) .unwrap_or(0); + let idle_timeout_secs: u64 = std::env::var("IDLE_TIMEOUT_SECS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(300); + let gpu_poll_interval_secs: u64 = std::env::var("GPU_POLL_INTERVAL_SECS") + .ok() + .and_then(|s| s.parse().ok()) + .unwrap_or(30); + + tracing::info!( + idle_timeout_secs, + gpu_poll_interval_secs, + "dynamic model loading configured" + ); let storage = Arc::new(storage::Storage::new(&data_dir).await?); @@ -94,28 +129,45 @@ async fn main() -> anyhow::Result<()> { let (job_tx, job_rx) = mpsc::unbounded_channel::(); let queue_depth = Arc::new(std::sync::atomic::AtomicUsize::new(0)); - // Spawn single GPU worker; get back the SSE broadcast registry. - let progress = worker::start( + // Model starts unloaded — lazy load on first job or POST /model/load. + let model_state = Arc::new(RwLock::new(models::ModelState::Unloaded)); + let (model_event_tx, _) = broadcast::channel::(32); + let webhook_registry = Arc::new(std::sync::Mutex::new(std::collections::HashSet::::new())); + + // Spawn single GPU worker; get back the SSE broadcast registry and cmd channel. 
+ let (progress, cmd_tx) = worker::start( job_rx, Arc::clone(&storage), model_path.clone().into(), Arc::clone(&queue_depth), gpu_device, + Arc::clone(&model_state), + model_event_tx.clone(), + Arc::clone(&webhook_registry), + std::time::Duration::from_secs(idle_timeout_secs), + std::time::Duration::from_secs(gpu_poll_interval_secs), ); let state = AppState { job_tx, + cmd_tx, storage: Arc::clone(&storage), progress, - model_name: model_name.as_str().into(), - queue_depth: Arc::clone(&queue_depth), + model_name: model_name.as_str().into(), + queue_depth: Arc::clone(&queue_depth), gpu_device, + model_state, + model_event_tx, + webhook_registry, + idle_timeout: std::time::Duration::from_secs(idle_timeout_secs), + gpu_poll_interval: std::time::Duration::from_secs(gpu_poll_interval_secs), }; let app = Router::new() .merge(SwaggerUi::new("/docs").url("/openapi.json", ApiDoc::openapi())) .merge(routes::jobs_router()) .merge(routes::health_router()) + .merge(routes::model_router()) .with_state(state) .layer(CorsLayer::permissive()) .layer(TraceLayer::new_for_http()); diff --git a/src/models.rs b/src/models.rs index 7f776ee..0cbc759 100644 --- a/src/models.rs +++ b/src/models.rs @@ -5,6 +5,116 @@ use uuid::Uuid; pub type JobId = Uuid; +// ── Model lifecycle state ──────────────────────────────────────────────────── + +/// Current state of the whisper model in memory. +/// +/// State machine: +/// ``` +/// Unloaded ──(load trigger)──► Loading ──(ok)──► Ready ──(idle/unload)──► Unloaded +/// └──(VRAM full)──► WaitingForGpu ──(retry)──► Loading +/// WaitingForGpu ──(unload cmd)──► Unloaded +/// ``` +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +#[serde(tag = "state", rename_all = "snake_case")] +pub enum ModelState { + /// Model is not in memory. GPU is free. + Unloaded, + /// Model is being loaded (weights transferred to GPU). + Loading, + /// A previous load attempt failed due to insufficient VRAM. 
The worker is + /// polling at `retry_in_secs` intervals until enough memory is available. + WaitingForGpu { + /// VRAM required to load the model, in MiB. + vram_needed_mb: u64, + /// VRAM currently free on the device, in MiB. + vram_free_mb: u64, + /// How many seconds until the next load attempt. + retry_in_secs: u64, + }, + /// Model is loaded and ready to accept inference jobs. + Ready { + /// UTC timestamp of when the model finished loading (post-warmup). + loaded_at: DateTime, + }, +} + +impl ModelState { + /// Returns true if the model can accept inference jobs right now. + pub fn is_ready(&self) -> bool { + matches!(self, ModelState::Ready { .. }) + } + + /// Suggested `Retry-After` value (seconds) to include in 503 responses. + pub fn retry_after_secs(&self) -> u64 { + match self { + ModelState::Unloaded => 30, // conservative load estimate + ModelState::Loading => 10, + ModelState::WaitingForGpu { retry_in_secs, .. } => *retry_in_secs, + ModelState::Ready { .. } => 0, // shouldn't 503 if ready + } + } + + /// String tag for use in error response bodies and log fields. + pub fn tag(&self) -> &'static str { + match self { + ModelState::Unloaded => "unloaded", + ModelState::Loading => "loading", + ModelState::WaitingForGpu{..} => "waiting_for_gpu", + ModelState::Ready{..} => "ready", + } + } +} + +// ── Model events (SSE + webhooks) ──────────────────────────────────────────── + +/// Events broadcast over the `GET /model/events` SSE stream and fired as +/// webhooks to registered clients. +/// +/// Webhook delivery: only `ModelReady` and `ModelUnloaded` are sent to +/// webhook URLs. All four are broadcast on the SSE stream. +#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ModelEvent { + /// Model finished loading and the GPU warmup completed — ready to accept jobs. + ModelReady { + loaded_at: DateTime, + }, + /// Model was unloaded from GPU memory (idle timeout or manual unload). 
+ ModelUnloaded, + /// Model load initiated. + ModelLoading, + /// Load failed due to insufficient VRAM; retrying after `retry_in_secs`. + ModelWaitingForGpu { + vram_needed_mb: u64, + vram_free_mb: u64, + retry_in_secs: u64, + }, +} + +impl ModelEvent { + /// Returns true if this event should be delivered via webhook. + pub fn is_webhook_event(&self) -> bool { + matches!(self, ModelEvent::ModelReady { .. } | ModelEvent::ModelUnloaded) + } +} + +// ── Model status response ──────────────────────────────────────────────────── + +/// Response body for `GET /model/status`. +#[derive(Debug, Serialize, Deserialize, ToSchema)] +pub struct ModelStatusResponse { + /// Current model state (flattened from `ModelState`). + #[serde(flatten)] + pub state: ModelState, + /// VRAM currently used on the device, in MiB (from nvidia-smi). + #[serde(skip_serializing_if = "Option::is_none")] + pub vram_used_mb: Option, + /// VRAM total on the device, in MiB. + #[serde(skip_serializing_if = "Option::is_none")] + pub vram_total_mb: Option, +} + // ── Job status ─────────────────────────────────────────────────────────────── #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, ToSchema)] @@ -130,6 +240,8 @@ pub struct HealthResponse { pub vram_total_mb: Option, pub model: String, pub queue_depth: usize, + /// Current state of the whisper model. 
+ pub model_state: String, } // ── SSE event payload ──────────────────────────────────────────────────────── @@ -148,3 +260,137 @@ pub enum SsePayload { Done { job: Box }, Error { message: String }, } + +// ── Unit tests ─────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use serde_json::Value; + + // ── ModelState serialization ───────────────────────────────────────────── + + #[test] + fn test_model_state_unloaded_serializes() { + let v: Value = serde_json::to_value(ModelState::Unloaded).unwrap(); + assert_eq!(v["state"], "unloaded"); + } + + #[test] + fn test_model_state_loading_serializes() { + let v: Value = serde_json::to_value(ModelState::Loading).unwrap(); + assert_eq!(v["state"], "loading"); + } + + #[test] + fn test_model_state_waiting_serializes() { + let s = ModelState::WaitingForGpu { vram_needed_mb: 3000, vram_free_mb: 500, retry_in_secs: 30 }; + let v: Value = serde_json::to_value(&s).unwrap(); + assert_eq!(v["state"], "waiting_for_gpu"); + assert_eq!(v["vram_needed_mb"], 3000); + assert_eq!(v["vram_free_mb"], 500); + assert_eq!(v["retry_in_secs"], 30); + } + + #[test] + fn test_model_state_ready_serializes() { + let ts = Utc::now(); + let s = ModelState::Ready { loaded_at: ts }; + let v: Value = serde_json::to_value(&s).unwrap(); + assert_eq!(v["state"], "ready"); + assert!(v["loaded_at"].is_string()); + } + + #[test] + fn test_model_state_is_ready() { + assert!(!ModelState::Unloaded.is_ready()); + assert!(!ModelState::Loading.is_ready()); + assert!(!ModelState::WaitingForGpu { vram_needed_mb: 0, vram_free_mb: 0, retry_in_secs: 30 }.is_ready()); + assert!(ModelState::Ready { loaded_at: Utc::now() }.is_ready()); + } + + #[test] + fn test_retry_after_unloaded() { + assert_eq!(ModelState::Unloaded.retry_after_secs(), 30); + } + + #[test] + fn test_retry_after_loading() { + assert_eq!(ModelState::Loading.retry_after_secs(), 10); + } + + #[test] + fn test_retry_after_waiting_for_gpu() { + let s 
= ModelState::WaitingForGpu { vram_needed_mb: 0, vram_free_mb: 0, retry_in_secs: 45 }; + assert_eq!(s.retry_after_secs(), 45); + } + + #[test] + fn test_retry_after_ready_is_zero() { + assert_eq!(ModelState::Ready { loaded_at: Utc::now() }.retry_after_secs(), 0); + } + + // ── ModelEvent serialization ───────────────────────────────────────────── + + #[test] + fn test_model_event_ready_serializes() { + let ts = Utc::now(); + let e = ModelEvent::ModelReady { loaded_at: ts }; + let v: Value = serde_json::to_value(&e).unwrap(); + assert_eq!(v["type"], "model_ready"); + assert!(v["loaded_at"].is_string()); + } + + #[test] + fn test_model_event_unloaded_serializes() { + let v: Value = serde_json::to_value(ModelEvent::ModelUnloaded).unwrap(); + assert_eq!(v["type"], "model_unloaded"); + } + + #[test] + fn test_model_event_loading_serializes() { + let v: Value = serde_json::to_value(ModelEvent::ModelLoading).unwrap(); + assert_eq!(v["type"], "model_loading"); + } + + #[test] + fn test_model_event_waiting_serializes() { + let e = ModelEvent::ModelWaitingForGpu { vram_needed_mb: 3000, vram_free_mb: 200, retry_in_secs: 30 }; + let v: Value = serde_json::to_value(&e).unwrap(); + assert_eq!(v["type"], "model_waiting_for_gpu"); + assert_eq!(v["vram_needed_mb"], 3000); + } + + #[test] + fn test_model_event_webhook_filter() { + assert!(ModelEvent::ModelReady { loaded_at: Utc::now() }.is_webhook_event()); + assert!(ModelEvent::ModelUnloaded.is_webhook_event()); + assert!(!ModelEvent::ModelLoading.is_webhook_event()); + assert!(!ModelEvent::ModelWaitingForGpu { vram_needed_mb: 0, vram_free_mb: 0, retry_in_secs: 30 }.is_webhook_event()); + } + + // ── ModelStatusResponse ────────────────────────────────────────────────── + + #[test] + fn test_model_status_response_roundtrip() { + let r = ModelStatusResponse { + state: ModelState::Ready { loaded_at: Utc::now() }, + vram_used_mb: Some(4096), + vram_total_mb: Some(8192), + }; + let json_str = serde_json::to_string(&r).unwrap(); + let 
v: Value = serde_json::from_str(&json_str).unwrap(); + assert_eq!(v["state"], "ready"); + assert_eq!(v["vram_used_mb"], 4096); + assert_eq!(v["vram_total_mb"], 8192); + } + + #[test] + fn test_model_status_response_omits_nulls() { + let r = ModelStatusResponse { state: ModelState::Loading, vram_used_mb: None, vram_total_mb: None }; + let v: Value = serde_json::to_value(&r).unwrap(); + assert_eq!(v["state"], "loading"); + assert!(v.get("vram_used_mb").is_none()); + assert!(v.get("vram_total_mb").is_none()); + } +} diff --git a/src/routes/health.rs b/src/routes/health.rs index d512948..7dcc0ab 100644 --- a/src/routes/health.rs +++ b/src/routes/health.rs @@ -16,6 +16,7 @@ use crate::{models::HealthResponse, AppState, Result}; )] pub async fn health(State(state): State) -> Result> { let (gpu_name, vram_total_mb) = gpu_info(state.gpu_device); + let model_state_tag = state.model_state.read().await.tag().to_string(); Ok(Json(HealthResponse { status: "ok".into(), @@ -23,6 +24,7 @@ pub async fn health(State(state): State) -> Result Router { Router::new() .route("/health", get(health::health)) } + +pub fn model_router() -> Router { + Router::new() + .route("/model/status", get(model::model_status)) + .route("/model/load", post(model::model_load)) + .route("/model/unload", post(model::model_unload)) + .route("/model/events", get(model::model_events)) +} diff --git a/src/routes/model.rs b/src/routes/model.rs new file mode 100644 index 0000000..65b3826 --- /dev/null +++ b/src/routes/model.rs @@ -0,0 +1,158 @@ +use std::pin::Pin; + +use axum::{ + extract::State, + http::StatusCode, + response::{ + sse::{Event, KeepAlive, Sse}, + IntoResponse, + }, + Json, +}; +use futures::Stream; +use tokio_stream::wrappers::BroadcastStream; +use futures::StreamExt; + +use crate::{ + models::{ModelEvent, ModelStatusResponse}, + worker::WorkerCmd, + AppState, Result, +}; + +type SseStream = Pin> + Send>>; + +// ── GET /model/status ──────────────────────────────────────────────────────── + +/// 
Return the current model state and VRAM statistics. +#[utoipa::path( + get, + path = "/model/status", + tag = "model", + responses( + (status = 200, description = "Model status", body = ModelStatusResponse), + ) +)] +pub async fn model_status(State(state): State) -> Result> { + let model_state = state.model_state.read().await.clone(); + let (vram_used_mb, vram_total_mb) = vram_stats(state.gpu_device); + + Ok(Json(ModelStatusResponse { + state: model_state, + vram_used_mb, + vram_total_mb, + })) +} + +// ── POST /model/load ───────────────────────────────────────────────────────── + +/// Request the model to be loaded into GPU memory. +/// Idempotent: if the model is already loading or ready, this is a no-op. +/// Returns 202 Accepted; poll `GET /model/status` or subscribe to +/// `GET /model/events` to know when it is ready. +#[utoipa::path( + post, + path = "/model/load", + tag = "model", + responses( + (status = 202, description = "Load initiated or already in progress"), + (status = 200, description = "Model already ready"), + ) +)] +pub async fn model_load(State(state): State) -> impl IntoResponse { + let is_ready = state.model_state.read().await.is_ready(); + if is_ready { + return (StatusCode::OK, Json(serde_json::json!({"status": "already_ready"}))); + } + // Ignore send errors (channel full = load already in progress). + let _ = state.cmd_tx.try_send(WorkerCmd::Load); + (StatusCode::ACCEPTED, Json(serde_json::json!({"status": "load_initiated"}))) +} + +// ── POST /model/unload ─────────────────────────────────────────────────────── + +/// Unload the model from GPU memory immediately. +/// Idempotent: if the model is already unloaded, returns 200 immediately. 
+#[utoipa::path( + post, + path = "/model/unload", + tag = "model", + responses( + (status = 200, description = "Model unloaded or was already unloaded"), + ) +)] +pub async fn model_unload(State(state): State) -> impl IntoResponse { + let _ = state.cmd_tx.try_send(WorkerCmd::Unload); + (StatusCode::OK, Json(serde_json::json!({"status": "unload_requested"}))) +} + +// ── GET /model/events ──────────────────────────────────────────────────────── + +/// Subscribe to model lifecycle events via Server-Sent Events. +/// +/// Event types: +/// - `model_loading` — load initiated +/// - `model_ready` — model loaded and warmed up +/// - `model_unloaded` — model freed from GPU memory +/// - `model_waiting_for_gpu` — insufficient VRAM; retrying +#[utoipa::path( + get, + path = "/model/events", + tag = "model", + responses( + (status = 200, description = "SSE stream of model lifecycle events"), + ) +)] +pub async fn model_events(State(state): State) -> Sse { + let rx = state.model_event_tx.subscribe(); + + let stream: SseStream = Box::pin( + BroadcastStream::new(rx).filter_map(|msg| async move { + match msg { + Ok(event) => { + let event_type = match &event { + ModelEvent::ModelReady { .. 
} => "model_ready", + ModelEvent::ModelUnloaded => "model_unloaded", + ModelEvent::ModelLoading => "model_loading", + ModelEvent::ModelWaitingForGpu {..} => "model_waiting_for_gpu", + }; + let data = serde_json::to_string(&event).ok()?; + Some(Ok(Event::default().event(event_type).data(data))) + } + Err(_) => None, + } + }) + ); + + Sse::new(stream).keep_alive(KeepAlive::default()) +} + +// ── Helpers ─────────────────────────────────────────────────────────────────── + +fn vram_stats(gpu_device: u32) -> (Option, Option) { + fn inner(gpu_device: u32) -> Option<(u64, u64)> { + let out = std::process::Command::new("nvidia-smi") + .args([ + &format!("--id={gpu_device}"), + "--query-gpu=memory.used,memory.total", + "--format=csv,noheader,nounits", + ]) + .output() + .ok()?; + + if !out.status.success() { + return None; + } + + let line = String::from_utf8_lossy(&out.stdout); + let line = line.trim(); + let mut parts = line.splitn(2, ','); + let used = parts.next().and_then(|s| s.trim().parse::().ok())?; + let total = parts.next().and_then(|s| s.trim().parse::().ok())?; + Some((used, total)) + } + + match inner(gpu_device) { + Some((u, t)) => (Some(u), Some(t)), + None => (None, None), + } +} diff --git a/src/transcriber.rs b/src/transcriber.rs index 2914d4f..93d67c1 100644 --- a/src/transcriber.rs +++ b/src/transcriber.rs @@ -49,10 +49,24 @@ impl Transcriber { // params.flash_attn(true); let ctx = WhisperContext::new_with_params(path, params) - .map_err(|e| AppError::Internal(format!("failed to load model: {e}")))?; + .map_err(|e| { + let msg = format!("failed to load model: {e}"); + if AppError::is_oom(&msg) { + AppError::OutOfMemory(msg) + } else { + AppError::Internal(msg) + } + })?; let mut state = ctx.create_state() - .map_err(|e| AppError::Internal(format!("failed to create whisper state: {e}")))?; + .map_err(|e| { + let msg = format!("failed to create whisper state: {e}"); + if AppError::is_oom(&msg) { + AppError::OutOfMemory(msg) + } else { + 
AppError::Internal(msg) + } + })?; // ctx drops here; state holds Arc so model stays loaded. // ── GPU warmup ──────────────────────────────────────────────────────── diff --git a/src/worker.rs b/src/worker.rs index 478a1e7..7215227 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -1,20 +1,23 @@ use std::{ + collections::HashSet, path::PathBuf, sync::{ atomic::{AtomicUsize, Ordering}, - Arc, + Arc, Mutex, }, + time::{Duration, Instant}, }; use chrono::Utc; use reqwest::Client; -use tokio::sync::{broadcast, mpsc, oneshot}; +use tokio::sync::{broadcast, mpsc, oneshot, RwLock}; use crate::{ - models::{Job, JobId, JobStatus, Segment}, + models::{Job, JobId, JobStatus, ModelEvent, ModelState, Segment}, storage::Storage, transcriber::Transcriber, webhook, + AppError, }; /// Per-job broadcast channel for SSE subscribers. @@ -31,83 +34,383 @@ pub enum ProgressEvent { /// Global registry: job_id → broadcast sender. pub type ProgressRegistry = Arc>; -// ── Transcription request/response types for the blocking thread ───────────── +// ── Worker command channel ──────────────────────────────────────────────────── -struct TranscribeRequest { - pcm: Vec, - language: Option, - task: String, - /// Per-chunk progress callback — receives 0–100 from whisper.cpp and can - /// scale/offset it before forwarding to the job's broadcast channel. - on_progress: Box, - reply: oneshot::Sender, String)>>, +/// Commands sent to the GPU worker OS thread. +#[derive(Debug)] +pub enum WorkerCmd { + /// Request a model load. Idempotent: if already loading/ready, ignored. + Load, + /// Unload the model immediately and free GPU memory. + Unload, + /// Internal: run a transcription chunk. 
+ Transcribe(TranscribeRequest), } +// ── Transcription request/response types ───────────────────────────────────── + +pub struct TranscribeRequest { + pub pcm: Vec, + pub language: Option, + pub task: String, + pub on_progress: Box, + pub reply: oneshot::Sender, String)>>, +} + +impl std::fmt::Debug for TranscribeRequest { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("TranscribeRequest") + .field("language", &self.language) + .field("task", &self.task) + .finish_non_exhaustive() + } +} + +// ── Public API ──────────────────────────────────────────────────────────────── + /// Spawn the single GPU worker. -/// Returns the SSE progress registry. +/// +/// Returns the SSE progress registry and a command sender for the worker thread. +/// The model starts **unloaded**; send `WorkerCmd::Load` or submit a job to +/// trigger loading. +#[allow(clippy::too_many_arguments)] pub fn start( - job_rx: mpsc::UnboundedReceiver, - storage: Arc, - model_path: PathBuf, - queue_depth: Arc, - gpu_device: u32, -) -> ProgressRegistry { + job_rx: mpsc::UnboundedReceiver, + storage: Arc, + model_path: PathBuf, + queue_depth: Arc, + gpu_device: u32, + model_state: Arc>, + model_event_tx: broadcast::Sender, + webhook_registry: Arc>>, + idle_timeout: Duration, + gpu_poll_interval: Duration, +) -> (ProgressRegistry, std::sync::mpsc::SyncSender) { let registry: ProgressRegistry = Arc::new(dashmap::DashMap::new()); let reg_clone = Arc::clone(&registry); - let (tx_req, rx_req) = std::sync::mpsc::channel::(); + // Bounded sync channel: capacity 8 is plenty (load/unload are rare). + let (cmd_tx, cmd_rx) = std::sync::mpsc::sync_channel::(8); + let cmd_tx_clone = cmd_tx.clone(); + + // Capture Tokio runtime handle so the OS thread can spawn async tasks.
+ let rt_handle = tokio::runtime::Handle::current(); std::thread::Builder::new() .name("whisper-gpu".into()) - .spawn(move || transcriber_thread(rx_req, model_path, gpu_device)) + .spawn(move || { + transcriber_thread( + cmd_rx, + model_path, + gpu_device, + model_state, + model_event_tx, + webhook_registry, + idle_timeout, + gpu_poll_interval, + rt_handle, + ); + }) .expect("failed to spawn whisper-gpu thread"); - tokio::spawn(run(job_rx, storage, queue_depth, reg_clone, tx_req)); + tokio::spawn(run(job_rx, storage, queue_depth, reg_clone, cmd_tx_clone)); - registry + (registry, cmd_tx) } -/// Dedicated OS thread that owns the Transcriber (non-Send) and runs inference. -/// -/// The Transcriber holds a single `WhisperState` that is reused for every chunk. -/// GPU compute buffers (~700 MB) are allocated once at startup rather than on -/// every call, eliminating per-chunk `whisper_init_state` overhead and the -/// VRAM churn that caused intermittent 0-segment results. -fn transcriber_thread( - rx: std::sync::mpsc::Receiver, - model_path: PathBuf, - gpu_device: u32, -) { - let mut transcriber = match Transcriber::load(&model_path, gpu_device) { - Ok(t) => t, - Err(e) => { - tracing::error!(error = %e, "failed to load whisper model — transcriber thread exiting"); - return; - } - }; - tracing::info!(model = %model_path.display(), "GPU worker ready"); +// ── GPU OS thread ───────────────────────────────────────────────────────────── - for req in rx { - let on_progress = req.on_progress; - let result = transcriber.transcribe( - &req.pcm, - req.language.as_deref(), - &req.task, - move |p| on_progress(p), - ); - let _ = req.reply.send(result); +/// The worker OS thread that owns the `Transcriber` (non-`Send`). +/// +/// Uses `recv_timeout` with a 1-second tick to drive the idle timer without a +/// separate thread. 
+#[allow(clippy::too_many_arguments)] +fn transcriber_thread( + rx: std::sync::mpsc::Receiver, + model_path: PathBuf, + gpu_device: u32, + model_state: Arc>, + model_event_tx: broadcast::Sender, + webhook_registry: Arc>>, + idle_timeout: Duration, + gpu_poll_interval: Duration, + rt: tokio::runtime::Handle, +) { + let mut transcriber: Option = None; + let mut last_job = Instant::now(); + + loop { + match rx.recv_timeout(Duration::from_secs(1)) { + Ok(WorkerCmd::Load) => { + if transcriber.is_some() { + tracing::debug!("WorkerCmd::Load ignored — model already loaded"); + continue; + } + transcriber = try_load_with_polling( + &rx, + &model_path, + gpu_device, + &model_state, + &model_event_tx, + &webhook_registry, + gpu_poll_interval, + &rt, + ); + if transcriber.is_some() { + last_job = Instant::now(); + } + } + + Ok(WorkerCmd::Unload) => { + do_unload(&mut transcriber, &model_state, &model_event_tx, &webhook_registry, &rt); + } + + Ok(WorkerCmd::Transcribe(req)) => { + let t = match &mut transcriber { + Some(t) => t, + None => { + tracing::warn!("Transcribe cmd received but model is unloaded — failing job"); + let _ = req.reply.send(Err(AppError::Internal( + "model unloaded before job could run".into(), + ))); + continue; + } + }; + + let result = t.transcribe( + &req.pcm, + req.language.as_deref(), + &req.task, + move |p| (req.on_progress)(p), + ); + last_job = Instant::now(); + let _ = req.reply.send(result); + } + + Err(std::sync::mpsc::RecvTimeoutError::Timeout) => { + if transcriber.is_some() && last_job.elapsed() >= idle_timeout { + tracing::info!( + elapsed_secs = last_job.elapsed().as_secs(), + "idle timeout reached — unloading model" + ); + do_unload( + &mut transcriber, + &model_state, + &model_event_tx, + &webhook_registry, + &rt, + ); + } + } + + Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => { + tracing::info!("worker command channel closed — shutting down GPU thread"); + break; + } + } } } +/// Attempt to load the model, polling on VRAM 
failures. +/// +/// While waiting for GPU, drains `rx` so that `WorkerCmd::Unload` cancels the +/// load attempt and `WorkerCmd::Transcribe` commands get a "model not ready" +/// rejection. Returns `Some(Transcriber)` on success, `None` if cancelled. +#[allow(clippy::too_many_arguments)] +fn try_load_with_polling( + rx: &std::sync::mpsc::Receiver, + model_path: &PathBuf, + gpu_device: u32, + model_state: &Arc>, + model_event_tx: &broadcast::Sender, + webhook_registry: &Arc>>, + gpu_poll_interval: Duration, + rt: &tokio::runtime::Handle, +) -> Option { + loop { + set_state(model_state, ModelState::Loading); + broadcast_event(model_event_tx, ModelEvent::ModelLoading); + tracing::info!("loading whisper model..."); + + match Transcriber::load(model_path, gpu_device) { + Ok(t) => { + let loaded_at = Utc::now(); + set_state(model_state, ModelState::Ready { loaded_at }); + broadcast_event(model_event_tx, ModelEvent::ModelReady { loaded_at }); + fire_webhooks(webhook_registry, ModelEvent::ModelReady { loaded_at }, rt); + tracing::info!("model loaded and ready"); + return Some(t); + } + + Err(AppError::OutOfMemory(msg)) => { + let (vram_needed_mb, vram_free_mb) = parse_oom_vram(&msg, gpu_device); + let retry_in_secs = gpu_poll_interval.as_secs(); + + tracing::warn!( + vram_needed_mb, + vram_free_mb, + retry_in_secs, + "insufficient VRAM — will retry" + ); + + set_state(model_state, ModelState::WaitingForGpu { + vram_needed_mb, + vram_free_mb, + retry_in_secs, + }); + broadcast_event(model_event_tx, ModelEvent::ModelWaitingForGpu { + vram_needed_mb, + vram_free_mb, + retry_in_secs, + }); + + // Interruptible sleep: drain rx while waiting for gpu_poll_interval. 
+ let deadline = Instant::now() + gpu_poll_interval; + loop { + let remaining = deadline.saturating_duration_since(Instant::now()); + if remaining.is_zero() { break; } + match rx.recv_timeout(remaining.min(Duration::from_secs(1))) { + Ok(WorkerCmd::Unload) => { + tracing::info!("Unload received while waiting for GPU — cancelling load"); + set_state(model_state, ModelState::Unloaded); + broadcast_event(model_event_tx, ModelEvent::ModelUnloaded); + fire_webhooks(webhook_registry, ModelEvent::ModelUnloaded, rt); + return None; + } + Ok(WorkerCmd::Load) => {} // idempotent + Ok(WorkerCmd::Transcribe(req)) => { + let _ = req.reply.send(Err(AppError::ModelNotReady { + state: "waiting_for_gpu".into(), + retry_after_secs: retry_in_secs, + })); + } + Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {} + Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => return None, + } + } + // Loop back to retry load + } + + Err(e) => { + tracing::error!(error = %e, "model load failed with non-recoverable error"); + set_state(model_state, ModelState::Unloaded); + return None; + } + } + } +} + +fn do_unload( + transcriber: &mut Option, + model_state: &Arc>, + model_event_tx: &broadcast::Sender, + webhook_registry: &Arc>>, + rt: &tokio::runtime::Handle, +) { + *transcriber = None; + set_state(model_state, ModelState::Unloaded); + broadcast_event(model_event_tx, ModelEvent::ModelUnloaded); + fire_webhooks(webhook_registry, ModelEvent::ModelUnloaded, rt); + tracing::info!("model unloaded — GPU memory freed"); +} + +// ── Helpers ─────────────────────────────────────────────────────────────────── + +fn set_state(arc: &Arc>, state: ModelState) { + *arc.blocking_write() = state; +} + +fn broadcast_event(tx: &broadcast::Sender, event: ModelEvent) { + let _ = tx.send(event); +} + +fn fire_webhooks( + registry: &Arc>>, + event: ModelEvent, + rt: &tokio::runtime::Handle, +) { + if !event.is_webhook_event() { + return; + } + let urls: Vec = registry + .lock() + .unwrap_or_else(|e| 
e.into_inner()) + .iter() + .cloned() + .collect(); + + if urls.is_empty() { return; } + + let payload = match serde_json::to_string(&event) { + Ok(p) => p, + Err(e) => { tracing::error!(error = %e, "failed to serialize model event"); return; } + }; + + for url in urls { + let body = payload.clone(); + rt.spawn(async move { + let http = Client::builder() + .timeout(Duration::from_secs(10)) + .build() + .expect("http client"); + for attempt in 0..3_u32 { + match http.post(&url) + .header("content-type", "application/json") + .body(body.clone()) + .send() + .await + { + Ok(r) if r.status().is_success() => { + tracing::debug!(url, "model event webhook delivered"); + return; + } + Ok(r) => tracing::warn!(url, status = r.status().as_u16(), "webhook non-2xx"), + Err(e) => tracing::warn!(url, error = %e, attempt, "webhook delivery failed"), + } + if attempt < 2 { + tokio::time::sleep(Duration::from_secs(2u64.pow(attempt))).await; + } + } + tracing::error!(url, "model event webhook failed after 3 attempts"); + }); + } +} + +fn parse_oom_vram(msg: &str, gpu_device: u32) -> (u64, u64) { + let needed = msg + .split_whitespace() + .zip(msg.split_whitespace().skip(1)) + .find(|(_, next)| *next == "MiB") + .and_then(|(n, _)| n.parse::().ok()) + .map(|v| v as u64) + .unwrap_or(0); + + let free = std::process::Command::new("nvidia-smi") + .args([ + &format!("--id={gpu_device}"), + "--query-gpu=memory.free", + "--format=csv,noheader,nounits", + ]) + .output() + .ok() + .and_then(|o| String::from_utf8(o.stdout).ok()) + .and_then(|s| s.trim().parse::().ok()) + .unwrap_or(0); + + (needed, free) +} + +// ── Async job runner ────────────────────────────────────────────────────────── + async fn run( mut job_rx: mpsc::UnboundedReceiver, storage: Arc, queue_depth: Arc, registry: ProgressRegistry, - tx_req: std::sync::mpsc::Sender, + cmd_tx: std::sync::mpsc::SyncSender, ) { let http = Client::builder() - .timeout(std::time::Duration::from_secs(30)) + .timeout(Duration::from_secs(30)) 
.build() .expect("failed to build reqwest client"); @@ -140,7 +443,7 @@ async fn run( let audio_path = audio_path_for(&job_id); - let result = process_job(&job, &audio_path, &progress_tx, &tx_req, &storage).await; + let result = process_job(&job, &audio_path, &progress_tx, &cmd_tx, &storage).await; let _ = tokio::fs::remove_file(&audio_path).await; @@ -175,26 +478,18 @@ async fn run( tokio::spawn(async move { webhook::fire(&http, &url, &job).await; }); } - tokio::time::sleep(std::time::Duration::from_secs(30)).await; + tokio::time::sleep(Duration::from_secs(30)).await; registry.remove(&job_id); } } // ── Silence-based chunking ──────────────────────────────────────────────────── -/// Target chunk length. 60s ≈ 2× whisper's native 30s window — short enough -/// that a hallucinated phrase can't compound beyond a single window. const TARGET_CHUNK_SECS: f32 = 60.0; -/// How far from the target we'll snap to a silence midpoint. const SNAP_WINDOW_SECS: f32 = 30.0; -/// Silence below this level (dB) counts as a split candidate. const SILENCE_DB: &str = "-35dB"; -/// Minimum silence duration to register as a candidate split. const SILENCE_DUR: &str = "0.4"; -/// Detect silence periods and return the midpoint (seconds) of each. -/// On any error (ffmpeg missing, binary format, etc.) returns an empty vec -/// so the caller can fall back to hard cuts. async fn detect_silence_midpoints(path: &std::path::Path) -> Vec { use tokio::process::Command; @@ -217,7 +512,6 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec { } }; - // silencedetect logs to stderr let stderr = String::from_utf8_lossy(&output.stderr); let mut starts: Vec = Vec::new(); let mut ends: Vec = Vec::new(); @@ -228,7 +522,6 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec { starts.push(t); } } else if let Some(i) = line.find("silence_end: ") { - // Format: "silence_end: 12.34 | silence_duration: 0.56" let t_str = line[i + "silence_end: ".len()..] 
.split(" |") .next() @@ -248,10 +541,6 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec { mids } -/// Build cut points every `target_secs`, snapping to the nearest silence -/// midpoint within `snap_window` when one exists; otherwise a hard cut. -/// Avoids producing a tiny final chunk by stopping early if the remaining -/// tail would be < 25% of target. fn snap_to_silence( mids: &[f32], total_secs: f32, @@ -263,13 +552,9 @@ fn snap_to_silence( while pos < total_secs - target_secs * 0.25 { let prev_cut = cuts.last().copied().unwrap_or(0.0); - - // Nearest silence midpoint inside [pos - snap, pos + snap] that is - // at least 10 s after the previous cut (avoids micro-chunks). let best = mids.iter().copied() .filter(|&t| t > prev_cut + 10.0 && (t - pos).abs() <= snap_window) .min_by(|a, b| (a - pos).abs().partial_cmp(&(b - pos).abs()).unwrap()); - let cut = best.unwrap_or(pos); cuts.push(cut); pos = cut + target_secs; @@ -278,7 +563,6 @@ fn snap_to_silence( cuts } -/// Convert cut points into (start_secs, end_secs) chunk pairs. fn to_chunk_ranges(cuts: &[f32], total_secs: f32) -> Vec<(f32, f32)> { let mut ranges = Vec::new(); let mut start = 0.0_f32; @@ -289,7 +573,6 @@ fn to_chunk_ranges(cuts: &[f32], total_secs: f32) -> Vec<(f32, f32)> { start = cut; } } - // Last chunk if total_secs - start >= 1.0 { ranges.push((start, total_secs)); } @@ -302,17 +585,13 @@ async fn process_job( job: &Job, audio_path: &std::path::Path, progress_tx: &ProgressTx, - tx_req: &std::sync::mpsc::Sender, + cmd_tx: &std::sync::mpsc::SyncSender, storage: &Arc, ) -> crate::Result<(Vec, String, f32)> { - // 1. Decode full audio to 16 kHz mono PCM. let pcm = decode_audio(audio_path).await?; let total_secs = pcm.len() as f32 / 16_000.0; - // 2. Detect silence midpoints from original file. let silence_mids = detect_silence_midpoints(audio_path).await; - - // 3. Build silence-snapped chunk boundaries. 
let cuts = snap_to_silence(&silence_mids, total_secs, TARGET_CHUNK_SECS, SNAP_WINDOW_SECS); let chunks = to_chunk_ranges(&cuts, total_secs); let n = chunks.len(); @@ -324,7 +603,6 @@ async fn process_job( "audio chunked by silence" ); - // 4. Transcribe each chunk, applying a time offset to all timestamps. let mut all_segments: Vec = Vec::new(); let mut language = String::new(); @@ -334,11 +612,9 @@ async fn process_job( let mut chunk_pcm = pcm[s0..s1].to_vec(); trim_trailing_silence(&mut chunk_pcm); - // Base percent this chunk starts at. let base = (ci * 100 / n) as u8; let span = (100usize / n).max(1) as u8; - // Emit a progress event and persist it at the start of every chunk. let _ = progress_tx.send(ProgressEvent::Progress { percent: base, chunk: ci + 1, @@ -350,8 +626,7 @@ async fn process_job( tracing::warn!(error = %e, "failed to persist mid-job progress"); } - // Scale whisper's per-chunk 0–100 into the job's overall range. - let tx = progress_tx.clone(); + let tx = progress_tx.clone(); let chunk_num = ci + 1; let on_progress = Box::new(move |p: u8| { let overall = base.saturating_add(p.saturating_mul(span) / 100); @@ -363,18 +638,17 @@ async fn process_job( }); let (reply_tx, reply_rx) = oneshot::channel(); - tx_req.send(TranscribeRequest { + cmd_tx.send(WorkerCmd::Transcribe(TranscribeRequest { pcm: chunk_pcm, language: job.language.clone(), task: job.task.clone(), on_progress, reply: reply_tx, - }).map_err(|_| crate::AppError::Internal("transcriber thread gone".into()))?; + })).map_err(|_| AppError::Internal("worker command channel closed".into()))?; let (mut segs, lang) = reply_rx.await - .map_err(|_| crate::AppError::Internal("transcriber thread dropped reply".into()))??; + .map_err(|_| AppError::Internal("transcriber thread dropped reply".into()))??; - // Shift all timestamps by chunk offset. 
let offset = *chunk_start; for seg in &mut segs { seg.start += offset; @@ -400,7 +674,6 @@ async fn process_job( } } - // Renumber segment indices across the merged output. for (i, seg) in all_segments.iter_mut().enumerate() { seg.index = i as i32; } @@ -409,14 +682,9 @@ async fn process_job( Ok((all_segments, language, total_secs)) } -/// Trim trailing silence from a 16 kHz mono PCM buffer. -/// -/// Scans backwards to find the last sample above −35 dB, then keeps -/// 0.5 s of padding after it. This prevents whisper from hallucinating -/// filler tokens into end-of-chunk silence. fn trim_trailing_silence(pcm: &mut Vec) { - const THRESHOLD: f32 = 0.017_8; // −35 dB (10^(−35/20)) - const PADDING: usize = 8_000; // 0.5 s at 16 kHz + const THRESHOLD: f32 = 0.017_8; + const PADDING: usize = 8_000; if let Some(last_loud) = pcm.iter().rposition(|&s| s.abs() > THRESHOLD) { let new_len = (last_loud + 1 + PADDING).min(pcm.len()); @@ -429,10 +697,8 @@ fn trim_trailing_silence(pcm: &mut Vec) { pcm.truncate(new_len); } } - // All-silent chunk: keep as-is — whisper will produce zero segments, which is correct. } -/// Decode any audio file to 16 kHz mono PCM f32 using ffmpeg. 
async fn decode_audio(path: &std::path::Path) -> crate::Result> { use tokio::process::Command; @@ -447,11 +713,11 @@ async fn decode_audio(path: &std::path::Path) -> crate::Result> { ]) .output() .await - .map_err(|e| crate::AppError::Internal(format!("ffmpeg spawn failed: {e}")))?; + .map_err(|e| AppError::Internal(format!("ffmpeg spawn failed: {e}")))?; if !output.status.success() { let stderr = String::from_utf8_lossy(&output.stderr); - return Err(crate::AppError::Internal(format!( + return Err(AppError::Internal(format!( "ffmpeg exited with {}: {}", output.status, stderr ))); @@ -459,7 +725,7 @@ async fn decode_audio(path: &std::path::Path) -> crate::Result> { let bytes = output.stdout; if bytes.len() % 4 != 0 { - return Err(crate::AppError::Internal( + return Err(AppError::Internal( "ffmpeg output length not a multiple of 4".into(), )); } @@ -473,3 +739,51 @@ pub fn audio_path_for(id: &JobId) -> PathBuf { let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/data".into()); PathBuf::from(data_dir).join(format!("{id}.audio")) } + +// ── Unit tests ──────────────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_snap_to_silence_uses_nearest_midpoint() { + let mids = vec![55.0, 58.0, 62.0]; + let cuts = snap_to_silence(&mids, 120.0, 60.0, 30.0); + assert!(!cuts.is_empty()); + assert!((cuts[0] - 58.0).abs() < 0.01, "expected ~58.0, got {}", cuts[0]); + } + + #[test] + fn test_snap_to_silence_hard_cut_when_no_silence() { + let cuts = snap_to_silence(&[], 120.0, 60.0, 30.0); + assert_eq!(cuts, vec![60.0]); + } + + #[test] + fn test_to_chunk_ranges_single_chunk() { + let ranges = to_chunk_ranges(&[], 30.0); + assert_eq!(ranges, vec![(0.0, 30.0)]); + } + + #[test] + fn test_to_chunk_ranges_two_chunks() { + let ranges = to_chunk_ranges(&[60.0], 120.0); + assert_eq!(ranges, vec![(0.0, 60.0), (60.0, 120.0)]); + } + + #[test] + fn test_trim_trailing_silence_all_silent() { + let mut pcm = 
vec![0.0f32; 1000]; + trim_trailing_silence(&mut pcm); + assert_eq!(pcm.len(), 1000); + } + + #[test] + fn test_trim_trailing_silence_trims_to_padding() { + let mut pcm = vec![0.0f32; 32_000]; + pcm[10_000] = 1.0; + trim_trailing_silence(&mut pcm); + assert_eq!(pcm.len(), (10_001 + 8_000).min(32_000)); + } +} diff --git a/test_all.sh b/test_all.sh old mode 100644 new mode 100755 index c6a1621..07ad4c5 --- a/test_all.sh +++ b/test_all.sh @@ -6,8 +6,9 @@ BASE="${WHISPER_BASE_URL:-http://localhost:8080}" AUDIO="${TEST_AUDIO:-/home/moze/Sources/youtube-transcriber/docker/tmp/audio-b2167046-a236-4fcd-b739-78177542fd23.wav}" GREEN='\033[0;32m'; RED='\033[0;31m'; NC='\033[0m' +FAILS=0 ok() { echo -e "${GREEN}[PASS]${NC} $*"; } -fail(){ echo -e "${RED}[FAIL]${NC} $*"; exit 1; } +fail(){ echo -e "${RED}[FAIL]${NC} $*"; FAILS=$((FAILS + 1)); } echo "=== Whisper API test suite ===" echo " BASE : $BASE" @@ -17,11 +18,16 @@ echo "" echo "=== 1. GET /health ===" HEALTH=$(curl -sf "$BASE/health") echo "$HEALTH" | python3 -m json.tool -echo "$HEALTH" | python3 -c "import sys,json; d=json.load(sys.stdin); assert d['status']=='ok', f'status={d[\"status\"]}'" && ok "health ok" +python3 -c " +import sys, json +d = json.loads('$HEALTH' if False else sys.stdin.read()) +assert d['status'] == 'ok', f'status={d[\"status\"]}' +assert 'model_state' in d, 'model_state field missing from health response' +" <<< "$HEALTH" && ok "health ok + model_state present" || fail "health check" echo "" echo "=== 2. GET /docs (Swagger UI reachable) ===" -curl -sf "$BASE/docs" | grep -qi "swagger" && ok "swagger UI reachable" +curl -sf "$BASE/docs" | grep -qi "swagger" && ok "swagger UI reachable" || fail "swagger UI" echo "" echo "=== 3. 
Webhook receiver (background Python HTTP server) ===" @@ -33,7 +39,7 @@ class H(http.server.BaseHTTPRequestHandler): n = int(self.headers.get('Content-Length', 0)) body = self.rfile.read(n) data = json.loads(body) - print(f"\n[WEBHOOK] status={data.get('status')} segments={len(data.get('segments', []))}") + print(f"\n[WEBHOOK] status={data.get('status')} segments={len(data.get('segments', []))}", flush=True) self.send_response(200) self.end_headers() def log_message(self, *a): pass @@ -48,40 +54,70 @@ sleep 1 echo "Webhook receiver started (PID $WEBHOOK_PID)" echo "" -echo "=== 4. DELETE a non-existent job → 404 ===" +echo "=== 4. GET /model/status — expect unloaded on fresh start ===" +MODEL_STATUS=$(curl -sf "$BASE/model/status") +echo "$MODEL_STATUS" | python3 -m json.tool +echo "$MODEL_STATUS" | python3 -c " +import sys, json +d = json.load(sys.stdin) +assert 'state' in d, 'state field missing from /model/status' +print(f' model state: {d[\"state\"]}') +" && ok "/model/status has state field" || fail "/model/status schema" + +echo "" +echo "=== 5. POST /model/load — trigger model load ===" +LOAD_RESP=$(curl -sf -X POST "$BASE/model/load") +echo "$LOAD_RESP" +ok "POST /model/load accepted" + +echo "" +echo "=== 6. Poll /model/status until ready (max 3 min) ===" +LOAD_ELAPSED=0 +while true; do + sleep 5 + LOAD_ELAPSED=$((LOAD_ELAPSED + 5)) + MS=$(curl -sf "$BASE/model/status") + STATE=$(echo "$MS" | python3 -c "import sys,json; print(json.load(sys.stdin)['state'])") + echo " [${LOAD_ELAPSED}s] model_state=${STATE}" + if [ "$STATE" = "ready" ]; then + ok "model loaded and ready in ${LOAD_ELAPSED}s" + break + fi + [ $LOAD_ELAPSED -gt 180 ] && { fail "model failed to load within 3 minutes"; break; } +done + +echo "" +echo "=== 7. 
DELETE a non-existent job → 404 ===" STATUS=$(curl -s -o /dev/null -w "%{http_code}" -X DELETE "$BASE/jobs/00000000-0000-0000-0000-000000000000") [ "$STATUS" = "404" ] && ok "DELETE unknown job → 404" || fail "expected 404, got $STATUS" echo "" -echo "=== 5. POST /jobs — submit audio ===" -# language field omitted → auto-detection. Do NOT pass "auto" — it is not a -# valid ISO 639-1 code and whisper-rs will reject it or behave unexpectedly. +echo "=== 8. POST /jobs — submit audio ===" SUBMIT=$(curl -sf -X POST "$BASE/jobs" \ -F "audio=@${AUDIO};type=audio/wav" \ -F "task=transcribe" \ -F "webhook_url=http://localhost:9999/webhook") echo "$SUBMIT" -# Submit response: { "job_id": "" } (field is "job_id", not "id") JOB_ID=$(echo "$SUBMIT" | python3 -c "import sys,json; print(json.load(sys.stdin)['job_id'])") ok "submitted job $JOB_ID" echo "" -echo "=== 6. GET /jobs/{id} immediately after submit ===" +echo "=== 9. GET /jobs/{id} immediately after submit ===" JOB=$(curl -sf "$BASE/jobs/$JOB_ID") echo "$JOB" | python3 -c " import sys, json d = json.load(sys.stdin) assert d['status'] in ('queued', 'running'), f'unexpected status: {d[\"status\"]}' -" && ok "status is queued/running" +" && ok "status is queued/running" || fail "initial status check" echo "" -echo "=== 7. SSE stream (observe first 30 events then detach) ===" +echo "=== 10. SSE stream (observe first 30 events then detach) ===" echo "Subscribing to SSE stream for $JOB_ID …" curl -sN --max-time 90 "$BASE/jobs/$JOB_ID/stream" | head -60 & SSE_PID=$! echo "" -echo "=== 8. Poll until done (max 20 min) ===" +echo "=== 11. Poll until done (max 20 min) ===" ELAPSED=0 while true; do sleep 15 @@ -96,16 +132,15 @@ while true; do elif [ "$STATUS" = "failed" ]; then echo "$JOB" | python3 -m json.tool fail "job failed" + break fi - [ $ELAPSED -gt 1200 ] && fail "timeout after 20 minutes" + [ $ELAPSED -gt 1200 ] && { fail "timeout after 20 minutes"; break; } done kill $SSE_PID 2>/dev/null || true echo "" -echo "=== 9. 
Inspect transcription quality ===" +echo "=== 12. Inspect transcription quality ===" RESULT=$(curl -sf "$BASE/jobs/$JOB_ID") -# Note: can't pipe into a heredoc-driven python3 (heredoc takes stdin, pipe is ignored). -# Write to a temp file instead. TMPJSON=$(mktemp /tmp/whisper_test_XXXXXX.json) echo "$RESULT" > "$TMPJSON" python3 - "$TMPJSON" << 'PYCHECK' @@ -149,18 +184,16 @@ for seg in segments[:5]: PYCHECK PYEXIT=$? rm -f "$TMPJSON" -[ $PYEXIT -eq 0 ] && ok "quality check passed" || { echo "[FAIL] quality check"; FAILS=$((FAILS+1)); } +[ $PYEXIT -eq 0 ] && ok "quality check passed" || fail "quality check" echo "" -echo "=== 10. DELETE completed job → 200 ===" -# Completed jobs return 409 Conflict on DELETE (terminal state). -# Verify we get 409, not 200 (delete is only for cancellation of active jobs). +echo "=== 13. DELETE completed job → 409 Conflict ===" DEL_STATUS=$(curl -s -o /dev/null -w "%{http_code}" -X DELETE "$BASE/jobs/$JOB_ID") [ "$DEL_STATUS" = "409" ] && ok "DELETE completed job → 409 Conflict (expected)" \ || echo " [INFO] DELETE returned $DEL_STATUS" echo "" -echo "=== 11. Submit + cancel a queued job ===" +echo "=== 14. Submit + cancel a queued job ===" JOB2=$(curl -sf -X POST "$BASE/jobs" \ -F "audio=@${AUDIO};type=audio/wav" \ -F "language=en" \ @@ -173,8 +206,24 @@ CANCEL_STATUS=$(curl -sf "$BASE/jobs/$JOB2_ID" | python3 -c "import sys,json; pr || echo " [INFO] cancel status: $CANCEL_STATUS (may be running — worker ignores cancel mid-chunk)" echo "" -echo "=== 12. Verify webhook fired ===" +echo "=== 15. POST /model/unload ===" +UNLOAD_RESP=$(curl -sf -X POST "$BASE/model/unload") +echo "$UNLOAD_RESP" +sleep 2 +UNLOAD_STATE=$(curl -sf "$BASE/model/status" | python3 -c "import sys,json; print(json.load(sys.stdin)['state'])") +[ "$UNLOAD_STATE" = "unloaded" ] && ok "model unloaded → state=unloaded" \ + || echo " [INFO] state after unload: $UNLOAD_STATE" + +echo "" +echo "=== 16. 
Verify webhook fired ===" sleep 3 kill $WEBHOOK_PID 2>/dev/null || true -ok "all tests complete" +ok "webhook server stopped" +echo "" +if [ $FAILS -eq 0 ]; then + echo -e "${GREEN}=== ALL TESTS PASSED ===${NC}" +else + echo -e "${RED}=== $FAILS TEST(S) FAILED ===${NC}" + exit 1 +fi diff --git a/tests/test_idle_timeout.sh b/tests/test_idle_timeout.sh new file mode 100755 index 0000000..f4394d0 --- /dev/null +++ b/tests/test_idle_timeout.sh @@ -0,0 +1,246 @@ +#!/usr/bin/env bash +# tests/test_idle_timeout.sh +# +# Integration tests for the idle-timeout auto-unload feature. +# REQUIRES the server to be started with a short idle timeout: +# +# IDLE_TIMEOUT_SECS=5 ./whisper-server +# # or via Docker: +# docker run -e IDLE_TIMEOUT_SECS=5 ... +# +# The default idle timeout is 5 minutes; these tests use a 5-second window +# to keep the suite fast. + +set -euo pipefail + +BASE="${WHISPER_BASE_URL:-http://localhost:8080}" +IDLE_TIMEOUT="${EXPECTED_IDLE_TIMEOUT_SECS:-5}" +AUDIO="${TEST_AUDIO:-}" + +GREEN='\033[0;32m'; RED='\033[0;31m'; YELLOW='\033[0;33m'; NC='\033[0m' + +PASS=0; FAIL=0 + +ok() { echo -e "${GREEN}[PASS]${NC} $1"; PASS=$((PASS+1)); } +fail() { echo -e "${RED}[FAIL]${NC} $1"; FAIL=$((FAIL+1)); } +skip() { echo -e "${YELLOW}[SKIP]${NC} $1"; } +info() { echo " $1"; } + +echo "=== Idle Timeout Tests ===" +echo " BASE: $BASE" +echo " IDLE_TIMEOUT_SECS: $IDLE_TIMEOUT (must be configured on the server)" +echo "" +echo "NOTE: These tests require the server to be running with IDLE_TIMEOUT_SECS=$IDLE_TIMEOUT" +echo "" + +# ── Helpers ────────────────────────────────────────────────────────────────── + +get_state() { + curl -sf "$BASE/model/status" | python3 -c "import sys,json; print(json.load(sys.stdin)['state'])" +} + +ensure_ready() { + local state + state=$(get_state) + if [ "$state" = "ready" ]; then return 0; fi + curl -sf -X POST "$BASE/model/load" > /dev/null + local elapsed=0 + while true; do + sleep 3; elapsed=$((elapsed+3)) + state=$(get_state) + [ "$state" 
= "ready" ] && return 0 + [ $elapsed -gt 180 ] && return 1 + done +} + +ensure_unloaded() { + curl -sf -X POST "$BASE/model/unload" > /dev/null || true + sleep 2 +} + +# ── TEST 1: Load model, then wait for idle unload ───────────── +echo "--- Test 1: Idle timeout triggers auto-unload ---" + +ensure_unloaded +ensure_ready || { fail "T1: model load failed"; } + +WAIT_SECS=$((IDLE_TIMEOUT + 3)) +info "Model is ready. Waiting $WAIT_SECS seconds (idle timeout=$IDLE_TIMEOUT + 3s buffer)..." +sleep $WAIT_SECS + +STATE=$(get_state) +if [ "$STATE" = "unloaded" ]; then + ok "T1: model auto-unloaded after ${IDLE_TIMEOUT}s idle" +else + fail "T1: expected unloaded after idle timeout, got $STATE" + info "Is the server running with IDLE_TIMEOUT_SECS=$IDLE_TIMEOUT?" +fi + +# ── TEST 2: model_unloaded webhook fires on idle timeout ───────────────────── +echo "" +echo "--- Test 2: model_unloaded webhook fires on idle timeout ---" + +ensure_unloaded + +# Start webhook receiver +python3 - <<'PYEOF' & +import http.server, json, sys, signal + +class H(http.server.BaseHTTPRequestHandler): + def do_POST(self): + n = int(self.headers.get('Content-Length', 0)) + body = json.loads(self.rfile.read(n)) + with open('/tmp/idle_wh_event.json', 'w') as f: + json.dump(body, f) + self.send_response(200); self.end_headers() + def log_message(self, *a): pass + +signal.signal(signal.SIGTERM, lambda *_: sys.exit(0)) +http.server.HTTPServer(('', 9995), H).serve_forever() +PYEOF +WH_PID=$! +sleep 1 + +# Register webhook via a job submission (will 503 since unloaded) +curl -sf -X POST "$BASE/jobs" \ + -F "audio=@/dev/urandom;type=audio/wav" \ + -F "webhook_url=http://localhost:9995/wh" \ + --max-time 5 > /dev/null 2>&1 || true + +# Load model +ensure_ready || { fail "T2: model load failed"; kill $WH_PID 2>/dev/null; } + +# Wait for idle timeout +WAIT_SECS=$((IDLE_TIMEOUT + 5)) +info "Waiting ${WAIT_SECS}s for idle timeout..."
+sleep $WAIT_SECS + +kill $WH_PID 2>/dev/null || true +wait $WH_PID 2>/dev/null || true + +if [ -f /tmp/idle_wh_event.json ]; then + EVENT_TYPE=$(python3 -c "import json; print(json.load(open('/tmp/idle_wh_event.json')).get('type','?'))") + rm -f /tmp/idle_wh_event.json + [ "$EVENT_TYPE" = "model_unloaded" ] && ok "T2: model_unloaded webhook fired on idle timeout" \ + || fail "T2: webhook type=$EVENT_TYPE (expected model_unloaded)" +else + fail "T2: no webhook received within timeout" +fi + +# ── TEST 3: Job submission after idle timeout → 503 → triggers reload ───────── +echo "" +echo "--- Test 3: Job triggers reload after idle unload ---" + +ensure_unloaded +ensure_ready || { fail "T3: initial load failed"; } + +# Wait for auto-unload +WAIT_SECS=$((IDLE_TIMEOUT + 3)) +info "Waiting ${WAIT_SECS}s for idle unload..." +sleep $WAIT_SECS + +STATE=$(get_state) +[ "$STATE" = "unloaded" ] || info "Note: state=$STATE (expected unloaded)" + +# Submit job → 503, triggers reload +HTTP=$(curl -s -o /tmp/t3_body.json -w "%{http_code}" -X POST "$BASE/jobs" \ + -F "audio=@/dev/urandom;type=audio/wav" \ + --max-time 5 2>/dev/null || echo "000") + +if [ "$HTTP" = "503" ]; then + ok "T3a: POST /jobs → 503 after idle unload" +else + skip "T3a: POST /jobs returned $HTTP (model may have reloaded)" +fi + +# State should be loading or ready (reload triggered by job submission) +sleep 2 +STATE=$(get_state) +if [ "$STATE" = "loading" ] || [ "$STATE" = "ready" ]; then + ok "T3b: reload triggered by job submission ($STATE)" +else + fail "T3b: expected loading/ready, got $STATE" +fi + +rm -f /tmp/t3_body.json + +# ── TEST 4: Idle timer resets per job (wait 60% of timeout → still ready) ───── +echo "" +echo "--- Test 4: Idle timer resets with each completed job ---" + +ensure_unloaded +ensure_ready || { fail "T4: model load failed"; } + +HALF_WAIT=$((IDLE_TIMEOUT - 1)) +info "Waiting ${HALF_WAIT}s (less than idle timeout)..." 
+sleep $HALF_WAIT + +STATE=$(get_state) +if [ "$STATE" = "ready" ]; then + ok "T4a: model still ready after ${HALF_WAIT}s (less than ${IDLE_TIMEOUT}s timeout)" +else + fail "T4a: model unexpectedly $STATE after only ${HALF_WAIT}s" +fi + +# Wait for full unload +REMAINING=$((IDLE_TIMEOUT - HALF_WAIT + 3)) +info "Waiting another ${REMAINING}s for full idle unload..." +sleep $REMAINING +STATE=$(get_state) +[ "$STATE" = "unloaded" ] && ok "T4b: model unloaded after total > ${IDLE_TIMEOUT}s idle" \ + || fail "T4b: expected unloaded, got $STATE" + +# ── TEST 5: Job resets idle timer ───────────────────────────────────────────── +echo "" +echo "--- Test 5: Completing a job resets the idle timer ---" + +if [ -z "$AUDIO" ]; then + skip "T5: TEST_AUDIO not set — skipping timer-reset test" +else + ensure_unloaded + ensure_ready || { fail "T5: model load failed"; } + + # Submit a job + SUBMIT=$(curl -sf -X POST "$BASE/jobs" \ + -F "audio=@${AUDIO};type=audio/wav" \ + -F "task=transcribe" 2>&1) + JOB_ID=$(echo "$SUBMIT" | python3 -c "import sys,json; print(json.load(sys.stdin)['job_id'])" 2>/dev/null || echo "") + + if [ -z "$JOB_ID" ]; then + fail "T5: job submission failed" + else + # Wait for job to finish + elapsed=0 + while true; do + sleep 5; elapsed=$((elapsed+5)) + STATUS=$(curl -sf "$BASE/jobs/$JOB_ID" | python3 -c "import sys,json; print(json.load(sys.stdin)['status'])") + [ "$STATUS" = "done" ] || [ "$STATUS" = "failed" ] && break + [ $elapsed -gt 300 ] && break + done + info "Job finished in ${elapsed}s with status=$STATUS" + + # Now wait IDLE_TIMEOUT - 2 seconds — should still be ready + SAFE_WAIT=$((IDLE_TIMEOUT - 2)) + [ $SAFE_WAIT -lt 1 ] && SAFE_WAIT=1 + info "Waiting ${SAFE_WAIT}s after job completion (less than idle timeout)..." 
+ sleep $SAFE_WAIT + STATE=$(get_state) + [ "$STATE" = "ready" ] && ok "T5a: model still ready ${SAFE_WAIT}s after job completion" \ + || fail "T5a: model unexpectedly $STATE after job" + + # Wait for idle timeout + REMAINING=$((IDLE_TIMEOUT - SAFE_WAIT + 3)) + info "Waiting ${REMAINING}s more for idle unload..." + sleep $REMAINING + STATE=$(get_state) + [ "$STATE" = "unloaded" ] && ok "T5b: model auto-unloaded after idle period post-job" \ + || fail "T5b: expected unloaded, got $STATE" + fi +fi + +# ── Summary ──────────────────────────────────────────────────────────────────── +echo "" +echo "==========================================" +echo " Results: ${PASS} passed, ${FAIL} failed" +echo "==========================================" +[ $FAIL -eq 0 ] && echo -e "${GREEN}ALL PASSED${NC}" || { echo -e "${RED}FAILURES: $FAIL${NC}"; exit 1; } diff --git a/tests/test_model_lifecycle.sh b/tests/test_model_lifecycle.sh new file mode 100755 index 0000000..c0f2982 --- /dev/null +++ b/tests/test_model_lifecycle.sh @@ -0,0 +1,470 @@ +#!/usr/bin/env bash +# tests/test_model_lifecycle.sh +# +# Integration tests for dynamic model loading/unloading. +# Requires a running whisper-server with GPU access. +# +# Usage: +# WHISPER_BASE_URL=http://localhost:8080 bash tests/test_model_lifecycle.sh +# +# Tests are designed to be independent; each section that needs a specific +# state resets it explicitly at the start. 
+ +set -euo pipefail + +BASE="${WHISPER_BASE_URL:-http://localhost:8080}" +AUDIO="${TEST_AUDIO:-}" + +GREEN='\033[0;32m'; RED='\033[0;31m'; YELLOW='\033[0;33m'; NC='\033[0m' + +PASS=0; FAIL=0 + +ok() { echo -e "${GREEN}[PASS]${NC} $1"; PASS=$((PASS+1)); } +fail() { echo -e "${RED}[FAIL]${NC} $1"; FAIL=$((FAIL+1)); } +skip() { echo -e "${YELLOW}[SKIP]${NC} $1"; } +info() { echo " $1"; } + +echo "=== Model Lifecycle Integration Tests ===" +echo " BASE: $BASE" +echo "" + +# ── Helpers ────────────────────────────────────────────────────────────────── + +get_state() { + curl -sf "$BASE/model/status" | python3 -c "import sys,json; print(json.load(sys.stdin)['state'])" +} + +ensure_unloaded() { + curl -sf -X POST "$BASE/model/unload" > /dev/null + sleep 2 + local s + s=$(get_state) + if [ "$s" != "unloaded" ]; then + echo " WARNING: expected unloaded, got $s — waiting 5s" + sleep 5 + fi +} + +ensure_ready() { + local state + state=$(get_state) + if [ "$state" = "ready" ]; then return 0; fi + curl -sf -X POST "$BASE/model/load" > /dev/null + local elapsed=0 + while true; do + sleep 5; elapsed=$((elapsed+5)) + state=$(get_state) + [ "$state" = "ready" ] && return 0 + [ $elapsed -gt 180 ] && echo " TIMEOUT: model did not become ready" && return 1 + done +} + +poll_state_transition() { + local target="$1" max_secs="${2:-120}" + local elapsed=0 + while true; do + sleep 2; elapsed=$((elapsed+2)) + local s + s=$(get_state) + [ "$s" = "$target" ] && return 0 + [ $elapsed -ge $max_secs ] && return 1 + done +} + +# ── TEST 1: Startup state is unloaded ──────────────────────────────────────── +echo "--- Test 1: Startup state is unloaded (or after explicit unload) ---" +ensure_unloaded +STATE=$(get_state) +if [ "$STATE" = "unloaded" ]; then + ok "T1: state=unloaded after explicit unload" +else + fail "T1: expected unloaded, got $STATE" +fi + +# ── TEST 2: POST /model/load returns 202 ───────────────────────────────────── +echo "" +echo "--- Test 2: POST /model/load returns 202 ---" 
+ensure_unloaded +HTTP=$(curl -s -o /dev/null -w "%{http_code}" -X POST "$BASE/model/load") +if [ "$HTTP" = "202" ]; then + ok "T2: POST /model/load → 202 Accepted" +else + fail "T2: expected 202, got $HTTP" +fi +# Cancel the in-progress load to clean up +curl -sf -X POST "$BASE/model/unload" > /dev/null || true +sleep 2 + +# ── TEST 3: State transitions to loading/ready after load trigger ───────────── +echo "" +echo "--- Test 3: State transitions to loading (not stuck at unloaded) ---" +ensure_unloaded +curl -sf -X POST "$BASE/model/load" > /dev/null +sleep 1 +STATE=$(get_state) +if [ "$STATE" = "loading" ] || [ "$STATE" = "ready" ]; then + ok "T3: state transitioned to $STATE (not stuck at unloaded)" +else + fail "T3: expected loading or ready, got $STATE" +fi + +# ── TEST 4: Model reaches ready state and loaded_at is set ─────────────────── +echo "" +echo "--- Test 4: Model reaches ready state with loaded_at timestamp ---" +# Already loading from T3 — wait for ready +if ! poll_state_transition "ready" 180; then + fail "T4: model did not become ready within 3 minutes" +else + STATUS_JSON=$(curl -sf "$BASE/model/status") + LOADED_AT=$(echo "$STATUS_JSON" | python3 -c "import sys,json; d=json.load(sys.stdin); print(d.get('loaded_at','MISSING'))" 2>/dev/null || echo "MISSING") + if [ "$LOADED_AT" != "MISSING" ] && [ "$LOADED_AT" != "null" ] && [ -n "$LOADED_AT" ]; then + ok "T4: model=ready, loaded_at=$LOADED_AT" + else + fail "T4: model ready but loaded_at is missing or null" + fi +fi + +# ── TEST 5: Idempotent load — POST /model/load when ready returns 200 ───────── +echo "" +echo "--- Test 5: POST /model/load when already ready → 200 ---" +ensure_ready || { fail "T5: could not load model"; } +HTTP=$(curl -s -o /dev/null -w "%{http_code}" -X POST "$BASE/model/load") +STATE=$(get_state) +if [ "$HTTP" = "200" ] && [ "$STATE" = "ready" ]; then + ok "T5: idempotent load → 200, state stays ready" +elif [ "$HTTP" = "202" ] && [ "$STATE" = "ready" ]; then + ok "T5: 
idempotent load → 202, state stays ready" +else + fail "T5: expected 200 and ready, got HTTP=$HTTP state=$STATE" +fi + +# ── TEST 6: Job accepted when ready (segments > 0) ──────────────────────────── +echo "" +echo "--- Test 6: Job accepted when model is ready ---" +if [ -z "$AUDIO" ]; then + skip "T6: TEST_AUDIO not set — skipping job submission test" +else + ensure_ready || { fail "T6: model load failed"; } + SUBMIT=$(curl -sf -X POST "$BASE/jobs" -F "audio=@${AUDIO};type=audio/wav" -F "task=transcribe" 2>&1) + JOB_ID=$(echo "$SUBMIT" | python3 -c "import sys,json; print(json.load(sys.stdin)['job_id'])" 2>/dev/null || echo "") + if [ -n "$JOB_ID" ]; then + ok "T6: job accepted, id=$JOB_ID" + # Poll to done + elapsed=0 + while true; do + sleep 10; elapsed=$((elapsed+10)) + STATUS=$(curl -sf "$BASE/jobs/$JOB_ID" | python3 -c "import sys,json; print(json.load(sys.stdin)['status'])") + [ "$STATUS" = "done" ] && break + [ "$STATUS" = "failed" ] && break + [ $elapsed -gt 600 ] && break + done + SEGS=$(curl -sf "$BASE/jobs/$JOB_ID" | python3 -c "import sys,json; d=json.load(sys.stdin); print(len(d.get('segments',[])))") + [ "$SEGS" -gt 0 ] && ok "T6b: job done with $SEGS segments" || fail "T6b: job done but 0 segments" + else + fail "T6: job submission failed: $SUBMIT" + fi +fi + +# ── TEST 7: POST /model/unload → state=unloaded ─────────────────────────────── +echo "" +echo "--- Test 7: POST /model/unload ---" +ensure_ready || { fail "T7: model load failed"; } +curl -sf -X POST "$BASE/model/unload" > /dev/null +sleep 3 +STATE=$(get_state) +if [ "$STATE" = "unloaded" ]; then + ok "T7: POST /model/unload → state=unloaded" +else + fail "T7: expected unloaded after unload, got $STATE" +fi + +# ── TEST 8: POST /jobs when unloaded → 503 + Retry-After ───────────────────── +echo "" +echo "--- Test 8: POST /jobs when unloaded → 503 + Retry-After ---" +ensure_unloaded +# Submit a tiny dummy payload (won't be valid audio but that's ok for this test) +HTTP=$(curl -s -o 
/tmp/t8_body.json -w "%{http_code}" -X POST "$BASE/jobs" \ + -F "audio=@/dev/urandom;type=audio/wav" \ + --max-time 5 2>/dev/null || echo "000") +# If the model auto-loads it might start processing; check for 503 first +if [ "$HTTP" = "503" ]; then + # curl rejects -I (HEAD) combined with -F (multipart POST); use -i to include response headers, as in Test 10 + RETRY_AFTER=$(curl -si -X POST "$BASE/jobs" \ + -F "audio=@/dev/urandom;type=audio/wav" \ + --max-time 5 2>/dev/null | grep -i "retry-after" | awk '{print $2}' | tr -d '\r' || echo "") + BODY=$(cat /tmp/t8_body.json 2>/dev/null || echo "{}") + HAS_STATE=$(echo "$BODY" | python3 -c "import sys,json; d=json.load(sys.stdin); print('state' in d)" 2>/dev/null || echo "False") + HAS_RETRY=$(echo "$BODY" | python3 -c "import sys,json; d=json.load(sys.stdin); print('retry_after_secs' in d)" 2>/dev/null || echo "False") + if [ "$HAS_STATE" = "True" ] && [ "$HAS_RETRY" = "True" ]; then + ok "T8: 503 with state + retry_after_secs in body" + else + fail "T8: 503 but body missing state/retry_after_secs. body=$BODY" + fi + if [ -n "$RETRY_AFTER" ]; then + ok "T8b: Retry-After header present: $RETRY_AFTER" + else + fail "T8b: Retry-After header missing from 503 response" + fi +else + skip "T8: got HTTP $HTTP (model may have loaded before check); skipping" +fi + +# ── TEST 9: Rejected job triggers load ──────────────────────────────────────── +echo "" +echo "--- Test 9: Job rejection triggers model load ---" +ensure_unloaded +# Send a job (we expect 503) +curl -sf -X POST "$BASE/jobs" \ + -F "audio=@/dev/urandom;type=audio/wav" \ + --max-time 5 > /dev/null 2>&1 || true +sleep 2 +STATE=$(get_state) +if [ "$STATE" = "loading" ] || [ "$STATE" = "ready" ]; then + ok "T9: model started loading after job rejection ($STATE)" +else + fail "T9: expected loading/ready after job rejection, got $STATE" +fi +# Stop the load to clean up +curl -sf -X POST "$BASE/model/unload" > /dev/null || true +sleep 2 + +# ── TEST 10: Retry-After values ─────────────────────────────────────────────── +echo "" +echo "--- Test 10: Retry-After values match state
---" +ensure_unloaded +# Unloaded → Retry-After: 30 +RESP_UNLOADED=$(curl -si -X POST "$BASE/jobs" -F "audio=@/dev/urandom;type=audio/wav" --max-time 5 2>/dev/null || echo "") +RA_UNLOADED=$(echo "$RESP_UNLOADED" | grep -i "retry-after" | awk '{print $2}' | tr -d '\r' || echo "") +[ "$RA_UNLOADED" = "30" ] && ok "T10a: Retry-After=30 when unloaded" \ + || skip "T10a: Retry-After=$RA_UNLOADED (expected 30) — model may have started loading" + +# ── TEST 11: Retry-After=10 during loading ──────────────────────────────────── +echo "" +echo "--- Test 11: Retry-After=10 when loading ---" +ensure_unloaded +curl -sf -X POST "$BASE/model/load" > /dev/null +sleep 1 # In loading state +STATE=$(get_state) +if [ "$STATE" = "loading" ]; then + RESP_LOADING=$(curl -si -X POST "$BASE/jobs" -F "audio=@/dev/urandom;type=audio/wav" --max-time 5 2>/dev/null || echo "") + RA_LOADING=$(echo "$RESP_LOADING" | grep -i "retry-after" | awk '{print $2}' | tr -d '\r' || echo "") + [ "$RA_LOADING" = "10" ] && ok "T11: Retry-After=10 when loading" \ + || fail "T11: expected Retry-After=10, got '$RA_LOADING' (state=$STATE)" +else + skip "T11: model already $STATE — can't test loading state Retry-After" +fi + +# ── TEST 12: 503 body schema validation ────────────────────────────────────── +echo "" +echo "--- Test 12: 503 body schema validation ---" +ensure_unloaded +BODY=$(curl -sf -X POST "$BASE/jobs" -F "audio=@/dev/urandom;type=audio/wav" --max-time 5 2>/dev/null || echo "{}") +python3 - < "$SSE_LOG" & +SSE_PID=$! 
+sleep 1 + +# Trigger load +curl -sf -X POST "$BASE/model/load" > /dev/null +poll_state_transition "ready" 180 || true + +sleep 2 +kill $SSE_PID 2>/dev/null || true +wait $SSE_PID 2>/dev/null || true + +if grep -q "model_loading" "$SSE_LOG" 2>/dev/null; then + ok "T14a: SSE received model_loading event" +else + fail "T14a: SSE did not receive model_loading event" +fi + +if grep -q "model_ready" "$SSE_LOG" 2>/dev/null; then + ok "T14b: SSE received model_ready event" +else + fail "T14b: SSE did not receive model_ready event" +fi + +# Now unload to get model_unloaded event +curl -sf -X POST "$BASE/model/unload" > /dev/null +sleep 1 + +SSE_LOG2=$(mktemp /tmp/sse_events_XXXXXX.txt) +curl -sN --max-time 10 "$BASE/model/events" > "$SSE_LOG2" & +SSE_PID2=$! +sleep 2 +kill $SSE_PID2 2>/dev/null || true +wait $SSE_PID2 2>/dev/null || true + +# model_unloaded fires immediately on unload command +if grep -q "model_unloaded" "$SSE_LOG" 2>/dev/null || grep -q "model_unloaded" "$SSE_LOG2" 2>/dev/null; then + ok "T14c: SSE received model_unloaded event" +else + fail "T14c: SSE did not receive model_unloaded event" +fi + +rm -f "$SSE_LOG" "$SSE_LOG2" + +# ── TEST 15: model_ready webhook fires after load ────────────────────────────── +echo "" +echo "--- Test 15: model_ready webhook ---" +ensure_unloaded + +# Start webhook receiver +WEBHOOK_LOG=$(mktemp /tmp/webhook_log_XXXXXX.txt) +python3 - <<'PYEOF' & +import http.server, json, sys, signal, os + +class H(http.server.BaseHTTPRequestHandler): + def do_POST(self): + n = int(self.headers.get('Content-Length', 0)) + body = json.loads(self.rfile.read(n)) + with open('/tmp/t15_webhook.json', 'w') as f: + json.dump(body, f) + self.send_response(200); self.end_headers() + def log_message(self, *a): pass + +signal.signal(signal.SIGTERM, lambda *_: sys.exit(0)) +http.server.HTTPServer(('', 9998), H).serve_forever() +PYEOF +WBOOK_PID=$! 
+sleep 1 + +# Register a webhook via a (doomed) job submission — this registers the URL +# even though the model is unloaded (and the job will 503) +curl -sf -X POST "$BASE/jobs" \ + -F "audio=@/dev/urandom;type=audio/wav" \ + -F "webhook_url=http://localhost:9998/wh" \ + --max-time 5 > /dev/null 2>&1 || true + +# Now load the model +curl -sf -X POST "$BASE/model/load" > /dev/null +poll_state_transition "ready" 180 || true +sleep 3 + +kill $WBOOK_PID 2>/dev/null || true +wait $WBOOK_PID 2>/dev/null || true + +if [ -f /tmp/t15_webhook.json ]; then + EVENT_TYPE=$(python3 -c "import json; d=json.load(open('/tmp/t15_webhook.json')); print(d.get('type','?'))") + [ "$EVENT_TYPE" = "model_ready" ] && ok "T15: model_ready webhook fired" \ + || fail "T15: webhook fired but type=$EVENT_TYPE (expected model_ready)" + rm -f /tmp/t15_webhook.json +else + fail "T15: model_ready webhook not received within timeout" +fi + +# ── TEST 16: model_unloaded webhook fires ───────────────────────────────────── +echo "" +echo "--- Test 16: model_unloaded webhook ---" + +python3 - <<'PYEOF' & +import http.server, json, sys, signal + +class H(http.server.BaseHTTPRequestHandler): + def do_POST(self): + n = int(self.headers.get('Content-Length', 0)) + body = json.loads(self.rfile.read(n)) + with open('/tmp/t16_webhook.json', 'w') as f: + json.dump(body, f) + self.send_response(200); self.end_headers() + def log_message(self, *a): pass + +signal.signal(signal.SIGTERM, lambda *_: sys.exit(0)) +http.server.HTTPServer(('', 9997), H).serve_forever() +PYEOF +WBOOK2_PID=$! 
+sleep 1 + +# Register webhook URL +curl -sf -X POST "$BASE/jobs" \ + -F "audio=@/dev/urandom;type=audio/wav" \ + -F "webhook_url=http://localhost:9997/wh" \ + --max-time 5 > /dev/null 2>&1 || true + +ensure_ready + +# Unload +curl -sf -X POST "$BASE/model/unload" > /dev/null +sleep 5 + +kill $WBOOK2_PID 2>/dev/null || true +wait $WBOOK2_PID 2>/dev/null || true + +if [ -f /tmp/t16_webhook.json ]; then + EVENT_TYPE=$(python3 -c "import json; d=json.load(open('/tmp/t16_webhook.json')); print(d.get('type','?'))") + [ "$EVENT_TYPE" = "model_unloaded" ] && ok "T16: model_unloaded webhook fired" \ + || fail "T16: webhook type=$EVENT_TYPE (expected model_unloaded)" + rm -f /tmp/t16_webhook.json +else + fail "T16: model_unloaded webhook not received" +fi + +# ── TEST 17: Concurrent load requests — single load, stable ready ───────────── +echo "" +echo "--- Test 17: Concurrent POST /model/load requests ---" +ensure_unloaded +# Send 3 concurrent load requests +curl -sf -X POST "$BASE/model/load" > /dev/null & +curl -sf -X POST "$BASE/model/load" > /dev/null & +curl -sf -X POST "$BASE/model/load" > /dev/null & +wait +poll_state_transition "ready" 180 || true +STATE=$(get_state) +[ "$STATE" = "ready" ] && ok "T17: concurrent loads handled cleanly, state=ready" \ + || fail "T17: expected ready after concurrent loads, got $STATE" + +# ── TEST 18: POST /model/unload during loading → clean unloaded ─────────────── +echo "" +echo "--- Test 18: POST /model/unload during loading ---" +ensure_unloaded +curl -sf -X POST "$BASE/model/load" > /dev/null +sleep 1 # Hopefully still in loading state +curl -sf -X POST "$BASE/model/unload" > /dev/null +# Allow time for the unload to propagate +sleep 5 +STATE=$(get_state) +if [ "$STATE" = "unloaded" ]; then + ok "T18: unload during loading → clean unloaded" +elif [ "$STATE" = "ready" ]; then + # Load completed before unload arrived — immediately unload + curl -sf -X POST "$BASE/model/unload" > /dev/null + sleep 3 + STATE=$(get_state) + [ 
"$STATE" = "unloaded" ] && ok "T18: load completed then unloaded (race condition OK)" \ + || fail "T18: state=$STATE after load+unload" +else + fail "T18: unexpected state after unload-during-load: $STATE" +fi + +# ── Summary ──────────────────────────────────────────────────────────────────── +echo "" +echo "==========================================" +echo " Results: ${PASS} passed, ${FAIL} failed" +echo "==========================================" +[ $FAIL -eq 0 ] && echo -e "${GREEN}ALL PASSED${NC}" || { echo -e "${RED}FAILURES: $FAIL${NC}"; exit 1; }