feat: dynamic model loading/unloading with GPU polling
All checks were successful
Build & Push Docker Image / build-and-push (push) Successful in 8m41s

- Model starts unloaded (lazy); loads on first job or POST /model/load
- Auto-unloads after IDLE_TIMEOUT_SECS (default 300) of inactivity
- POST /model/unload for immediate manual release
- GPU-busy detection: on VRAM OOM, enters WaitingForGpu and retries
  every GPU_POLL_INTERVAL_SECS (default 30) indefinitely
- POST /jobs when unloaded → 503 + Retry-After header, triggers load
- AppError::OutOfMemory and AppError::ModelNotReady variants
- WorkerCmd channel (SyncSender<WorkerCmd>) replaces bare tx_req channel
- Idle timer via recv_timeout(1s) tick inside OS thread (no extra thread)
- Model lifecycle events broadcast via tokio broadcast channel (SSE + webhooks)
- webhook_registry: all clients that ever submitted a webhook_url receive
  model_ready and model_unloaded webhooks
- GPU warmup retained on every (re)load

New routes:
  GET  /model/status  — current state + VRAM stats
  POST /model/load    — trigger load (idempotent)
  POST /model/unload  — immediate unload
  GET  /model/events  — SSE stream of model lifecycle events

New env vars:
  IDLE_TIMEOUT_SECS       (default 300)
  GPU_POLL_INTERVAL_SECS  (default 30)

Tests:
  tests/test_model_lifecycle.sh — 18 integration tests (full state machine,
    SSE events, webhooks, concurrency, unload-during-load)
  tests/test_idle_timeout.sh    — 5 tests with short IDLE_TIMEOUT_SECS=5
  test_all.sh updated: loads model before job submission, asserts
    model_state in /health, adds POST /model/unload at end

Docs:
  docs/USAGE.md: model lifecycle section, new env vars, 503 retry pattern,
    updated /health response shape

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
mozempk
2026-05-08 17:57:20 +02:00
parent 78c6fab81b
commit b191fbe200
13 changed files with 2053 additions and 148 deletions

View File

@@ -66,6 +66,8 @@ The bundled `docker-compose.yml` mounts named volumes for data and models and se
| `WHISPER_MODEL_PATH` | `/models/ggml-large-v3.bin` | Absolute path to GGML model file | | `WHISPER_MODEL_PATH` | `/models/ggml-large-v3.bin` | Absolute path to GGML model file |
| `WHISPER_MODEL` | `large-v3` | Model name reported by `/health` (display only) | | `WHISPER_MODEL` | `large-v3` | Model name reported by `/health` (display only) |
| `CUDA_DEVICE` | `0` | CUDA device index to use for inference | | `CUDA_DEVICE` | `0` | CUDA device index to use for inference |
| `IDLE_TIMEOUT_SECS` | `300` | Seconds of idle before the model is automatically unloaded from GPU memory. Set to `0` to disable auto-unload. |
| `GPU_POLL_INTERVAL_SECS` | `30` | Seconds between VRAM-availability retries when a load fails due to insufficient VRAM. |
### Note on CUDA device ordering ### Note on CUDA device ordering
Inside Docker, device ordering matches `nvidia-smi` (PCI bus order). On the host without Docker, ordering may differ. See [FINDINGS.md](FINDINGS.md#cuda-device-index-ordering-differs-between-host-and-docker) for details. Inside Docker, device ordering matches `nvidia-smi` (PCI bus order). On the host without Docker, ordering may differ. See [FINDINGS.md](FINDINGS.md#cuda-device-index-ordering-differs-between-host-and-docker) for details.
@@ -76,6 +78,194 @@ Inside Docker, device ordering matches `nvidia-smi` (PCI bus order). On the host
The interactive Swagger UI is available at `http://localhost:8080/docs`. The interactive Swagger UI is available at `http://localhost:8080/docs`.
---
## Model Lifecycle Management
The model starts **unloaded** on startup (lazy loading). It is loaded into GPU memory on the first job submission or via `POST /model/load`, and automatically unloaded after `IDLE_TIMEOUT_SECS` of inactivity.
### Model State Machine
```
Unloaded ──(job / POST /model/load)──► Loading ──(success)──► Ready
└──(VRAM full)──► WaitingForGpu ──(retry OK)──► Loading
Ready ──(idle timeout / POST /model/unload)──► Unloaded
WaitingForGpu ──(POST /model/unload)──► Unloaded
```
### `GET /model/status`
Returns the current model state and VRAM statistics.
```bash
curl http://localhost:8080/model/status
```
**When unloaded:**
```json
{ "state": "unloaded" }
```
**When loading:**
```json
{ "state": "loading" }
```
**When ready:**
```json
{
"state": "ready",
"loaded_at": "2026-05-10T14:00:00Z",
"vram_used_mb": 4096,
"vram_total_mb": 8192
}
```
**When waiting for VRAM:**
```json
{
"state": "waiting_for_gpu",
"vram_needed_mb": 3951,
"vram_free_mb": 512,
"retry_in_secs": 30
}
```
---
### `POST /model/load`
Request the model to be loaded. Idempotent — if already loading or ready, returns immediately.
```bash
curl -X POST http://localhost:8080/model/load
```
- Returns `202 Accepted` with `{"status":"load_initiated"}` when load is triggered
- Returns `200 OK` with `{"status":"already_ready"}` when model is already ready
- Poll `GET /model/status` or subscribe to `GET /model/events` to know when ready
---
### `POST /model/unload`
Unload the model from GPU memory immediately, freeing VRAM.
```bash
curl -X POST http://localhost:8080/model/unload
```
Returns `200 OK` regardless of current state.
---
### `GET /model/events` — Model SSE stream
Subscribe to model lifecycle events via Server-Sent Events.
```bash
curl -N http://localhost:8080/model/events
```
**Event types:**
```
event: model_loading
data: {"type":"model_loading"}
event: model_ready
data: {"type":"model_ready","loaded_at":"2026-05-10T14:00:00Z"}
event: model_unloaded
data: {"type":"model_unloaded"}
event: model_waiting_for_gpu
data: {"type":"model_waiting_for_gpu","vram_needed_mb":3951,"vram_free_mb":512,"retry_in_secs":30}
```
**JavaScript example:**
```javascript
const es = new EventSource('/model/events');
es.addEventListener('model_ready', () => {
console.log('Model loaded — ready to transcribe');
});
es.addEventListener('model_unloaded', () => {
console.log('Model freed GPU memory');
});
```
---
### Webhooks for model events
When any job is submitted with a `webhook_url`, that URL is registered to receive model lifecycle webhooks for the lifetime of the server process. The following events trigger a webhook POST:
| Event | Fired when |
|-------|-----------|
| `model_ready` | Model finishes loading (after GPU warmup) |
| `model_unloaded` | Model is freed from GPU memory |
**Webhook payload** (`Content-Type: application/json`):
```json
{ "type": "model_ready", "loaded_at": "2026-05-10T14:00:00Z" }
{ "type": "model_unloaded" }
```
Delivery is attempted up to 3 times with exponential backoff (1s, 2s).
---
### Handling 503 Model Not Ready
When you submit a job and the model is not yet loaded, you receive `503 Service Unavailable` with a `Retry-After` header:
```
HTTP/1.1 503 Service Unavailable
Retry-After: 30
Content-Type: application/json
{
"error": "model_not_ready",
"state": "unloaded",
"retry_after_secs": 30
}
```
| State at rejection | `retry_after_secs` | Meaning |
|---|---|---|
| `unloaded` | 30 | Load was triggered; retry after ~30s |
| `loading` | 10 | Check again in 10s |
| `waiting_for_gpu` | `GPU_POLL_INTERVAL_SECS` | VRAM contention; retry later |
A job rejection when the model is `unloaded` **automatically triggers a load** — you do not need to call `POST /model/load` separately.
**Recommended client pattern:**
```javascript
async function submitWithRetry(formData, maxAttempts = 10) {
for (let i = 0; i < maxAttempts; i++) {
const resp = await fetch('/jobs', { method: 'POST', body: formData });
if (resp.ok) return resp.json();
if (resp.status === 503) {
const retryAfter = parseInt(resp.headers.get('Retry-After') ?? '30');
const body = await resp.json();
console.log(`Model ${body.state} — retrying in ${retryAfter}s`);
await new Promise(r => setTimeout(r, retryAfter * 1000));
continue;
}
throw new Error(`Submit failed: ${resp.status}`);
}
throw new Error('Gave up after max attempts');
}
```
---
## API Reference
The interactive Swagger UI is available at `http://localhost:8080/docs`.
### `POST /jobs` — Submit a transcription job ### `POST /jobs` — Submit a transcription job
Accepts a multipart/form-data body. Accepts a multipart/form-data body.
@@ -249,11 +439,12 @@ curl http://localhost:8080/health
"gpu_name": "NVIDIA GeForce RTX 2080", "gpu_name": "NVIDIA GeForce RTX 2080",
"vram_total_mb": 8192, "vram_total_mb": 8192,
"model": "large-v3", "model": "large-v3",
"queue_depth": 0 "queue_depth": 0,
"model_state": "ready"
} }
``` ```
`queue_depth` is the number of jobs waiting to be processed (not counting the one currently running). `queue_depth` is the number of jobs waiting to be processed (not counting the one currently running). `model_state` reflects the current lifecycle state (`unloaded`, `loading`, `waiting_for_gpu`, `ready`).
--- ---
@@ -340,6 +531,11 @@ curl -X POST http://localhost:8080/jobs \
## Troubleshooting ## Troubleshooting
### Server returns `503 model_not_ready`
- The model starts unloaded. Call `POST /model/load` explicitly, or just retry the job submission — rejection automatically triggers a load.
- If state is `waiting_for_gpu`, another process is using the GPU's VRAM. The server will retry automatically every `GPU_POLL_INTERVAL_SECS` seconds.
- Monitor `GET /model/status` or subscribe to `GET /model/events` to know when the model is ready.
### Server returns 0 segments ### Server returns 0 segments
- Check that you are **not** setting `language` to an empty string — omit the field entirely for auto-detection - Check that you are **not** setting `language` to an empty string — omit the field entirely for auto-detection
- Verify the audio file is not corrupted: `ffprobe audio.mp3` - Verify the audio file is not corrupted: `ffprobe audio.mp3`

View File

@@ -1,6 +1,6 @@
use thiserror::Error; use thiserror::Error;
use axum::{ use axum::{
http::StatusCode, http::{StatusCode, HeaderValue, header},
response::{IntoResponse, Response}, response::{IntoResponse, Response},
Json, Json,
}; };
@@ -21,19 +21,138 @@ pub enum AppError {
#[error("internal error: {0}")] #[error("internal error: {0}")]
Internal(String), Internal(String),
/// Returned when `whisper_init_state` or `cudaMalloc` fails due to
/// insufficient VRAM. The worker uses this to distinguish a recoverable
/// VRAM-pressure failure from a hard internal error.
#[error("out of GPU memory: {0}")]
OutOfMemory(String),
/// Returned when a job is submitted but the model is not yet loaded.
/// Carries the current state tag and recommended Retry-After seconds.
#[error("model not ready: {state}")]
ModelNotReady { state: String, retry_after_secs: u64 },
}
impl AppError {
/// Returns true if the error string contains patterns emitted by
/// whisper.cpp / GGML when a CUDA memory allocation fails.
pub fn is_oom(msg: &str) -> bool {
msg.contains("cudaMalloc failed")
|| msg.contains("out of memory")
|| msg.contains("CUDA error: out of memory")
|| msg.contains("alloc_buffer")
}
} }
impl IntoResponse for AppError { impl IntoResponse for AppError {
fn into_response(self) -> Response { fn into_response(self) -> Response {
let (status, message) = match &self { match self {
AppError::NotFound(m) => (StatusCode::NOT_FOUND, m.clone()), AppError::NotFound(m) => {
AppError::BadRequest(m) => (StatusCode::BAD_REQUEST, m.clone()), (StatusCode::NOT_FOUND, Json(json!({ "error": m }))).into_response()
AppError::Conflict(m) => (StatusCode::CONFLICT, m.clone()), }
AppError::Internal(m) => (StatusCode::INTERNAL_SERVER_ERROR, m.clone()), AppError::BadRequest(m) => {
}; (StatusCode::BAD_REQUEST, Json(json!({ "error": m }))).into_response()
}
tracing::error!(status = status.as_u16(), error = %message); AppError::Conflict(m) => {
(StatusCode::CONFLICT, Json(json!({ "error": m }))).into_response()
(status, Json(json!({ "error": message }))).into_response() }
AppError::Internal(m) => {
tracing::error!(error = %m, "internal error");
(StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": m }))).into_response()
}
AppError::OutOfMemory(m) => {
tracing::warn!(error = %m, "GPU out of memory during model load");
(StatusCode::SERVICE_UNAVAILABLE, Json(json!({ "error": m }))).into_response()
}
AppError::ModelNotReady { state, retry_after_secs } => {
let body = Json(json!({
"error": "model_not_ready",
"state": state,
"retry_after_secs": retry_after_secs,
}));
let mut resp = (StatusCode::SERVICE_UNAVAILABLE, body).into_response();
resp.headers_mut().insert(
header::RETRY_AFTER,
HeaderValue::from_str(&retry_after_secs.to_string())
.unwrap_or(HeaderValue::from_static("30")),
);
resp
}
}
}
}
// ── Unit tests ───────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
use axum::body::to_bytes;
#[test]
fn test_is_oom_cuda_malloc() {
assert!(AppError::is_oom("cudaMalloc failed: out of memory"));
}
#[test]
fn test_is_oom_alloc_buffer() {
// Exact message from ggml_backend_cuda_buffer_type_alloc_buffer
assert!(AppError::is_oom(
"ggml_backend_cuda_buffer_type_alloc_buffer: allocating 2951.01 MiB on device 0: cudaMalloc failed: out of memory"
));
}
#[test]
fn test_is_oom_generic_out_of_memory() {
assert!(AppError::is_oom("CUDA error: out of memory"));
}
#[test]
fn test_is_oom_other_error() {
assert!(!AppError::is_oom("failed to open model file"));
assert!(!AppError::is_oom("invalid model format"));
assert!(!AppError::is_oom(""));
}
#[tokio::test]
async fn test_model_not_ready_response_has_retry_after_header() {
let err = AppError::ModelNotReady { state: "loading".into(), retry_after_secs: 10 };
let resp = err.into_response();
assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
let retry_after = resp.headers().get(header::RETRY_AFTER)
.expect("Retry-After header missing");
assert_eq!(retry_after, "10");
}
#[tokio::test]
async fn test_model_not_ready_response_body() {
let err = AppError::ModelNotReady { state: "unloaded".into(), retry_after_secs: 30 };
let resp = err.into_response();
let bytes = to_bytes(resp.into_body(), usize::MAX).await.unwrap();
let v: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
assert_eq!(v["error"], "model_not_ready");
assert_eq!(v["state"], "unloaded");
assert_eq!(v["retry_after_secs"], 30);
}
#[tokio::test]
async fn test_model_not_ready_loading_retry_after_10() {
let err = AppError::ModelNotReady { state: "loading".into(), retry_after_secs: 10 };
let resp = err.into_response();
assert_eq!(
resp.headers().get(header::RETRY_AFTER).unwrap(),
"10"
);
}
#[tokio::test]
async fn test_model_not_ready_unloaded_retry_after_30() {
let err = AppError::ModelNotReady { state: "unloaded".into(), retry_after_secs: 30 };
let resp = err.into_response();
assert_eq!(
resp.headers().get(header::RETRY_AFTER).unwrap(),
"30"
);
} }
} }

