feat: dynamic model loading/unloading with GPU polling
All checks were successful
Build & Push Docker Image / build-and-push (push) Successful in 8m41s
All checks were successful
Build & Push Docker Image / build-and-push (push) Successful in 8m41s
- Model starts unloaded (lazy); loads on first job or POST /model/load
- Auto-unloads after IDLE_TIMEOUT_SECS (default 300) of inactivity
- POST /model/unload for immediate manual release
- GPU-busy detection: on VRAM OOM, enters WaitingForGpu and retries
every GPU_POLL_INTERVAL_SECS (default 30) indefinitely
- POST /jobs when unloaded → 503 + Retry-After header, triggers load
- AppError::OutOfMemory and AppError::ModelNotReady variants
- WorkerCmd channel (SyncSender<WorkerCmd>) replaces bare tx_req channel
- Idle timer via recv_timeout(1s) tick inside OS thread (no extra thread)
- Model lifecycle events broadcast via tokio broadcast channel (SSE + webhooks)
- webhook_registry: all clients that ever submitted a webhook_url receive
model_ready and model_unloaded webhooks
- GPU warmup retained on every (re)load
New routes:
GET /model/status — current state + VRAM stats
POST /model/load — trigger load (idempotent)
POST /model/unload — immediate unload
GET /model/events — SSE stream of model lifecycle events
New env vars:
IDLE_TIMEOUT_SECS (default 300)
GPU_POLL_INTERVAL_SECS (default 30)
Tests:
tests/test_model_lifecycle.sh — 18 integration tests (full state machine,
SSE events, webhooks, concurrency, unload-during-load)
tests/test_idle_timeout.sh — 5 tests with short IDLE_TIMEOUT_SECS=5
test_all.sh updated: loads model before job submission, asserts
model_state in /health, adds POST /model/unload at end
Docs:
docs/USAGE.md: model lifecycle section, new env vars, 503 retry pattern,
updated /health response shape
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
200
docs/USAGE.md
200
docs/USAGE.md
@@ -66,6 +66,8 @@ The bundled `docker-compose.yml` mounts named volumes for data and models and se
|
||||
| `WHISPER_MODEL_PATH` | `/models/ggml-large-v3.bin` | Absolute path to GGML model file |
|
||||
| `WHISPER_MODEL` | `large-v3` | Model name reported by `/health` (display only) |
|
||||
| `CUDA_DEVICE` | `0` | CUDA device index to use for inference |
|
||||
| `IDLE_TIMEOUT_SECS` | `300` | Seconds of idle before the model is automatically unloaded from GPU memory. Set to `0` to disable auto-unload. |
|
||||
| `GPU_POLL_INTERVAL_SECS` | `30` | Seconds between VRAM-availability retries when a load fails due to insufficient VRAM. |
|
||||
|
||||
### Note on CUDA device ordering
|
||||
Inside Docker, device ordering matches `nvidia-smi` (PCI bus order). On the host without Docker, ordering may differ. See [FINDINGS.md](FINDINGS.md#cuda-device-index-ordering-differs-between-host-and-docker) for details.
|
||||
@@ -76,6 +78,194 @@ Inside Docker, device ordering matches `nvidia-smi` (PCI bus order). On the host
|
||||
|
||||
The interactive Swagger UI is available at `http://localhost:8080/docs`.
|
||||
|
||||
---
|
||||
|
||||
## Model Lifecycle Management
|
||||
|
||||
The model starts **unloaded** on startup (lazy loading). It is loaded into GPU memory on the first job submission or via `POST /model/load`, and automatically unloaded after `IDLE_TIMEOUT_SECS` of inactivity.
|
||||
|
||||
### Model State Machine
|
||||
|
||||
```
|
||||
Unloaded ──(job / POST /model/load)──► Loading ──(success)──► Ready
|
||||
└──(VRAM full)──► WaitingForGpu ──(retry OK)──► Loading
|
||||
Ready ──(idle timeout / POST /model/unload)──► Unloaded
|
||||
WaitingForGpu ──(POST /model/unload)──► Unloaded
|
||||
```
|
||||
|
||||
### `GET /model/status`
|
||||
|
||||
Returns the current model state and VRAM statistics.
|
||||
|
||||
```bash
|
||||
curl http://localhost:8080/model/status
|
||||
```
|
||||
|
||||
**When unloaded:**
|
||||
```json
|
||||
{ "state": "unloaded" }
|
||||
```
|
||||
|
||||
**When loading:**
|
||||
```json
|
||||
{ "state": "loading" }
|
||||
```
|
||||
|
||||
**When ready:**
|
||||
```json
|
||||
{
|
||||
"state": "ready",
|
||||
"loaded_at": "2026-05-10T14:00:00Z",
|
||||
"vram_used_mb": 4096,
|
||||
"vram_total_mb": 8192
|
||||
}
|
||||
```
|
||||
|
||||
**When waiting for VRAM:**
|
||||
```json
|
||||
{
|
||||
"state": "waiting_for_gpu",
|
||||
"vram_needed_mb": 3951,
|
||||
"vram_free_mb": 512,
|
||||
"retry_in_secs": 30
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### `POST /model/load`
|
||||
|
||||
Request the model to be loaded. Idempotent — if already loading or ready, returns immediately.
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8080/model/load
|
||||
```
|
||||
|
||||
- Returns `202 Accepted` with `{"status":"load_initiated"}` when load is triggered
|
||||
- Returns `200 OK` with `{"status":"already_ready"}` when model is already ready
|
||||
- Poll `GET /model/status` or subscribe to `GET /model/events` to know when ready
|
||||
|
||||
---
|
||||
|
||||
### `POST /model/unload`
|
||||
|
||||
Unload the model from GPU memory immediately, freeing VRAM.
|
||||
|
||||
```bash
|
||||
curl -X POST http://localhost:8080/model/unload
|
||||
```
|
||||
|
||||
Returns `200 OK` regardless of current state.
|
||||
|
||||
---
|
||||
|
||||
### `GET /model/events` — Model SSE stream
|
||||
|
||||
Subscribe to model lifecycle events via Server-Sent Events.
|
||||
|
||||
```bash
|
||||
curl -N http://localhost:8080/model/events
|
||||
```
|
||||
|
||||
**Event types:**
|
||||
|
||||
```
|
||||
event: model_loading
|
||||
data: {"type":"model_loading"}
|
||||
|
||||
event: model_ready
|
||||
data: {"type":"model_ready","loaded_at":"2026-05-10T14:00:00Z"}
|
||||
|
||||
event: model_unloaded
|
||||
data: {"type":"model_unloaded"}
|
||||
|
||||
event: model_waiting_for_gpu
|
||||
data: {"type":"model_waiting_for_gpu","vram_needed_mb":3951,"vram_free_mb":512,"retry_in_secs":30}
|
||||
```
|
||||
|
||||
**JavaScript example:**
|
||||
```javascript
|
||||
const es = new EventSource('/model/events');
|
||||
|
||||
es.addEventListener('model_ready', () => {
|
||||
console.log('Model loaded — ready to transcribe');
|
||||
});
|
||||
|
||||
es.addEventListener('model_unloaded', () => {
|
||||
console.log('Model freed GPU memory');
|
||||
});
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
### Webhooks for model events
|
||||
|
||||
When any job is submitted with a `webhook_url`, that URL is registered to receive model lifecycle webhooks for the lifetime of the server process. The following events trigger a webhook POST:
|
||||
|
||||
| Event | Fired when |
|
||||
|-------|-----------|
|
||||
| `model_ready` | Model finishes loading (after GPU warmup) |
|
||||
| `model_unloaded` | Model is freed from GPU memory |
|
||||
|
||||
**Webhook payload** (`Content-Type: application/json`):
|
||||
```json
|
||||
{ "type": "model_ready", "loaded_at": "2026-05-10T14:00:00Z" }
|
||||
{ "type": "model_unloaded" }
|
||||
```
|
||||
|
||||
Delivery is attempted up to 3 times with exponential backoff (1s, 2s).
|
||||
|
||||
---
|
||||
|
||||
### Handling 503 Model Not Ready
|
||||
|
||||
When you submit a job and the model is not yet loaded, you receive `503 Service Unavailable` with a `Retry-After` header:
|
||||
|
||||
```
|
||||
HTTP/1.1 503 Service Unavailable
|
||||
Retry-After: 30
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"error": "model_not_ready",
|
||||
"state": "unloaded",
|
||||
"retry_after_secs": 30
|
||||
}
|
||||
```
|
||||
|
||||
| State at rejection | `retry_after_secs` | Meaning |
|
||||
|---|---|---|
|
||||
| `unloaded` | 30 | Load was triggered; retry after ~30s |
|
||||
| `loading` | 10 | Check again in 10s |
|
||||
| `waiting_for_gpu` | `GPU_POLL_INTERVAL_SECS` | VRAM contention; retry later |
|
||||
|
||||
A job rejection when the model is `unloaded` **automatically triggers a load** — you do not need to call `POST /model/load` separately.
|
||||
|
||||
**Recommended client pattern:**
|
||||
```javascript
|
||||
async function submitWithRetry(formData, maxAttempts = 10) {
|
||||
for (let i = 0; i < maxAttempts; i++) {
|
||||
const resp = await fetch('/jobs', { method: 'POST', body: formData });
|
||||
if (resp.ok) return resp.json();
|
||||
if (resp.status === 503) {
|
||||
const retryAfter = parseInt(resp.headers.get('Retry-After') ?? '30');
|
||||
const body = await resp.json();
|
||||
console.log(`Model ${body.state} — retrying in ${retryAfter}s`);
|
||||
await new Promise(r => setTimeout(r, retryAfter * 1000));
|
||||
continue;
|
||||
}
|
||||
throw new Error(`Submit failed: ${resp.status}`);
|
||||
}
|
||||
throw new Error('Gave up after max attempts');
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## API Reference
|
||||
|
||||
The interactive Swagger UI is available at `http://localhost:8080/docs`.
|
||||
|
||||
### `POST /jobs` — Submit a transcription job
|
||||
|
||||
Accepts a multipart/form-data body.
|
||||
@@ -249,11 +439,12 @@ curl http://localhost:8080/health
|
||||
"gpu_name": "NVIDIA GeForce RTX 2080",
|
||||
"vram_total_mb": 8192,
|
||||
"model": "large-v3",
|
||||
"queue_depth": 0
|
||||
"queue_depth": 0,
|
||||
"model_state": "ready"
|
||||
}
|
||||
```
|
||||
|
||||
`queue_depth` is the number of jobs waiting to be processed (not counting the one currently running).
|
||||
`queue_depth` is the number of jobs waiting to be processed (not counting the one currently running). `model_state` reflects the current lifecycle state (`unloaded`, `loading`, `waiting_for_gpu`, `ready`).
|
||||
|
||||
---
|
||||
|
||||
@@ -340,6 +531,11 @@ curl -X POST http://localhost:8080/jobs \
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Server returns `503 model_not_ready`
|
||||
- The model starts unloaded. Call `POST /model/load` explicitly, or just retry the job submission — rejection automatically triggers a load.
|
||||
- If state is `waiting_for_gpu`, another process is using the GPU's VRAM. The server will retry automatically every `GPU_POLL_INTERVAL_SECS` seconds.
|
||||
- Monitor `GET /model/status` or subscribe to `GET /model/events` to know when the model is ready.
|
||||
|
||||
### Server returns 0 segments
|
||||
- Check that you are **not** setting `language` to an empty string — omit the field entirely for auto-detection
|
||||
- Verify the audio file is not corrupted: `ffprobe audio.mp3`
|
||||
|
||||
141
src/error.rs
141
src/error.rs
@@ -1,6 +1,6 @@
|
||||
use thiserror::Error;
|
||||
use axum::{
|
||||
http::StatusCode,
|
||||
http::{StatusCode, HeaderValue, header},
|
||||
response::{IntoResponse, Response},
|
||||
Json,
|
||||
};
|
||||
@@ -21,19 +21,138 @@ pub enum AppError {
|
||||
|
||||
#[error("internal error: {0}")]
|
||||
Internal(String),
|
||||
|
||||
/// Returned when `whisper_init_state` or `cudaMalloc` fails due to
|
||||
/// insufficient VRAM. The worker uses this to distinguish a recoverable
|
||||
/// VRAM-pressure failure from a hard internal error.
|
||||
#[error("out of GPU memory: {0}")]
|
||||
OutOfMemory(String),
|
||||
|
||||
/// Returned when a job is submitted but the model is not yet loaded.
|
||||
/// Carries the current state tag and recommended Retry-After seconds.
|
||||
#[error("model not ready: {state}")]
|
||||
ModelNotReady { state: String, retry_after_secs: u64 },
|
||||
}
|
||||
|
||||
impl AppError {
|
||||
/// Returns true if the error string contains patterns emitted by
|
||||
/// whisper.cpp / GGML when a CUDA memory allocation fails.
|
||||
pub fn is_oom(msg: &str) -> bool {
|
||||
msg.contains("cudaMalloc failed")
|
||||
|| msg.contains("out of memory")
|
||||
|| msg.contains("CUDA error: out of memory")
|
||||
|| msg.contains("alloc_buffer")
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for AppError {
|
||||
fn into_response(self) -> Response {
|
||||
let (status, message) = match &self {
|
||||
AppError::NotFound(m) => (StatusCode::NOT_FOUND, m.clone()),
|
||||
AppError::BadRequest(m) => (StatusCode::BAD_REQUEST, m.clone()),
|
||||
AppError::Conflict(m) => (StatusCode::CONFLICT, m.clone()),
|
||||
AppError::Internal(m) => (StatusCode::INTERNAL_SERVER_ERROR, m.clone()),
|
||||
};
|
||||
|
||||
tracing::error!(status = status.as_u16(), error = %message);
|
||||
|
||||
(status, Json(json!({ "error": message }))).into_response()
|
||||
match self {
|
||||
AppError::NotFound(m) => {
|
||||
(StatusCode::NOT_FOUND, Json(json!({ "error": m }))).into_response()
|
||||
}
|
||||
AppError::BadRequest(m) => {
|
||||
(StatusCode::BAD_REQUEST, Json(json!({ "error": m }))).into_response()
|
||||
}
|
||||
AppError::Conflict(m) => {
|
||||
(StatusCode::CONFLICT, Json(json!({ "error": m }))).into_response()
|
||||
}
|
||||
AppError::Internal(m) => {
|
||||
tracing::error!(error = %m, "internal error");
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": m }))).into_response()
|
||||
}
|
||||
AppError::OutOfMemory(m) => {
|
||||
tracing::warn!(error = %m, "GPU out of memory during model load");
|
||||
(StatusCode::SERVICE_UNAVAILABLE, Json(json!({ "error": m }))).into_response()
|
||||
}
|
||||
AppError::ModelNotReady { state, retry_after_secs } => {
|
||||
let body = Json(json!({
|
||||
"error": "model_not_ready",
|
||||
"state": state,
|
||||
"retry_after_secs": retry_after_secs,
|
||||
}));
|
||||
let mut resp = (StatusCode::SERVICE_UNAVAILABLE, body).into_response();
|
||||
resp.headers_mut().insert(
|
||||
header::RETRY_AFTER,
|
||||
HeaderValue::from_str(&retry_after_secs.to_string())
|
||||
.unwrap_or(HeaderValue::from_static("30")),
|
||||
);
|
||||
resp
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Unit tests ───────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use axum::body::to_bytes;
|
||||
|
||||
#[test]
|
||||
fn test_is_oom_cuda_malloc() {
|
||||
assert!(AppError::is_oom("cudaMalloc failed: out of memory"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_oom_alloc_buffer() {
|
||||
// Exact message from ggml_backend_cuda_buffer_type_alloc_buffer
|
||||
assert!(AppError::is_oom(
|
||||
"ggml_backend_cuda_buffer_type_alloc_buffer: allocating 2951.01 MiB on device 0: cudaMalloc failed: out of memory"
|
||||
));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_oom_generic_out_of_memory() {
|
||||
assert!(AppError::is_oom("CUDA error: out of memory"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_is_oom_other_error() {
|
||||
assert!(!AppError::is_oom("failed to open model file"));
|
||||
assert!(!AppError::is_oom("invalid model format"));
|
||||
assert!(!AppError::is_oom(""));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_model_not_ready_response_has_retry_after_header() {
|
||||
let err = AppError::ModelNotReady { state: "loading".into(), retry_after_secs: 10 };
|
||||
let resp = err.into_response();
|
||||
assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
|
||||
let retry_after = resp.headers().get(header::RETRY_AFTER)
|
||||
.expect("Retry-After header missing");
|
||||
assert_eq!(retry_after, "10");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_model_not_ready_response_body() {
|
||||
let err = AppError::ModelNotReady { state: "unloaded".into(), retry_after_secs: 30 };
|
||||
let resp = err.into_response();
|
||||
let bytes = to_bytes(resp.into_body(), usize::MAX).await.unwrap();
|
||||
let v: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
|
||||
assert_eq!(v["error"], "model_not_ready");
|
||||
assert_eq!(v["state"], "unloaded");
|
||||
assert_eq!(v["retry_after_secs"], 30);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_model_not_ready_loading_retry_after_10() {
|
||||
let err = AppError::ModelNotReady { state: "loading".into(), retry_after_secs: 10 };
|
||||
let resp = err.into_response();
|
||||
assert_eq!(
|
||||
resp.headers().get(header::RETRY_AFTER).unwrap(),
|
||||
"10"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_model_not_ready_unloaded_retry_after_30() {
|
||||
let err = AppError::ModelNotReady { state: "unloaded".into(), retry_after_secs: 30 };
|
||||
let resp = err.into_response();
|
||||
assert_eq!(
|
||||
resp.headers().get(header::RETRY_AFTER).unwrap(),
|
||||
"30"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
60
src/main.rs
60
src/main.rs
@@ -1,7 +1,7 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::Router;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::{broadcast, mpsc, RwLock};
|
||||
use tower_http::{cors::CorsLayer, trace::TraceLayer};
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
|
||||
use utoipa::OpenApi;
|
||||
@@ -21,8 +21,10 @@ pub use error::{AppError, Result};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
/// Channel to submit jobs to the single GPU worker.
|
||||
/// Channel to submit jobs to the single GPU worker (job IDs only).
|
||||
pub job_tx: mpsc::UnboundedSender<models::JobId>,
|
||||
/// Channel to send control commands to the worker OS thread.
|
||||
pub cmd_tx: std::sync::mpsc::SyncSender<worker::WorkerCmd>,
|
||||
/// Shared handle to the on-disk job store.
|
||||
pub storage: Arc<storage::Storage>,
|
||||
/// SSE broadcast registry: job_id → sender.
|
||||
@@ -33,6 +35,17 @@ pub struct AppState {
|
||||
pub queue_depth: Arc<std::sync::atomic::AtomicUsize>,
|
||||
/// CUDA device index used for inference.
|
||||
pub gpu_device: u32,
|
||||
/// Current state of the whisper model.
|
||||
pub model_state: Arc<RwLock<models::ModelState>>,
|
||||
/// Broadcast channel for model lifecycle events (SSE + webhooks).
|
||||
pub model_event_tx: broadcast::Sender<models::ModelEvent>,
|
||||
/// All webhook URLs ever registered via job submission.
|
||||
/// Used to fire model_ready / model_unloaded notifications.
|
||||
pub webhook_registry: Arc<std::sync::Mutex<std::collections::HashSet<String>>>,
|
||||
/// How long the model stays loaded with no active jobs.
|
||||
pub idle_timeout: std::time::Duration,
|
||||
/// How often to retry loading when GPU is busy.
|
||||
pub gpu_poll_interval: std::time::Duration,
|
||||
}
|
||||
|
||||
// ── OpenAPI spec root ────────────────────────────────────────────────────────
|
||||
@@ -50,6 +63,10 @@ pub struct AppState {
|
||||
routes::jobs::stream_job,
|
||||
routes::jobs::delete_job,
|
||||
routes::health::health,
|
||||
routes::model::model_status,
|
||||
routes::model::model_load,
|
||||
routes::model::model_unload,
|
||||
routes::model::model_events,
|
||||
),
|
||||
components(schemas(
|
||||
models::Job,
|
||||
@@ -58,10 +75,14 @@ pub struct AppState {
|
||||
models::Word,
|
||||
models::SubmitResponse,
|
||||
models::HealthResponse,
|
||||
models::ModelState,
|
||||
models::ModelEvent,
|
||||
models::ModelStatusResponse,
|
||||
)),
|
||||
tags(
|
||||
(name = "jobs", description = "Transcription job management"),
|
||||
(name = "system", description = "Service health"),
|
||||
(name = "model", description = "Model lifecycle management"),
|
||||
)
|
||||
)]
|
||||
struct ApiDoc;
|
||||
@@ -85,6 +106,20 @@ async fn main() -> anyhow::Result<()> {
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(0);
|
||||
let idle_timeout_secs: u64 = std::env::var("IDLE_TIMEOUT_SECS")
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(300);
|
||||
let gpu_poll_interval_secs: u64 = std::env::var("GPU_POLL_INTERVAL_SECS")
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(30);
|
||||
|
||||
tracing::info!(
|
||||
idle_timeout_secs,
|
||||
gpu_poll_interval_secs,
|
||||
"dynamic model loading configured"
|
||||
);
|
||||
|
||||
let storage = Arc::new(storage::Storage::new(&data_dir).await?);
|
||||
|
||||
@@ -94,28 +129,45 @@ async fn main() -> anyhow::Result<()> {
|
||||
let (job_tx, job_rx) = mpsc::unbounded_channel::<models::JobId>();
|
||||
let queue_depth = Arc::new(std::sync::atomic::AtomicUsize::new(0));
|
||||
|
||||
// Spawn single GPU worker; get back the SSE broadcast registry.
|
||||
let progress = worker::start(
|
||||
// Model starts unloaded — lazy load on first job or POST /model/load.
|
||||
let model_state = Arc::new(RwLock::new(models::ModelState::Unloaded));
|
||||
let (model_event_tx, _) = broadcast::channel::<models::ModelEvent>(32);
|
||||
let webhook_registry = Arc::new(std::sync::Mutex::new(std::collections::HashSet::<String>::new()));
|
||||
|
||||
// Spawn single GPU worker; get back the SSE broadcast registry and cmd channel.
|
||||
let (progress, cmd_tx) = worker::start(
|
||||
job_rx,
|
||||
Arc::clone(&storage),
|
||||
model_path.clone().into(),
|
||||
Arc::clone(&queue_depth),
|
||||
gpu_device,
|
||||
Arc::clone(&model_state),
|
||||
model_event_tx.clone(),
|
||||
Arc::clone(&webhook_registry),
|
||||
std::time::Duration::from_secs(idle_timeout_secs),
|
||||
std::time::Duration::from_secs(gpu_poll_interval_secs),
|
||||
);
|
||||
|
||||
let state = AppState {
|
||||
job_tx,
|
||||
cmd_tx,
|
||||
storage: Arc::clone(&storage),
|
||||
progress,
|
||||
model_name: model_name.as_str().into(),
|
||||
queue_depth: Arc::clone(&queue_depth),
|
||||
gpu_device,
|
||||
model_state,
|
||||
model_event_tx,
|
||||
webhook_registry,
|
||||
idle_timeout: std::time::Duration::from_secs(idle_timeout_secs),
|
||||
gpu_poll_interval: std::time::Duration::from_secs(gpu_poll_interval_secs),
|
||||
};
|
||||
|
||||
let app = Router::new()
|
||||
.merge(SwaggerUi::new("/docs").url("/openapi.json", ApiDoc::openapi()))
|
||||
.merge(routes::jobs_router())
|
||||
.merge(routes::health_router())
|
||||
.merge(routes::model_router())
|
||||
.with_state(state)
|
||||
.layer(CorsLayer::permissive())
|
||||
.layer(TraceLayer::new_for_http());
|
||||
|
||||
246
src/models.rs
246
src/models.rs
@@ -5,6 +5,116 @@ use uuid::Uuid;
|
||||
|
||||
pub type JobId = Uuid;
|
||||
|
||||
// ── Model lifecycle state ────────────────────────────────────────────────────
|
||||
|
||||
/// Current state of the whisper model in memory.
|
||||
///
|
||||
/// State machine:
|
||||
/// ```
|
||||
/// Unloaded ──(load trigger)──► Loading ──(ok)──► Ready ──(idle/unload)──► Unloaded
|
||||
/// └──(VRAM full)──► WaitingForGpu ──(retry)──► Loading
|
||||
/// WaitingForGpu ──(unload cmd)──► Unloaded
|
||||
/// ```
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
|
||||
#[serde(tag = "state", rename_all = "snake_case")]
|
||||
pub enum ModelState {
|
||||
/// Model is not in memory. GPU is free.
|
||||
Unloaded,
|
||||
/// Model is being loaded (weights transferred to GPU).
|
||||
Loading,
|
||||
/// A previous load attempt failed due to insufficient VRAM. The worker is
|
||||
/// polling at `retry_in_secs` intervals until enough memory is available.
|
||||
WaitingForGpu {
|
||||
/// VRAM required to load the model, in MiB.
|
||||
vram_needed_mb: u64,
|
||||
/// VRAM currently free on the device, in MiB.
|
||||
vram_free_mb: u64,
|
||||
/// How many seconds until the next load attempt.
|
||||
retry_in_secs: u64,
|
||||
},
|
||||
/// Model is loaded and ready to accept inference jobs.
|
||||
Ready {
|
||||
/// UTC timestamp of when the model finished loading (post-warmup).
|
||||
loaded_at: DateTime<Utc>,
|
||||
},
|
||||
}
|
||||
|
||||
impl ModelState {
|
||||
/// Returns true if the model can accept inference jobs right now.
|
||||
pub fn is_ready(&self) -> bool {
|
||||
matches!(self, ModelState::Ready { .. })
|
||||
}
|
||||
|
||||
/// Suggested `Retry-After` value (seconds) to include in 503 responses.
|
||||
pub fn retry_after_secs(&self) -> u64 {
|
||||
match self {
|
||||
ModelState::Unloaded => 30, // conservative load estimate
|
||||
ModelState::Loading => 10,
|
||||
ModelState::WaitingForGpu { retry_in_secs, .. } => *retry_in_secs,
|
||||
ModelState::Ready { .. } => 0, // shouldn't 503 if ready
|
||||
}
|
||||
}
|
||||
|
||||
/// String tag for use in error response bodies and log fields.
|
||||
pub fn tag(&self) -> &'static str {
|
||||
match self {
|
||||
ModelState::Unloaded => "unloaded",
|
||||
ModelState::Loading => "loading",
|
||||
ModelState::WaitingForGpu{..} => "waiting_for_gpu",
|
||||
ModelState::Ready{..} => "ready",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ── Model events (SSE + webhooks) ────────────────────────────────────────────
|
||||
|
||||
/// Events broadcast over the `GET /model/events` SSE stream and fired as
|
||||
/// webhooks to registered clients.
|
||||
///
|
||||
/// Webhook delivery: only `ModelReady` and `ModelUnloaded` are sent to
|
||||
/// webhook URLs. All four are broadcast on the SSE stream.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ModelEvent {
|
||||
/// Model finished loading and the GPU warmup completed — ready to accept jobs.
|
||||
ModelReady {
|
||||
loaded_at: DateTime<Utc>,
|
||||
},
|
||||
/// Model was unloaded from GPU memory (idle timeout or manual unload).
|
||||
ModelUnloaded,
|
||||
/// Model load initiated.
|
||||
ModelLoading,
|
||||
/// Load failed due to insufficient VRAM; retrying after `retry_in_secs`.
|
||||
ModelWaitingForGpu {
|
||||
vram_needed_mb: u64,
|
||||
vram_free_mb: u64,
|
||||
retry_in_secs: u64,
|
||||
},
|
||||
}
|
||||
|
||||
impl ModelEvent {
|
||||
/// Returns true if this event should be delivered via webhook.
|
||||
pub fn is_webhook_event(&self) -> bool {
|
||||
matches!(self, ModelEvent::ModelReady { .. } | ModelEvent::ModelUnloaded)
|
||||
}
|
||||
}
|
||||
|
||||
// ── Model status response ────────────────────────────────────────────────────
|
||||
|
||||
/// Response body for `GET /model/status`.
|
||||
#[derive(Debug, Serialize, Deserialize, ToSchema)]
|
||||
pub struct ModelStatusResponse {
|
||||
/// Current model state (flattened from `ModelState`).
|
||||
#[serde(flatten)]
|
||||
pub state: ModelState,
|
||||
/// VRAM currently used on the device, in MiB (from nvidia-smi).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub vram_used_mb: Option<u64>,
|
||||
/// VRAM total on the device, in MiB.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub vram_total_mb: Option<u64>,
|
||||
}
|
||||
|
||||
// ── Job status ───────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, ToSchema)]
|
||||
@@ -130,6 +240,8 @@ pub struct HealthResponse {
|
||||
pub vram_total_mb: Option<u64>,
|
||||
pub model: String,
|
||||
pub queue_depth: usize,
|
||||
/// Current state of the whisper model.
|
||||
pub model_state: String,
|
||||
}
|
||||
|
||||
// ── SSE event payload ────────────────────────────────────────────────────────
|
||||
@@ -148,3 +260,137 @@ pub enum SsePayload {
|
||||
Done { job: Box<Job> },
|
||||
Error { message: String },
|
||||
}
|
||||
|
||||
// ── Unit tests ───────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json::Value;
|
||||
|
||||
// ── ModelState serialization ─────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn test_model_state_unloaded_serializes() {
|
||||
let v: Value = serde_json::to_value(ModelState::Unloaded).unwrap();
|
||||
assert_eq!(v["state"], "unloaded");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_state_loading_serializes() {
|
||||
let v: Value = serde_json::to_value(ModelState::Loading).unwrap();
|
||||
assert_eq!(v["state"], "loading");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_state_waiting_serializes() {
|
||||
let s = ModelState::WaitingForGpu { vram_needed_mb: 3000, vram_free_mb: 500, retry_in_secs: 30 };
|
||||
let v: Value = serde_json::to_value(&s).unwrap();
|
||||
assert_eq!(v["state"], "waiting_for_gpu");
|
||||
assert_eq!(v["vram_needed_mb"], 3000);
|
||||
assert_eq!(v["vram_free_mb"], 500);
|
||||
assert_eq!(v["retry_in_secs"], 30);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_state_ready_serializes() {
|
||||
let ts = Utc::now();
|
||||
let s = ModelState::Ready { loaded_at: ts };
|
||||
let v: Value = serde_json::to_value(&s).unwrap();
|
||||
assert_eq!(v["state"], "ready");
|
||||
assert!(v["loaded_at"].is_string());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_state_is_ready() {
|
||||
assert!(!ModelState::Unloaded.is_ready());
|
||||
assert!(!ModelState::Loading.is_ready());
|
||||
assert!(!ModelState::WaitingForGpu { vram_needed_mb: 0, vram_free_mb: 0, retry_in_secs: 30 }.is_ready());
|
||||
assert!(ModelState::Ready { loaded_at: Utc::now() }.is_ready());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_retry_after_unloaded() {
|
||||
assert_eq!(ModelState::Unloaded.retry_after_secs(), 30);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_retry_after_loading() {
|
||||
assert_eq!(ModelState::Loading.retry_after_secs(), 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_retry_after_waiting_for_gpu() {
|
||||
let s = ModelState::WaitingForGpu { vram_needed_mb: 0, vram_free_mb: 0, retry_in_secs: 45 };
|
||||
assert_eq!(s.retry_after_secs(), 45);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_retry_after_ready_is_zero() {
|
||||
assert_eq!(ModelState::Ready { loaded_at: Utc::now() }.retry_after_secs(), 0);
|
||||
}
|
||||
|
||||
// ── ModelEvent serialization ─────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn test_model_event_ready_serializes() {
|
||||
let ts = Utc::now();
|
||||
let e = ModelEvent::ModelReady { loaded_at: ts };
|
||||
let v: Value = serde_json::to_value(&e).unwrap();
|
||||
assert_eq!(v["type"], "model_ready");
|
||||
assert!(v["loaded_at"].is_string());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_event_unloaded_serializes() {
|
||||
let v: Value = serde_json::to_value(ModelEvent::ModelUnloaded).unwrap();
|
||||
assert_eq!(v["type"], "model_unloaded");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_event_loading_serializes() {
|
||||
let v: Value = serde_json::to_value(ModelEvent::ModelLoading).unwrap();
|
||||
assert_eq!(v["type"], "model_loading");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_event_waiting_serializes() {
|
||||
let e = ModelEvent::ModelWaitingForGpu { vram_needed_mb: 3000, vram_free_mb: 200, retry_in_secs: 30 };
|
||||
let v: Value = serde_json::to_value(&e).unwrap();
|
||||
assert_eq!(v["type"], "model_waiting_for_gpu");
|
||||
assert_eq!(v["vram_needed_mb"], 3000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_event_webhook_filter() {
|
||||
assert!(ModelEvent::ModelReady { loaded_at: Utc::now() }.is_webhook_event());
|
||||
assert!(ModelEvent::ModelUnloaded.is_webhook_event());
|
||||
assert!(!ModelEvent::ModelLoading.is_webhook_event());
|
||||
assert!(!ModelEvent::ModelWaitingForGpu { vram_needed_mb: 0, vram_free_mb: 0, retry_in_secs: 30 }.is_webhook_event());
|
||||
}
|
||||
|
||||
// ── ModelStatusResponse ──────────────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn test_model_status_response_roundtrip() {
|
||||
let r = ModelStatusResponse {
|
||||
state: ModelState::Ready { loaded_at: Utc::now() },
|
||||
vram_used_mb: Some(4096),
|
||||
vram_total_mb: Some(8192),
|
||||
};
|
||||
let json_str = serde_json::to_string(&r).unwrap();
|
||||
let v: Value = serde_json::from_str(&json_str).unwrap();
|
||||
assert_eq!(v["state"], "ready");
|
||||
assert_eq!(v["vram_used_mb"], 4096);
|
||||
assert_eq!(v["vram_total_mb"], 8192);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_status_response_omits_nulls() {
|
||||
let r = ModelStatusResponse { state: ModelState::Loading, vram_used_mb: None, vram_total_mb: None };
|
||||
let v: Value = serde_json::to_value(&r).unwrap();
|
||||
assert_eq!(v["state"], "loading");
|
||||
assert!(v.get("vram_used_mb").is_none());
|
||||
assert!(v.get("vram_total_mb").is_none());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ use crate::{models::HealthResponse, AppState, Result};
|
||||
)]
|
||||
pub async fn health(State(state): State<AppState>) -> Result<Json<HealthResponse>> {
|
||||
let (gpu_name, vram_total_mb) = gpu_info(state.gpu_device);
|
||||
let model_state_tag = state.model_state.read().await.tag().to_string();
|
||||
|
||||
Ok(Json(HealthResponse {
|
||||
status: "ok".into(),
|
||||
@@ -23,6 +24,7 @@ pub async fn health(State(state): State<AppState>) -> Result<Json<HealthResponse
|
||||
vram_total_mb,
|
||||
model: state.model_name.to_string(),
|
||||
queue_depth: state.queue_depth.load(Ordering::Relaxed),
|
||||
model_state: model_state_tag,
|
||||
}))
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
models::{Job, JobId, JobStatus, SubmitResponse},
|
||||
worker::{audio_path_for, ProgressEvent},
|
||||
worker::{audio_path_for, ProgressEvent, WorkerCmd},
|
||||
AppError, AppState, Result,
|
||||
};
|
||||
|
||||
@@ -107,6 +107,36 @@ pub async fn submit_job(
|
||||
));
|
||||
}
|
||||
|
||||
// Check model state before accepting the job.
|
||||
let (model_ready, retry_after_secs, state_tag) = {
|
||||
let ms = state.model_state.read().await;
|
||||
let ready = ms.is_ready();
|
||||
let retry = ms.retry_after_secs();
|
||||
let tag = ms.tag().to_string();
|
||||
(ready, retry, tag)
|
||||
};
|
||||
|
||||
// Register the webhook URL regardless of model state — so model lifecycle
|
||||
// events are delivered even if the job itself is rejected.
|
||||
if let Some(url) = &webhook_url {
|
||||
state.webhook_registry.lock()
|
||||
.unwrap_or_else(|e| e.into_inner())
|
||||
.insert(url.clone());
|
||||
}
|
||||
|
||||
if !model_ready {
|
||||
// Trigger a load if the model is simply unloaded (not already loading).
|
||||
if state_tag == "unloaded" {
|
||||
let _ = state.cmd_tx.try_send(WorkerCmd::Load);
|
||||
}
|
||||
// Clean up the audio file we already wrote to disk.
|
||||
let _ = tokio::fs::remove_file(&audio_path).await;
|
||||
return Err(AppError::ModelNotReady {
|
||||
state: state_tag,
|
||||
retry_after_secs,
|
||||
});
|
||||
}
|
||||
|
||||
let mut job = Job::new(id, task, webhook_url, filename);
|
||||
job.language = language;
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
pub mod health;
|
||||
pub mod jobs;
|
||||
pub mod model;
|
||||
|
||||
use axum::{extract::DefaultBodyLimit, routing::{delete, get, post}, Router};
|
||||
use crate::AppState;
|
||||
@@ -17,3 +18,11 @@ pub fn health_router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/health", get(health::health))
|
||||
}
|
||||
|
||||
pub fn model_router() -> Router<AppState> {
|
||||
Router::new()
|
||||
.route("/model/status", get(model::model_status))
|
||||
.route("/model/load", post(model::model_load))
|
||||
.route("/model/unload", post(model::model_unload))
|
||||
.route("/model/events", get(model::model_events))
|
||||
}
|
||||
|
||||
158
src/routes/model.rs
Normal file
158
src/routes/model.rs
Normal file
@@ -0,0 +1,158 @@
|
||||
use std::pin::Pin;
|
||||
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::{
|
||||
sse::{Event, KeepAlive, Sse},
|
||||
IntoResponse,
|
||||
},
|
||||
Json,
|
||||
};
|
||||
use futures::Stream;
|
||||
use tokio_stream::wrappers::BroadcastStream;
|
||||
use futures::StreamExt;
|
||||
|
||||
use crate::{
|
||||
models::{ModelEvent, ModelStatusResponse},
|
||||
worker::WorkerCmd,
|
||||
AppState, Result,
|
||||
};
|
||||
|
||||
type SseStream = Pin<Box<dyn Stream<Item = std::result::Result<Event, std::convert::Infallible>> + Send>>;
|
||||
|
||||
// ── GET /model/status ────────────────────────────────────────────────────────
|
||||
|
||||
/// Return the current model state and VRAM statistics.
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/model/status",
|
||||
tag = "model",
|
||||
responses(
|
||||
(status = 200, description = "Model status", body = ModelStatusResponse),
|
||||
)
|
||||
)]
|
||||
pub async fn model_status(State(state): State<AppState>) -> Result<Json<ModelStatusResponse>> {
|
||||
let model_state = state.model_state.read().await.clone();
|
||||
let (vram_used_mb, vram_total_mb) = vram_stats(state.gpu_device);
|
||||
|
||||
Ok(Json(ModelStatusResponse {
|
||||
state: model_state,
|
||||
vram_used_mb,
|
||||
vram_total_mb,
|
||||
}))
|
||||
}
|
||||
|
||||
// ── POST /model/load ─────────────────────────────────────────────────────────
|
||||
|
||||
/// Request the model to be loaded into GPU memory.
|
||||
/// Idempotent: if the model is already loading or ready, this is a no-op.
|
||||
/// Returns 202 Accepted; poll `GET /model/status` or subscribe to
|
||||
/// `GET /model/events` to know when it is ready.
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/model/load",
|
||||
tag = "model",
|
||||
responses(
|
||||
(status = 202, description = "Load initiated or already in progress"),
|
||||
(status = 200, description = "Model already ready"),
|
||||
)
|
||||
)]
|
||||
pub async fn model_load(State(state): State<AppState>) -> impl IntoResponse {
|
||||
let is_ready = state.model_state.read().await.is_ready();
|
||||
if is_ready {
|
||||
return (StatusCode::OK, Json(serde_json::json!({"status": "already_ready"})));
|
||||
}
|
||||
// Ignore send errors (channel full = load already in progress).
|
||||
let _ = state.cmd_tx.try_send(WorkerCmd::Load);
|
||||
(StatusCode::ACCEPTED, Json(serde_json::json!({"status": "load_initiated"})))
|
||||
}
|
||||
|
||||
// ── POST /model/unload ───────────────────────────────────────────────────────
|
||||
|
||||
/// Unload the model from GPU memory immediately.
|
||||
/// Idempotent: if the model is already unloaded, returns 200 immediately.
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/model/unload",
|
||||
tag = "model",
|
||||
responses(
|
||||
(status = 200, description = "Model unloaded or was already unloaded"),
|
||||
)
|
||||
)]
|
||||
pub async fn model_unload(State(state): State<AppState>) -> impl IntoResponse {
|
||||
let _ = state.cmd_tx.try_send(WorkerCmd::Unload);
|
||||
(StatusCode::OK, Json(serde_json::json!({"status": "unload_requested"})))
|
||||
}
|
||||
|
||||
// ── GET /model/events ────────────────────────────────────────────────────────
|
||||
|
||||
/// Subscribe to model lifecycle events via Server-Sent Events.
|
||||
///
|
||||
/// Event types:
|
||||
/// - `model_loading` — load initiated
|
||||
/// - `model_ready` — model loaded and warmed up
|
||||
/// - `model_unloaded` — model freed from GPU memory
|
||||
/// - `model_waiting_for_gpu` — insufficient VRAM; retrying
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/model/events",
|
||||
tag = "model",
|
||||
responses(
|
||||
(status = 200, description = "SSE stream of model lifecycle events"),
|
||||
)
|
||||
)]
|
||||
pub async fn model_events(State(state): State<AppState>) -> Sse<SseStream> {
|
||||
let rx = state.model_event_tx.subscribe();
|
||||
|
||||
let stream: SseStream = Box::pin(
|
||||
BroadcastStream::new(rx).filter_map(|msg| async move {
|
||||
match msg {
|
||||
Ok(event) => {
|
||||
let event_type = match &event {
|
||||
ModelEvent::ModelReady { .. } => "model_ready",
|
||||
ModelEvent::ModelUnloaded => "model_unloaded",
|
||||
ModelEvent::ModelLoading => "model_loading",
|
||||
ModelEvent::ModelWaitingForGpu {..} => "model_waiting_for_gpu",
|
||||
};
|
||||
let data = serde_json::to_string(&event).ok()?;
|
||||
Some(Ok(Event::default().event(event_type).data(data)))
|
||||
}
|
||||
Err(_) => None,
|
||||
}
|
||||
})
|
||||
);
|
||||
|
||||
Sse::new(stream).keep_alive(KeepAlive::default())
|
||||
}
|
||||
|
||||
// ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
fn vram_stats(gpu_device: u32) -> (Option<u64>, Option<u64>) {
|
||||
fn inner(gpu_device: u32) -> Option<(u64, u64)> {
|
||||
let out = std::process::Command::new("nvidia-smi")
|
||||
.args([
|
||||
&format!("--id={gpu_device}"),
|
||||
"--query-gpu=memory.used,memory.total",
|
||||
"--format=csv,noheader,nounits",
|
||||
])
|
||||
.output()
|
||||
.ok()?;
|
||||
|
||||
if !out.status.success() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let line = String::from_utf8_lossy(&out.stdout);
|
||||
let line = line.trim();
|
||||
let mut parts = line.splitn(2, ',');
|
||||
let used = parts.next().and_then(|s| s.trim().parse::<u64>().ok())?;
|
||||
let total = parts.next().and_then(|s| s.trim().parse::<u64>().ok())?;
|
||||
Some((used, total))
|
||||
}
|
||||
|
||||
match inner(gpu_device) {
|
||||
Some((u, t)) => (Some(u), Some(t)),
|
||||
None => (None, None),
|
||||
}
|
||||
}
|
||||
@@ -49,10 +49,24 @@ impl Transcriber {
|
||||
// params.flash_attn(true);
|
||||
|
||||
let ctx = WhisperContext::new_with_params(path, params)
|
||||
.map_err(|e| AppError::Internal(format!("failed to load model: {e}")))?;
|
||||
.map_err(|e| {
|
||||
let msg = format!("failed to load model: {e}");
|
||||
if AppError::is_oom(&msg) {
|
||||
AppError::OutOfMemory(msg)
|
||||
} else {
|
||||
AppError::Internal(msg)
|
||||
}
|
||||
})?;
|
||||
|
||||
let mut state = ctx.create_state()
|
||||
.map_err(|e| AppError::Internal(format!("failed to create whisper state: {e}")))?;
|
||||
.map_err(|e| {
|
||||
let msg = format!("failed to create whisper state: {e}");
|
||||
if AppError::is_oom(&msg) {
|
||||
AppError::OutOfMemory(msg)
|
||||
} else {
|
||||
AppError::Internal(msg)
|
||||
}
|
||||
})?;
|
||||
// ctx drops here; state holds Arc<WhisperInnerContext> so model stays loaded.
|
||||
|
||||
// ── GPU warmup ────────────────────────────────────────────────────────
|
||||
|
||||
482
src/worker.rs
482
src/worker.rs
@@ -1,20 +1,23 @@
|
||||
use std::{
|
||||
collections::HashSet,
|
||||
path::PathBuf,
|
||||
sync::{
|
||||
atomic::{AtomicUsize, Ordering},
|
||||
Arc,
|
||||
Arc, Mutex,
|
||||
},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use chrono::Utc;
|
||||
use reqwest::Client;
|
||||
use tokio::sync::{broadcast, mpsc, oneshot};
|
||||
use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
|
||||
|
||||
use crate::{
|
||||
models::{Job, JobId, JobStatus, Segment},
|
||||
models::{Job, JobId, JobStatus, ModelEvent, ModelState, Segment},
|
||||
storage::Storage,
|
||||
transcriber::Transcriber,
|
||||
webhook,
|
||||
AppError,
|
||||
};
|
||||
|
||||
/// Per-job broadcast channel for SSE subscribers.
|
||||
@@ -31,83 +34,383 @@ pub enum ProgressEvent {
|
||||
/// Global registry: job_id → broadcast sender.
|
||||
pub type ProgressRegistry = Arc<dashmap::DashMap<JobId, ProgressTx>>;
|
||||
|
||||
// ── Transcription request/response types for the blocking thread ─────────────
|
||||
// ── Worker command channel ────────────────────────────────────────────────────
|
||||
|
||||
struct TranscribeRequest {
|
||||
pcm: Vec<f32>,
|
||||
language: Option<String>,
|
||||
task: String,
|
||||
/// Per-chunk progress callback — receives 0–100 from whisper.cpp and can
|
||||
/// scale/offset it before forwarding to the job's broadcast channel.
|
||||
on_progress: Box<dyn Fn(u8) + Send + 'static>,
|
||||
reply: oneshot::Sender<crate::Result<(Vec<Segment>, String)>>,
|
||||
/// Commands sent to the GPU worker OS thread.
|
||||
#[derive(Debug)]
|
||||
pub enum WorkerCmd {
|
||||
/// Request a model load. Idempotent: if already loading/ready, ignored.
|
||||
Load,
|
||||
/// Unload the model immediately and free GPU memory.
|
||||
Unload,
|
||||
/// Internal: run a transcription chunk.
|
||||
Transcribe(TranscribeRequest),
|
||||
}
|
||||
|
||||
// ── Transcription request/response types ─────────────────────────────────────
|
||||
|
||||
pub struct TranscribeRequest {
|
||||
pub pcm: Vec<f32>,
|
||||
pub language: Option<String>,
|
||||
pub task: String,
|
||||
pub on_progress: Box<dyn Fn(u8) + Send + 'static>,
|
||||
pub reply: oneshot::Sender<crate::Result<(Vec<Segment>, String)>>,
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for TranscribeRequest {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("TranscribeRequest")
|
||||
.field("language", &self.language)
|
||||
.field("task", &self.task)
|
||||
.finish_non_exhaustive()
|
||||
}
|
||||
}
|
||||
|
||||
// ── Public API ────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Spawn the single GPU worker.
|
||||
/// Returns the SSE progress registry.
|
||||
///
|
||||
/// Returns the SSE progress registry and a command sender for the worker thread.
|
||||
/// The model starts **unloaded**; send `WorkerCmd::Load` or submit a job to
|
||||
/// trigger loading.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn start(
|
||||
job_rx: mpsc::UnboundedReceiver<JobId>,
|
||||
storage: Arc<Storage>,
|
||||
model_path: PathBuf,
|
||||
queue_depth: Arc<AtomicUsize>,
|
||||
gpu_device: u32,
|
||||
) -> ProgressRegistry {
|
||||
model_state: Arc<RwLock<ModelState>>,
|
||||
model_event_tx: broadcast::Sender<ModelEvent>,
|
||||
webhook_registry: Arc<Mutex<HashSet<String>>>,
|
||||
idle_timeout: Duration,
|
||||
gpu_poll_interval: Duration,
|
||||
) -> (ProgressRegistry, std::sync::mpsc::SyncSender<WorkerCmd>) {
|
||||
let registry: ProgressRegistry = Arc::new(dashmap::DashMap::new());
|
||||
let reg_clone = Arc::clone(®istry);
|
||||
|
||||
let (tx_req, rx_req) = std::sync::mpsc::channel::<TranscribeRequest>();
|
||||
// Bounded sync channel: capacity 8 is plenty (load/unload are rare).
|
||||
let (cmd_tx, cmd_rx) = std::sync::mpsc::sync_channel::<WorkerCmd>(8);
|
||||
let cmd_tx_clone = cmd_tx.clone();
|
||||
|
||||
// Capture Tokio runtime handle so the OS thread can spawn async tasks.
|
||||
let rt_handle = tokio::runtime::Handle::current();
|
||||
|
||||
std::thread::Builder::new()
|
||||
.name("whisper-gpu".into())
|
||||
.spawn(move || transcriber_thread(rx_req, model_path, gpu_device))
|
||||
.spawn(move || {
|
||||
transcriber_thread(
|
||||
cmd_rx,
|
||||
model_path,
|
||||
gpu_device,
|
||||
model_state,
|
||||
model_event_tx,
|
||||
webhook_registry,
|
||||
idle_timeout,
|
||||
gpu_poll_interval,
|
||||
rt_handle,
|
||||
);
|
||||
})
|
||||
.expect("failed to spawn whisper-gpu thread");
|
||||
|
||||
tokio::spawn(run(job_rx, storage, queue_depth, reg_clone, tx_req));
|
||||
tokio::spawn(run(job_rx, storage, queue_depth, reg_clone, cmd_tx_clone));
|
||||
|
||||
registry
|
||||
(registry, cmd_tx)
|
||||
}
|
||||
|
||||
/// Dedicated OS thread that owns the Transcriber (non-Send) and runs inference.
|
||||
// ── GPU OS thread ─────────────────────────────────────────────────────────────
|
||||
|
||||
/// The worker OS thread that owns the `Transcriber` (non-`Send`).
|
||||
///
|
||||
/// The Transcriber holds a single `WhisperState` that is reused for every chunk.
|
||||
/// GPU compute buffers (~700 MB) are allocated once at startup rather than on
|
||||
/// every call, eliminating per-chunk `whisper_init_state` overhead and the
|
||||
/// VRAM churn that caused intermittent 0-segment results.
|
||||
/// Uses `recv_timeout` with a 1-second tick to drive the idle timer without a
|
||||
/// separate thread.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn transcriber_thread(
|
||||
rx: std::sync::mpsc::Receiver<TranscribeRequest>,
|
||||
rx: std::sync::mpsc::Receiver<WorkerCmd>,
|
||||
model_path: PathBuf,
|
||||
gpu_device: u32,
|
||||
model_state: Arc<RwLock<ModelState>>,
|
||||
model_event_tx: broadcast::Sender<ModelEvent>,
|
||||
webhook_registry: Arc<Mutex<HashSet<String>>>,
|
||||
idle_timeout: Duration,
|
||||
gpu_poll_interval: Duration,
|
||||
rt: tokio::runtime::Handle,
|
||||
) {
|
||||
let mut transcriber = match Transcriber::load(&model_path, gpu_device) {
|
||||
Ok(t) => t,
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "failed to load whisper model — transcriber thread exiting");
|
||||
return;
|
||||
let mut transcriber: Option<Transcriber> = None;
|
||||
let mut last_job = Instant::now();
|
||||
|
||||
loop {
|
||||
match rx.recv_timeout(Duration::from_secs(1)) {
|
||||
Ok(WorkerCmd::Load) => {
|
||||
if transcriber.is_some() {
|
||||
tracing::debug!("WorkerCmd::Load ignored — model already loaded");
|
||||
continue;
|
||||
}
|
||||
transcriber = try_load_with_polling(
|
||||
&rx,
|
||||
&model_path,
|
||||
gpu_device,
|
||||
&model_state,
|
||||
&model_event_tx,
|
||||
&webhook_registry,
|
||||
gpu_poll_interval,
|
||||
&rt,
|
||||
);
|
||||
if transcriber.is_some() {
|
||||
last_job = Instant::now();
|
||||
}
|
||||
}
|
||||
|
||||
Ok(WorkerCmd::Unload) => {
|
||||
do_unload(&mut transcriber, &model_state, &model_event_tx, &webhook_registry, &rt);
|
||||
}
|
||||
|
||||
Ok(WorkerCmd::Transcribe(req)) => {
|
||||
let t = match &mut transcriber {
|
||||
Some(t) => t,
|
||||
None => {
|
||||
tracing::warn!("Transcribe cmd received but model is unloaded — failing job");
|
||||
let _ = req.reply.send(Err(AppError::Internal(
|
||||
"model unloaded before job could run".into(),
|
||||
)));
|
||||
continue;
|
||||
}
|
||||
};
|
||||
tracing::info!(model = %model_path.display(), "GPU worker ready");
|
||||
|
||||
for req in rx {
|
||||
let on_progress = req.on_progress;
|
||||
let result = transcriber.transcribe(
|
||||
let result = t.transcribe(
|
||||
&req.pcm,
|
||||
req.language.as_deref(),
|
||||
&req.task,
|
||||
move |p| on_progress(p),
|
||||
move |p| (req.on_progress)(p),
|
||||
);
|
||||
last_job = Instant::now();
|
||||
let _ = req.reply.send(result);
|
||||
}
|
||||
|
||||
Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
|
||||
if transcriber.is_some() && last_job.elapsed() >= idle_timeout {
|
||||
tracing::info!(
|
||||
elapsed_secs = last_job.elapsed().as_secs(),
|
||||
"idle timeout reached — unloading model"
|
||||
);
|
||||
do_unload(
|
||||
&mut transcriber,
|
||||
&model_state,
|
||||
&model_event_tx,
|
||||
&webhook_registry,
|
||||
&rt,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
|
||||
tracing::info!("worker command channel closed — shutting down GPU thread");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Attempt to load the model, polling on VRAM failures.
|
||||
///
|
||||
/// While waiting for GPU, drains `rx` so that `WorkerCmd::Unload` cancels the
|
||||
/// load attempt and `WorkerCmd::Transcribe` commands get a "model not ready"
|
||||
/// rejection. Returns `Some(Transcriber)` on success, `None` if cancelled.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn try_load_with_polling(
|
||||
rx: &std::sync::mpsc::Receiver<WorkerCmd>,
|
||||
model_path: &PathBuf,
|
||||
gpu_device: u32,
|
||||
model_state: &Arc<RwLock<ModelState>>,
|
||||
model_event_tx: &broadcast::Sender<ModelEvent>,
|
||||
webhook_registry: &Arc<Mutex<HashSet<String>>>,
|
||||
gpu_poll_interval: Duration,
|
||||
rt: &tokio::runtime::Handle,
|
||||
) -> Option<Transcriber> {
|
||||
loop {
|
||||
set_state(model_state, ModelState::Loading);
|
||||
broadcast_event(model_event_tx, ModelEvent::ModelLoading);
|
||||
tracing::info!("loading whisper model...");
|
||||
|
||||
match Transcriber::load(model_path, gpu_device) {
|
||||
Ok(t) => {
|
||||
let loaded_at = Utc::now();
|
||||
set_state(model_state, ModelState::Ready { loaded_at });
|
||||
broadcast_event(model_event_tx, ModelEvent::ModelReady { loaded_at });
|
||||
fire_webhooks(webhook_registry, ModelEvent::ModelReady { loaded_at }, rt);
|
||||
tracing::info!("model loaded and ready");
|
||||
return Some(t);
|
||||
}
|
||||
|
||||
Err(AppError::OutOfMemory(msg)) => {
|
||||
let (vram_needed_mb, vram_free_mb) = parse_oom_vram(&msg, gpu_device);
|
||||
let retry_in_secs = gpu_poll_interval.as_secs();
|
||||
|
||||
tracing::warn!(
|
||||
vram_needed_mb,
|
||||
vram_free_mb,
|
||||
retry_in_secs,
|
||||
"insufficient VRAM — will retry"
|
||||
);
|
||||
|
||||
set_state(model_state, ModelState::WaitingForGpu {
|
||||
vram_needed_mb,
|
||||
vram_free_mb,
|
||||
retry_in_secs,
|
||||
});
|
||||
broadcast_event(model_event_tx, ModelEvent::ModelWaitingForGpu {
|
||||
vram_needed_mb,
|
||||
vram_free_mb,
|
||||
retry_in_secs,
|
||||
});
|
||||
|
||||
// Interruptible sleep: drain rx while waiting for gpu_poll_interval.
|
||||
let deadline = Instant::now() + gpu_poll_interval;
|
||||
loop {
|
||||
let remaining = deadline.saturating_duration_since(Instant::now());
|
||||
if remaining.is_zero() { break; }
|
||||
match rx.recv_timeout(remaining.min(Duration::from_secs(1))) {
|
||||
Ok(WorkerCmd::Unload) => {
|
||||
tracing::info!("Unload received while waiting for GPU — cancelling load");
|
||||
set_state(model_state, ModelState::Unloaded);
|
||||
broadcast_event(model_event_tx, ModelEvent::ModelUnloaded);
|
||||
fire_webhooks(webhook_registry, ModelEvent::ModelUnloaded, rt);
|
||||
return None;
|
||||
}
|
||||
Ok(WorkerCmd::Load) => {} // idempotent
|
||||
Ok(WorkerCmd::Transcribe(req)) => {
|
||||
let _ = req.reply.send(Err(AppError::ModelNotReady {
|
||||
state: "waiting_for_gpu".into(),
|
||||
retry_after_secs: retry_in_secs,
|
||||
}));
|
||||
}
|
||||
Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {}
|
||||
Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => return None,
|
||||
}
|
||||
}
|
||||
// Loop back to retry load
|
||||
}
|
||||
|
||||
Err(e) => {
|
||||
tracing::error!(error = %e, "model load failed with non-recoverable error");
|
||||
set_state(model_state, ModelState::Unloaded);
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn do_unload(
|
||||
transcriber: &mut Option<Transcriber>,
|
||||
model_state: &Arc<RwLock<ModelState>>,
|
||||
model_event_tx: &broadcast::Sender<ModelEvent>,
|
||||
webhook_registry: &Arc<Mutex<HashSet<String>>>,
|
||||
rt: &tokio::runtime::Handle,
|
||||
) {
|
||||
*transcriber = None;
|
||||
set_state(model_state, ModelState::Unloaded);
|
||||
broadcast_event(model_event_tx, ModelEvent::ModelUnloaded);
|
||||
fire_webhooks(webhook_registry, ModelEvent::ModelUnloaded, rt);
|
||||
tracing::info!("model unloaded — GPU memory freed");
|
||||
}
|
||||
|
||||
// ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
fn set_state(arc: &Arc<RwLock<ModelState>>, state: ModelState) {
|
||||
*arc.blocking_write() = state;
|
||||
}
|
||||
|
||||
fn broadcast_event(tx: &broadcast::Sender<ModelEvent>, event: ModelEvent) {
|
||||
let _ = tx.send(event);
|
||||
}
|
||||
|
||||
fn fire_webhooks(
|
||||
registry: &Arc<Mutex<HashSet<String>>>,
|
||||
event: ModelEvent,
|
||||
rt: &tokio::runtime::Handle,
|
||||
) {
|
||||
if !event.is_webhook_event() {
|
||||
return;
|
||||
}
|
||||
let urls: Vec<String> = registry
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner())
|
||||
.iter()
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
if urls.is_empty() { return; }
|
||||
|
||||
let payload = match serde_json::to_string(&event) {
|
||||
Ok(p) => p,
|
||||
Err(e) => { tracing::error!(error = %e, "failed to serialize model event"); return; }
|
||||
};
|
||||
|
||||
for url in urls {
|
||||
let body = payload.clone();
|
||||
rt.spawn(async move {
|
||||
let http = Client::builder()
|
||||
.timeout(Duration::from_secs(10))
|
||||
.build()
|
||||
.expect("http client");
|
||||
for attempt in 0..3_u32 {
|
||||
match http.post(&url)
|
||||
.header("content-type", "application/json")
|
||||
.body(body.clone())
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(r) if r.status().is_success() => {
|
||||
tracing::debug!(url, "model event webhook delivered");
|
||||
return;
|
||||
}
|
||||
Ok(r) => tracing::warn!(url, status = r.status().as_u16(), "webhook non-2xx"),
|
||||
Err(e) => tracing::warn!(url, error = %e, attempt, "webhook delivery failed"),
|
||||
}
|
||||
if attempt < 2 {
|
||||
tokio::time::sleep(Duration::from_secs(2u64.pow(attempt))).await;
|
||||
}
|
||||
}
|
||||
tracing::error!(url, "model event webhook failed after 3 attempts");
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_oom_vram(msg: &str, gpu_device: u32) -> (u64, u64) {
|
||||
let needed = msg
|
||||
.split_whitespace()
|
||||
.zip(msg.split_whitespace().skip(1))
|
||||
.find(|(_, next)| *next == "MiB")
|
||||
.and_then(|(n, _)| n.parse::<f64>().ok())
|
||||
.map(|v| v as u64)
|
||||
.unwrap_or(0);
|
||||
|
||||
let free = std::process::Command::new("nvidia-smi")
|
||||
.args([
|
||||
&format!("--id={gpu_device}"),
|
||||
"--query-gpu=memory.free",
|
||||
"--format=csv,noheader,nounits",
|
||||
])
|
||||
.output()
|
||||
.ok()
|
||||
.and_then(|o| String::from_utf8(o.stdout).ok())
|
||||
.and_then(|s| s.trim().parse::<u64>().ok())
|
||||
.unwrap_or(0);
|
||||
|
||||
(needed, free)
|
||||
}
|
||||
|
||||
// ── Async job runner ──────────────────────────────────────────────────────────
|
||||
|
||||
async fn run(
|
||||
mut job_rx: mpsc::UnboundedReceiver<JobId>,
|
||||
storage: Arc<Storage>,
|
||||
queue_depth: Arc<AtomicUsize>,
|
||||
registry: ProgressRegistry,
|
||||
tx_req: std::sync::mpsc::Sender<TranscribeRequest>,
|
||||
cmd_tx: std::sync::mpsc::SyncSender<WorkerCmd>,
|
||||
) {
|
||||
let http = Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(30))
|
||||
.timeout(Duration::from_secs(30))
|
||||
.build()
|
||||
.expect("failed to build reqwest client");
|
||||
|
||||
@@ -140,7 +443,7 @@ async fn run(
|
||||
|
||||
let audio_path = audio_path_for(&job_id);
|
||||
|
||||
let result = process_job(&job, &audio_path, &progress_tx, &tx_req, &storage).await;
|
||||
let result = process_job(&job, &audio_path, &progress_tx, &cmd_tx, &storage).await;
|
||||
|
||||
let _ = tokio::fs::remove_file(&audio_path).await;
|
||||
|
||||
@@ -175,26 +478,18 @@ async fn run(
|
||||
tokio::spawn(async move { webhook::fire(&http, &url, &job).await; });
|
||||
}
|
||||
|
||||
tokio::time::sleep(std::time::Duration::from_secs(30)).await;
|
||||
tokio::time::sleep(Duration::from_secs(30)).await;
|
||||
registry.remove(&job_id);
|
||||
}
|
||||
}
|
||||
|
||||
// ── Silence-based chunking ────────────────────────────────────────────────────
|
||||
|
||||
/// Target chunk length. 60s ≈ 2× whisper's native 30s window — short enough
|
||||
/// that a hallucinated phrase can't compound beyond a single window.
|
||||
const TARGET_CHUNK_SECS: f32 = 60.0;
|
||||
/// How far from the target we'll snap to a silence midpoint.
|
||||
const SNAP_WINDOW_SECS: f32 = 30.0;
|
||||
/// Silence below this level (dB) counts as a split candidate.
|
||||
const SILENCE_DB: &str = "-35dB";
|
||||
/// Minimum silence duration to register as a candidate split.
|
||||
const SILENCE_DUR: &str = "0.4";
|
||||
|
||||
/// Detect silence periods and return the midpoint (seconds) of each.
|
||||
/// On any error (ffmpeg missing, binary format, etc.) returns an empty vec
|
||||
/// so the caller can fall back to hard cuts.
|
||||
async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
|
||||
use tokio::process::Command;
|
||||
|
||||
@@ -217,7 +512,6 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
|
||||
}
|
||||
};
|
||||
|
||||
// silencedetect logs to stderr
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
let mut starts: Vec<f32> = Vec::new();
|
||||
let mut ends: Vec<f32> = Vec::new();
|
||||
@@ -228,7 +522,6 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
|
||||
starts.push(t);
|
||||
}
|
||||
} else if let Some(i) = line.find("silence_end: ") {
|
||||
// Format: "silence_end: 12.34 | silence_duration: 0.56"
|
||||
let t_str = line[i + "silence_end: ".len()..]
|
||||
.split(" |")
|
||||
.next()
|
||||
@@ -248,10 +541,6 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
|
||||
mids
|
||||
}
|
||||
|
||||
/// Build cut points every `target_secs`, snapping to the nearest silence
|
||||
/// midpoint within `snap_window` when one exists; otherwise a hard cut.
|
||||
/// Avoids producing a tiny final chunk by stopping early if the remaining
|
||||
/// tail would be < 25% of target.
|
||||
fn snap_to_silence(
|
||||
mids: &[f32],
|
||||
total_secs: f32,
|
||||
@@ -263,13 +552,9 @@ fn snap_to_silence(
|
||||
|
||||
while pos < total_secs - target_secs * 0.25 {
|
||||
let prev_cut = cuts.last().copied().unwrap_or(0.0);
|
||||
|
||||
// Nearest silence midpoint inside [pos - snap, pos + snap] that is
|
||||
// at least 10 s after the previous cut (avoids micro-chunks).
|
||||
let best = mids.iter().copied()
|
||||
.filter(|&t| t > prev_cut + 10.0 && (t - pos).abs() <= snap_window)
|
||||
.min_by(|a, b| (a - pos).abs().partial_cmp(&(b - pos).abs()).unwrap());
|
||||
|
||||
let cut = best.unwrap_or(pos);
|
||||
cuts.push(cut);
|
||||
pos = cut + target_secs;
|
||||
@@ -278,7 +563,6 @@ fn snap_to_silence(
|
||||
cuts
|
||||
}
|
||||
|
||||
/// Convert cut points into (start_secs, end_secs) chunk pairs.
|
||||
fn to_chunk_ranges(cuts: &[f32], total_secs: f32) -> Vec<(f32, f32)> {
|
||||
let mut ranges = Vec::new();
|
||||
let mut start = 0.0_f32;
|
||||
@@ -289,7 +573,6 @@ fn to_chunk_ranges(cuts: &[f32], total_secs: f32) -> Vec<(f32, f32)> {
|
||||
start = cut;
|
||||
}
|
||||
}
|
||||
// Last chunk
|
||||
if total_secs - start >= 1.0 {
|
||||
ranges.push((start, total_secs));
|
||||
}
|
||||
@@ -302,17 +585,13 @@ async fn process_job(
|
||||
job: &Job,
|
||||
audio_path: &std::path::Path,
|
||||
progress_tx: &ProgressTx,
|
||||
tx_req: &std::sync::mpsc::Sender<TranscribeRequest>,
|
||||
cmd_tx: &std::sync::mpsc::SyncSender<WorkerCmd>,
|
||||
storage: &Arc<Storage>,
|
||||
) -> crate::Result<(Vec<Segment>, String, f32)> {
|
||||
// 1. Decode full audio to 16 kHz mono PCM.
|
||||
let pcm = decode_audio(audio_path).await?;
|
||||
let total_secs = pcm.len() as f32 / 16_000.0;
|
||||
|
||||
// 2. Detect silence midpoints from original file.
|
||||
let silence_mids = detect_silence_midpoints(audio_path).await;
|
||||
|
||||
// 3. Build silence-snapped chunk boundaries.
|
||||
let cuts = snap_to_silence(&silence_mids, total_secs, TARGET_CHUNK_SECS, SNAP_WINDOW_SECS);
|
||||
let chunks = to_chunk_ranges(&cuts, total_secs);
|
||||
let n = chunks.len();
|
||||
@@ -324,7 +603,6 @@ async fn process_job(
|
||||
"audio chunked by silence"
|
||||
);
|
||||
|
||||
// 4. Transcribe each chunk, applying a time offset to all timestamps.
|
||||
let mut all_segments: Vec<Segment> = Vec::new();
|
||||
let mut language = String::new();
|
||||
|
||||
@@ -334,11 +612,9 @@ async fn process_job(
|
||||
let mut chunk_pcm = pcm[s0..s1].to_vec();
|
||||
trim_trailing_silence(&mut chunk_pcm);
|
||||
|
||||
// Base percent this chunk starts at.
|
||||
let base = (ci * 100 / n) as u8;
|
||||
let span = (100usize / n).max(1) as u8;
|
||||
|
||||
// Emit a progress event and persist it at the start of every chunk.
|
||||
let _ = progress_tx.send(ProgressEvent::Progress {
|
||||
percent: base,
|
||||
chunk: ci + 1,
|
||||
@@ -350,7 +626,6 @@ async fn process_job(
|
||||
tracing::warn!(error = %e, "failed to persist mid-job progress");
|
||||
}
|
||||
|
||||
// Scale whisper's per-chunk 0–100 into the job's overall range.
|
||||
let tx = progress_tx.clone();
|
||||
let chunk_num = ci + 1;
|
||||
let on_progress = Box::new(move |p: u8| {
|
||||
@@ -363,18 +638,17 @@ async fn process_job(
|
||||
});
|
||||
|
||||
let (reply_tx, reply_rx) = oneshot::channel();
|
||||
tx_req.send(TranscribeRequest {
|
||||
cmd_tx.send(WorkerCmd::Transcribe(TranscribeRequest {
|
||||
pcm: chunk_pcm,
|
||||
language: job.language.clone(),
|
||||
task: job.task.clone(),
|
||||
on_progress,
|
||||
reply: reply_tx,
|
||||
}).map_err(|_| crate::AppError::Internal("transcriber thread gone".into()))?;
|
||||
})).map_err(|_| AppError::Internal("worker command channel closed".into()))?;
|
||||
|
||||
let (mut segs, lang) = reply_rx.await
|
||||
.map_err(|_| crate::AppError::Internal("transcriber thread dropped reply".into()))??;
|
||||
.map_err(|_| AppError::Internal("transcriber thread dropped reply".into()))??;
|
||||
|
||||
// Shift all timestamps by chunk offset.
|
||||
let offset = *chunk_start;
|
||||
for seg in &mut segs {
|
||||
seg.start += offset;
|
||||
@@ -400,7 +674,6 @@ async fn process_job(
|
||||
}
|
||||
}
|
||||
|
||||
// Renumber segment indices across the merged output.
|
||||
for (i, seg) in all_segments.iter_mut().enumerate() {
|
||||
seg.index = i as i32;
|
||||
}
|
||||
@@ -409,14 +682,9 @@ async fn process_job(
|
||||
Ok((all_segments, language, total_secs))
|
||||
}
|
||||
|
||||
/// Trim trailing silence from a 16 kHz mono PCM buffer.
|
||||
///
|
||||
/// Scans backwards to find the last sample above −35 dB, then keeps
|
||||
/// 0.5 s of padding after it. This prevents whisper from hallucinating
|
||||
/// filler tokens into end-of-chunk silence.
|
||||
fn trim_trailing_silence(pcm: &mut Vec<f32>) {
|
||||
const THRESHOLD: f32 = 0.017_8; // −35 dB (10^(−35/20))
|
||||
const PADDING: usize = 8_000; // 0.5 s at 16 kHz
|
||||
const THRESHOLD: f32 = 0.017_8;
|
||||
const PADDING: usize = 8_000;
|
||||
|
||||
if let Some(last_loud) = pcm.iter().rposition(|&s| s.abs() > THRESHOLD) {
|
||||
let new_len = (last_loud + 1 + PADDING).min(pcm.len());
|
||||
@@ -429,10 +697,8 @@ fn trim_trailing_silence(pcm: &mut Vec<f32>) {
|
||||
pcm.truncate(new_len);
|
||||
}
|
||||
}
|
||||
// All-silent chunk: keep as-is — whisper will produce zero segments, which is correct.
|
||||
}
|
||||
|
||||
/// Decode any audio file to 16 kHz mono PCM f32 using ffmpeg.
|
||||
async fn decode_audio(path: &std::path::Path) -> crate::Result<Vec<f32>> {
|
||||
use tokio::process::Command;
|
||||
|
||||
@@ -447,11 +713,11 @@ async fn decode_audio(path: &std::path::Path) -> crate::Result<Vec<f32>> {
|
||||
])
|
||||
.output()
|
||||
.await
|
||||
.map_err(|e| crate::AppError::Internal(format!("ffmpeg spawn failed: {e}")))?;
|
||||
.map_err(|e| AppError::Internal(format!("ffmpeg spawn failed: {e}")))?;
|
||||
|
||||
if !output.status.success() {
|
||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||
return Err(crate::AppError::Internal(format!(
|
||||
return Err(AppError::Internal(format!(
|
||||
"ffmpeg exited with {}: {}",
|
||||
output.status, stderr
|
||||
)));
|
||||
@@ -459,7 +725,7 @@ async fn decode_audio(path: &std::path::Path) -> crate::Result<Vec<f32>> {
|
||||
|
||||
let bytes = output.stdout;
|
||||
if bytes.len() % 4 != 0 {
|
||||
return Err(crate::AppError::Internal(
|
||||
return Err(AppError::Internal(
|
||||
"ffmpeg output length not a multiple of 4".into(),
|
||||
));
|
||||
}
|
||||
@@ -473,3 +739,51 @@ pub fn audio_path_for(id: &JobId) -> PathBuf {
|
||||
let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/data".into());
|
||||
PathBuf::from(data_dir).join(format!("{id}.audio"))
|
||||
}
|
||||
|
||||
// ── Unit tests ────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_snap_to_silence_uses_nearest_midpoint() {
|
||||
let mids = vec![55.0, 58.0, 62.0];
|
||||
let cuts = snap_to_silence(&mids, 120.0, 60.0, 30.0);
|
||||
assert!(!cuts.is_empty());
|
||||
assert!((cuts[0] - 58.0).abs() < 0.01, "expected ~58.0, got {}", cuts[0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_snap_to_silence_hard_cut_when_no_silence() {
|
||||
let cuts = snap_to_silence(&[], 120.0, 60.0, 30.0);
|
||||
assert_eq!(cuts, vec![60.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_to_chunk_ranges_single_chunk() {
|
||||
let ranges = to_chunk_ranges(&[], 30.0);
|
||||
assert_eq!(ranges, vec![(0.0, 30.0)]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_to_chunk_ranges_two_chunks() {
|
||||
let ranges = to_chunk_ranges(&[60.0], 120.0);
|
||||
assert_eq!(ranges, vec![(0.0, 60.0), (60.0, 120.0)]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_trim_trailing_silence_all_silent() {
|
||||
let mut pcm = vec![0.0f32; 1000];
|
||||
trim_trailing_silence(&mut pcm);
|
||||
assert_eq!(pcm.len(), 1000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_trim_trailing_silence_trims_to_padding() {
|
||||
let mut pcm = vec![0.0f32; 32_000];
|
||||
pcm[10_000] = 1.0;
|
||||
trim_trailing_silence(&mut pcm);
|
||||
assert_eq!(pcm.len(), (10_001 + 8_000).min(32_000));
|
||||
}
|
||||
}
|
||||
|
||||
97
test_all.sh
Normal file → Executable file
97
test_all.sh
Normal file → Executable file
@@ -6,8 +6,9 @@ BASE="${WHISPER_BASE_URL:-http://localhost:8080}"
|
||||
AUDIO="${TEST_AUDIO:-/home/moze/Sources/youtube-transcriber/docker/tmp/audio-b2167046-a236-4fcd-b739-78177542fd23.wav}"
|
||||
|
||||
GREEN='\033[0;32m'; RED='\033[0;31m'; NC='\033[0m'
|
||||
FAILS=0
|
||||
ok() { echo -e "${GREEN}[PASS]${NC} $*"; }
|
||||
fail(){ echo -e "${RED}[FAIL]${NC} $*"; exit 1; }
|
||||
fail(){ echo -e "${RED}[FAIL]${NC} $*"; FAILS=$((FAILS + 1)); }
|
||||
|
||||
echo "=== Whisper API test suite ==="
|
||||
echo " BASE : $BASE"
|
||||
@@ -17,11 +18,16 @@ echo ""
|
||||
echo "=== 1. GET /health ==="
|
||||
HEALTH=$(curl -sf "$BASE/health")
|
||||
echo "$HEALTH" | python3 -m json.tool
|
||||
echo "$HEALTH" | python3 -c "import sys,json; d=json.load(sys.stdin); assert d['status']=='ok', f'status={d[\"status\"]}'" && ok "health ok"
|
||||
python3 -c "
|
||||
import sys, json
|
||||
d = json.loads('$HEALTH' if False else sys.stdin.read())
|
||||
assert d['status'] == 'ok', f'status={d[\"status\"]}'
|
||||
assert 'model_state' in d, 'model_state field missing from health response'
|
||||
" <<< "$HEALTH" && ok "health ok + model_state present" || fail "health check"
|
||||
|
||||
echo ""
|
||||
echo "=== 2. GET /docs (Swagger UI reachable) ==="
|
||||
curl -sf "$BASE/docs" | grep -qi "swagger" && ok "swagger UI reachable"
|
||||
curl -sf "$BASE/docs" | grep -qi "swagger" && ok "swagger UI reachable" || fail "swagger UI"
|
||||
|
||||
echo ""
|
||||
echo "=== 3. Webhook receiver (background Python HTTP server) ==="
|
||||
@@ -33,7 +39,7 @@ class H(http.server.BaseHTTPRequestHandler):
|
||||
n = int(self.headers.get('Content-Length', 0))
|
||||
body = self.rfile.read(n)
|
||||
data = json.loads(body)
|
||||
print(f"\n[WEBHOOK] status={data.get('status')} segments={len(data.get('segments', []))}")
|
||||
print(f"\n[WEBHOOK] status={data.get('status')} segments={len(data.get('segments', []))}", flush=True)
|
||||
self.send_response(200)
|
||||
self.end_headers()
|
||||
def log_message(self, *a): pass
|
||||
@@ -48,40 +54,70 @@ sleep 1
|
||||
echo "Webhook receiver started (PID $WEBHOOK_PID)"
|
||||
|
||||
echo ""
|
||||
echo "=== 4. DELETE a non-existent job → 404 ==="
|
||||
echo "=== 4. GET /model/status — expect unloaded on fresh start ==="
|
||||
MODEL_STATUS=$(curl -sf "$BASE/model/status")
|
||||
echo "$MODEL_STATUS" | python3 -m json.tool
|
||||
echo "$MODEL_STATUS" | python3 -c "
|
||||
import sys, json
|
||||
d = json.load(sys.stdin)
|
||||
assert 'state' in d, 'state field missing from /model/status'
|
||||
print(f' model state: {d[\"state\"]}')
|
||||
" && ok "/model/status has state field" || fail "/model/status schema"
|
||||
|
||||
echo ""
|
||||
echo "=== 5. POST /model/load — trigger model load ==="
|
||||
LOAD_RESP=$(curl -sf -X POST "$BASE/model/load")
|
||||
echo "$LOAD_RESP"
|
||||
ok "POST /model/load accepted"
|
||||
|
||||
echo ""
|
||||
echo "=== 6. Poll /model/status until ready (max 3 min) ==="
|
||||
LOAD_ELAPSED=0
|
||||
while true; do
|
||||
sleep 5
|
||||
LOAD_ELAPSED=$((LOAD_ELAPSED + 5))
|
||||
MS=$(curl -sf "$BASE/model/status")
|
||||
STATE=$(echo "$MS" | python3 -c "import sys,json; print(json.load(sys.stdin)['state'])")
|
||||
echo " [${LOAD_ELAPSED}s] model_state=${STATE}"
|
||||
if [ "$STATE" = "ready" ]; then
|
||||
ok "model loaded and ready in ${LOAD_ELAPSED}s"
|
||||
break
|
||||
fi
|
||||
[ $LOAD_ELAPSED -gt 180 ] && { fail "model failed to load within 3 minutes"; break; }
|
||||
done
|
||||
|
||||
echo ""
|
||||
echo "=== 7. DELETE a non-existent job → 404 ==="
|
||||
STATUS=$(curl -s -o /dev/null -w "%{http_code}" -X DELETE "$BASE/jobs/00000000-0000-0000-0000-000000000000")
|
||||
[ "$STATUS" = "404" ] && ok "DELETE unknown job → 404" || fail "expected 404, got $STATUS"
|
||||
|
||||
echo ""
|
||||
echo "=== 5. POST /jobs — submit audio ==="
|
||||
# language field omitted → auto-detection. Do NOT pass "auto" — it is not a
|
||||
# valid ISO 639-1 code and whisper-rs will reject it or behave unexpectedly.
|
||||
echo "=== 8. POST /jobs — submit audio ==="
|
||||
SUBMIT=$(curl -sf -X POST "$BASE/jobs" \
|
||||
-F "audio=@${AUDIO};type=audio/wav" \
|
||||
-F "task=transcribe" \
|
||||
-F "webhook_url=http://localhost:9999/webhook")
|
||||
echo "$SUBMIT"
|
||||
# Submit response: { "job_id": "<uuid>" } (field is "job_id", not "id")
|
||||
JOB_ID=$(echo "$SUBMIT" | python3 -c "import sys,json; print(json.load(sys.stdin)['job_id'])")
|
||||
ok "submitted job $JOB_ID"
|
||||
|
||||
echo ""
|
||||
echo "=== 6. GET /jobs/{id} immediately after submit ==="
|
||||
echo "=== 9. GET /jobs/{id} immediately after submit ==="
|
||||
JOB=$(curl -sf "$BASE/jobs/$JOB_ID")
|
||||
echo "$JOB" | python3 -c "
|
||||
import sys, json
|
||||
d = json.load(sys.stdin)
|
||||
assert d['status'] in ('queued', 'running'), f'unexpected status: {d[\"status\"]}'
|
||||
" && ok "status is queued/running"
|
||||
" && ok "status is queued/running" || fail "initial status check"
|
||||
|
||||
echo ""
|
||||
echo "=== 7. SSE stream (observe first 30 events then detach) ==="
|
||||
echo "=== 10. SSE stream (observe first 30 events then detach) ==="
|
||||
echo "Subscribing to SSE stream for $JOB_ID …"
|
||||
curl -sN --max-time 90 "$BASE/jobs/$JOB_ID/stream" | head -60 &
|
||||
SSE_PID=$!
|
||||
|
||||
echo ""
|
||||
echo "=== 8. Poll until done (max 20 min) ==="
|
||||
echo "=== 11. Poll until done (max 20 min) ==="
|
||||
ELAPSED=0
|
||||
while true; do
|
||||
sleep 15
|
||||
@@ -96,16 +132,15 @@ while true; do
|
||||
elif [ "$STATUS" = "failed" ]; then
|
||||
echo "$JOB" | python3 -m json.tool
|
||||
fail "job failed"
|
||||
break
|
||||
fi
|
||||
[ $ELAPSED -gt 1200 ] && fail "timeout after 20 minutes"
|
||||
[ $ELAPSED -gt 1200 ] && { fail "timeout after 20 minutes"; break; }
|
||||
done
|
||||
kill $SSE_PID 2>/dev/null || true
|
||||
|
||||
echo ""
|
||||
echo "=== 9. Inspect transcription quality ==="
|
||||
echo "=== 12. Inspect transcription quality ==="
|
||||
RESULT=$(curl -sf "$BASE/jobs/$JOB_ID")
|
||||
# Note: can't pipe into a heredoc-driven python3 (heredoc takes stdin, pipe is ignored).
|
||||
# Write to a temp file instead.
|
||||
TMPJSON=$(mktemp /tmp/whisper_test_XXXXXX.json)
|
||||
echo "$RESULT" > "$TMPJSON"
|
||||
python3 - "$TMPJSON" << 'PYCHECK'
|
||||
@@ -149,18 +184,16 @@ for seg in segments[:5]:
|
||||
PYCHECK
|
||||
PYEXIT=$?
|
||||
rm -f "$TMPJSON"
|
||||
[ $PYEXIT -eq 0 ] && ok "quality check passed" || { echo "[FAIL] quality check"; FAILS=$((FAILS+1)); }
|
||||
[ $PYEXIT -eq 0 ] && ok "quality check passed" || fail "quality check"
|
||||
|
||||
echo ""
|
||||
echo "=== 10. DELETE completed job → 200 ==="
|
||||
# Completed jobs return 409 Conflict on DELETE (terminal state).
|
||||
# Verify we get 409, not 200 (delete is only for cancellation of active jobs).
|
||||
echo "=== 13. DELETE completed job → 409 Conflict ==="
|
||||
DEL_STATUS=$(curl -s -o /dev/null -w "%{http_code}" -X DELETE "$BASE/jobs/$JOB_ID")
|
||||
[ "$DEL_STATUS" = "409" ] && ok "DELETE completed job → 409 Conflict (expected)" \
|
||||
|| echo " [INFO] DELETE returned $DEL_STATUS"
|
||||
|
||||
echo ""
|
||||
echo "=== 11. Submit + cancel a queued job ==="
|
||||
echo "=== 14. Submit + cancel a queued job ==="
|
||||
JOB2=$(curl -sf -X POST "$BASE/jobs" \
|
||||
-F "audio=@${AUDIO};type=audio/wav" \
|
||||
-F "language=en" \
|
||||
@@ -173,8 +206,24 @@ CANCEL_STATUS=$(curl -sf "$BASE/jobs/$JOB2_ID" | python3 -c "import sys,json; pr
|
||||
|| echo " [INFO] cancel status: $CANCEL_STATUS (may be running — worker ignores cancel mid-chunk)"
|
||||
|
||||
echo ""
|
||||
echo "=== 12. Verify webhook fired ==="
|
||||
echo "=== 15. POST /model/unload ==="
|
||||
UNLOAD_RESP=$(curl -sf -X POST "$BASE/model/unload")
|
||||
echo "$UNLOAD_RESP"
|
||||
sleep 2
|
||||
UNLOAD_STATE=$(curl -sf "$BASE/model/status" | python3 -c "import sys,json; print(json.load(sys.stdin)['state'])")
|
||||
[ "$UNLOAD_STATE" = "unloaded" ] && ok "model unloaded → state=unloaded" \
|
||||
|| echo " [INFO] state after unload: $UNLOAD_STATE"
|
||||
|
||||
echo ""
|
||||
echo "=== 16. Verify webhook fired ==="
|
||||
sleep 3
|
||||
kill $WEBHOOK_PID 2>/dev/null || true
|
||||
ok "all tests complete"
|
||||
ok "webhook server stopped"
|
||||
|
||||
echo ""
|
||||
if [ $FAILS -eq 0 ]; then
|
||||
echo -e "${GREEN}=== ALL TESTS PASSED ===${NC}"
|
||||
else
|
||||
echo -e "${RED}=== $FAILS TEST(S) FAILED ===${NC}"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
246
tests/test_idle_timeout.sh
Executable file
246
tests/test_idle_timeout.sh
Executable file
@@ -0,0 +1,246 @@
|
||||
#!/usr/bin/env bash
|
||||
# tests/test_idle_timeout.sh
|
||||
#
|
||||
# Integration tests for the idle-timeout auto-unload feature.
|
||||
# REQUIRES the server to be started with a short idle timeout:
|
||||
#
|
||||
# IDLE_TIMEOUT_SECS=5 ./whisper-server
|
||||
# # or via Docker:
|
||||
# docker run -e IDLE_TIMEOUT_SECS=5 ...
|
||||
#
|
||||
# The default idle timeout is 5 minutes; these tests use a 5-second window
|
||||
# to keep the suite fast.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
BASE="${WHISPER_BASE_URL:-http://localhost:8080}"
|
||||
IDLE_TIMEOUT="${EXPECTED_IDLE_TIMEOUT_SECS:-5}"
|
||||
AUDIO="${TEST_AUDIO:-}"
|
||||
|
||||
GREEN='\033[0;32m'; RED='\033[0;31m'; YELLOW='\033[0;33m'; NC='\033[0m'
|
||||
|
||||
PASS=0; FAIL=0
|
||||
|
||||
ok() { echo -e "${GREEN}[PASS]${NC} $1"; PASS=$((PASS+1)); }
|
||||
fail() { echo -e "${RED}[FAIL]${NC} $1"; FAIL=$((FAIL+1)); }
|
||||
skip() { echo -e "${YELLOW}[SKIP]${NC} $1"; }
|
||||
info() { echo " $1"; }
|
||||
|
||||
echo "=== Idle Timeout Tests ==="
|
||||
echo " BASE: $BASE"
|
||||
echo " IDLE_TIMEOUT_SECS: $IDLE_TIMEOUT (must be configured on the server)"
|
||||
echo ""
|
||||
echo "NOTE: These tests require the server to be running with IDLE_TIMEOUT_SECS=$IDLE_TIMEOUT"
|
||||
echo ""
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||||
|
||||
get_state() {
|
||||
curl -sf "$BASE/model/status" | python3 -c "import sys,json; print(json.load(sys.stdin)['state'])"
|
||||
}
|
||||
|
||||
ensure_ready() {
|
||||
local state
|
||||
state=$(get_state)
|
||||
if [ "$state" = "ready" ]; then return 0; fi
|
||||
curl -sf -X POST "$BASE/model/load" > /dev/null
|
||||
local elapsed=0
|
||||
while true; do
|
||||
sleep 3; elapsed=$((elapsed+3))
|
||||
state=$(get_state)
|
||||
[ "$state" = "ready" ] && return 0
|
||||
[ $elapsed -gt 180 ] && return 1
|
||||
done
|
||||
}
|
||||
|
||||
ensure_unloaded() {
|
||||
curl -sf -X POST "$BASE/model/unload" > /dev/null || true
|
||||
sleep 2
|
||||
}
|
||||
|
||||
# ── TEST 1: Load model, complete a job, then wait for idle unload ─────────────
|
||||
echo "--- Test 1: Idle timeout triggers auto-unload ---"
|
||||
|
||||
ensure_unloaded
|
||||
ensure_ready || { fail "T1: model load failed"; }
|
||||
|
||||
WAIT_SECS=$((IDLE_TIMEOUT + 3))
|
||||
info "Model is ready. Waiting $WAIT_SECS seconds (idle timeout=$IDLE_TIMEOUT + 3s buffer)..."
|
||||
sleep $WAIT_SECS
|
||||
|
||||
STATE=$(get_state)
|
||||
if [ "$STATE" = "unloaded" ]; then
|
||||
ok "T1: model auto-unloaded after ${IDLE_TIMEOUT}s idle"
|
||||
else
|
||||
fail "T1: expected unloaded after idle timeout, got $STATE"
|
||||
info "Is the server running with IDLE_TIMEOUT_SECS=$IDLE_TIMEOUT?"
|
||||
fi
|
||||
|
||||
# ── TEST 2: model_unloaded webhook fires on idle timeout ─────────────────────
|
||||
echo ""
|
||||
echo "--- Test 2: model_unloaded webhook fires on idle timeout ---"
|
||||
|
||||
ensure_unloaded
|
||||
|
||||
# Start webhook receiver
|
||||
python3 - <<'PYEOF' &
|
||||
import http.server, json, sys, signal
|
||||
|
||||
class H(http.server.BaseHTTPRequestHandler):
|
||||
def do_POST(self):
|
||||
n = int(self.headers.get('Content-Length', 0))
|
||||
body = json.loads(self.rfile.read(n))
|
||||
with open('/tmp/idle_wh_event.json', 'w') as f:
|
||||
json.dump(body, f)
|
||||
self.send_response(200); self.end_headers()
|
||||
def log_message(self, *a): pass
|
||||
|
||||
signal.signal(signal.SIGTERM, lambda *_: sys.exit(0))
|
||||
http.server.HTTPServer(('', 9995), H).serve_forever()
|
||||
PYEOF
|
||||
WH_PID=$!
|
||||
sleep 1
|
||||
|
||||
# Register webhook via a job submission (will 503 since unloaded)
|
||||
curl -sf -X POST "$BASE/jobs" \
|
||||
-F "audio=@/dev/urandom;type=audio/wav" \
|
||||
-F "webhook_url=http://localhost:9995/wh" \
|
||||
--max-time 5 > /dev/null 2>&1 || true
|
||||
|
||||
# Load model
|
||||
ensure_ready || { fail "T2: model load failed"; kill $WH_PID 2>/dev/null; }
|
||||
|
||||
# Wait for idle timeout
|
||||
WAIT_SECS=$((IDLE_TIMEOUT + 5))
|
||||
info "Waiting ${WAIT_SECS}s for idle timeout..."
|
||||
sleep $WAIT_SECS
|
||||
|
||||
kill $WH_PID 2>/dev/null || true
|
||||
wait $WH_PID 2>/dev/null || true
|
||||
|
||||
if [ -f /tmp/idle_wh_event.json ]; then
|
||||
EVENT_TYPE=$(python3 -c "import json; print(json.load(open('/tmp/idle_wh_event.json')).get('type','?'))")
|
||||
rm -f /tmp/idle_wh_event.json
|
||||
[ "$EVENT_TYPE" = "model_unloaded" ] && ok "T2: model_unloaded webhook fired on idle timeout" \
|
||||
|| fail "T2: webhook type=$EVENT_TYPE (expected model_unloaded)"
|
||||
else
|
||||
fail "T2: no webhook received within timeout"
|
||||
fi
|
||||
|
||||
# ── TEST 3: Job submission after idle timeout → 503 → triggers reload ─────────
|
||||
echo ""
|
||||
echo "--- Test 3: Job triggers reload after idle unload ---"
|
||||
|
||||
ensure_unloaded
|
||||
ensure_ready || { fail "T3: initial load failed"; }
|
||||
|
||||
# Wait for auto-unload
|
||||
WAIT_SECS=$((IDLE_TIMEOUT + 3))
|
||||
info "Waiting ${WAIT_SECS}s for idle unload..."
|
||||
sleep $WAIT_SECS
|
||||
|
||||
STATE=$(get_state)
|
||||
[ "$STATE" = "unloaded" ] || info "Note: state=$STATE (expected unloaded)"
|
||||
|
||||
# Submit job → 503, triggers reload
|
||||
HTTP=$(curl -s -o /tmp/t3_body.json -w "%{http_code}" -X POST "$BASE/jobs" \
|
||||
-F "audio=@/dev/urandom;type=audio/wav" \
|
||||
--max-time 5 2>/dev/null || echo "000")
|
||||
|
||||
if [ "$HTTP" = "503" ]; then
|
||||
ok "T3a: POST /jobs → 503 after idle unload"
|
||||
else
|
||||
skip "T3a: POST /jobs returned $HTTP (model may have reloaded)"
|
||||
fi
|
||||
|
||||
# State should be loading or ready (reload triggered by job submission)
|
||||
sleep 2
|
||||
STATE=$(get_state)
|
||||
if [ "$STATE" = "loading" ] || [ "$STATE" = "ready" ]; then
|
||||
ok "T3b: reload triggered by job submission ($STATE)"
|
||||
else
|
||||
fail "T3b: expected loading/ready, got $STATE"
|
||||
fi
|
||||
|
||||
rm -f /tmp/t3_body.json
|
||||
|
||||
# ── TEST 4: Idle timer resets per job (wait 60% of timeout → still ready) ─────
|
||||
echo ""
|
||||
echo "--- Test 4: Idle timer resets with each completed job ---"
|
||||
|
||||
ensure_unloaded
|
||||
ensure_ready || { fail "T4: model load failed"; }
|
||||
|
||||
HALF_WAIT=$((IDLE_TIMEOUT - 1))
|
||||
info "Waiting ${HALF_WAIT}s (less than idle timeout)..."
|
||||
sleep $HALF_WAIT
|
||||
|
||||
STATE=$(get_state)
|
||||
if [ "$STATE" = "ready" ]; then
|
||||
ok "T4a: model still ready after ${HALF_WAIT}s (less than ${IDLE_TIMEOUT}s timeout)"
|
||||
else
|
||||
fail "T4a: model unexpectedly $STATE after only ${HALF_WAIT}s"
|
||||
fi
|
||||
|
||||
# Wait for full unload
|
||||
REMAINING=$((IDLE_TIMEOUT - HALF_WAIT + 3))
|
||||
info "Waiting another ${REMAINING}s for full idle unload..."
|
||||
sleep $REMAINING
|
||||
STATE=$(get_state)
|
||||
[ "$STATE" = "unloaded" ] && ok "T4b: model unloaded after total > ${IDLE_TIMEOUT}s idle" \
|
||||
|| fail "T4b: expected unloaded, got $STATE"
|
||||
|
||||
# ── TEST 5: Job resets idle timer ─────────────────────────────────────────────
|
||||
echo ""
|
||||
echo "--- Test 5: Completing a job resets the idle timer ---"
|
||||
|
||||
if [ -z "$AUDIO" ]; then
|
||||
skip "T5: TEST_AUDIO not set — skipping timer-reset test"
|
||||
else
|
||||
ensure_unloaded
|
||||
ensure_ready || { fail "T5: model load failed"; }
|
||||
|
||||
# Submit a job
|
||||
SUBMIT=$(curl -sf -X POST "$BASE/jobs" \
|
||||
-F "audio=@${AUDIO};type=audio/wav" \
|
||||
-F "task=transcribe" 2>&1)
|
||||
JOB_ID=$(echo "$SUBMIT" | python3 -c "import sys,json; print(json.load(sys.stdin)['job_id'])" 2>/dev/null || echo "")
|
||||
|
||||
if [ -z "$JOB_ID" ]; then
|
||||
fail "T5: job submission failed"
|
||||
else
|
||||
# Wait for job to finish
|
||||
elapsed=0
|
||||
while true; do
|
||||
sleep 5; elapsed=$((elapsed+5))
|
||||
STATUS=$(curl -sf "$BASE/jobs/$JOB_ID" | python3 -c "import sys,json; print(json.load(sys.stdin)['status'])")
|
||||
[ "$STATUS" = "done" ] || [ "$STATUS" = "failed" ] && break
|
||||
[ $elapsed -gt 300 ] && break
|
||||
done
|
||||
info "Job finished in ${elapsed}s with status=$STATUS"
|
||||
|
||||
# Now wait IDLE_TIMEOUT - 2 seconds — should still be ready
|
||||
SAFE_WAIT=$((IDLE_TIMEOUT - 2))
|
||||
[ $SAFE_WAIT -lt 1 ] && SAFE_WAIT=1
|
||||
info "Waiting ${SAFE_WAIT}s after job completion (less than idle timeout)..."
|
||||
sleep $SAFE_WAIT
|
||||
STATE=$(get_state)
|
||||
[ "$STATE" = "ready" ] && ok "T5a: model still ready ${SAFE_WAIT}s after job completion" \
|
||||
|| fail "T5a: model unexpectedly $STATE after job"
|
||||
|
||||
# Wait for idle timeout
|
||||
REMAINING=$((IDLE_TIMEOUT - SAFE_WAIT + 3))
|
||||
info "Waiting ${REMAINING}s more for idle unload..."
|
||||
sleep $REMAINING
|
||||
STATE=$(get_state)
|
||||
[ "$STATE" = "unloaded" ] && ok "T5b: model auto-unloaded after idle period post-job" \
|
||||
|| fail "T5b: expected unloaded, got $STATE"
|
||||
fi
|
||||
fi
|
||||
|
||||
# ── Summary ────────────────────────────────────────────────────────────────────
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo " Results: ${PASS} passed, ${FAIL} failed"
|
||||
echo "=========================================="
|
||||
[ $FAIL -eq 0 ] && echo -e "${GREEN}ALL PASSED${NC}" || { echo -e "${RED}FAILURES: $FAIL${NC}"; exit 1; }
|
||||
470
tests/test_model_lifecycle.sh
Executable file
470
tests/test_model_lifecycle.sh
Executable file
@@ -0,0 +1,470 @@
|
||||
#!/usr/bin/env bash
|
||||
# tests/test_model_lifecycle.sh
|
||||
#
|
||||
# Integration tests for dynamic model loading/unloading.
|
||||
# Requires a running whisper-server with GPU access.
|
||||
#
|
||||
# Usage:
|
||||
# WHISPER_BASE_URL=http://localhost:8080 bash tests/test_model_lifecycle.sh
|
||||
#
|
||||
# Tests are designed to be independent; each section that needs a specific
|
||||
# state resets it explicitly at the start.
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
BASE="${WHISPER_BASE_URL:-http://localhost:8080}"
|
||||
AUDIO="${TEST_AUDIO:-}"
|
||||
|
||||
GREEN='\033[0;32m'; RED='\033[0;31m'; YELLOW='\033[0;33m'; NC='\033[0m'
|
||||
|
||||
PASS=0; FAIL=0
|
||||
|
||||
ok() { echo -e "${GREEN}[PASS]${NC} $1"; PASS=$((PASS+1)); }
|
||||
fail() { echo -e "${RED}[FAIL]${NC} $1"; FAIL=$((FAIL+1)); }
|
||||
skip() { echo -e "${YELLOW}[SKIP]${NC} $1"; }
|
||||
info() { echo " $1"; }
|
||||
|
||||
echo "=== Model Lifecycle Integration Tests ==="
|
||||
echo " BASE: $BASE"
|
||||
echo ""
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||||
|
||||
get_state() {
|
||||
curl -sf "$BASE/model/status" | python3 -c "import sys,json; print(json.load(sys.stdin)['state'])"
|
||||
}
|
||||
|
||||
ensure_unloaded() {
|
||||
curl -sf -X POST "$BASE/model/unload" > /dev/null
|
||||
sleep 2
|
||||
local s
|
||||
s=$(get_state)
|
||||
if [ "$s" != "unloaded" ]; then
|
||||
echo " WARNING: expected unloaded, got $s — waiting 5s"
|
||||
sleep 5
|
||||
fi
|
||||
}
|
||||
|
||||
ensure_ready() {
|
||||
local state
|
||||
state=$(get_state)
|
||||
if [ "$state" = "ready" ]; then return 0; fi
|
||||
curl -sf -X POST "$BASE/model/load" > /dev/null
|
||||
local elapsed=0
|
||||
while true; do
|
||||
sleep 5; elapsed=$((elapsed+5))
|
||||
state=$(get_state)
|
||||
[ "$state" = "ready" ] && return 0
|
||||
[ $elapsed -gt 180 ] && echo " TIMEOUT: model did not become ready" && return 1
|
||||
done
|
||||
}
|
||||
|
||||
poll_state_transition() {
|
||||
local target="$1" max_secs="${2:-120}"
|
||||
local elapsed=0
|
||||
while true; do
|
||||
sleep 2; elapsed=$((elapsed+2))
|
||||
local s
|
||||
s=$(get_state)
|
||||
[ "$s" = "$target" ] && return 0
|
||||
[ $elapsed -ge $max_secs ] && return 1
|
||||
done
|
||||
}
|
||||
|
||||
# ── TEST 1: Startup state is unloaded ────────────────────────────────────────
|
||||
echo "--- Test 1: Startup state is unloaded (or after explicit unload) ---"
|
||||
ensure_unloaded
|
||||
STATE=$(get_state)
|
||||
if [ "$STATE" = "unloaded" ]; then
|
||||
ok "T1: state=unloaded after explicit unload"
|
||||
else
|
||||
fail "T1: expected unloaded, got $STATE"
|
||||
fi
|
||||
|
||||
# ── TEST 2: POST /model/load returns 202 ─────────────────────────────────────
|
||||
echo ""
|
||||
echo "--- Test 2: POST /model/load returns 202 ---"
|
||||
ensure_unloaded
|
||||
HTTP=$(curl -s -o /dev/null -w "%{http_code}" -X POST "$BASE/model/load")
|
||||
if [ "$HTTP" = "202" ]; then
|
||||
ok "T2: POST /model/load → 202 Accepted"
|
||||
else
|
||||
fail "T2: expected 202, got $HTTP"
|
||||
fi
|
||||
# Cancel the in-progress load to clean up
|
||||
curl -sf -X POST "$BASE/model/unload" > /dev/null || true
|
||||
sleep 2
|
||||
|
||||
# ── TEST 3: State transitions to loading/ready after load trigger ─────────────
|
||||
echo ""
|
||||
echo "--- Test 3: State transitions to loading (not stuck at unloaded) ---"
|
||||
ensure_unloaded
|
||||
curl -sf -X POST "$BASE/model/load" > /dev/null
|
||||
sleep 1
|
||||
STATE=$(get_state)
|
||||
if [ "$STATE" = "loading" ] || [ "$STATE" = "ready" ]; then
|
||||
ok "T3: state transitioned to $STATE (not stuck at unloaded)"
|
||||
else
|
||||
fail "T3: expected loading or ready, got $STATE"
|
||||
fi
|
||||
|
||||
# ── TEST 4: Model reaches ready state and loaded_at is set ───────────────────
|
||||
echo ""
|
||||
echo "--- Test 4: Model reaches ready state with loaded_at timestamp ---"
|
||||
# Already loading from T3 — wait for ready
|
||||
if ! poll_state_transition "ready" 180; then
|
||||
fail "T4: model did not become ready within 3 minutes"
|
||||
else
|
||||
STATUS_JSON=$(curl -sf "$BASE/model/status")
|
||||
LOADED_AT=$(echo "$STATUS_JSON" | python3 -c "import sys,json; d=json.load(sys.stdin); print(d.get('loaded_at','MISSING'))" 2>/dev/null || echo "MISSING")
|
||||
if [ "$LOADED_AT" != "MISSING" ] && [ "$LOADED_AT" != "null" ] && [ -n "$LOADED_AT" ]; then
|
||||
ok "T4: model=ready, loaded_at=$LOADED_AT"
|
||||
else
|
||||
fail "T4: model ready but loaded_at is missing or null"
|
||||
fi
|
||||
fi
|
||||
|
||||
# ── TEST 5: Idempotent load — POST /model/load when ready returns 200 ─────────
|
||||
echo ""
|
||||
echo "--- Test 5: POST /model/load when already ready → 200 ---"
|
||||
ensure_ready || { fail "T5: could not load model"; }
|
||||
HTTP=$(curl -s -o /dev/null -w "%{http_code}" -X POST "$BASE/model/load")
|
||||
STATE=$(get_state)
|
||||
if [ "$HTTP" = "200" ] && [ "$STATE" = "ready" ]; then
|
||||
ok "T5: idempotent load → 200, state stays ready"
|
||||
elif [ "$HTTP" = "202" ] && [ "$STATE" = "ready" ]; then
|
||||
ok "T5: idempotent load → 202, state stays ready"
|
||||
else
|
||||
fail "T5: expected 200 and ready, got HTTP=$HTTP state=$STATE"
|
||||
fi
|
||||
|
||||
# ── TEST 6: Job accepted when ready (segments > 0) ────────────────────────────
|
||||
echo ""
|
||||
echo "--- Test 6: Job accepted when model is ready ---"
|
||||
if [ -z "$AUDIO" ]; then
|
||||
skip "T6: TEST_AUDIO not set — skipping job submission test"
|
||||
else
|
||||
ensure_ready || { fail "T6: model load failed"; }
|
||||
SUBMIT=$(curl -sf -X POST "$BASE/jobs" -F "audio=@${AUDIO};type=audio/wav" -F "task=transcribe" 2>&1)
|
||||
JOB_ID=$(echo "$SUBMIT" | python3 -c "import sys,json; print(json.load(sys.stdin)['job_id'])" 2>/dev/null || echo "")
|
||||
if [ -n "$JOB_ID" ]; then
|
||||
ok "T6: job accepted, id=$JOB_ID"
|
||||
# Poll to done
|
||||
elapsed=0
|
||||
while true; do
|
||||
sleep 10; elapsed=$((elapsed+10))
|
||||
STATUS=$(curl -sf "$BASE/jobs/$JOB_ID" | python3 -c "import sys,json; print(json.load(sys.stdin)['status'])")
|
||||
[ "$STATUS" = "done" ] && break
|
||||
[ "$STATUS" = "failed" ] && break
|
||||
[ $elapsed -gt 600 ] && break
|
||||
done
|
||||
SEGS=$(curl -sf "$BASE/jobs/$JOB_ID" | python3 -c "import sys,json; d=json.load(sys.stdin); print(len(d.get('segments',[])))")
|
||||
[ "$SEGS" -gt 0 ] && ok "T6b: job done with $SEGS segments" || fail "T6b: job done but 0 segments"
|
||||
else
|
||||
fail "T6: job submission failed: $SUBMIT"
|
||||
fi
|
||||
fi
|
||||
|
||||
# ── TEST 7: POST /model/unload → state=unloaded ───────────────────────────────
|
||||
echo ""
|
||||
echo "--- Test 7: POST /model/unload ---"
|
||||
ensure_ready || { fail "T7: model load failed"; }
|
||||
curl -sf -X POST "$BASE/model/unload" > /dev/null
|
||||
sleep 3
|
||||
STATE=$(get_state)
|
||||
if [ "$STATE" = "unloaded" ]; then
|
||||
ok "T7: POST /model/unload → state=unloaded"
|
||||
else
|
||||
fail "T7: expected unloaded after unload, got $STATE"
|
||||
fi
|
||||
|
||||
# ── TEST 8: POST /jobs when unloaded → 503 + Retry-After ─────────────────────
|
||||
echo ""
|
||||
echo "--- Test 8: POST /jobs when unloaded → 503 + Retry-After ---"
|
||||
ensure_unloaded
|
||||
# Submit a tiny dummy payload (won't be valid audio but that's ok for this test)
|
||||
HTTP=$(curl -s -o /tmp/t8_body.json -w "%{http_code}" -X POST "$BASE/jobs" \
|
||||
-F "audio=@/dev/urandom;type=audio/wav" \
|
||||
--max-time 5 2>/dev/null || echo "000")
|
||||
# If the model auto-loads it might start processing; check for 503 first
|
||||
if [ "$HTTP" = "503" ]; then
|
||||
RETRY_AFTER=$(curl -sI -X POST "$BASE/jobs" \
|
||||
-F "audio=@/dev/urandom;type=audio/wav" \
|
||||
--max-time 5 2>/dev/null | grep -i "retry-after" | awk '{print $2}' | tr -d '\r' || echo "")
|
||||
BODY=$(cat /tmp/t8_body.json 2>/dev/null || echo "{}")
|
||||
HAS_STATE=$(echo "$BODY" | python3 -c "import sys,json; d=json.load(sys.stdin); print('state' in d)" 2>/dev/null || echo "False")
|
||||
HAS_RETRY=$(echo "$BODY" | python3 -c "import sys,json; d=json.load(sys.stdin); print('retry_after_secs' in d)" 2>/dev/null || echo "False")
|
||||
if [ "$HAS_STATE" = "True" ] && [ "$HAS_RETRY" = "True" ]; then
|
||||
ok "T8: 503 with state + retry_after_secs in body"
|
||||
else
|
||||
fail "T8: 503 but body missing state/retry_after_secs. body=$BODY"
|
||||
fi
|
||||
if [ -n "$RETRY_AFTER" ]; then
|
||||
ok "T8b: Retry-After header present: $RETRY_AFTER"
|
||||
else
|
||||
fail "T8b: Retry-After header missing from 503 response"
|
||||
fi
|
||||
else
|
||||
skip "T8: got HTTP $HTTP (model may have loaded before check) — skipping"
|
||||
fi
|
||||
|
||||
# ── TEST 9: Rejected job triggers load ────────────────────────────────────────
|
||||
echo ""
|
||||
echo "--- Test 9: Job rejection triggers model load ---"
|
||||
ensure_unloaded
|
||||
# Send a job (we expect 503)
|
||||
curl -sf -X POST "$BASE/jobs" \
|
||||
-F "audio=@/dev/urandom;type=audio/wav" \
|
||||
--max-time 5 > /dev/null 2>&1 || true
|
||||
sleep 2
|
||||
STATE=$(get_state)
|
||||
if [ "$STATE" = "loading" ] || [ "$STATE" = "ready" ]; then
|
||||
ok "T9: model started loading after job rejection ($STATE)"
|
||||
else
|
||||
fail "T9: expected loading/ready after job rejection, got $STATE"
|
||||
fi
|
||||
# Stop the load to clean up
|
||||
curl -sf -X POST "$BASE/model/unload" > /dev/null || true
|
||||
sleep 2
|
||||
|
||||
# ── TEST 10: Retry-After values ───────────────────────────────────────────────
|
||||
echo ""
|
||||
echo "--- Test 10: Retry-After values match state ---"
|
||||
ensure_unloaded
|
||||
# Unloaded → Retry-After: 30
|
||||
RESP_UNLOADED=$(curl -si -X POST "$BASE/jobs" -F "audio=@/dev/urandom;type=audio/wav" --max-time 5 2>/dev/null || echo "")
|
||||
RA_UNLOADED=$(echo "$RESP_UNLOADED" | grep -i "retry-after" | awk '{print $2}' | tr -d '\r' || echo "")
|
||||
[ "$RA_UNLOADED" = "30" ] && ok "T10a: Retry-After=30 when unloaded" \
|
||||
|| skip "T10a: Retry-After=$RA_UNLOADED (expected 30) — model may have started loading"
|
||||
|
||||
# ── TEST 11: Retry-After=10 during loading ────────────────────────────────────
|
||||
echo ""
|
||||
echo "--- Test 11: Retry-After=10 when loading ---"
|
||||
ensure_unloaded
|
||||
curl -sf -X POST "$BASE/model/load" > /dev/null
|
||||
sleep 1 # In loading state
|
||||
STATE=$(get_state)
|
||||
if [ "$STATE" = "loading" ]; then
|
||||
RESP_LOADING=$(curl -si -X POST "$BASE/jobs" -F "audio=@/dev/urandom;type=audio/wav" --max-time 5 2>/dev/null || echo "")
|
||||
RA_LOADING=$(echo "$RESP_LOADING" | grep -i "retry-after" | awk '{print $2}' | tr -d '\r' || echo "")
|
||||
[ "$RA_LOADING" = "10" ] && ok "T11: Retry-After=10 when loading" \
|
||||
|| fail "T11: expected Retry-After=10, got '$RA_LOADING' (state=$STATE)"
|
||||
else
|
||||
skip "T11: model already $STATE — can't test loading state Retry-After"
|
||||
fi
|
||||
|
||||
# ── TEST 12: 503 body schema validation ──────────────────────────────────────
|
||||
echo ""
|
||||
echo "--- Test 12: 503 body schema validation ---"
|
||||
ensure_unloaded
|
||||
BODY=$(curl -sf -X POST "$BASE/jobs" -F "audio=@/dev/urandom;type=audio/wav" --max-time 5 2>/dev/null || echo "{}")
|
||||
python3 - <<PYCHECK
|
||||
import json
|
||||
body = json.loads('$BODY')
|
||||
required = {'error', 'state', 'retry_after_secs'}
|
||||
missing = required - set(body.keys())
|
||||
if missing:
|
||||
print(f"MISSING: {missing}")
|
||||
exit(1)
|
||||
assert body['error'] == 'model_not_ready', f"error={body['error']}"
|
||||
assert isinstance(body['retry_after_secs'], int), f"retry_after_secs not int: {body['retry_after_secs']}"
|
||||
print("schema ok")
|
||||
PYCHECK
|
||||
[ $? -eq 0 ] && ok "T12: 503 body has correct schema" || fail "T12: 503 body schema invalid"
|
||||
|
||||
# ── TEST 13: GET /health has model_state field ────────────────────────────────
|
||||
echo ""
|
||||
echo "--- Test 13: GET /health has model_state ---"
|
||||
HEALTH=$(curl -sf "$BASE/health")
|
||||
HAS_MODEL_STATE=$(echo "$HEALTH" | python3 -c "import sys,json; d=json.load(sys.stdin); print('model_state' in d)")
|
||||
[ "$HAS_MODEL_STATE" = "True" ] && ok "T13: /health has model_state" || fail "T13: /health missing model_state"
|
||||
|
||||
# ── TEST 14: SSE /model/events delivers model_ready event ─────────────────────
|
||||
echo ""
|
||||
echo "--- Test 14: GET /model/events SSE delivers model_ready ---"
|
||||
ensure_unloaded
|
||||
|
||||
# Collect SSE events for up to 3 minutes
|
||||
SSE_LOG=$(mktemp /tmp/sse_events_XXXXXX.txt)
|
||||
curl -sN --max-time 180 "$BASE/model/events" > "$SSE_LOG" &
|
||||
SSE_PID=$!
|
||||
sleep 1
|
||||
|
||||
# Trigger load
|
||||
curl -sf -X POST "$BASE/model/load" > /dev/null
|
||||
poll_state_transition "ready" 180 || true
|
||||
|
||||
sleep 2
|
||||
kill $SSE_PID 2>/dev/null || true
|
||||
wait $SSE_PID 2>/dev/null || true
|
||||
|
||||
if grep -q "model_loading" "$SSE_LOG" 2>/dev/null; then
|
||||
ok "T14a: SSE received model_loading event"
|
||||
else
|
||||
fail "T14a: SSE did not receive model_loading event"
|
||||
fi
|
||||
|
||||
if grep -q "model_ready" "$SSE_LOG" 2>/dev/null; then
|
||||
ok "T14b: SSE received model_ready event"
|
||||
else
|
||||
fail "T14b: SSE did not receive model_ready event"
|
||||
fi
|
||||
|
||||
# Now unload to get model_unloaded event
|
||||
curl -sf -X POST "$BASE/model/unload" > /dev/null
|
||||
sleep 1
|
||||
|
||||
SSE_LOG2=$(mktemp /tmp/sse_events_XXXXXX.txt)
|
||||
curl -sN --max-time 10 "$BASE/model/events" > "$SSE_LOG2" &
|
||||
SSE_PID2=$!
|
||||
sleep 2
|
||||
kill $SSE_PID2 2>/dev/null || true
|
||||
wait $SSE_PID2 2>/dev/null || true
|
||||
|
||||
# model_unloaded fires immediately on unload command
|
||||
if grep -q "model_unloaded" "$SSE_LOG" 2>/dev/null || grep -q "model_unloaded" "$SSE_LOG2" 2>/dev/null; then
|
||||
ok "T14c: SSE received model_unloaded event"
|
||||
else
|
||||
fail "T14c: SSE did not receive model_unloaded event"
|
||||
fi
|
||||
|
||||
rm -f "$SSE_LOG" "$SSE_LOG2"
|
||||
|
||||
# ── TEST 15: model_ready webhook fires after load ──────────────────────────────
|
||||
echo ""
|
||||
echo "--- Test 15: model_ready webhook ---"
|
||||
ensure_unloaded
|
||||
|
||||
# Start webhook receiver
|
||||
WEBHOOK_LOG=$(mktemp /tmp/webhook_log_XXXXXX.txt)
|
||||
python3 - <<'PYEOF' &
|
||||
import http.server, json, sys, signal, os
|
||||
|
||||
class H(http.server.BaseHTTPRequestHandler):
|
||||
def do_POST(self):
|
||||
n = int(self.headers.get('Content-Length', 0))
|
||||
body = json.loads(self.rfile.read(n))
|
||||
with open('/tmp/t15_webhook.json', 'w') as f:
|
||||
json.dump(body, f)
|
||||
self.send_response(200); self.end_headers()
|
||||
def log_message(self, *a): pass
|
||||
|
||||
signal.signal(signal.SIGTERM, lambda *_: sys.exit(0))
|
||||
http.server.HTTPServer(('', 9998), H).serve_forever()
|
||||
PYEOF
|
||||
WBOOK_PID=$!
|
||||
sleep 1
|
||||
|
||||
# Register a webhook via a (doomed) job submission — this registers the URL
|
||||
# even though the model is unloaded (and the job will 503)
|
||||
curl -sf -X POST "$BASE/jobs" \
|
||||
-F "audio=@/dev/urandom;type=audio/wav" \
|
||||
-F "webhook_url=http://localhost:9998/wh" \
|
||||
--max-time 5 > /dev/null 2>&1 || true
|
||||
|
||||
# Now load the model
|
||||
curl -sf -X POST "$BASE/model/load" > /dev/null
|
||||
poll_state_transition "ready" 180 || true
|
||||
sleep 3
|
||||
|
||||
kill $WBOOK_PID 2>/dev/null || true
|
||||
wait $WBOOK_PID 2>/dev/null || true
|
||||
|
||||
if [ -f /tmp/t15_webhook.json ]; then
|
||||
EVENT_TYPE=$(python3 -c "import json; d=json.load(open('/tmp/t15_webhook.json')); print(d.get('type','?'))")
|
||||
[ "$EVENT_TYPE" = "model_ready" ] && ok "T15: model_ready webhook fired" \
|
||||
|| fail "T15: webhook fired but type=$EVENT_TYPE (expected model_ready)"
|
||||
rm -f /tmp/t15_webhook.json
|
||||
else
|
||||
fail "T15: model_ready webhook not received within timeout"
|
||||
fi
|
||||
|
||||
# ── TEST 16: model_unloaded webhook fires ─────────────────────────────────────
|
||||
echo ""
|
||||
echo "--- Test 16: model_unloaded webhook ---"
|
||||
|
||||
python3 - <<'PYEOF' &
|
||||
import http.server, json, sys, signal
|
||||
|
||||
class H(http.server.BaseHTTPRequestHandler):
|
||||
def do_POST(self):
|
||||
n = int(self.headers.get('Content-Length', 0))
|
||||
body = json.loads(self.rfile.read(n))
|
||||
with open('/tmp/t16_webhook.json', 'w') as f:
|
||||
json.dump(body, f)
|
||||
self.send_response(200); self.end_headers()
|
||||
def log_message(self, *a): pass
|
||||
|
||||
signal.signal(signal.SIGTERM, lambda *_: sys.exit(0))
|
||||
http.server.HTTPServer(('', 9997), H).serve_forever()
|
||||
PYEOF
|
||||
WBOOK2_PID=$!
|
||||
sleep 1
|
||||
|
||||
# Register webhook URL
|
||||
curl -sf -X POST "$BASE/jobs" \
|
||||
-F "audio=@/dev/urandom;type=audio/wav" \
|
||||
-F "webhook_url=http://localhost:9997/wh" \
|
||||
--max-time 5 > /dev/null 2>&1 || true
|
||||
|
||||
ensure_ready
|
||||
|
||||
# Unload
|
||||
curl -sf -X POST "$BASE/model/unload" > /dev/null
|
||||
sleep 5
|
||||
|
||||
kill $WBOOK2_PID 2>/dev/null || true
|
||||
wait $WBOOK2_PID 2>/dev/null || true
|
||||
|
||||
if [ -f /tmp/t16_webhook.json ]; then
|
||||
EVENT_TYPE=$(python3 -c "import json; d=json.load(open('/tmp/t16_webhook.json')); print(d.get('type','?'))")
|
||||
[ "$EVENT_TYPE" = "model_unloaded" ] && ok "T16: model_unloaded webhook fired" \
|
||||
|| fail "T16: webhook type=$EVENT_TYPE (expected model_unloaded)"
|
||||
rm -f /tmp/t16_webhook.json
|
||||
else
|
||||
fail "T16: model_unloaded webhook not received"
|
||||
fi
|
||||
|
||||
# ── TEST 17: Concurrent load requests — single load, stable ready ─────────────
|
||||
echo ""
|
||||
echo "--- Test 17: Concurrent POST /model/load requests ---"
|
||||
ensure_unloaded
|
||||
# Send 3 concurrent load requests
|
||||
curl -sf -X POST "$BASE/model/load" > /dev/null &
|
||||
curl -sf -X POST "$BASE/model/load" > /dev/null &
|
||||
curl -sf -X POST "$BASE/model/load" > /dev/null &
|
||||
wait
|
||||
poll_state_transition "ready" 180 || true
|
||||
STATE=$(get_state)
|
||||
[ "$STATE" = "ready" ] && ok "T17: concurrent loads handled cleanly, state=ready" \
|
||||
|| fail "T17: expected ready after concurrent loads, got $STATE"
|
||||
|
||||
# ── TEST 18: POST /model/unload during loading → clean unloaded ───────────────
|
||||
echo ""
|
||||
echo "--- Test 18: POST /model/unload during loading ---"
|
||||
ensure_unloaded
|
||||
curl -sf -X POST "$BASE/model/load" > /dev/null
|
||||
sleep 1 # Hopefully still in loading state
|
||||
curl -sf -X POST "$BASE/model/unload" > /dev/null
|
||||
# Allow time for the unload to propagate
|
||||
sleep 5
|
||||
STATE=$(get_state)
|
||||
if [ "$STATE" = "unloaded" ]; then
|
||||
ok "T18: unload during loading → clean unloaded"
|
||||
elif [ "$STATE" = "ready" ]; then
|
||||
# Load completed before unload arrived — immediately unload
|
||||
curl -sf -X POST "$BASE/model/unload" > /dev/null
|
||||
sleep 3
|
||||
STATE=$(get_state)
|
||||
[ "$STATE" = "unloaded" ] && ok "T18: load completed then unloaded (race condition OK)" \
|
||||
|| fail "T18: state=$STATE after load+unload"
|
||||
else
|
||||
fail "T18: unexpected state after unload-during-load: $STATE"
|
||||
fi
|
||||
|
||||
# ── Summary ────────────────────────────────────────────────────────────────────
|
||||
echo ""
|
||||
echo "=========================================="
|
||||
echo " Results: ${PASS} passed, ${FAIL} failed"
|
||||
echo "=========================================="
|
||||
[ $FAIL -eq 0 ] && echo -e "${GREEN}ALL PASSED${NC}" || { echo -e "${RED}FAILURES: $FAIL${NC}"; exit 1; }
|
||||
Reference in New Issue
Block a user