View File

@@ -1,7 +1,7 @@
use std::sync::Arc; use std::sync::Arc;
use axum::Router; use axum::Router;
use tokio::sync::mpsc; use tokio::sync::{broadcast, mpsc, RwLock};
use tower_http::{cors::CorsLayer, trace::TraceLayer}; use tower_http::{cors::CorsLayer, trace::TraceLayer};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
use utoipa::OpenApi; use utoipa::OpenApi;
@@ -21,8 +21,10 @@ pub use error::{AppError, Result};
#[derive(Clone)] #[derive(Clone)]
pub struct AppState { pub struct AppState {
/// Channel to submit jobs to the single GPU worker. /// Channel to submit jobs to the single GPU worker (job IDs only).
pub job_tx: mpsc::UnboundedSender<models::JobId>, pub job_tx: mpsc::UnboundedSender<models::JobId>,
/// Channel to send control commands to the worker OS thread.
pub cmd_tx: std::sync::mpsc::SyncSender<worker::WorkerCmd>,
/// Shared handle to the on-disk job store. /// Shared handle to the on-disk job store.
pub storage: Arc<storage::Storage>, pub storage: Arc<storage::Storage>,
/// SSE broadcast registry: job_id → sender. /// SSE broadcast registry: job_id → sender.
@@ -33,6 +35,17 @@ pub struct AppState {
pub queue_depth: Arc<std::sync::atomic::AtomicUsize>, pub queue_depth: Arc<std::sync::atomic::AtomicUsize>,
/// CUDA device index used for inference. /// CUDA device index used for inference.
pub gpu_device: u32, pub gpu_device: u32,
/// Current state of the whisper model.
pub model_state: Arc<RwLock<models::ModelState>>,
/// Broadcast channel for model lifecycle events (SSE + webhooks).
pub model_event_tx: broadcast::Sender<models::ModelEvent>,
/// All webhook URLs ever registered via job submission.
/// Used to fire model_ready / model_unloaded notifications.
pub webhook_registry: Arc<std::sync::Mutex<std::collections::HashSet<String>>>,
/// How long the model stays loaded with no active jobs.
pub idle_timeout: std::time::Duration,
/// How often to retry loading when GPU is busy.
pub gpu_poll_interval: std::time::Duration,
} }
// ── OpenAPI spec root ──────────────────────────────────────────────────────── // ── OpenAPI spec root ────────────────────────────────────────────────────────
@@ -50,6 +63,10 @@ pub struct AppState {
routes::jobs::stream_job, routes::jobs::stream_job,
routes::jobs::delete_job, routes::jobs::delete_job,
routes::health::health, routes::health::health,
routes::model::model_status,
routes::model::model_load,
routes::model::model_unload,
routes::model::model_events,
), ),
components(schemas( components(schemas(
models::Job, models::Job,
@@ -58,10 +75,14 @@ pub struct AppState {
models::Word, models::Word,
models::SubmitResponse, models::SubmitResponse,
models::HealthResponse, models::HealthResponse,
models::ModelState,
models::ModelEvent,
models::ModelStatusResponse,
)), )),
tags( tags(
(name = "jobs", description = "Transcription job management"), (name = "jobs", description = "Transcription job management"),
(name = "system", description = "Service health"), (name = "system", description = "Service health"),
(name = "model", description = "Model lifecycle management"),
) )
)] )]
struct ApiDoc; struct ApiDoc;
@@ -85,6 +106,20 @@ async fn main() -> anyhow::Result<()> {
.ok() .ok()
.and_then(|s| s.parse().ok()) .and_then(|s| s.parse().ok())
.unwrap_or(0); .unwrap_or(0);
let idle_timeout_secs: u64 = std::env::var("IDLE_TIMEOUT_SECS")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(300);
let gpu_poll_interval_secs: u64 = std::env::var("GPU_POLL_INTERVAL_SECS")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(30);
tracing::info!(
idle_timeout_secs,
gpu_poll_interval_secs,
"dynamic model loading configured"
);
let storage = Arc::new(storage::Storage::new(&data_dir).await?); let storage = Arc::new(storage::Storage::new(&data_dir).await?);
@@ -94,28 +129,45 @@ async fn main() -> anyhow::Result<()> {
let (job_tx, job_rx) = mpsc::unbounded_channel::<models::JobId>(); let (job_tx, job_rx) = mpsc::unbounded_channel::<models::JobId>();
let queue_depth = Arc::new(std::sync::atomic::AtomicUsize::new(0)); let queue_depth = Arc::new(std::sync::atomic::AtomicUsize::new(0));
// Spawn single GPU worker; get back the SSE broadcast registry. // Model starts unloaded — lazy load on first job or POST /model/load.
let progress = worker::start( let model_state = Arc::new(RwLock::new(models::ModelState::Unloaded));
let (model_event_tx, _) = broadcast::channel::<models::ModelEvent>(32);
let webhook_registry = Arc::new(std::sync::Mutex::new(std::collections::HashSet::<String>::new()));
// Spawn single GPU worker; get back the SSE broadcast registry and cmd channel.
let (progress, cmd_tx) = worker::start(
job_rx, job_rx,
Arc::clone(&storage), Arc::clone(&storage),
model_path.clone().into(), model_path.clone().into(),
Arc::clone(&queue_depth), Arc::clone(&queue_depth),
gpu_device, gpu_device,
Arc::clone(&model_state),
model_event_tx.clone(),
Arc::clone(&webhook_registry),
std::time::Duration::from_secs(idle_timeout_secs),
std::time::Duration::from_secs(gpu_poll_interval_secs),
); );
let state = AppState { let state = AppState {
job_tx, job_tx,
cmd_tx,
storage: Arc::clone(&storage), storage: Arc::clone(&storage),
progress, progress,
model_name: model_name.as_str().into(), model_name: model_name.as_str().into(),
queue_depth: Arc::clone(&queue_depth), queue_depth: Arc::clone(&queue_depth),
gpu_device, gpu_device,
model_state,
model_event_tx,
webhook_registry,
idle_timeout: std::time::Duration::from_secs(idle_timeout_secs),
gpu_poll_interval: std::time::Duration::from_secs(gpu_poll_interval_secs),
}; };
let app = Router::new() let app = Router::new()
.merge(SwaggerUi::new("/docs").url("/openapi.json", ApiDoc::openapi())) .merge(SwaggerUi::new("/docs").url("/openapi.json", ApiDoc::openapi()))
.merge(routes::jobs_router()) .merge(routes::jobs_router())
.merge(routes::health_router()) .merge(routes::health_router())
.merge(routes::model_router())
.with_state(state) .with_state(state)
.layer(CorsLayer::permissive()) .layer(CorsLayer::permissive())
.layer(TraceLayer::new_for_http()); .layer(TraceLayer::new_for_http());

View File

@@ -5,6 +5,116 @@ use uuid::Uuid;
pub type JobId = Uuid; pub type JobId = Uuid;
// ── Model lifecycle state ────────────────────────────────────────────────────
/// Current state of the whisper model in memory.
///
/// State machine:
/// ```
/// Unloaded ──(load trigger)──► Loading ──(ok)──► Ready ──(idle/unload)──► Unloaded
/// └──(VRAM full)──► WaitingForGpu ──(retry)──► Loading
/// WaitingForGpu ──(unload cmd)──► Unloaded
/// ```
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
#[serde(tag = "state", rename_all = "snake_case")]
pub enum ModelState {
/// Model is not in memory. GPU is free.
Unloaded,
/// Model is being loaded (weights transferred to GPU).
Loading,
/// A previous load attempt failed due to insufficient VRAM. The worker is
/// polling at `retry_in_secs` intervals until enough memory is available.
WaitingForGpu {
/// VRAM required to load the model, in MiB.
vram_needed_mb: u64,
/// VRAM currently free on the device, in MiB.
vram_free_mb: u64,
/// How many seconds until the next load attempt.
retry_in_secs: u64,
},
/// Model is loaded and ready to accept inference jobs.
Ready {
/// UTC timestamp of when the model finished loading (post-warmup).
loaded_at: DateTime<Utc>,
},
}
impl ModelState {
/// Returns true if the model can accept inference jobs right now.
pub fn is_ready(&self) -> bool {
matches!(self, ModelState::Ready { .. })
}
/// Suggested `Retry-After` value (seconds) to include in 503 responses.
pub fn retry_after_secs(&self) -> u64 {
match self {
ModelState::Unloaded => 30, // conservative load estimate
ModelState::Loading => 10,
ModelState::WaitingForGpu { retry_in_secs, .. } => *retry_in_secs,
ModelState::Ready { .. } => 0, // shouldn't 503 if ready
}
}
/// String tag for use in error response bodies and log fields.
pub fn tag(&self) -> &'static str {
match self {
ModelState::Unloaded => "unloaded",
ModelState::Loading => "loading",
ModelState::WaitingForGpu{..} => "waiting_for_gpu",
ModelState::Ready{..} => "ready",
}
}
}
// ── Model events (SSE + webhooks) ────────────────────────────────────────────
/// Events broadcast over the `GET /model/events` SSE stream and fired as
/// webhooks to registered clients.
///
/// Webhook delivery: only `ModelReady` and `ModelUnloaded` are sent to
/// webhook URLs. All four are broadcast on the SSE stream.
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ModelEvent {
/// Model finished loading and the GPU warmup completed — ready to accept jobs.
ModelReady {
loaded_at: DateTime<Utc>,
},
/// Model was unloaded from GPU memory (idle timeout or manual unload).
ModelUnloaded,
/// Model load initiated.
ModelLoading,
/// Load failed due to insufficient VRAM; retrying after `retry_in_secs`.
ModelWaitingForGpu {
vram_needed_mb: u64,
vram_free_mb: u64,
retry_in_secs: u64,
},
}
impl ModelEvent {
/// Returns true if this event should be delivered via webhook.
pub fn is_webhook_event(&self) -> bool {
matches!(self, ModelEvent::ModelReady { .. } | ModelEvent::ModelUnloaded)
}
}
// ── Model status response ────────────────────────────────────────────────────
/// Response body for `GET /model/status`.
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct ModelStatusResponse {
/// Current model state (flattened from `ModelState`).
#[serde(flatten)]
pub state: ModelState,
/// VRAM currently used on the device, in MiB (from nvidia-smi).
#[serde(skip_serializing_if = "Option::is_none")]
pub vram_used_mb: Option<u64>,
/// VRAM total on the device, in MiB.
#[serde(skip_serializing_if = "Option::is_none")]
pub vram_total_mb: Option<u64>,
}
// ── Job status ─────────────────────────────────────────────────────────────── // ── Job status ───────────────────────────────────────────────────────────────
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, ToSchema)] #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, ToSchema)]
@@ -130,6 +240,8 @@ pub struct HealthResponse {
pub vram_total_mb: Option<u64>, pub vram_total_mb: Option<u64>,
pub model: String, pub model: String,
pub queue_depth: usize, pub queue_depth: usize,
/// Current state of the whisper model.
pub model_state: String,
} }
// ── SSE event payload ──────────────────────────────────────────────────────── // ── SSE event payload ────────────────────────────────────────────────────────
@@ -148,3 +260,137 @@ pub enum SsePayload {
Done { job: Box<Job> }, Done { job: Box<Job> },
Error { message: String }, Error { message: String },
} }
// ── Unit tests ───────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
use serde_json::Value;
// ── ModelState serialization ─────────────────────────────────────────────
#[test]
fn test_model_state_unloaded_serializes() {
let v: Value = serde_json::to_value(ModelState::Unloaded).unwrap();
assert_eq!(v["state"], "unloaded");
}
#[test]
fn test_model_state_loading_serializes() {
let v: Value = serde_json::to_value(ModelState::Loading).unwrap();
assert_eq!(v["state"], "loading");
}
#[test]
fn test_model_state_waiting_serializes() {
let s = ModelState::WaitingForGpu { vram_needed_mb: 3000, vram_free_mb: 500, retry_in_secs: 30 };
let v: Value = serde_json::to_value(&s).unwrap();
assert_eq!(v["state"], "waiting_for_gpu");
assert_eq!(v["vram_needed_mb"], 3000);
assert_eq!(v["vram_free_mb"], 500);
assert_eq!(v["retry_in_secs"], 30);
}
#[test]
fn test_model_state_ready_serializes() {
let ts = Utc::now();
let s = ModelState::Ready { loaded_at: ts };
let v: Value = serde_json::to_value(&s).unwrap();
assert_eq!(v["state"], "ready");
assert!(v["loaded_at"].is_string());
}
#[test]
fn test_model_state_is_ready() {
assert!(!ModelState::Unloaded.is_ready());
assert!(!ModelState::Loading.is_ready());
assert!(!ModelState::WaitingForGpu { vram_needed_mb: 0, vram_free_mb: 0, retry_in_secs: 30 }.is_ready());
assert!(ModelState::Ready { loaded_at: Utc::now() }.is_ready());
}
#[test]
fn test_retry_after_unloaded() {
assert_eq!(ModelState::Unloaded.retry_after_secs(), 30);
}
#[test]
fn test_retry_after_loading() {
assert_eq!(ModelState::Loading.retry_after_secs(), 10);
}
#[test]
fn test_retry_after_waiting_for_gpu() {
let s = ModelState::WaitingForGpu { vram_needed_mb: 0, vram_free_mb: 0, retry_in_secs: 45 };
assert_eq!(s.retry_after_secs(), 45);
}
#[test]
fn test_retry_after_ready_is_zero() {
assert_eq!(ModelState::Ready { loaded_at: Utc::now() }.retry_after_secs(), 0);
}
// ── ModelEvent serialization ─────────────────────────────────────────────
#[test]
fn test_model_event_ready_serializes() {
let ts = Utc::now();
let e = ModelEvent::ModelReady { loaded_at: ts };
let v: Value = serde_json::to_value(&e).unwrap();
assert_eq!(v["type"], "model_ready");
assert!(v["loaded_at"].is_string());
}
#[test]
fn test_model_event_unloaded_serializes() {
let v: Value = serde_json::to_value(ModelEvent::ModelUnloaded).unwrap();
assert_eq!(v["type"], "model_unloaded");
}
#[test]
fn test_model_event_loading_serializes() {
let v: Value = serde_json::to_value(ModelEvent::ModelLoading).unwrap();
assert_eq!(v["type"], "model_loading");
}
#[test]
fn test_model_event_waiting_serializes() {
let e = ModelEvent::ModelWaitingForGpu { vram_needed_mb: 3000, vram_free_mb: 200, retry_in_secs: 30 };
let v: Value = serde_json::to_value(&e).unwrap();
assert_eq!(v["type"], "model_waiting_for_gpu");
assert_eq!(v["vram_needed_mb"], 3000);
}
#[test]
fn test_model_event_webhook_filter() {
assert!(ModelEvent::ModelReady { loaded_at: Utc::now() }.is_webhook_event());
assert!(ModelEvent::ModelUnloaded.is_webhook_event());
assert!(!ModelEvent::ModelLoading.is_webhook_event());
assert!(!ModelEvent::ModelWaitingForGpu { vram_needed_mb: 0, vram_free_mb: 0, retry_in_secs: 30 }.is_webhook_event());
}
// ── ModelStatusResponse ──────────────────────────────────────────────────
#[test]
fn test_model_status_response_roundtrip() {
let r = ModelStatusResponse {
state: ModelState::Ready { loaded_at: Utc::now() },
vram_used_mb: Some(4096),
vram_total_mb: Some(8192),
};
let json_str = serde_json::to_string(&r).unwrap();
let v: Value = serde_json::from_str(&json_str).unwrap();
assert_eq!(v["state"], "ready");
assert_eq!(v["vram_used_mb"], 4096);
assert_eq!(v["vram_total_mb"], 8192);
}
#[test]
fn test_model_status_response_omits_nulls() {
let r = ModelStatusResponse { state: ModelState::Loading, vram_used_mb: None, vram_total_mb: None };
let v: Value = serde_json::to_value(&r).unwrap();
assert_eq!(v["state"], "loading");
assert!(v.get("vram_used_mb").is_none());
assert!(v.get("vram_total_mb").is_none());
}
}

View File

@@ -16,6 +16,7 @@ use crate::{models::HealthResponse, AppState, Result};
)] )]
pub async fn health(State(state): State<AppState>) -> Result<Json<HealthResponse>> { pub async fn health(State(state): State<AppState>) -> Result<Json<HealthResponse>> {
let (gpu_name, vram_total_mb) = gpu_info(state.gpu_device); let (gpu_name, vram_total_mb) = gpu_info(state.gpu_device);
let model_state_tag = state.model_state.read().await.tag().to_string();
Ok(Json(HealthResponse { Ok(Json(HealthResponse {
status: "ok".into(), status: "ok".into(),
@@ -23,6 +24,7 @@ pub async fn health(State(state): State<AppState>) -> Result<Json<HealthResponse
vram_total_mb, vram_total_mb,
model: state.model_name.to_string(), model: state.model_name.to_string(),
queue_depth: state.queue_depth.load(Ordering::Relaxed), queue_depth: state.queue_depth.load(Ordering::Relaxed),
model_state: model_state_tag,
})) }))
} }

View File

@@ -19,7 +19,7 @@ use uuid::Uuid;
use crate::{ use crate::{
models::{Job, JobId, JobStatus, SubmitResponse}, models::{Job, JobId, JobStatus, SubmitResponse},
worker::{audio_path_for, ProgressEvent}, worker::{audio_path_for, ProgressEvent, WorkerCmd},
AppError, AppState, Result, AppError, AppState, Result,
}; };
@@ -107,6 +107,36 @@ pub async fn submit_job(
)); ));
} }
// Check model state before accepting the job.
let (model_ready, retry_after_secs, state_tag) = {
let ms = state.model_state.read().await;
let ready = ms.is_ready();
let retry = ms.retry_after_secs();
let tag = ms.tag().to_string();
(ready, retry, tag)
};
// Register the webhook URL regardless of model state — so model lifecycle
// events are delivered even if the job itself is rejected.
if let Some(url) = &webhook_url {
state.webhook_registry.lock()
.unwrap_or_else(|e| e.into_inner())
.insert(url.clone());
}
if !model_ready {
// Trigger a load if the model is simply unloaded (not already loading).
if state_tag == "unloaded" {
let _ = state.cmd_tx.try_send(WorkerCmd::Load);
}
// Clean up the audio file we already wrote to disk.
let _ = tokio::fs::remove_file(&audio_path).await;
return Err(AppError::ModelNotReady {
state: state_tag,
retry_after_secs,
});
}
let mut job = Job::new(id, task, webhook_url, filename); let mut job = Job::new(id, task, webhook_url, filename);
job.language = language; job.language = language;

View File

@@ -1,5 +1,6 @@
pub mod health; pub mod health;
pub mod jobs; pub mod jobs;
pub mod model;
use axum::{extract::DefaultBodyLimit, routing::{delete, get, post}, Router}; use axum::{extract::DefaultBodyLimit, routing::{delete, get, post}, Router};
use crate::AppState; use crate::AppState;
@@ -17,3 +18,11 @@ pub fn health_router() -> Router<AppState> {
Router::new() Router::new()
.route("/health", get(health::health)) .route("/health", get(health::health))
} }
pub fn model_router() -> Router<AppState> {
Router::new()
.route("/model/status", get(model::model_status))
.route("/model/load", post(model::model_load))
.route("/model/unload", post(model::model_unload))
.route("/model/events", get(model::model_events))
}

158
src/routes/model.rs Normal file
View File

@@ -0,0 +1,158 @@
use std::pin::Pin;
use axum::{
extract::State,
http::StatusCode,
response::{
sse::{Event, KeepAlive, Sse},
IntoResponse,
},
Json,
};
use futures::Stream;
use tokio_stream::wrappers::BroadcastStream;
use futures::StreamExt;
use crate::{
models::{ModelEvent, ModelStatusResponse},
worker::WorkerCmd,
AppState, Result,
};
type SseStream = Pin<Box<dyn Stream<Item = std::result::Result<Event, std::convert::Infallible>> + Send>>;
// ── GET /model/status ────────────────────────────────────────────────────────
/// Return the current model state and VRAM statistics.
#[utoipa::path(
get,
path = "/model/status",
tag = "model",
responses(
(status = 200, description = "Model status", body = ModelStatusResponse),
)
)]
pub async fn model_status(State(state): State<AppState>) -> Result<Json<ModelStatusResponse>> {
let model_state = state.model_state.read().await.clone();
let (vram_used_mb, vram_total_mb) = vram_stats(state.gpu_device);
Ok(Json(ModelStatusResponse {
state: model_state,
vram_used_mb,
vram_total_mb,
}))
}
// ── POST /model/load ─────────────────────────────────────────────────────────
/// Request the model to be loaded into GPU memory.
/// Idempotent: if the model is already loading or ready, this is a no-op.
/// Returns 202 Accepted; poll `GET /model/status` or subscribe to
/// `GET /model/events` to know when it is ready.
#[utoipa::path(
post,
path = "/model/load",
tag = "model",
responses(
(status = 202, description = "Load initiated or already in progress"),
(status = 200, description = "Model already ready"),
)
)]
pub async fn model_load(State(state): State<AppState>) -> impl IntoResponse {
let is_ready = state.model_state.read().await.is_ready();
if is_ready {
return (StatusCode::OK, Json(serde_json::json!({"status": "already_ready"})));
}
// Ignore send errors (channel full = load already in progress).
let _ = state.cmd_tx.try_send(WorkerCmd::Load);
(StatusCode::ACCEPTED, Json(serde_json::json!({"status": "load_initiated"})))
}
// ── POST /model/unload ───────────────────────────────────────────────────────
/// Unload the model from GPU memory immediately.
/// Idempotent: if the model is already unloaded, returns 200 immediately.
#[utoipa::path(
post,
path = "/model/unload",
tag = "model",
responses(
(status = 200, description = "Model unloaded or was already unloaded"),
)
)]
pub async fn model_unload(State(state): State<AppState>) -> impl IntoResponse {
let _ = state.cmd_tx.try_send(WorkerCmd::Unload);
(StatusCode::OK, Json(serde_json::json!({"status": "unload_requested"})))
}
// ── GET /model/events ────────────────────────────────────────────────────────
/// Subscribe to model lifecycle events via Server-Sent Events.
///
/// Event types:
/// - `model_loading` — load initiated
/// - `model_ready` — model loaded and warmed up
/// - `model_unloaded` — model freed from GPU memory
/// - `model_waiting_for_gpu` — insufficient VRAM; retrying
#[utoipa::path(
get,
path = "/model/events",
tag = "model",
responses(
(status = 200, description = "SSE stream of model lifecycle events"),
)
)]
pub async fn model_events(State(state): State<AppState>) -> Sse<SseStream> {
let rx = state.model_event_tx.subscribe();
let stream: SseStream = Box::pin(
BroadcastStream::new(rx).filter_map(|msg| async move {
match msg {
Ok(event) => {
let event_type = match &event {
ModelEvent::ModelReady { .. } => "model_ready",
ModelEvent::ModelUnloaded => "model_unloaded",
ModelEvent::ModelLoading => "model_loading",
ModelEvent::ModelWaitingForGpu {..} => "model_waiting_for_gpu",
};
let data = serde_json::to_string(&event).ok()?;
Some(Ok(Event::default().event(event_type).data(data)))
}
Err(_) => None,
}
})
);
Sse::new(stream).keep_alive(KeepAlive::default())
}
// ── Helpers ───────────────────────────────────────────────────────────────────
fn vram_stats(gpu_device: u32) -> (Option<u64>, Option<u64>) {
fn inner(gpu_device: u32) -> Option<(u64, u64)> {
let out = std::process::Command::new("nvidia-smi")
.args([
&format!("--id={gpu_device}"),
"--query-gpu=memory.used,memory.total",
"--format=csv,noheader,nounits",
])
.output()
.ok()?;
if !out.status.success() {
return None;
}
let line = String::from_utf8_lossy(&out.stdout);
let line = line.trim();
let mut parts = line.splitn(2, ',');
let used = parts.next().and_then(|s| s.trim().parse::<u64>().ok())?;
let total = parts.next().and_then(|s| s.trim().parse::<u64>().ok())?;
Some((used, total))
}
match inner(gpu_device) {
Some((u, t)) => (Some(u), Some(t)),
None => (None, None),
}
}

View File

@@ -49,10 +49,24 @@ impl Transcriber {
// params.flash_attn(true); // params.flash_attn(true);
let ctx = WhisperContext::new_with_params(path, params) let ctx = WhisperContext::new_with_params(path, params)
.map_err(|e| AppError::Internal(format!("failed to load model: {e}")))?; .map_err(|e| {
let msg = format!("failed to load model: {e}");
if AppError::is_oom(&msg) {
AppError::OutOfMemory(msg)
} else {
AppError::Internal(msg)
}
})?;
let mut state = ctx.create_state() let mut state = ctx.create_state()
.map_err(|e| AppError::Internal(format!("failed to create whisper state: {e}")))?; .map_err(|e| {
let msg = format!("failed to create whisper state: {e}");
if AppError::is_oom(&msg) {
AppError::OutOfMemory(msg)
} else {
AppError::Internal(msg)
}
})?;
// ctx drops here; state holds Arc<WhisperInnerContext> so model stays loaded. // ctx drops here; state holds Arc<WhisperInnerContext> so model stays loaded.
// ── GPU warmup ──────────────────────────────────────────────────────── // ── GPU warmup ────────────────────────────────────────────────────────

View File

@@ -1,20 +1,23 @@
use std::{ use std::{
collections::HashSet,
path::PathBuf, path::PathBuf,
sync::{ sync::{
atomic::{AtomicUsize, Ordering}, atomic::{AtomicUsize, Ordering},
Arc, Arc, Mutex,
}, },
time::{Duration, Instant},
}; };
use chrono::Utc; use chrono::Utc;
use reqwest::Client; use reqwest::Client;
use tokio::sync::{broadcast, mpsc, oneshot}; use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
use crate::{ use crate::{
models::{Job, JobId, JobStatus, Segment}, models::{Job, JobId, JobStatus, ModelEvent, ModelState, Segment},
storage::Storage, storage::Storage,
transcriber::Transcriber, transcriber::Transcriber,
webhook, webhook,
AppError,
}; };
/// Per-job broadcast channel for SSE subscribers. /// Per-job broadcast channel for SSE subscribers.
@@ -31,83 +34,383 @@ pub enum ProgressEvent {
/// Global registry: job_id → broadcast sender. /// Global registry: job_id → broadcast sender.
pub type ProgressRegistry = Arc<dashmap::DashMap<JobId, ProgressTx>>; pub type ProgressRegistry = Arc<dashmap::DashMap<JobId, ProgressTx>>;
// ── Transcription request/response types for the blocking thread ───────────── // ── Worker command channel ────────────────────────────────────────────────────
struct TranscribeRequest { /// Commands sent to the GPU worker OS thread.
pcm: Vec<f32>, #[derive(Debug)]
language: Option<String>, pub enum WorkerCmd {
task: String, /// Request a model load. Idempotent: if already loading/ready, ignored.
/// Per-chunk progress callback — receives 0100 from whisper.cpp and can Load,
/// scale/offset it before forwarding to the job's broadcast channel. /// Unload the model immediately and free GPU memory.
on_progress: Box<dyn Fn(u8) + Send + 'static>, Unload,
reply: oneshot::Sender<crate::Result<(Vec<Segment>, String)>>, /// Internal: run a transcription chunk.
Transcribe(TranscribeRequest),
} }
// ── Transcription request/response types ─────────────────────────────────────
pub struct TranscribeRequest {
pub pcm: Vec<f32>,
pub language: Option<String>,
pub task: String,
pub on_progress: Box<dyn Fn(u8) + Send + 'static>,
pub reply: oneshot::Sender<crate::Result<(Vec<Segment>, String)>>,
}
impl std::fmt::Debug for TranscribeRequest {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("TranscribeRequest")
.field("language", &self.language)
.field("task", &self.task)
.finish_non_exhaustive()
}
}
// ── Public API ────────────────────────────────────────────────────────────────
/// Spawn the single GPU worker. /// Spawn the single GPU worker.
/// Returns the SSE progress registry. ///
/// Returns the SSE progress registry and a command sender for the worker thread.
/// The model starts **unloaded**; send `WorkerCmd::Load` or submit a job to
/// trigger loading.
#[allow(clippy::too_many_arguments)]
pub fn start( pub fn start(
job_rx: mpsc::UnboundedReceiver<JobId>, job_rx: mpsc::UnboundedReceiver<JobId>,
storage: Arc<Storage>, storage: Arc<Storage>,
model_path: PathBuf, model_path: PathBuf,
queue_depth: Arc<AtomicUsize>, queue_depth: Arc<AtomicUsize>,
gpu_device: u32, gpu_device: u32,
) -> ProgressRegistry { model_state: Arc<RwLock<ModelState>>,
model_event_tx: broadcast::Sender<ModelEvent>,
webhook_registry: Arc<Mutex<HashSet<String>>>,
idle_timeout: Duration,
gpu_poll_interval: Duration,
) -> (ProgressRegistry, std::sync::mpsc::SyncSender<WorkerCmd>) {
let registry: ProgressRegistry = Arc::new(dashmap::DashMap::new()); let registry: ProgressRegistry = Arc::new(dashmap::DashMap::new());
let reg_clone = Arc::clone(&registry); let reg_clone = Arc::clone(&registry);
let (tx_req, rx_req) = std::sync::mpsc::channel::<TranscribeRequest>(); // Bounded sync channel: capacity 8 is plenty (load/unload are rare).
let (cmd_tx, cmd_rx) = std::sync::mpsc::sync_channel::<WorkerCmd>(8);
let cmd_tx_clone = cmd_tx.clone();
// Capture Tokio runtime handle so the OS thread can spawn async tasks.
let rt_handle = tokio::runtime::Handle::current();
std::thread::Builder::new() std::thread::Builder::new()
.name("whisper-gpu".into()) .name("whisper-gpu".into())
.spawn(move || transcriber_thread(rx_req, model_path, gpu_device)) .spawn(move || {
transcriber_thread(
cmd_rx,
model_path,
gpu_device,
model_state,
model_event_tx,
webhook_registry,
idle_timeout,
gpu_poll_interval,
rt_handle,
);
})
.expect("failed to spawn whisper-gpu thread"); .expect("failed to spawn whisper-gpu thread");
tokio::spawn(run(job_rx, storage, queue_depth, reg_clone, tx_req)); tokio::spawn(run(job_rx, storage, queue_depth, reg_clone, cmd_tx_clone));
registry (registry, cmd_tx)
} }
/// Dedicated OS thread that owns the Transcriber (non-Send) and runs inference. // ── GPU OS thread ─────────────────────────────────────────────────────────────
///
/// The Transcriber holds a single `WhisperState` that is reused for every chunk.
/// GPU compute buffers (~700 MB) are allocated once at startup rather than on
/// every call, eliminating per-chunk `whisper_init_state` overhead and the
/// VRAM churn that caused intermittent 0-segment results.
fn transcriber_thread(
rx: std::sync::mpsc::Receiver<TranscribeRequest>,
model_path: PathBuf,
gpu_device: u32,
) {
let mut transcriber = match Transcriber::load(&model_path, gpu_device) {
Ok(t) => t,
Err(e) => {
tracing::error!(error = %e, "failed to load whisper model — transcriber thread exiting");
return;
}
};
tracing::info!(model = %model_path.display(), "GPU worker ready");
for req in rx { /// The worker OS thread that owns the `Transcriber` (non-`Send`).
let on_progress = req.on_progress; ///
let result = transcriber.transcribe( /// Uses `recv_timeout` with a 1-second tick to drive the idle timer without a
&req.pcm, /// separate thread.
req.language.as_deref(), #[allow(clippy::too_many_arguments)]
&req.task, fn transcriber_thread(
move |p| on_progress(p), rx: std::sync::mpsc::Receiver<WorkerCmd>,
); model_path: PathBuf,
let _ = req.reply.send(result); gpu_device: u32,
model_state: Arc<RwLock<ModelState>>,
model_event_tx: broadcast::Sender<ModelEvent>,
webhook_registry: Arc<Mutex<HashSet<String>>>,
idle_timeout: Duration,
gpu_poll_interval: Duration,
rt: tokio::runtime::Handle,
) {
let mut transcriber: Option<Transcriber> = None;
let mut last_job = Instant::now();
loop {
match rx.recv_timeout(Duration::from_secs(1)) {
Ok(WorkerCmd::Load) => {
if transcriber.is_some() {
tracing::debug!("WorkerCmd::Load ignored — model already loaded");
continue;
}
transcriber = try_load_with_polling(
&rx,
&model_path,
gpu_device,
&model_state,
&model_event_tx,
&webhook_registry,
gpu_poll_interval,
&rt,
);
if transcriber.is_some() {
last_job = Instant::now();
}
}
Ok(WorkerCmd::Unload) => {
do_unload(&mut transcriber, &model_state, &model_event_tx, &webhook_registry, &rt);
}
Ok(WorkerCmd::Transcribe(req)) => {
let t = match &mut transcriber {
Some(t) => t,
None => {
tracing::warn!("Transcribe cmd received but model is unloaded — failing job");
let _ = req.reply.send(Err(AppError::Internal(
"model unloaded before job could run".into(),
)));
continue;
}
};
let result = t.transcribe(
&req.pcm,
req.language.as_deref(),
&req.task,
move |p| (req.on_progress)(p),
);
last_job = Instant::now();
let _ = req.reply.send(result);
}
Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
if transcriber.is_some() && last_job.elapsed() >= idle_timeout {
tracing::info!(
elapsed_secs = last_job.elapsed().as_secs(),
"idle timeout reached — unloading model"
);
do_unload(
&mut transcriber,
&model_state,
&model_event_tx,
&webhook_registry,
&rt,
);
}
}
Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
tracing::info!("worker command channel closed — shutting down GPU thread");
break;
}
}
} }
} }
/// Attempt to load the model, polling on VRAM failures.
///
/// While waiting for GPU, drains `rx` so that `WorkerCmd::Unload` cancels the
/// load attempt and `WorkerCmd::Transcribe` commands get a "model not ready"
/// rejection. Returns `Some(Transcriber)` on success, `None` if cancelled.
#[allow(clippy::too_many_arguments)]
fn try_load_with_polling(
rx: &std::sync::mpsc::Receiver<WorkerCmd>,
model_path: &PathBuf,
gpu_device: u32,
model_state: &Arc<RwLock<ModelState>>,
model_event_tx: &broadcast::Sender<ModelEvent>,
webhook_registry: &Arc<Mutex<HashSet<String>>>,
gpu_poll_interval: Duration,
rt: &tokio::runtime::Handle,
) -> Option<Transcriber> {
loop {
set_state(model_state, ModelState::Loading);
broadcast_event(model_event_tx, ModelEvent::ModelLoading);
tracing::info!("loading whisper model...");
match Transcriber::load(model_path, gpu_device) {
Ok(t) => {
let loaded_at = Utc::now();
set_state(model_state, ModelState::Ready { loaded_at });
broadcast_event(model_event_tx, ModelEvent::ModelReady { loaded_at });
fire_webhooks(webhook_registry, ModelEvent::ModelReady { loaded_at }, rt);
tracing::info!("model loaded and ready");
return Some(t);
}
Err(AppError::OutOfMemory(msg)) => {
let (vram_needed_mb, vram_free_mb) = parse_oom_vram(&msg, gpu_device);
let retry_in_secs = gpu_poll_interval.as_secs();
tracing::warn!(
vram_needed_mb,
vram_free_mb,
retry_in_secs,
"insufficient VRAM — will retry"
);
set_state(model_state, ModelState::WaitingForGpu {
vram_needed_mb,
vram_free_mb,
retry_in_secs,
});
broadcast_event(model_event_tx, ModelEvent::ModelWaitingForGpu {
vram_needed_mb,
vram_free_mb,
retry_in_secs,
});
// Interruptible sleep: drain rx while waiting for gpu_poll_interval.
let deadline = Instant::now() + gpu_poll_interval;
loop {
let remaining = deadline.saturating_duration_since(Instant::now());
if remaining.is_zero() { break; }
match rx.recv_timeout(remaining.min(Duration::from_secs(1))) {
Ok(WorkerCmd::Unload) => {
tracing::info!("Unload received while waiting for GPU — cancelling load");
set_state(model_state, ModelState::Unloaded);
broadcast_event(model_event_tx, ModelEvent::ModelUnloaded);
fire_webhooks(webhook_registry, ModelEvent::ModelUnloaded, rt);
return None;
}
Ok(WorkerCmd::Load) => {} // idempotent
Ok(WorkerCmd::Transcribe(req)) => {
let _ = req.reply.send(Err(AppError::ModelNotReady {
state: "waiting_for_gpu".into(),
retry_after_secs: retry_in_secs,
}));
}
Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {}
Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => return None,
}
}
// Loop back to retry load
}
Err(e) => {
tracing::error!(error = %e, "model load failed with non-recoverable error");
set_state(model_state, ModelState::Unloaded);
return None;
}
}
}
}
fn do_unload(
transcriber: &mut Option<Transcriber>,
model_state: &Arc<RwLock<ModelState>>,
model_event_tx: &broadcast::Sender<ModelEvent>,
webhook_registry: &Arc<Mutex<HashSet<String>>>,
rt: &tokio::runtime::Handle,
) {
*transcriber = None;
set_state(model_state, ModelState::Unloaded);
broadcast_event(model_event_tx, ModelEvent::ModelUnloaded);
fire_webhooks(webhook_registry, ModelEvent::ModelUnloaded, rt);
tracing::info!("model unloaded — GPU memory freed");
}
// ── Helpers ───────────────────────────────────────────────────────────────────
fn set_state(arc: &Arc<RwLock<ModelState>>, state: ModelState) {
*arc.blocking_write() = state;
}
fn broadcast_event(tx: &broadcast::Sender<ModelEvent>, event: ModelEvent) {
let _ = tx.send(event);
}
fn fire_webhooks(
registry: &Arc<Mutex<HashSet<String>>>,
event: ModelEvent,
rt: &tokio::runtime::Handle,
) {
if !event.is_webhook_event() {
return;
}
let urls: Vec<String> = registry
.lock()
.unwrap_or_else(|e| e.into_inner())
.iter()
.cloned()
.collect();
if urls.is_empty() { return; }
let payload = match serde_json::to_string(&event) {
Ok(p) => p,
Err(e) => { tracing::error!(error = %e, "failed to serialize model event"); return; }
};
for url in urls {
let body = payload.clone();
rt.spawn(async move {
let http = Client::builder()
.timeout(Duration::from_secs(10))
.build()
.expect("http client");
for attempt in 0..3_u32 {
match http.post(&url)
.header("content-type", "application/json")
.body(body.clone())
.send()
.await
{
Ok(r) if r.status().is_success() => {
tracing::debug!(url, "model event webhook delivered");
return;
}
Ok(r) => tracing::warn!(url, status = r.status().as_u16(), "webhook non-2xx"),
Err(e) => tracing::warn!(url, error = %e, attempt, "webhook delivery failed"),
}
if attempt < 2 {
tokio::time::sleep(Duration::from_secs(2u64.pow(attempt))).await;
}
}
tracing::error!(url, "model event webhook failed after 3 attempts");
});
}
}
fn parse_oom_vram(msg: &str, gpu_device: u32) -> (u64, u64) {
let needed = msg
.split_whitespace()
.zip(msg.split_whitespace().skip(1))
.find(|(_, next)| *next == "MiB")
.and_then(|(n, _)| n.parse::<f64>().ok())
.map(|v| v as u64)
.unwrap_or(0);
let free = std::process::Command::new("nvidia-smi")
.args([
&format!("--id={gpu_device}"),
"--query-gpu=memory.free",
"--format=csv,noheader,nounits",
])
.output()
.ok()
.and_then(|o| String::from_utf8(o.stdout).ok())
.and_then(|s| s.trim().parse::<u64>().ok())
.unwrap_or(0);
(needed, free)
}
// ── Async job runner ──────────────────────────────────────────────────────────
async fn run( async fn run(
mut job_rx: mpsc::UnboundedReceiver<JobId>, mut job_rx: mpsc::UnboundedReceiver<JobId>,
storage: Arc<Storage>, storage: Arc<Storage>,
queue_depth: Arc<AtomicUsize>, queue_depth: Arc<AtomicUsize>,
registry: ProgressRegistry, registry: ProgressRegistry,
tx_req: std::sync::mpsc::Sender<TranscribeRequest>, cmd_tx: std::sync::mpsc::SyncSender<WorkerCmd>,
) { ) {
let http = Client::builder() let http = Client::builder()
.timeout(std::time::Duration::from_secs(30)) .timeout(Duration::from_secs(30))
.build() .build()
.expect("failed to build reqwest client"); .expect("failed to build reqwest client");
@@ -140,7 +443,7 @@ async fn run(
let audio_path = audio_path_for(&job_id); let audio_path = audio_path_for(&job_id);
let result = process_job(&job, &audio_path, &progress_tx, &tx_req, &storage).await; let result = process_job(&job, &audio_path, &progress_tx, &cmd_tx, &storage).await;
let _ = tokio::fs::remove_file(&audio_path).await; let _ = tokio::fs::remove_file(&audio_path).await;
@@ -175,26 +478,18 @@ async fn run(
tokio::spawn(async move { webhook::fire(&http, &url, &job).await; }); tokio::spawn(async move { webhook::fire(&http, &url, &job).await; });
} }
tokio::time::sleep(std::time::Duration::from_secs(30)).await; tokio::time::sleep(Duration::from_secs(30)).await;
registry.remove(&job_id); registry.remove(&job_id);
} }
} }
// ── Silence-based chunking ──────────────────────────────────────────────────── // ── Silence-based chunking ────────────────────────────────────────────────────
/// Target chunk length. 60s ≈ 2× whisper's native 30s window — short enough
/// that a hallucinated phrase can't compound beyond a single window.
const TARGET_CHUNK_SECS: f32 = 60.0; const TARGET_CHUNK_SECS: f32 = 60.0;
/// How far from the target we'll snap to a silence midpoint.
const SNAP_WINDOW_SECS: f32 = 30.0; const SNAP_WINDOW_SECS: f32 = 30.0;
/// Silence below this level (dB) counts as a split candidate.
const SILENCE_DB: &str = "-35dB"; const SILENCE_DB: &str = "-35dB";
/// Minimum silence duration to register as a candidate split.
const SILENCE_DUR: &str = "0.4"; const SILENCE_DUR: &str = "0.4";
/// Detect silence periods and return the midpoint (seconds) of each.
/// On any error (ffmpeg missing, binary format, etc.) returns an empty vec
/// so the caller can fall back to hard cuts.
async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> { async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
use tokio::process::Command; use tokio::process::Command;
@@ -217,7 +512,6 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
} }
}; };
// silencedetect logs to stderr
let stderr = String::from_utf8_lossy(&output.stderr); let stderr = String::from_utf8_lossy(&output.stderr);
let mut starts: Vec<f32> = Vec::new(); let mut starts: Vec<f32> = Vec::new();
let mut ends: Vec<f32> = Vec::new(); let mut ends: Vec<f32> = Vec::new();
@@ -228,7 +522,6 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
starts.push(t); starts.push(t);
} }
} else if let Some(i) = line.find("silence_end: ") { } else if let Some(i) = line.find("silence_end: ") {
// Format: "silence_end: 12.34 | silence_duration: 0.56"
let t_str = line[i + "silence_end: ".len()..] let t_str = line[i + "silence_end: ".len()..]
.split(" |") .split(" |")
.next() .next()
@@ -248,10 +541,6 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
mids mids
} }
/// Build cut points every `target_secs`, snapping to the nearest silence
/// midpoint within `snap_window` when one exists; otherwise a hard cut.
/// Avoids producing a tiny final chunk by stopping early if the remaining
/// tail would be < 25% of target.
fn snap_to_silence( fn snap_to_silence(
mids: &[f32], mids: &[f32],
total_secs: f32, total_secs: f32,
@@ -263,13 +552,9 @@ fn snap_to_silence(
while pos < total_secs - target_secs * 0.25 { while pos < total_secs - target_secs * 0.25 {
let prev_cut = cuts.last().copied().unwrap_or(0.0); let prev_cut = cuts.last().copied().unwrap_or(0.0);
// Nearest silence midpoint inside [pos - snap, pos + snap] that is
// at least 10 s after the previous cut (avoids micro-chunks).
let best = mids.iter().copied() let best = mids.iter().copied()
.filter(|&t| t > prev_cut + 10.0 && (t - pos).abs() <= snap_window) .filter(|&t| t > prev_cut + 10.0 && (t - pos).abs() <= snap_window)
.min_by(|a, b| (a - pos).abs().partial_cmp(&(b - pos).abs()).unwrap()); .min_by(|a, b| (a - pos).abs().partial_cmp(&(b - pos).abs()).unwrap());
let cut = best.unwrap_or(pos); let cut = best.unwrap_or(pos);
cuts.push(cut); cuts.push(cut);
pos = cut + target_secs; pos = cut + target_secs;
@@ -278,7 +563,6 @@ fn snap_to_silence(
cuts cuts
} }
/// Convert cut points into (start_secs, end_secs) chunk pairs.
fn to_chunk_ranges(cuts: &[f32], total_secs: f32) -> Vec<(f32, f32)> { fn to_chunk_ranges(cuts: &[f32], total_secs: f32) -> Vec<(f32, f32)> {
let mut ranges = Vec::new(); let mut ranges = Vec::new();
let mut start = 0.0_f32; let mut start = 0.0_f32;
@@ -289,7 +573,6 @@ fn to_chunk_ranges(cuts: &[f32], total_secs: f32) -> Vec<(f32, f32)> {
start = cut; start = cut;
} }
} }
// Last chunk
if total_secs - start >= 1.0 { if total_secs - start >= 1.0 {
ranges.push((start, total_secs)); ranges.push((start, total_secs));
} }
@@ -302,17 +585,13 @@ async fn process_job(
job: &Job, job: &Job,
audio_path: &std::path::Path, audio_path: &std::path::Path,
progress_tx: &ProgressTx, progress_tx: &ProgressTx,
tx_req: &std::sync::mpsc::Sender<TranscribeRequest>, cmd_tx: &std::sync::mpsc::SyncSender<WorkerCmd>,
storage: &Arc<Storage>, storage: &Arc<Storage>,
) -> crate::Result<(Vec<Segment>, String, f32)> { ) -> crate::Result<(Vec<Segment>, String, f32)> {
// 1. Decode full audio to 16 kHz mono PCM.
let pcm = decode_audio(audio_path).await?; let pcm = decode_audio(audio_path).await?;
let total_secs = pcm.len() as f32 / 16_000.0; let total_secs = pcm.len() as f32 / 16_000.0;
// 2. Detect silence midpoints from original file.
let silence_mids = detect_silence_midpoints(audio_path).await; let silence_mids = detect_silence_midpoints(audio_path).await;
// 3. Build silence-snapped chunk boundaries.
let cuts = snap_to_silence(&silence_mids, total_secs, TARGET_CHUNK_SECS, SNAP_WINDOW_SECS); let cuts = snap_to_silence(&silence_mids, total_secs, TARGET_CHUNK_SECS, SNAP_WINDOW_SECS);
let chunks = to_chunk_ranges(&cuts, total_secs); let chunks = to_chunk_ranges(&cuts, total_secs);
let n = chunks.len(); let n = chunks.len();
@@ -324,7 +603,6 @@ async fn process_job(
"audio chunked by silence" "audio chunked by silence"
); );
// 4. Transcribe each chunk, applying a time offset to all timestamps.
let mut all_segments: Vec<Segment> = Vec::new(); let mut all_segments: Vec<Segment> = Vec::new();
let mut language = String::new(); let mut language = String::new();
@@ -334,11 +612,9 @@ async fn process_job(
let mut chunk_pcm = pcm[s0..s1].to_vec(); let mut chunk_pcm = pcm[s0..s1].to_vec();
trim_trailing_silence(&mut chunk_pcm); trim_trailing_silence(&mut chunk_pcm);
// Base percent this chunk starts at.
let base = (ci * 100 / n) as u8; let base = (ci * 100 / n) as u8;
let span = (100usize / n).max(1) as u8; let span = (100usize / n).max(1) as u8;
// Emit a progress event and persist it at the start of every chunk.
let _ = progress_tx.send(ProgressEvent::Progress { let _ = progress_tx.send(ProgressEvent::Progress {
percent: base, percent: base,
chunk: ci + 1, chunk: ci + 1,
@@ -350,8 +626,7 @@ async fn process_job(
tracing::warn!(error = %e, "failed to persist mid-job progress"); tracing::warn!(error = %e, "failed to persist mid-job progress");
} }
// Scale whisper's per-chunk 0100 into the job's overall range. let tx = progress_tx.clone();
let tx = progress_tx.clone();
let chunk_num = ci + 1; let chunk_num = ci + 1;
let on_progress = Box::new(move |p: u8| { let on_progress = Box::new(move |p: u8| {
let overall = base.saturating_add(p.saturating_mul(span) / 100); let overall = base.saturating_add(p.saturating_mul(span) / 100);
@@ -363,18 +638,17 @@ async fn process_job(
}); });
let (reply_tx, reply_rx) = oneshot::channel(); let (reply_tx, reply_rx) = oneshot::channel();
tx_req.send(TranscribeRequest { cmd_tx.send(WorkerCmd::Transcribe(TranscribeRequest {
pcm: chunk_pcm, pcm: chunk_pcm,
language: job.language.clone(), language: job.language.clone(),
task: job.task.clone(), task: job.task.clone(),
on_progress, on_progress,
reply: reply_tx, reply: reply_tx,
}).map_err(|_| crate::AppError::Internal("transcriber thread gone".into()))?; })).map_err(|_| AppError::Internal("worker command channel closed".into()))?;
let (mut segs, lang) = reply_rx.await let (mut segs, lang) = reply_rx.await
.map_err(|_| crate::AppError::Internal("transcriber thread dropped reply".into()))??; .map_err(|_| AppError::Internal("transcriber thread dropped reply".into()))??;
// Shift all timestamps by chunk offset.
let offset = *chunk_start; let offset = *chunk_start;
for seg in &mut segs { for seg in &mut segs {
seg.start += offset; seg.start += offset;
@@ -400,7 +674,6 @@ async fn process_job(
} }
} }
// Renumber segment indices across the merged output.
for (i, seg) in all_segments.iter_mut().enumerate() { for (i, seg) in all_segments.iter_mut().enumerate() {
seg.index = i as i32; seg.index = i as i32;
} }
@@ -409,14 +682,9 @@ async fn process_job(
Ok((all_segments, language, total_secs)) Ok((all_segments, language, total_secs))
} }
/// Trim trailing silence from a 16 kHz mono PCM buffer.
///
/// Scans backwards to find the last sample above 35 dB, then keeps
/// 0.5 s of padding after it. This prevents whisper from hallucinating
/// filler tokens into end-of-chunk silence.
fn trim_trailing_silence(pcm: &mut Vec<f32>) { fn trim_trailing_silence(pcm: &mut Vec<f32>) {
const THRESHOLD: f32 = 0.017_8; // 35 dB (10^(35/20)) const THRESHOLD: f32 = 0.017_8;
const PADDING: usize = 8_000; // 0.5 s at 16 kHz const PADDING: usize = 8_000;
if let Some(last_loud) = pcm.iter().rposition(|&s| s.abs() > THRESHOLD) { if let Some(last_loud) = pcm.iter().rposition(|&s| s.abs() > THRESHOLD) {
let new_len = (last_loud + 1 + PADDING).min(pcm.len()); let new_len = (last_loud + 1 + PADDING).min(pcm.len());
@@ -429,10 +697,8 @@ fn trim_trailing_silence(pcm: &mut Vec<f32>) {
pcm.truncate(new_len); pcm.truncate(new_len);
} }
} }
// All-silent chunk: keep as-is — whisper will produce zero segments, which is correct.
} }
/// Decode any audio file to 16 kHz mono PCM f32 using ffmpeg.
async fn decode_audio(path: &std::path::Path) -> crate::Result<Vec<f32>> { async fn decode_audio(path: &std::path::Path) -> crate::Result<Vec<f32>> {
use tokio::process::Command; use tokio::process::Command;
@@ -447,11 +713,11 @@ async fn decode_audio(path: &std::path::Path) -> crate::Result<Vec<f32>> {
]) ])
.output() .output()
.await .await
.map_err(|e| crate::AppError::Internal(format!("ffmpeg spawn failed: {e}")))?; .map_err(|e| AppError::Internal(format!("ffmpeg spawn failed: {e}")))?;
if !output.status.success() { if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr); let stderr = String::from_utf8_lossy(&output.stderr);
return Err(crate::AppError::Internal(format!( return Err(AppError::Internal(format!(
"ffmpeg exited with {}: {}", "ffmpeg exited with {}: {}",
output.status, stderr output.status, stderr
))); )));
@@ -459,7 +725,7 @@ async fn decode_audio(path: &std::path::Path) -> crate::Result<Vec<f32>> {
let bytes = output.stdout; let bytes = output.stdout;
if bytes.len() % 4 != 0 { if bytes.len() % 4 != 0 {
return Err(crate::AppError::Internal( return Err(AppError::Internal(
"ffmpeg output length not a multiple of 4".into(), "ffmpeg output length not a multiple of 4".into(),
)); ));
} }
@@ -473,3 +739,51 @@ pub fn audio_path_for(id: &JobId) -> PathBuf {
let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/data".into()); let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/data".into());
PathBuf::from(data_dir).join(format!("{id}.audio")) PathBuf::from(data_dir).join(format!("{id}.audio"))
} }
// ── Unit tests ────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_snap_to_silence_uses_nearest_midpoint() {
let mids = vec![55.0, 58.0, 62.0];
let cuts = snap_to_silence(&mids, 120.0, 60.0, 30.0);
assert!(!cuts.is_empty());
assert!((cuts[0] - 58.0).abs() < 0.01, "expected ~58.0, got {}", cuts[0]);
}
#[test]
fn test_snap_to_silence_hard_cut_when_no_silence() {
let cuts = snap_to_silence(&[], 120.0, 60.0, 30.0);
assert_eq!(cuts, vec![60.0]);
}
#[test]
fn test_to_chunk_ranges_single_chunk() {
let ranges = to_chunk_ranges(&[], 30.0);
assert_eq!(ranges, vec![(0.0, 30.0)]);
}
#[test]
fn test_to_chunk_ranges_two_chunks() {
let ranges = to_chunk_ranges(&[60.0], 120.0);
assert_eq!(ranges, vec![(0.0, 60.0), (60.0, 120.0)]);
}
#[test]
fn test_trim_trailing_silence_all_silent() {
let mut pcm = vec![0.0f32; 1000];
trim_trailing_silence(&mut pcm);
assert_eq!(pcm.len(), 1000);
}
#[test]
fn test_trim_trailing_silence_trims_to_padding() {
let mut pcm = vec![0.0f32; 32_000];
pcm[10_000] = 1.0;
trim_trailing_silence(&mut pcm);
assert_eq!(pcm.len(), (10_001 + 8_000).min(32_000));
}
}

97
test_all.sh Normal file → Executable file
View File

@@ -6,8 +6,9 @@ BASE="${WHISPER_BASE_URL:-http://localhost:8080}"
AUDIO="${TEST_AUDIO:-/home/moze/Sources/youtube-transcriber/docker/tmp/audio-b2167046-a236-4fcd-b739-78177542fd23.wav}" AUDIO="${TEST_AUDIO:-/home/moze/Sources/youtube-transcriber/docker/tmp/audio-b2167046-a236-4fcd-b739-78177542fd23.wav}"
GREEN='\033[0;32m'; RED='\033[0;31m'; NC='\033[0m' GREEN='\033[0;32m'; RED='\033[0;31m'; NC='\033[0m'
FAILS=0
ok() { echo -e "${GREEN}[PASS]${NC} $*"; } ok() { echo -e "${GREEN}[PASS]${NC} $*"; }
fail(){ echo -e "${RED}[FAIL]${NC} $*"; exit 1; } fail(){ echo -e "${RED}[FAIL]${NC} $*"; FAILS=$((FAILS + 1)); }
echo "=== Whisper API test suite ===" echo "=== Whisper API test suite ==="
echo " BASE : $BASE" echo " BASE : $BASE"
@@ -17,11 +18,16 @@ echo ""
echo "=== 1. GET /health ===" echo "=== 1. GET /health ==="
HEALTH=$(curl -sf "$BASE/health") HEALTH=$(curl -sf "$BASE/health")
echo "$HEALTH" | python3 -m json.tool echo "$HEALTH" | python3 -m json.tool
echo "$HEALTH" | python3 -c "import sys,json; d=json.load(sys.stdin); assert d['status']=='ok', f'status={d[\"status\"]}'" && ok "health ok" python3 -c "
import sys, json
d = json.loads('$HEALTH' if False else sys.stdin.read())
assert d['status'] == 'ok', f'status={d[\"status\"]}'
assert 'model_state' in d, 'model_state field missing from health response'
" <<< "$HEALTH" && ok "health ok + model_state present" || fail "health check"
echo "" echo ""
echo "=== 2. GET /docs (Swagger UI reachable) ===" echo "=== 2. GET /docs (Swagger UI reachable) ==="
curl -sf "$BASE/docs" | grep -qi "swagger" && ok "swagger UI reachable" curl -sf "$BASE/docs" | grep -qi "swagger" && ok "swagger UI reachable" || fail "swagger UI"
echo "" echo ""
echo "=== 3. Webhook receiver (background Python HTTP server) ===" echo "=== 3. Webhook receiver (background Python HTTP server) ==="
@@ -33,7 +39,7 @@ class H(http.server.BaseHTTPRequestHandler):
n = int(self.headers.get('Content-Length', 0)) n = int(self.headers.get('Content-Length', 0))
body = self.rfile.read(n) body = self.rfile.read(n)
data = json.loads(body) data = json.loads(body)
print(f"\n[WEBHOOK] status={data.get('status')} segments={len(data.get('segments', []))}") print(f"\n[WEBHOOK] status={data.get('status')} segments={len(data.get('segments', []))}", flush=True)
self.send_response(200) self.send_response(200)
self.end_headers() self.end_headers()
def log_message(self, *a): pass def log_message(self, *a): pass
@@ -48,40 +54,70 @@ sleep 1
echo "Webhook receiver started (PID $WEBHOOK_PID)" echo "Webhook receiver started (PID $WEBHOOK_PID)"
echo "" echo ""
echo "=== 4. DELETE a non-existent job → 404 ===" echo "=== 4. GET /model/status — expect unloaded on fresh start ==="
MODEL_STATUS=$(curl -sf "$BASE/model/status")
echo "$MODEL_STATUS" | python3 -m json.tool
echo "$MODEL_STATUS" | python3 -c "
import sys, json
d = json.load(sys.stdin)
assert 'state' in d, 'state field missing from /model/status'
print(f' model state: {d[\"state\"]}')
" && ok "/model/status has state field" || fail "/model/status schema"
echo ""
echo "=== 5. POST /model/load — trigger model load ==="
LOAD_RESP=$(curl -sf -X POST "$BASE/model/load")
echo "$LOAD_RESP"
ok "POST /model/load accepted"
echo ""
echo "=== 6. Poll /model/status until ready (max 3 min) ==="
LOAD_ELAPSED=0
while true; do
sleep 5
LOAD_ELAPSED=$((LOAD_ELAPSED + 5))
MS=$(curl -sf "$BASE/model/status")
STATE=$(echo "$MS" | python3 -c "import sys,json; print(json.load(sys.stdin)['state'])")
echo " [${LOAD_ELAPSED}s] model_state=${STATE}"
if [ "$STATE" = "ready" ]; then
ok "model loaded and ready in ${LOAD_ELAPSED}s"
break
fi
[ $LOAD_ELAPSED -gt 180 ] && { fail "model failed to load within 3 minutes"; break; }
done
echo ""
echo "=== 7. DELETE a non-existent job → 404 ==="
STATUS=$(curl -s -o /dev/null -w "%{http_code}" -X DELETE "$BASE/jobs/00000000-0000-0000-0000-000000000000") STATUS=$(curl -s -o /dev/null -w "%{http_code}" -X DELETE "$BASE/jobs/00000000-0000-0000-0000-000000000000")
[ "$STATUS" = "404" ] && ok "DELETE unknown job → 404" || fail "expected 404, got $STATUS" [ "$STATUS" = "404" ] && ok "DELETE unknown job → 404" || fail "expected 404, got $STATUS"
echo "" echo ""
echo "=== 5. POST /jobs — submit audio ===" echo "=== 8. POST /jobs — submit audio ==="
# language field omitted → auto-detection. Do NOT pass "auto" — it is not a
# valid ISO 639-1 code and whisper-rs will reject it or behave unexpectedly.
SUBMIT=$(curl -sf -X POST "$BASE/jobs" \ SUBMIT=$(curl -sf -X POST "$BASE/jobs" \
-F "audio=@${AUDIO};type=audio/wav" \ -F "audio=@${AUDIO};type=audio/wav" \
-F "task=transcribe" \ -F "task=transcribe" \
-F "webhook_url=http://localhost:9999/webhook") -F "webhook_url=http://localhost:9999/webhook")
echo "$SUBMIT" echo "$SUBMIT"
# Submit response: { "job_id": "<uuid>" } (field is "job_id", not "id")
JOB_ID=$(echo "$SUBMIT" | python3 -c "import sys,json; print(json.load(sys.stdin)['job_id'])") JOB_ID=$(echo "$SUBMIT" | python3 -c "import sys,json; print(json.load(sys.stdin)['job_id'])")
ok "submitted job $JOB_ID" ok "submitted job $JOB_ID"
echo "" echo ""
echo "=== 6. GET /jobs/{id} immediately after submit ===" echo "=== 9. GET /jobs/{id} immediately after submit ==="
JOB=$(curl -sf "$BASE/jobs/$JOB_ID") JOB=$(curl -sf "$BASE/jobs/$JOB_ID")
echo "$JOB" | python3 -c " echo "$JOB" | python3 -c "
import sys, json import sys, json
d = json.load(sys.stdin) d = json.load(sys.stdin)
assert d['status'] in ('queued', 'running'), f'unexpected status: {d[\"status\"]}' assert d['status'] in ('queued', 'running'), f'unexpected status: {d[\"status\"]}'
" && ok "status is queued/running" " && ok "status is queued/running" || fail "initial status check"
echo "" echo ""
echo "=== 7. SSE stream (observe first 30 events then detach) ===" echo "=== 10. SSE stream (observe first 30 events then detach) ==="
echo "Subscribing to SSE stream for $JOB_ID" echo "Subscribing to SSE stream for $JOB_ID"
curl -sN --max-time 90 "$BASE/jobs/$JOB_ID/stream" | head -60 & curl -sN --max-time 90 "$BASE/jobs/$JOB_ID/stream" | head -60 &
SSE_PID=$! SSE_PID=$!
echo "" echo ""
echo "=== 8. Poll until done (max 20 min) ===" echo "=== 11. Poll until done (max 20 min) ==="
ELAPSED=0 ELAPSED=0
while true; do while true; do
sleep 15 sleep 15
@@ -96,16 +132,15 @@ while true; do
elif [ "$STATUS" = "failed" ]; then elif [ "$STATUS" = "failed" ]; then
echo "$JOB" | python3 -m json.tool echo "$JOB" | python3 -m json.tool
fail "job failed" fail "job failed"
break
fi fi
[ $ELAPSED -gt 1200 ] && fail "timeout after 20 minutes" [ $ELAPSED -gt 1200 ] && { fail "timeout after 20 minutes"; break; }
done done
kill $SSE_PID 2>/dev/null || true kill $SSE_PID 2>/dev/null || true
echo "" echo ""
echo "=== 9. Inspect transcription quality ===" echo "=== 12. Inspect transcription quality ==="
RESULT=$(curl -sf "$BASE/jobs/$JOB_ID") RESULT=$(curl -sf "$BASE/jobs/$JOB_ID")
# Note: can't pipe into a heredoc-driven python3 (heredoc takes stdin, pipe is ignored).
# Write to a temp file instead.
TMPJSON=$(mktemp /tmp/whisper_test_XXXXXX.json) TMPJSON=$(mktemp /tmp/whisper_test_XXXXXX.json)
echo "$RESULT" > "$TMPJSON" echo "$RESULT" > "$TMPJSON"
python3 - "$TMPJSON" << 'PYCHECK' python3 - "$TMPJSON" << 'PYCHECK'
@@ -149,18 +184,16 @@ for seg in segments[:5]:
PYCHECK PYCHECK
PYEXIT=$? PYEXIT=$?
rm -f "$TMPJSON" rm -f "$TMPJSON"
[ $PYEXIT -eq 0 ] && ok "quality check passed" || { echo "[FAIL] quality check"; FAILS=$((FAILS+1)); } [ $PYEXIT -eq 0 ] && ok "quality check passed" || fail "quality check"
echo "" echo ""
echo "=== 10. DELETE completed job → 200 ===" echo "=== 13. DELETE completed job → 409 Conflict ==="
# Completed jobs return 409 Conflict on DELETE (terminal state).
# Verify we get 409, not 200 (delete is only for cancellation of active jobs).
DEL_STATUS=$(curl -s -o /dev/null -w "%{http_code}" -X DELETE "$BASE/jobs/$JOB_ID") DEL_STATUS=$(curl -s -o /dev/null -w "%{http_code}" -X DELETE "$BASE/jobs/$JOB_ID")
[ "$DEL_STATUS" = "409" ] && ok "DELETE completed job → 409 Conflict (expected)" \ [ "$DEL_STATUS" = "409" ] && ok "DELETE completed job → 409 Conflict (expected)" \
|| echo " [INFO] DELETE returned $DEL_STATUS" || echo " [INFO] DELETE returned $DEL_STATUS"
echo "" echo ""
echo "=== 11. Submit + cancel a queued job ===" echo "=== 14. Submit + cancel a queued job ==="
JOB2=$(curl -sf -X POST "$BASE/jobs" \ JOB2=$(curl -sf -X POST "$BASE/jobs" \
-F "audio=@${AUDIO};type=audio/wav" \ -F "audio=@${AUDIO};type=audio/wav" \
-F "language=en" \ -F "language=en" \
@@ -173,8 +206,24 @@ CANCEL_STATUS=$(curl -sf "$BASE/jobs/$JOB2_ID" | python3 -c "import sys,json; pr
|| echo " [INFO] cancel status: $CANCEL_STATUS (may be running — worker ignores cancel mid-chunk)" || echo " [INFO] cancel status: $CANCEL_STATUS (may be running — worker ignores cancel mid-chunk)"
echo "" echo ""
echo "=== 12. Verify webhook fired ===" echo "=== 15. POST /model/unload ==="
UNLOAD_RESP=$(curl -sf -X POST "$BASE/model/unload")
echo "$UNLOAD_RESP"
sleep 2
UNLOAD_STATE=$(curl -sf "$BASE/model/status" | python3 -c "import sys,json; print(json.load(sys.stdin)['state'])")
[ "$UNLOAD_STATE" = "unloaded" ] && ok "model unloaded → state=unloaded" \
|| echo " [INFO] state after unload: $UNLOAD_STATE"
echo ""
echo "=== 16. Verify webhook fired ==="
sleep 3 sleep 3
kill $WEBHOOK_PID 2>/dev/null || true kill $WEBHOOK_PID 2>/dev/null || true
ok "all tests complete" ok "webhook server stopped"
echo ""
if [ $FAILS -eq 0 ]; then
echo -e "${GREEN}=== ALL TESTS PASSED ===${NC}"
else
echo -e "${RED}=== $FAILS TEST(S) FAILED ===${NC}"
exit 1
fi

246
tests/test_idle_timeout.sh Executable file
View File

@@ -0,0 +1,246 @@
#!/usr/bin/env bash
# tests/test_idle_timeout.sh
#
# Integration tests for the idle-timeout auto-unload feature.
# REQUIRES the server to be started with a short idle timeout:
#
# IDLE_TIMEOUT_SECS=5 ./whisper-server
# # or via Docker:
# docker run -e IDLE_TIMEOUT_SECS=5 ...
#
# The default idle timeout is 5 minutes; these tests use a 5-second window
# to keep the suite fast.
set -euo pipefail
BASE="${WHISPER_BASE_URL:-http://localhost:8080}"
IDLE_TIMEOUT="${EXPECTED_IDLE_TIMEOUT_SECS:-5}"
AUDIO="${TEST_AUDIO:-}"
GREEN='\033[0;32m'; RED='\033[0;31m'; YELLOW='\033[0;33m'; NC='\033[0m'
PASS=0; FAIL=0
ok() { echo -e "${GREEN}[PASS]${NC} $1"; PASS=$((PASS+1)); }
fail() { echo -e "${RED}[FAIL]${NC} $1"; FAIL=$((FAIL+1)); }
skip() { echo -e "${YELLOW}[SKIP]${NC} $1"; }
info() { echo " $1"; }
echo "=== Idle Timeout Tests ==="
echo " BASE: $BASE"
echo " IDLE_TIMEOUT_SECS: $IDLE_TIMEOUT (must be configured on the server)"
echo ""
echo "NOTE: These tests require the server to be running with IDLE_TIMEOUT_SECS=$IDLE_TIMEOUT"
echo ""
# ── Helpers ──────────────────────────────────────────────────────────────────
get_state() {
curl -sf "$BASE/model/status" | python3 -c "import sys,json; print(json.load(sys.stdin)['state'])"
}
ensure_ready() {
local state
state=$(get_state)
if [ "$state" = "ready" ]; then return 0; fi
curl -sf -X POST "$BASE/model/load" > /dev/null
local elapsed=0
while true; do
sleep 3; elapsed=$((elapsed+3))
state=$(get_state)
[ "$state" = "ready" ] && return 0
[ $elapsed -gt 180 ] && return 1
done
}
ensure_unloaded() {
curl -sf -X POST "$BASE/model/unload" > /dev/null || true
sleep 2
}
# ── TEST 1: Load model, complete a job, then wait for idle unload ─────────────
echo "--- Test 1: Idle timeout triggers auto-unload ---"
ensure_unloaded
ensure_ready || { fail "T1: model load failed"; }
WAIT_SECS=$((IDLE_TIMEOUT + 3))
info "Model is ready. Waiting $WAIT_SECS seconds (idle timeout=$IDLE_TIMEOUT + 3s buffer)..."
sleep $WAIT_SECS
STATE=$(get_state)
if [ "$STATE" = "unloaded" ]; then
ok "T1: model auto-unloaded after ${IDLE_TIMEOUT}s idle"
else
fail "T1: expected unloaded after idle timeout, got $STATE"
info "Is the server running with IDLE_TIMEOUT_SECS=$IDLE_TIMEOUT?"
fi
# ── TEST 2: model_unloaded webhook fires on idle timeout ─────────────────────
echo ""
echo "--- Test 2: model_unloaded webhook fires on idle timeout ---"
ensure_unloaded
# Start webhook receiver
python3 - <<'PYEOF' &
import http.server, json, sys, signal
class H(http.server.BaseHTTPRequestHandler):
def do_POST(self):
n = int(self.headers.get('Content-Length', 0))
body = json.loads(self.rfile.read(n))
with open('/tmp/idle_wh_event.json', 'w') as f:
json.dump(body, f)
self.send_response(200); self.end_headers()
def log_message(self, *a): pass
signal.signal(signal.SIGTERM, lambda *_: sys.exit(0))
http.server.HTTPServer(('', 9995), H).serve_forever()
PYEOF
WH_PID=$!
sleep 1
# Register webhook via a job submission (will 503 since unloaded)
curl -sf -X POST "$BASE/jobs" \
-F "audio=@/dev/urandom;type=audio/wav" \
-F "webhook_url=http://localhost:9995/wh" \
--max-time 5 > /dev/null 2>&1 || true
# Load model
ensure_ready || { fail "T2: model load failed"; kill $WH_PID 2>/dev/null; }
# Wait for idle timeout
WAIT_SECS=$((IDLE_TIMEOUT + 5))
info "Waiting ${WAIT_SECS}s for idle timeout..."
sleep $WAIT_SECS
kill $WH_PID 2>/dev/null || true
wait $WH_PID 2>/dev/null || true
if [ -f /tmp/idle_wh_event.json ]; then
EVENT_TYPE=$(python3 -c "import json; print(json.load(open('/tmp/idle_wh_event.json')).get('type','?'))")
rm -f /tmp/idle_wh_event.json
[ "$EVENT_TYPE" = "model_unloaded" ] && ok "T2: model_unloaded webhook fired on idle timeout" \
|| fail "T2: webhook type=$EVENT_TYPE (expected model_unloaded)"
else
fail "T2: no webhook received within timeout"
fi
# ── TEST 3: Job submission after idle timeout → 503 → triggers reload ─────────
echo ""
echo "--- Test 3: Job triggers reload after idle unload ---"
ensure_unloaded
ensure_ready || { fail "T3: initial load failed"; }
# Wait for auto-unload
WAIT_SECS=$((IDLE_TIMEOUT + 3))
info "Waiting ${WAIT_SECS}s for idle unload..."
sleep $WAIT_SECS
STATE=$(get_state)
[ "$STATE" = "unloaded" ] || info "Note: state=$STATE (expected unloaded)"
# Submit job → 503, triggers reload
HTTP=$(curl -s -o /tmp/t3_body.json -w "%{http_code}" -X POST "$BASE/jobs" \
-F "audio=@/dev/urandom;type=audio/wav" \
--max-time 5 2>/dev/null || echo "000")
if [ "$HTTP" = "503" ]; then
ok "T3a: POST /jobs → 503 after idle unload"
else
skip "T3a: POST /jobs returned $HTTP (model may have reloaded)"
fi
# State should be loading or ready (reload triggered by job submission)
sleep 2
STATE=$(get_state)
if [ "$STATE" = "loading" ] || [ "$STATE" = "ready" ]; then
ok "T3b: reload triggered by job submission ($STATE)"
else
fail "T3b: expected loading/ready, got $STATE"
fi
rm -f /tmp/t3_body.json
# ── TEST 4: Idle timer resets per job (wait 60% of timeout → still ready) ─────
echo ""
echo "--- Test 4: Idle timer resets with each completed job ---"
ensure_unloaded
ensure_ready || { fail "T4: model load failed"; }
HALF_WAIT=$((IDLE_TIMEOUT - 1))
info "Waiting ${HALF_WAIT}s (less than idle timeout)..."
sleep $HALF_WAIT
STATE=$(get_state)
if [ "$STATE" = "ready" ]; then
ok "T4a: model still ready after ${HALF_WAIT}s (less than ${IDLE_TIMEOUT}s timeout)"
else
fail "T4a: model unexpectedly $STATE after only ${HALF_WAIT}s"
fi
# Wait for full unload
REMAINING=$((IDLE_TIMEOUT - HALF_WAIT + 3))
info "Waiting another ${REMAINING}s for full idle unload..."
sleep $REMAINING
STATE=$(get_state)
[ "$STATE" = "unloaded" ] && ok "T4b: model unloaded after total > ${IDLE_TIMEOUT}s idle" \
|| fail "T4b: expected unloaded, got $STATE"
# ── TEST 5: Job resets idle timer ─────────────────────────────────────────────
echo ""
echo "--- Test 5: Completing a job resets the idle timer ---"
if [ -z "$AUDIO" ]; then
skip "T5: TEST_AUDIO not set — skipping timer-reset test"
else
ensure_unloaded
ensure_ready || { fail "T5: model load failed"; }
# Submit a job
SUBMIT=$(curl -sf -X POST "$BASE/jobs" \
-F "audio=@${AUDIO};type=audio/wav" \
-F "task=transcribe" 2>&1)
JOB_ID=$(echo "$SUBMIT" | python3 -c "import sys,json; print(json.load(sys.stdin)['job_id'])" 2>/dev/null || echo "")
if [ -z "$JOB_ID" ]; then
fail "T5: job submission failed"
else
# Wait for job to finish
elapsed=0
while true; do
sleep 5; elapsed=$((elapsed+5))
STATUS=$(curl -sf "$BASE/jobs/$JOB_ID" | python3 -c "import sys,json; print(json.load(sys.stdin)['status'])")
[ "$STATUS" = "done" ] || [ "$STATUS" = "failed" ] && break
[ $elapsed -gt 300 ] && break
done
info "Job finished in ${elapsed}s with status=$STATUS"
# Now wait IDLE_TIMEOUT - 2 seconds — should still be ready
SAFE_WAIT=$((IDLE_TIMEOUT - 2))
[ $SAFE_WAIT -lt 1 ] && SAFE_WAIT=1
info "Waiting ${SAFE_WAIT}s after job completion (less than idle timeout)..."
sleep $SAFE_WAIT
STATE=$(get_state)
[ "$STATE" = "ready" ] && ok "T5a: model still ready ${SAFE_WAIT}s after job completion" \
|| fail "T5a: model unexpectedly $STATE after job"
# Wait for idle timeout
REMAINING=$((IDLE_TIMEOUT - SAFE_WAIT + 3))
info "Waiting ${REMAINING}s more for idle unload..."
sleep $REMAINING
STATE=$(get_state)
[ "$STATE" = "unloaded" ] && ok "T5b: model auto-unloaded after idle period post-job" \
|| fail "T5b: expected unloaded, got $STATE"
fi
fi
# ── Summary ────────────────────────────────────────────────────────────────────
echo ""
echo "=========================================="
echo " Results: ${PASS} passed, ${FAIL} failed"
echo "=========================================="
[ $FAIL -eq 0 ] && echo -e "${GREEN}ALL PASSED${NC}" || { echo -e "${RED}FAILURES: $FAIL${NC}"; exit 1; }

470
tests/test_model_lifecycle.sh Executable file
View File

@@ -0,0 +1,470 @@
#!/usr/bin/env bash
# tests/test_model_lifecycle.sh
#
# Integration tests for dynamic model loading/unloading.
# Requires a running whisper-server with GPU access.
#
# Usage:
# WHISPER_BASE_URL=http://localhost:8080 bash tests/test_model_lifecycle.sh
#
# Tests are designed to be independent; each section that needs a specific
# state resets it explicitly at the start.
set -euo pipefail
BASE="${WHISPER_BASE_URL:-http://localhost:8080}"
AUDIO="${TEST_AUDIO:-}"
GREEN='\033[0;32m'; RED='\033[0;31m'; YELLOW='\033[0;33m'; NC='\033[0m'
PASS=0; FAIL=0
ok() { echo -e "${GREEN}[PASS]${NC} $1"; PASS=$((PASS+1)); }
fail() { echo -e "${RED}[FAIL]${NC} $1"; FAIL=$((FAIL+1)); }
skip() { echo -e "${YELLOW}[SKIP]${NC} $1"; }
info() { echo " $1"; }
echo "=== Model Lifecycle Integration Tests ==="
echo " BASE: $BASE"
echo ""
# ── Helpers ──────────────────────────────────────────────────────────────────
get_state() {
curl -sf "$BASE/model/status" | python3 -c "import sys,json; print(json.load(sys.stdin)['state'])"
}
ensure_unloaded() {
curl -sf -X POST "$BASE/model/unload" > /dev/null
sleep 2
local s
s=$(get_state)
if [ "$s" != "unloaded" ]; then
echo " WARNING: expected unloaded, got $s — waiting 5s"
sleep 5
fi
}
ensure_ready() {
local state
state=$(get_state)
if [ "$state" = "ready" ]; then return 0; fi
curl -sf -X POST "$BASE/model/load" > /dev/null
local elapsed=0
while true; do
sleep 5; elapsed=$((elapsed+5))
state=$(get_state)
[ "$state" = "ready" ] && return 0
[ $elapsed -gt 180 ] && echo " TIMEOUT: model did not become ready" && return 1
done
}
poll_state_transition() {
local target="$1" max_secs="${2:-120}"
local elapsed=0
while true; do
sleep 2; elapsed=$((elapsed+2))
local s
s=$(get_state)
[ "$s" = "$target" ] && return 0
[ $elapsed -ge $max_secs ] && return 1
done
}
# ── TEST 1: Startup state is unloaded ────────────────────────────────────────
echo "--- Test 1: Startup state is unloaded (or after explicit unload) ---"
ensure_unloaded
STATE=$(get_state)
if [ "$STATE" = "unloaded" ]; then
ok "T1: state=unloaded after explicit unload"
else
fail "T1: expected unloaded, got $STATE"
fi
# ── TEST 2: POST /model/load returns 202 ─────────────────────────────────────
echo ""
echo "--- Test 2: POST /model/load returns 202 ---"
ensure_unloaded
HTTP=$(curl -s -o /dev/null -w "%{http_code}" -X POST "$BASE/model/load")
if [ "$HTTP" = "202" ]; then
ok "T2: POST /model/load → 202 Accepted"
else
fail "T2: expected 202, got $HTTP"
fi
# Cancel the in-progress load to clean up
curl -sf -X POST "$BASE/model/unload" > /dev/null || true
sleep 2
# ── TEST 3: State transitions to loading/ready after load trigger ─────────────
echo ""
echo "--- Test 3: State transitions to loading (not stuck at unloaded) ---"
ensure_unloaded
curl -sf -X POST "$BASE/model/load" > /dev/null
sleep 1
STATE=$(get_state)
if [ "$STATE" = "loading" ] || [ "$STATE" = "ready" ]; then
ok "T3: state transitioned to $STATE (not stuck at unloaded)"
else
fail "T3: expected loading or ready, got $STATE"
fi
# ── TEST 4: Model reaches ready state and loaded_at is set ───────────────────
echo ""
echo "--- Test 4: Model reaches ready state with loaded_at timestamp ---"
# Already loading from T3 — wait for ready
if ! poll_state_transition "ready" 180; then
fail "T4: model did not become ready within 3 minutes"
else
STATUS_JSON=$(curl -sf "$BASE/model/status")
LOADED_AT=$(echo "$STATUS_JSON" | python3 -c "import sys,json; d=json.load(sys.stdin); print(d.get('loaded_at','MISSING'))" 2>/dev/null || echo "MISSING")
if [ "$LOADED_AT" != "MISSING" ] && [ "$LOADED_AT" != "null" ] && [ -n "$LOADED_AT" ]; then
ok "T4: model=ready, loaded_at=$LOADED_AT"
else
fail "T4: model ready but loaded_at is missing or null"
fi
fi
# ── TEST 5: Idempotent load — POST /model/load when ready returns 200 ─────────
echo ""
echo "--- Test 5: POST /model/load when already ready → 200 ---"
ensure_ready || { fail "T5: could not load model"; }
HTTP=$(curl -s -o /dev/null -w "%{http_code}" -X POST "$BASE/model/load")
STATE=$(get_state)
if [ "$HTTP" = "200" ] && [ "$STATE" = "ready" ]; then
ok "T5: idempotent load → 200, state stays ready"
elif [ "$HTTP" = "202" ] && [ "$STATE" = "ready" ]; then
ok "T5: idempotent load → 202, state stays ready"
else
fail "T5: expected 200 and ready, got HTTP=$HTTP state=$STATE"
fi
# ── TEST 6: Job accepted when ready (segments > 0) ────────────────────────────
echo ""
echo "--- Test 6: Job accepted when model is ready ---"
if [ -z "$AUDIO" ]; then
skip "T6: TEST_AUDIO not set — skipping job submission test"
else
ensure_ready || { fail "T6: model load failed"; }
SUBMIT=$(curl -sf -X POST "$BASE/jobs" -F "audio=@${AUDIO};type=audio/wav" -F "task=transcribe" 2>&1)
JOB_ID=$(echo "$SUBMIT" | python3 -c "import sys,json; print(json.load(sys.stdin)['job_id'])" 2>/dev/null || echo "")
if [ -n "$JOB_ID" ]; then
ok "T6: job accepted, id=$JOB_ID"
# Poll to done
elapsed=0
while true; do
sleep 10; elapsed=$((elapsed+10))
STATUS=$(curl -sf "$BASE/jobs/$JOB_ID" | python3 -c "import sys,json; print(json.load(sys.stdin)['status'])")
[ "$STATUS" = "done" ] && break
[ "$STATUS" = "failed" ] && break
[ $elapsed -gt 600 ] && break
done
SEGS=$(curl -sf "$BASE/jobs/$JOB_ID" | python3 -c "import sys,json; d=json.load(sys.stdin); print(len(d.get('segments',[])))")
[ "$SEGS" -gt 0 ] && ok "T6b: job done with $SEGS segments" || fail "T6b: job done but 0 segments"
else
fail "T6: job submission failed: $SUBMIT"
fi
fi
# ── TEST 7: POST /model/unload → state=unloaded ───────────────────────────────
echo ""
echo "--- Test 7: POST /model/unload ---"
ensure_ready || { fail "T7: model load failed"; }
curl -sf -X POST "$BASE/model/unload" > /dev/null
sleep 3
STATE=$(get_state)
if [ "$STATE" = "unloaded" ]; then
ok "T7: POST /model/unload → state=unloaded"
else
fail "T7: expected unloaded after unload, got $STATE"
fi
# ── TEST 8: POST /jobs when unloaded → 503 + Retry-After ─────────────────────
echo ""
echo "--- Test 8: POST /jobs when unloaded → 503 + Retry-After ---"
ensure_unloaded
# Submit a tiny dummy payload (won't be valid audio but that's ok for this test)
HTTP=$(curl -s -o /tmp/t8_body.json -w "%{http_code}" -X POST "$BASE/jobs" \
-F "audio=@/dev/urandom;type=audio/wav" \
--max-time 5 2>/dev/null || echo "000")
# If the model auto-loads it might start processing; check for 503 first
if [ "$HTTP" = "503" ]; then
RETRY_AFTER=$(curl -sI -X POST "$BASE/jobs" \
-F "audio=@/dev/urandom;type=audio/wav" \
--max-time 5 2>/dev/null | grep -i "retry-after" | awk '{print $2}' | tr -d '\r' || echo "")
BODY=$(cat /tmp/t8_body.json 2>/dev/null || echo "{}")
HAS_STATE=$(echo "$BODY" | python3 -c "import sys,json; d=json.load(sys.stdin); print('state' in d)" 2>/dev/null || echo "False")
HAS_RETRY=$(echo "$BODY" | python3 -c "import sys,json; d=json.load(sys.stdin); print('retry_after_secs' in d)" 2>/dev/null || echo "False")
if [ "$HAS_STATE" = "True" ] && [ "$HAS_RETRY" = "True" ]; then
ok "T8: 503 with state + retry_after_secs in body"
else
fail "T8: 503 but body missing state/retry_after_secs. body=$BODY"
fi
if [ -n "$RETRY_AFTER" ]; then
ok "T8b: Retry-After header present: $RETRY_AFTER"
else
fail "T8b: Retry-After header missing from 503 response"
fi
else
skip "T8: got HTTP $HTTP (model may have loaded before check) — skipping"
fi
# ── TEST 9: Rejected job triggers load ────────────────────────────────────────
echo ""
echo "--- Test 9: Job rejection triggers model load ---"
ensure_unloaded
# Send a job (we expect 503)
curl -sf -X POST "$BASE/jobs" \
-F "audio=@/dev/urandom;type=audio/wav" \
--max-time 5 > /dev/null 2>&1 || true
sleep 2
STATE=$(get_state)
if [ "$STATE" = "loading" ] || [ "$STATE" = "ready" ]; then
ok "T9: model started loading after job rejection ($STATE)"
else
fail "T9: expected loading/ready after job rejection, got $STATE"
fi
# Stop the load to clean up
curl -sf -X POST "$BASE/model/unload" > /dev/null || true
sleep 2
# ── TEST 10: Retry-After values ───────────────────────────────────────────────
echo ""
echo "--- Test 10: Retry-After values match state ---"
ensure_unloaded
# Unloaded → Retry-After: 30
RESP_UNLOADED=$(curl -si -X POST "$BASE/jobs" -F "audio=@/dev/urandom;type=audio/wav" --max-time 5 2>/dev/null || echo "")
RA_UNLOADED=$(echo "$RESP_UNLOADED" | grep -i "retry-after" | awk '{print $2}' | tr -d '\r' || echo "")
[ "$RA_UNLOADED" = "30" ] && ok "T10a: Retry-After=30 when unloaded" \
|| skip "T10a: Retry-After=$RA_UNLOADED (expected 30) — model may have started loading"
# ── TEST 11: Retry-After=10 during loading ────────────────────────────────────
echo ""
echo "--- Test 11: Retry-After=10 when loading ---"
ensure_unloaded
curl -sf -X POST "$BASE/model/load" > /dev/null
sleep 1 # In loading state
STATE=$(get_state)
if [ "$STATE" = "loading" ]; then
RESP_LOADING=$(curl -si -X POST "$BASE/jobs" -F "audio=@/dev/urandom;type=audio/wav" --max-time 5 2>/dev/null || echo "")
RA_LOADING=$(echo "$RESP_LOADING" | grep -i "retry-after" | awk '{print $2}' | tr -d '\r' || echo "")
[ "$RA_LOADING" = "10" ] && ok "T11: Retry-After=10 when loading" \
|| fail "T11: expected Retry-After=10, got '$RA_LOADING' (state=$STATE)"
else
skip "T11: model already $STATE — can't test loading state Retry-After"
fi
# ── TEST 12: 503 body schema validation ──────────────────────────────────────
echo ""
echo "--- Test 12: 503 body schema validation ---"
ensure_unloaded
BODY=$(curl -sf -X POST "$BASE/jobs" -F "audio=@/dev/urandom;type=audio/wav" --max-time 5 2>/dev/null || echo "{}")
python3 - <<PYCHECK
import json
body = json.loads('$BODY')
required = {'error', 'state', 'retry_after_secs'}
missing = required - set(body.keys())
if missing:
print(f"MISSING: {missing}")
exit(1)
assert body['error'] == 'model_not_ready', f"error={body['error']}"
assert isinstance(body['retry_after_secs'], int), f"retry_after_secs not int: {body['retry_after_secs']}"
print("schema ok")
PYCHECK
[ $? -eq 0 ] && ok "T12: 503 body has correct schema" || fail "T12: 503 body schema invalid"
# ── TEST 13: GET /health has model_state field ────────────────────────────────
echo ""
echo "--- Test 13: GET /health has model_state ---"
HEALTH=$(curl -sf "$BASE/health")
HAS_MODEL_STATE=$(echo "$HEALTH" | python3 -c "import sys,json; d=json.load(sys.stdin); print('model_state' in d)")
[ "$HAS_MODEL_STATE" = "True" ] && ok "T13: /health has model_state" || fail "T13: /health missing model_state"
# ── TEST 14: SSE /model/events delivers model_ready event ─────────────────────
echo ""
echo "--- Test 14: GET /model/events SSE delivers model_ready ---"
ensure_unloaded
# Collect SSE events for up to 3 minutes
SSE_LOG=$(mktemp /tmp/sse_events_XXXXXX.txt)
curl -sN --max-time 180 "$BASE/model/events" > "$SSE_LOG" &
SSE_PID=$!
sleep 1
# Trigger load
curl -sf -X POST "$BASE/model/load" > /dev/null
poll_state_transition "ready" 180 || true
sleep 2
kill $SSE_PID 2>/dev/null || true
wait $SSE_PID 2>/dev/null || true
if grep -q "model_loading" "$SSE_LOG" 2>/dev/null; then
ok "T14a: SSE received model_loading event"
else
fail "T14a: SSE did not receive model_loading event"
fi
if grep -q "model_ready" "$SSE_LOG" 2>/dev/null; then
ok "T14b: SSE received model_ready event"
else
fail "T14b: SSE did not receive model_ready event"
fi
# Now unload to get model_unloaded event
curl -sf -X POST "$BASE/model/unload" > /dev/null
sleep 1
SSE_LOG2=$(mktemp /tmp/sse_events_XXXXXX.txt)
curl -sN --max-time 10 "$BASE/model/events" > "$SSE_LOG2" &
SSE_PID2=$!
sleep 2
kill $SSE_PID2 2>/dev/null || true
wait $SSE_PID2 2>/dev/null || true
# model_unloaded fires immediately on unload command
if grep -q "model_unloaded" "$SSE_LOG" 2>/dev/null || grep -q "model_unloaded" "$SSE_LOG2" 2>/dev/null; then
ok "T14c: SSE received model_unloaded event"
else
fail "T14c: SSE did not receive model_unloaded event"
fi
rm -f "$SSE_LOG" "$SSE_LOG2"
# ── TEST 15: model_ready webhook fires after load ──────────────────────────────
echo ""
echo "--- Test 15: model_ready webhook ---"
ensure_unloaded
# Start webhook receiver
WEBHOOK_LOG=$(mktemp /tmp/webhook_log_XXXXXX.txt)
python3 - <<'PYEOF' &
import http.server, json, sys, signal, os
class H(http.server.BaseHTTPRequestHandler):
def do_POST(self):
n = int(self.headers.get('Content-Length', 0))
body = json.loads(self.rfile.read(n))
with open('/tmp/t15_webhook.json', 'w') as f:
json.dump(body, f)
self.send_response(200); self.end_headers()
def log_message(self, *a): pass
signal.signal(signal.SIGTERM, lambda *_: sys.exit(0))
http.server.HTTPServer(('', 9998), H).serve_forever()
PYEOF
WBOOK_PID=$!
sleep 1
# Register a webhook via a (doomed) job submission — this registers the URL
# even though the model is unloaded (and the job will 503)
curl -sf -X POST "$BASE/jobs" \
-F "audio=@/dev/urandom;type=audio/wav" \
-F "webhook_url=http://localhost:9998/wh" \
--max-time 5 > /dev/null 2>&1 || true
# Now load the model
curl -sf -X POST "$BASE/model/load" > /dev/null
poll_state_transition "ready" 180 || true
sleep 3
kill $WBOOK_PID 2>/dev/null || true
wait $WBOOK_PID 2>/dev/null || true
if [ -f /tmp/t15_webhook.json ]; then
EVENT_TYPE=$(python3 -c "import json; d=json.load(open('/tmp/t15_webhook.json')); print(d.get('type','?'))")
[ "$EVENT_TYPE" = "model_ready" ] && ok "T15: model_ready webhook fired" \
|| fail "T15: webhook fired but type=$EVENT_TYPE (expected model_ready)"
rm -f /tmp/t15_webhook.json
else
fail "T15: model_ready webhook not received within timeout"
fi
# ── TEST 16: model_unloaded webhook fires ─────────────────────────────────────
echo ""
echo "--- Test 16: model_unloaded webhook ---"
python3 - <<'PYEOF' &
import http.server, json, sys, signal
class H(http.server.BaseHTTPRequestHandler):
def do_POST(self):
n = int(self.headers.get('Content-Length', 0))
body = json.loads(self.rfile.read(n))
with open('/tmp/t16_webhook.json', 'w') as f:
json.dump(body, f)
self.send_response(200); self.end_headers()
def log_message(self, *a): pass
signal.signal(signal.SIGTERM, lambda *_: sys.exit(0))
http.server.HTTPServer(('', 9997), H).serve_forever()
PYEOF
WBOOK2_PID=$!
sleep 1
# Register webhook URL
curl -sf -X POST "$BASE/jobs" \
-F "audio=@/dev/urandom;type=audio/wav" \
-F "webhook_url=http://localhost:9997/wh" \
--max-time 5 > /dev/null 2>&1 || true
ensure_ready
# Unload
curl -sf -X POST "$BASE/model/unload" > /dev/null
sleep 5
kill $WBOOK2_PID 2>/dev/null || true
wait $WBOOK2_PID 2>/dev/null || true
if [ -f /tmp/t16_webhook.json ]; then
EVENT_TYPE=$(python3 -c "import json; d=json.load(open('/tmp/t16_webhook.json')); print(d.get('type','?'))")
[ "$EVENT_TYPE" = "model_unloaded" ] && ok "T16: model_unloaded webhook fired" \
|| fail "T16: webhook type=$EVENT_TYPE (expected model_unloaded)"
rm -f /tmp/t16_webhook.json
else
fail "T16: model_unloaded webhook not received"
fi
# ── TEST 17: Concurrent load requests — single load, stable ready ─────────────
echo ""
echo "--- Test 17: Concurrent POST /model/load requests ---"
ensure_unloaded
# Send 3 concurrent load requests
curl -sf -X POST "$BASE/model/load" > /dev/null &
curl -sf -X POST "$BASE/model/load" > /dev/null &
curl -sf -X POST "$BASE/model/load" > /dev/null &
wait
poll_state_transition "ready" 180 || true
STATE=$(get_state)
[ "$STATE" = "ready" ] && ok "T17: concurrent loads handled cleanly, state=ready" \
|| fail "T17: expected ready after concurrent loads, got $STATE"
# ── TEST 18: POST /model/unload during loading → clean unloaded ───────────────
echo ""
echo "--- Test 18: POST /model/unload during loading ---"
ensure_unloaded
curl -sf -X POST "$BASE/model/load" > /dev/null
sleep 1 # Hopefully still in loading state
curl -sf -X POST "$BASE/model/unload" > /dev/null
# Allow time for the unload to propagate
sleep 5
STATE=$(get_state)
if [ "$STATE" = "unloaded" ]; then
ok "T18: unload during loading → clean unloaded"
elif [ "$STATE" = "ready" ]; then
# Load completed before unload arrived — immediately unload
curl -sf -X POST "$BASE/model/unload" > /dev/null
sleep 3
STATE=$(get_state)
[ "$STATE" = "unloaded" ] && ok "T18: load completed then unloaded (race condition OK)" \
|| fail "T18: state=$STATE after load+unload"
else
fail "T18: unexpected state after unload-during-load: $STATE"
fi
# ── Summary ────────────────────────────────────────────────────────────────────
echo ""
echo "=========================================="
echo " Results: ${PASS} passed, ${FAIL} failed"
echo "=========================================="
[ $FAIL -eq 0 ] && echo -e "${GREEN}ALL PASSED${NC}" || { echo -e "${RED}FAILURES: $FAIL${NC}"; exit 1; }