feat: dynamic model loading/unloading with GPU polling
All checks were successful
Build & Push Docker Image / build-and-push (push) Successful in 8m41s
- Model starts unloaded (lazy); loads on first job or POST /model/load
- Auto-unloads after IDLE_TIMEOUT_SECS (default 300) of inactivity
- POST /model/unload for immediate manual release
- GPU-busy detection: on VRAM OOM, enters WaitingForGpu and retries
every GPU_POLL_INTERVAL_SECS (default 30) indefinitely
- POST /jobs when unloaded → 503 + Retry-After header, triggers load
- AppError::OutOfMemory and AppError::ModelNotReady variants
- WorkerCmd channel (SyncSender<WorkerCmd>) replaces bare tx_req channel
- Idle timer via recv_timeout(1s) tick inside OS thread (no extra thread)
- Model lifecycle events broadcast via tokio broadcast channel (SSE + webhooks)
- webhook_registry: all clients that ever submitted a webhook_url receive
model_ready and model_unloaded webhooks
- GPU warmup retained on every (re)load
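The idle-timer bullet above (a `recv_timeout` tick on the worker's own OS thread, with no extra timer thread) can be sketched with the standard library alone. This is an illustration under stated assumptions, not the commit's worker: the `WorkerCmd` variants, the `run_worker` function, and the returned transition log are hypothetical names, and the tick length is a parameter so the loop can be exercised quickly. Treating a zero timeout as "auto-unload disabled" follows the `IDLE_TIMEOUT_SECS` description in the docs.

```rust
use std::sync::mpsc::{Receiver, RecvTimeoutError};
use std::time::{Duration, Instant};

/// Hypothetical control commands; the real `WorkerCmd` variants are not shown in this commit view.
#[derive(Debug)]
pub enum WorkerCmd {
    Load,
    Unload,
}

/// Runs the control loop on the calling thread until every sender is dropped.
/// Returns a log of lifecycle transitions so the behavior can be inspected.
pub fn run_worker(
    rx: Receiver<WorkerCmd>,
    idle_timeout: Duration,
    tick: Duration,
) -> Vec<&'static str> {
    let mut log = Vec::new();
    let mut loaded = false;
    let mut last_activity = Instant::now();

    loop {
        match rx.recv_timeout(tick) {
            Ok(WorkerCmd::Load) => {
                if !loaded {
                    loaded = true;
                    log.push("model_ready");
                }
                // Any load request counts as activity and restarts the idle clock.
                last_activity = Instant::now();
            }
            Ok(WorkerCmd::Unload) => {
                if loaded {
                    loaded = false;
                    log.push("model_unloaded");
                }
            }
            // No command arrived within one tick: check the idle deadline.
            Err(RecvTimeoutError::Timeout) => {
                let idle_enabled = !idle_timeout.is_zero();
                if loaded && idle_enabled && last_activity.elapsed() >= idle_timeout {
                    loaded = false;
                    log.push("model_unloaded");
                }
            }
            // All senders are gone: shut the worker down.
            Err(RecvTimeoutError::Disconnected) => break,
        }
    }
    log
}
```

Because the deadline is only checked when a tick expires, the unload can lag the configured timeout by up to one tick, which is the trade-off for not running a separate timer thread.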
New routes:
  GET /model/status: current state + VRAM stats
  POST /model/load: trigger load (idempotent)
  POST /model/unload: immediate unload
  GET /model/events: SSE stream of model lifecycle events
New env vars:
  IDLE_TIMEOUT_SECS (default 300)
  GPU_POLL_INTERVAL_SECS (default 30)
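A small sketch of how the two variables could be turned into durations. The fallback-to-default parsing mirrors the `std::env::var(...).ok().and_then(...).unwrap_or(...)` chain in `src/main.rs`; the helper names (`parse_secs`, `idle_timeout`, `gpu_poll_interval`) and the mapping of `0` to `None` are assumptions made for illustration, based on the documented "set to `0` to disable auto-unload" behavior.

```rust
use std::time::Duration;

/// Parses a seconds value, falling back to `default_secs` when the variable
/// is unset or is not a valid unsigned integer.
pub fn parse_secs(raw: Option<&str>, default_secs: u64) -> u64 {
    raw.and_then(|s| s.parse().ok()).unwrap_or(default_secs)
}

/// Hypothetical helper: `None` means "never auto-unload" (value 0).
pub fn idle_timeout(raw: Option<&str>) -> Option<Duration> {
    match parse_secs(raw, 300) {
        0 => None,
        secs => Some(Duration::from_secs(secs)),
    }
}

/// Hypothetical helper: interval between VRAM-availability retries.
pub fn gpu_poll_interval(raw: Option<&str>) -> Duration {
    Duration::from_secs(parse_secs(raw, 30))
}
```

A caller would pass the raw environment value, for example `idle_timeout(std::env::var("IDLE_TIMEOUT_SECS").ok().as_deref())`.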
Tests:
  tests/test_model_lifecycle.sh: 18 integration tests (full state machine,
    SSE events, webhooks, concurrency, unload-during-load)
  tests/test_idle_timeout.sh: 5 tests with short IDLE_TIMEOUT_SECS=5
  test_all.sh updated: loads model before job submission, asserts
    model_state in /health, adds POST /model/unload at end
Docs:
  docs/USAGE.md: model lifecycle section, new env vars, 503 retry pattern,
    updated /health response shape
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
docs/USAGE.md (200 lines changed)
@@ -66,6 +66,8 @@ The bundled `docker-compose.yml` mounts named volumes for data and models and se
 | `WHISPER_MODEL_PATH` | `/models/ggml-large-v3.bin` | Absolute path to GGML model file |
 | `WHISPER_MODEL` | `large-v3` | Model name reported by `/health` (display only) |
 | `CUDA_DEVICE` | `0` | CUDA device index to use for inference |
+| `IDLE_TIMEOUT_SECS` | `300` | Seconds of idle before the model is automatically unloaded from GPU memory. Set to `0` to disable auto-unload. |
+| `GPU_POLL_INTERVAL_SECS` | `30` | Seconds between VRAM-availability retries when a load fails due to insufficient VRAM. |
 
 ### Note on CUDA device ordering
 Inside Docker, device ordering matches `nvidia-smi` (PCI bus order). On the host without Docker, ordering may differ. See [FINDINGS.md](FINDINGS.md#cuda-device-index-ordering-differs-between-host-and-docker) for details.
@@ -76,6 +78,194 @@ Inside Docker, device ordering matches `nvidia-smi` (PCI bus order). On the host
 
 The interactive Swagger UI is available at `http://localhost:8080/docs`.
 
+---
+
+## Model Lifecycle Management
+
+The model starts **unloaded** on startup (lazy loading). It is loaded into GPU memory on the first job submission or via `POST /model/load`, and automatically unloaded after `IDLE_TIMEOUT_SECS` of inactivity.
+
+### Model State Machine
+
+```
+Unloaded ──(job / POST /model/load)──► Loading ──(success)──► Ready
+                                          └──(VRAM full)──► WaitingForGpu ──(retry OK)──► Loading
+Ready ──(idle timeout / POST /model/unload)──► Unloaded
+WaitingForGpu ──(POST /model/unload)──► Unloaded
+```
+
+### `GET /model/status`
+
+Returns the current model state and VRAM statistics.
+
+```bash
+curl http://localhost:8080/model/status
+```
+
+**When unloaded:**
+```json
+{ "state": "unloaded" }
+```
+
+**When loading:**
+```json
+{ "state": "loading" }
+```
+
+**When ready:**
+```json
+{
+  "state": "ready",
+  "loaded_at": "2026-05-10T14:00:00Z",
+  "vram_used_mb": 4096,
+  "vram_total_mb": 8192
+}
+```
+
+**When waiting for VRAM:**
+```json
+{
+  "state": "waiting_for_gpu",
+  "vram_needed_mb": 3951,
+  "vram_free_mb": 512,
+  "retry_in_secs": 30
+}
+```
+
+---
+
+### `POST /model/load`
+
+Request the model to be loaded. Idempotent — if already loading or ready, returns immediately.
+
+```bash
+curl -X POST http://localhost:8080/model/load
+```
+
+- Returns `202 Accepted` with `{"status":"load_initiated"}` when load is triggered
+- Returns `200 OK` with `{"status":"already_ready"}` when model is already ready
+- Poll `GET /model/status` or subscribe to `GET /model/events` to know when ready
+
+---
+
+### `POST /model/unload`
+
+Unload the model from GPU memory immediately, freeing VRAM.
+
+```bash
+curl -X POST http://localhost:8080/model/unload
+```
+
+Returns `200 OK` regardless of current state.
+
+---
+
+### `GET /model/events` — Model SSE stream
+
+Subscribe to model lifecycle events via Server-Sent Events.
+
+```bash
+curl -N http://localhost:8080/model/events
+```
+
+**Event types:**
+
+```
+event: model_loading
+data: {"type":"model_loading"}
+
+event: model_ready
+data: {"type":"model_ready","loaded_at":"2026-05-10T14:00:00Z"}
+
+event: model_unloaded
+data: {"type":"model_unloaded"}
+
+event: model_waiting_for_gpu
+data: {"type":"model_waiting_for_gpu","vram_needed_mb":3951,"vram_free_mb":512,"retry_in_secs":30}
+```
+
+**JavaScript example:**
+```javascript
+const es = new EventSource('/model/events');
+
+es.addEventListener('model_ready', () => {
+  console.log('Model loaded — ready to transcribe');
+});
+
+es.addEventListener('model_unloaded', () => {
+  console.log('Model freed GPU memory');
+});
+```
+
+---
+
+### Webhooks for model events
+
+When any job is submitted with a `webhook_url`, that URL is registered to receive model lifecycle webhooks for the lifetime of the server process. The following events trigger a webhook POST:
+
+| Event | Fired when |
+|-------|-----------|
+| `model_ready` | Model finishes loading (after GPU warmup) |
+| `model_unloaded` | Model is freed from GPU memory |
+
+**Webhook payload** (`Content-Type: application/json`):
+```json
+{ "type": "model_ready", "loaded_at": "2026-05-10T14:00:00Z" }
+{ "type": "model_unloaded" }
+```
+
+Delivery is attempted up to 3 times with exponential backoff (1s, 2s).
+
+---
+
+### Handling 503 Model Not Ready
+
+When you submit a job and the model is not yet loaded, you receive `503 Service Unavailable` with a `Retry-After` header:
+
+```
+HTTP/1.1 503 Service Unavailable
+Retry-After: 30
+Content-Type: application/json
+
+{
+  "error": "model_not_ready",
+  "state": "unloaded",
+  "retry_after_secs": 30
+}
+```
+
+| State at rejection | `retry_after_secs` | Meaning |
+|---|---|---|
+| `unloaded` | 30 | Load was triggered; retry after ~30s |
+| `loading` | 10 | Check again in 10s |
+| `waiting_for_gpu` | `GPU_POLL_INTERVAL_SECS` | VRAM contention; retry later |
+
+A job rejection when the model is `unloaded` **automatically triggers a load** — you do not need to call `POST /model/load` separately.
+
+**Recommended client pattern:**
+```javascript
+async function submitWithRetry(formData, maxAttempts = 10) {
+  for (let i = 0; i < maxAttempts; i++) {
+    const resp = await fetch('/jobs', { method: 'POST', body: formData });
+    if (resp.ok) return resp.json();
+    if (resp.status === 503) {
+      const retryAfter = parseInt(resp.headers.get('Retry-After') ?? '30');
+      const body = await resp.json();
+      console.log(`Model ${body.state} — retrying in ${retryAfter}s`);
+      await new Promise(r => setTimeout(r, retryAfter * 1000));
+      continue;
+    }
+    throw new Error(`Submit failed: ${resp.status}`);
+  }
+  throw new Error('Gave up after max attempts');
+}
+```
+
+---
+
+## API Reference
+
+The interactive Swagger UI is available at `http://localhost:8080/docs`.
+
 ### `POST /jobs` — Submit a transcription job
 
 Accepts a multipart/form-data body.
@@ -249,11 +439,12 @@ curl http://localhost:8080/health
   "gpu_name": "NVIDIA GeForce RTX 2080",
   "vram_total_mb": 8192,
   "model": "large-v3",
-  "queue_depth": 0
+  "queue_depth": 0,
+  "model_state": "ready"
 }
 ```
 
-`queue_depth` is the number of jobs waiting to be processed (not counting the one currently running).
+`queue_depth` is the number of jobs waiting to be processed (not counting the one currently running). `model_state` reflects the current lifecycle state (`unloaded`, `loading`, `waiting_for_gpu`, `ready`).
 
 ---
 
@@ -340,6 +531,11 @@ curl -X POST http://localhost:8080/jobs \
 
 ## Troubleshooting
 
+### Server returns `503 model_not_ready`
+- The model starts unloaded. Call `POST /model/load` explicitly, or just retry the job submission — rejection automatically triggers a load.
+- If state is `waiting_for_gpu`, another process is using the GPU's VRAM. The server will retry automatically every `GPU_POLL_INTERVAL_SECS` seconds.
+- Monitor `GET /model/status` or subscribe to `GET /model/events` to know when the model is ready.
+
 ### Server returns 0 segments
 - Check that you are **not** setting `language` to an empty string — omit the field entirely for auto-detection
 - Verify the audio file is not corrupted: `ffprobe audio.mp3`
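The state machine documented in `docs/USAGE.md` can be read as a pure transition function. The sketch below is one way to write it down in Rust, assuming hypothetical names (`State`, `Input`, `step`) that do not appear in the commit. Only the arrows drawn in the diagram are modeled; any other combination (for example an unload request while loading) is left unchanged here, which is an assumption rather than documented behavior.

```rust
/// Lifecycle states from the diagram (payload fields omitted for brevity).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum State {
    Unloaded,
    Loading,
    WaitingForGpu,
    Ready,
}

/// Hypothetical inputs that drive the transitions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Input {
    LoadRequested, // first job or POST /model/load
    LoadSucceeded,
    VramFull,
    RetryTick, // GPU poll interval elapsed
    UnloadRequested, // POST /model/unload
    IdleTimeout,
}

/// Returns the next state for a given (state, input) pair.
pub fn step(state: State, input: Input) -> State {
    use Input::*;
    use State::*;
    match (state, input) {
        (Unloaded, LoadRequested) => Loading,
        (Loading, LoadSucceeded) => Ready,
        (Loading, VramFull) => WaitingForGpu,
        (WaitingForGpu, RetryTick) => Loading,
        (WaitingForGpu, UnloadRequested) => Unloaded,
        (Ready, IdleTimeout) | (Ready, UnloadRequested) => Unloaded,
        // Pairs not drawn in the diagram leave the state as it is.
        (current, _) => current,
    }
}
```

Keeping the transitions in one function makes the idempotency of `POST /model/load` visible: a `LoadRequested` input in `Loading` or `Ready` falls through to the last arm and changes nothing.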
src/error.rs (141 lines changed)
@@ -1,6 +1,6 @@
 use thiserror::Error;
 use axum::{
-    http::StatusCode,
+    http::{StatusCode, HeaderValue, header},
     response::{IntoResponse, Response},
     Json,
 };
@@ -21,19 +21,138 @@ pub enum AppError {
 
     #[error("internal error: {0}")]
     Internal(String),
+
+    /// Returned when `whisper_init_state` or `cudaMalloc` fails due to
+    /// insufficient VRAM. The worker uses this to distinguish a recoverable
+    /// VRAM-pressure failure from a hard internal error.
+    #[error("out of GPU memory: {0}")]
+    OutOfMemory(String),
+
+    /// Returned when a job is submitted but the model is not yet loaded.
+    /// Carries the current state tag and recommended Retry-After seconds.
+    #[error("model not ready: {state}")]
+    ModelNotReady { state: String, retry_after_secs: u64 },
+}
+
+impl AppError {
+    /// Returns true if the error string contains patterns emitted by
+    /// whisper.cpp / GGML when a CUDA memory allocation fails.
+    pub fn is_oom(msg: &str) -> bool {
+        msg.contains("cudaMalloc failed")
+            || msg.contains("out of memory")
+            || msg.contains("CUDA error: out of memory")
+            || msg.contains("alloc_buffer")
+    }
 }
 
 impl IntoResponse for AppError {
     fn into_response(self) -> Response {
-        let (status, message) = match &self {
-            AppError::NotFound(m) => (StatusCode::NOT_FOUND, m.clone()),
-            AppError::BadRequest(m) => (StatusCode::BAD_REQUEST, m.clone()),
-            AppError::Conflict(m) => (StatusCode::CONFLICT, m.clone()),
-            AppError::Internal(m) => (StatusCode::INTERNAL_SERVER_ERROR, m.clone()),
-        };
-
-        tracing::error!(status = status.as_u16(), error = %message);
-
-        (status, Json(json!({ "error": message }))).into_response()
+        match self {
+            AppError::NotFound(m) => {
+                (StatusCode::NOT_FOUND, Json(json!({ "error": m }))).into_response()
+            }
+            AppError::BadRequest(m) => {
+                (StatusCode::BAD_REQUEST, Json(json!({ "error": m }))).into_response()
+            }
+            AppError::Conflict(m) => {
+                (StatusCode::CONFLICT, Json(json!({ "error": m }))).into_response()
+            }
+            AppError::Internal(m) => {
+                tracing::error!(error = %m, "internal error");
+                (StatusCode::INTERNAL_SERVER_ERROR, Json(json!({ "error": m }))).into_response()
+            }
+            AppError::OutOfMemory(m) => {
+                tracing::warn!(error = %m, "GPU out of memory during model load");
+                (StatusCode::SERVICE_UNAVAILABLE, Json(json!({ "error": m }))).into_response()
+            }
+            AppError::ModelNotReady { state, retry_after_secs } => {
+                let body = Json(json!({
+                    "error": "model_not_ready",
+                    "state": state,
+                    "retry_after_secs": retry_after_secs,
+                }));
+                let mut resp = (StatusCode::SERVICE_UNAVAILABLE, body).into_response();
+                resp.headers_mut().insert(
+                    header::RETRY_AFTER,
+                    HeaderValue::from_str(&retry_after_secs.to_string())
+                        .unwrap_or(HeaderValue::from_static("30")),
+                );
+                resp
+            }
+        }
+    }
+}
+
+// ── Unit tests ───────────────────────────────────────────────────────────────
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use axum::body::to_bytes;
+
+    #[test]
+    fn test_is_oom_cuda_malloc() {
+        assert!(AppError::is_oom("cudaMalloc failed: out of memory"));
+    }
+
+    #[test]
+    fn test_is_oom_alloc_buffer() {
+        // Exact message from ggml_backend_cuda_buffer_type_alloc_buffer
+        assert!(AppError::is_oom(
+            "ggml_backend_cuda_buffer_type_alloc_buffer: allocating 2951.01 MiB on device 0: cudaMalloc failed: out of memory"
+        ));
+    }
+
+    #[test]
+    fn test_is_oom_generic_out_of_memory() {
+        assert!(AppError::is_oom("CUDA error: out of memory"));
+    }
+
+    #[test]
+    fn test_is_oom_other_error() {
+        assert!(!AppError::is_oom("failed to open model file"));
+        assert!(!AppError::is_oom("invalid model format"));
+        assert!(!AppError::is_oom(""));
+    }
+
+    #[tokio::test]
+    async fn test_model_not_ready_response_has_retry_after_header() {
+        let err = AppError::ModelNotReady { state: "loading".into(), retry_after_secs: 10 };
+        let resp = err.into_response();
+        assert_eq!(resp.status(), StatusCode::SERVICE_UNAVAILABLE);
+        let retry_after = resp.headers().get(header::RETRY_AFTER)
+            .expect("Retry-After header missing");
+        assert_eq!(retry_after, "10");
+    }
+
+    #[tokio::test]
+    async fn test_model_not_ready_response_body() {
+        let err = AppError::ModelNotReady { state: "unloaded".into(), retry_after_secs: 30 };
+        let resp = err.into_response();
+        let bytes = to_bytes(resp.into_body(), usize::MAX).await.unwrap();
+        let v: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
+        assert_eq!(v["error"], "model_not_ready");
+        assert_eq!(v["state"], "unloaded");
+        assert_eq!(v["retry_after_secs"], 30);
+    }
+
+    #[tokio::test]
+    async fn test_model_not_ready_loading_retry_after_10() {
+        let err = AppError::ModelNotReady { state: "loading".into(), retry_after_secs: 10 };
+        let resp = err.into_response();
+        assert_eq!(
+            resp.headers().get(header::RETRY_AFTER).unwrap(),
+            "10"
+        );
+    }
+
+    #[tokio::test]
+    async fn test_model_not_ready_unloaded_retry_after_30() {
+        let err = AppError::ModelNotReady { state: "unloaded".into(), retry_after_secs: 30 };
+        let resp = err.into_response();
+        assert_eq!(
+            resp.headers().get(header::RETRY_AFTER).unwrap(),
+            "30"
+        );
     }
 }
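The doc comment on `AppError::OutOfMemory` says the worker uses it to tell a recoverable VRAM-pressure failure from a hard error. A minimal sketch of that decision follows, assuming a hypothetical `LoadFailure` enum and `classify_load_error` function; only the substring checks are copied from `AppError::is_oom`. One observation on those checks: any message containing `"CUDA error: out of memory"` also contains `"out of memory"`, so the third pattern never changes the result.

```rust
/// Hypothetical outcome of a failed model load.
#[derive(Debug, PartialEq, Eq)]
pub enum LoadFailure {
    /// Recoverable: enter a waiting state and poll until VRAM frees up.
    RetryWhenVramFrees,
    /// Waiting will not help (bad path, corrupt file, and so on).
    Fatal,
}

/// Same substring checks as `AppError::is_oom` in src/error.rs.
pub fn is_oom(msg: &str) -> bool {
    msg.contains("cudaMalloc failed")
        || msg.contains("out of memory")
        || msg.contains("CUDA error: out of memory")
        || msg.contains("alloc_buffer")
}

/// Maps a loader error message onto the retry decision.
pub fn classify_load_error(msg: &str) -> LoadFailure {
    if is_oom(msg) {
        LoadFailure::RetryWhenVramFrees
    } else {
        LoadFailure::Fatal
    }
}
```

String matching on library log output is fragile by nature, so a sketch like this would need revisiting whenever the upstream error wording changes.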
src/main.rs (60 lines changed)
@@ -1,7 +1,7 @@
 use std::sync::Arc;
 
 use axum::Router;
-use tokio::sync::mpsc;
+use tokio::sync::{broadcast, mpsc, RwLock};
 use tower_http::{cors::CorsLayer, trace::TraceLayer};
 use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
 use utoipa::OpenApi;
@@ -21,8 +21,10 @@ pub use error::{AppError, Result};
 
 #[derive(Clone)]
 pub struct AppState {
-    /// Channel to submit jobs to the single GPU worker.
+    /// Channel to submit jobs to the single GPU worker (job IDs only).
     pub job_tx: mpsc::UnboundedSender<models::JobId>,
+    /// Channel to send control commands to the worker OS thread.
+    pub cmd_tx: std::sync::mpsc::SyncSender<worker::WorkerCmd>,
     /// Shared handle to the on-disk job store.
     pub storage: Arc<storage::Storage>,
     /// SSE broadcast registry: job_id → sender.
@@ -33,6 +35,17 @@ pub struct AppState {
     pub queue_depth: Arc<std::sync::atomic::AtomicUsize>,
     /// CUDA device index used for inference.
     pub gpu_device: u32,
+    /// Current state of the whisper model.
+    pub model_state: Arc<RwLock<models::ModelState>>,
+    /// Broadcast channel for model lifecycle events (SSE + webhooks).
+    pub model_event_tx: broadcast::Sender<models::ModelEvent>,
+    /// All webhook URLs ever registered via job submission.
+    /// Used to fire model_ready / model_unloaded notifications.
+    pub webhook_registry: Arc<std::sync::Mutex<std::collections::HashSet<String>>>,
+    /// How long the model stays loaded with no active jobs.
+    pub idle_timeout: std::time::Duration,
+    /// How often to retry loading when GPU is busy.
+    pub gpu_poll_interval: std::time::Duration,
 }
 
 // ── OpenAPI spec root ────────────────────────────────────────────────────────
@@ -50,6 +63,10 @@ pub struct AppState {
         routes::jobs::stream_job,
         routes::jobs::delete_job,
         routes::health::health,
+        routes::model::model_status,
+        routes::model::model_load,
+        routes::model::model_unload,
+        routes::model::model_events,
     ),
     components(schemas(
         models::Job,
@@ -58,10 +75,14 @@ pub struct AppState {
         models::Word,
         models::SubmitResponse,
         models::HealthResponse,
+        models::ModelState,
+        models::ModelEvent,
+        models::ModelStatusResponse,
     )),
     tags(
         (name = "jobs", description = "Transcription job management"),
         (name = "system", description = "Service health"),
+        (name = "model", description = "Model lifecycle management"),
     )
 )]
 struct ApiDoc;
@@ -85,6 +106,20 @@ async fn main() -> anyhow::Result<()> {
         .ok()
         .and_then(|s| s.parse().ok())
         .unwrap_or(0);
+    let idle_timeout_secs: u64 = std::env::var("IDLE_TIMEOUT_SECS")
+        .ok()
+        .and_then(|s| s.parse().ok())
+        .unwrap_or(300);
+    let gpu_poll_interval_secs: u64 = std::env::var("GPU_POLL_INTERVAL_SECS")
+        .ok()
+        .and_then(|s| s.parse().ok())
+        .unwrap_or(30);
+
+    tracing::info!(
+        idle_timeout_secs,
+        gpu_poll_interval_secs,
+        "dynamic model loading configured"
+    );
 
     let storage = Arc::new(storage::Storage::new(&data_dir).await?);
 
@@ -94,28 +129,45 @@ async fn main() -> anyhow::Result<()> {
     let (job_tx, job_rx) = mpsc::unbounded_channel::<models::JobId>();
     let queue_depth = Arc::new(std::sync::atomic::AtomicUsize::new(0));
 
-    // Spawn single GPU worker; get back the SSE broadcast registry.
-    let progress = worker::start(
+    // Model starts unloaded — lazy load on first job or POST /model/load.
+    let model_state = Arc::new(RwLock::new(models::ModelState::Unloaded));
+    let (model_event_tx, _) = broadcast::channel::<models::ModelEvent>(32);
+    let webhook_registry = Arc::new(std::sync::Mutex::new(std::collections::HashSet::<String>::new()));
+
+    // Spawn single GPU worker; get back the SSE broadcast registry and cmd channel.
+    let (progress, cmd_tx) = worker::start(
         job_rx,
         Arc::clone(&storage),
         model_path.clone().into(),
         Arc::clone(&queue_depth),
         gpu_device,
+        Arc::clone(&model_state),
+        model_event_tx.clone(),
+        Arc::clone(&webhook_registry),
+        std::time::Duration::from_secs(idle_timeout_secs),
+        std::time::Duration::from_secs(gpu_poll_interval_secs),
     );
 
     let state = AppState {
         job_tx,
+        cmd_tx,
         storage: Arc::clone(&storage),
         progress,
         model_name: model_name.as_str().into(),
         queue_depth: Arc::clone(&queue_depth),
         gpu_device,
+        model_state,
+        model_event_tx,
+        webhook_registry,
+        idle_timeout: std::time::Duration::from_secs(idle_timeout_secs),
+        gpu_poll_interval: std::time::Duration::from_secs(gpu_poll_interval_secs),
     };
 
     let app = Router::new()
         .merge(SwaggerUi::new("/docs").url("/openapi.json", ApiDoc::openapi()))
         .merge(routes::jobs_router())
         .merge(routes::health_router())
+        .merge(routes::model_router())
         .with_state(state)
         .layer(CorsLayer::permissive())
         .layer(TraceLayer::new_for_http());
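`AppState::webhook_registry` is an `Arc<Mutex<HashSet<String>>>`, so registering the same `webhook_url` twice keeps a single entry. The sketch below shows one way such a registry might be used; `WebhookRegistry`, `register_webhook`, and `snapshot` are hypothetical names, and copying the URLs out before delivery is a design suggestion rather than something this commit view confirms.

```rust
use std::collections::HashSet;
use std::sync::{Arc, Mutex};

/// Same shape as the `webhook_registry` field in `AppState`.
pub type WebhookRegistry = Arc<Mutex<HashSet<String>>>;

/// Records a webhook URL; returns true if it was not registered before.
pub fn register_webhook(registry: &WebhookRegistry, url: &str) -> bool {
    registry
        .lock()
        .expect("webhook registry mutex poisoned")
        .insert(url.to_owned())
}

/// Copies the URLs out so the lock is released before any network I/O.
pub fn snapshot(registry: &WebhookRegistry) -> Vec<String> {
    let guard = registry.lock().expect("webhook registry mutex poisoned");
    let mut urls: Vec<String> = guard.iter().cloned().collect();
    urls.sort(); // deterministic order for logging and tests
    urls
}
```

Taking a snapshot keeps the `std::sync::Mutex` from being held across slow HTTP calls, which matters because the same lock is touched on every job submission.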
src/models.rs (246 lines changed)
@@ -5,6 +5,116 @@ use uuid::Uuid;
 
 pub type JobId = Uuid;
 
+// ── Model lifecycle state ────────────────────────────────────────────────────
+
+/// Current state of the whisper model in memory.
+///
+/// State machine:
+/// ```
+/// Unloaded ──(load trigger)──► Loading ──(ok)──► Ready ──(idle/unload)──► Unloaded
+///                                 └──(VRAM full)──► WaitingForGpu ──(retry)──► Loading
+/// WaitingForGpu ──(unload cmd)──► Unloaded
+/// ```
+#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
+#[serde(tag = "state", rename_all = "snake_case")]
+pub enum ModelState {
+    /// Model is not in memory. GPU is free.
+    Unloaded,
+    /// Model is being loaded (weights transferred to GPU).
+    Loading,
+    /// A previous load attempt failed due to insufficient VRAM. The worker is
+    /// polling at `retry_in_secs` intervals until enough memory is available.
+    WaitingForGpu {
+        /// VRAM required to load the model, in MiB.
+        vram_needed_mb: u64,
+        /// VRAM currently free on the device, in MiB.
+        vram_free_mb: u64,
+        /// How many seconds until the next load attempt.
+        retry_in_secs: u64,
+    },
+    /// Model is loaded and ready to accept inference jobs.
+    Ready {
+        /// UTC timestamp of when the model finished loading (post-warmup).
+        loaded_at: DateTime<Utc>,
+    },
+}
+
+impl ModelState {
+    /// Returns true if the model can accept inference jobs right now.
+    pub fn is_ready(&self) -> bool {
+        matches!(self, ModelState::Ready { .. })
+    }
+
+    /// Suggested `Retry-After` value (seconds) to include in 503 responses.
+    pub fn retry_after_secs(&self) -> u64 {
+        match self {
+            ModelState::Unloaded => 30, // conservative load estimate
+            ModelState::Loading => 10,
+            ModelState::WaitingForGpu { retry_in_secs, .. } => *retry_in_secs,
+            ModelState::Ready { .. } => 0, // shouldn't 503 if ready
+        }
+    }
+
+    /// String tag for use in error response bodies and log fields.
+    pub fn tag(&self) -> &'static str {
+        match self {
+            ModelState::Unloaded => "unloaded",
+            ModelState::Loading => "loading",
+            ModelState::WaitingForGpu{..} => "waiting_for_gpu",
+            ModelState::Ready{..} => "ready",
+        }
+    }
+}
+
+// ── Model events (SSE + webhooks) ────────────────────────────────────────────
+
+/// Events broadcast over the `GET /model/events` SSE stream and fired as
+/// webhooks to registered clients.
+///
+/// Webhook delivery: only `ModelReady` and `ModelUnloaded` are sent to
+/// webhook URLs. All four are broadcast on the SSE stream.
+#[derive(Debug, Clone, Serialize, Deserialize, ToSchema)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum ModelEvent {
+    /// Model finished loading and the GPU warmup completed — ready to accept jobs.
+    ModelReady {
+        loaded_at: DateTime<Utc>,
+    },
+    /// Model was unloaded from GPU memory (idle timeout or manual unload).
+    ModelUnloaded,
+    /// Model load initiated.
+    ModelLoading,
+    /// Load failed due to insufficient VRAM; retrying after `retry_in_secs`.
+    ModelWaitingForGpu {
+        vram_needed_mb: u64,
+        vram_free_mb: u64,
+        retry_in_secs: u64,
+    },
+}
+
+impl ModelEvent {
+    /// Returns true if this event should be delivered via webhook.
+    pub fn is_webhook_event(&self) -> bool {
+        matches!(self, ModelEvent::ModelReady { .. } | ModelEvent::ModelUnloaded)
+    }
+}
+
+// ── Model status response ────────────────────────────────────────────────────
+
+/// Response body for `GET /model/status`.
+#[derive(Debug, Serialize, Deserialize, ToSchema)]
+pub struct ModelStatusResponse {
+    /// Current model state (flattened from `ModelState`).
+    #[serde(flatten)]
+    pub state: ModelState,
+    /// VRAM currently used on the device, in MiB (from nvidia-smi).
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub vram_used_mb: Option<u64>,
+    /// VRAM total on the device, in MiB.
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub vram_total_mb: Option<u64>,
+}
+
 // ── Job status ───────────────────────────────────────────────────────────────
 
 #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, ToSchema)]
@@ -130,6 +240,8 @@ pub struct HealthResponse {
     pub vram_total_mb: Option<u64>,
     pub model: String,
     pub queue_depth: usize,
+    /// Current state of the whisper model.
+    pub model_state: String,
 }
 
 // ── SSE event payload ────────────────────────────────────────────────────────
@@ -148,3 +260,137 @@ pub enum SsePayload {
|
|||||||
Done { job: Box<Job> },
|
Done { job: Box<Job> },
|
||||||
Error { message: String },
|
Error { message: String },
|
||||||
}
|
}
|

// ── Unit tests ───────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::Value;

    // ── ModelState serialization ─────────────────────────────────────────────

    #[test]
    fn test_model_state_unloaded_serializes() {
        let v: Value = serde_json::to_value(ModelState::Unloaded).unwrap();
        assert_eq!(v["state"], "unloaded");
    }

    #[test]
    fn test_model_state_loading_serializes() {
        let v: Value = serde_json::to_value(ModelState::Loading).unwrap();
        assert_eq!(v["state"], "loading");
    }

    #[test]
    fn test_model_state_waiting_serializes() {
        let s = ModelState::WaitingForGpu { vram_needed_mb: 3000, vram_free_mb: 500, retry_in_secs: 30 };
        let v: Value = serde_json::to_value(&s).unwrap();
        assert_eq!(v["state"], "waiting_for_gpu");
        assert_eq!(v["vram_needed_mb"], 3000);
        assert_eq!(v["vram_free_mb"], 500);
        assert_eq!(v["retry_in_secs"], 30);
    }

    #[test]
    fn test_model_state_ready_serializes() {
        let ts = Utc::now();
        let s = ModelState::Ready { loaded_at: ts };
        let v: Value = serde_json::to_value(&s).unwrap();
        assert_eq!(v["state"], "ready");
        assert!(v["loaded_at"].is_string());
    }

    #[test]
    fn test_model_state_is_ready() {
        assert!(!ModelState::Unloaded.is_ready());
        assert!(!ModelState::Loading.is_ready());
        assert!(!ModelState::WaitingForGpu { vram_needed_mb: 0, vram_free_mb: 0, retry_in_secs: 30 }.is_ready());
        assert!(ModelState::Ready { loaded_at: Utc::now() }.is_ready());
    }

    #[test]
    fn test_retry_after_unloaded() {
        assert_eq!(ModelState::Unloaded.retry_after_secs(), 30);
    }

    #[test]
    fn test_retry_after_loading() {
        assert_eq!(ModelState::Loading.retry_after_secs(), 10);
    }

    #[test]
    fn test_retry_after_waiting_for_gpu() {
        let s = ModelState::WaitingForGpu { vram_needed_mb: 0, vram_free_mb: 0, retry_in_secs: 45 };
        assert_eq!(s.retry_after_secs(), 45);
    }

    #[test]
    fn test_retry_after_ready_is_zero() {
        assert_eq!(ModelState::Ready { loaded_at: Utc::now() }.retry_after_secs(), 0);
    }

    // ── ModelEvent serialization ─────────────────────────────────────────────

    #[test]
    fn test_model_event_ready_serializes() {
        let ts = Utc::now();
        let e = ModelEvent::ModelReady { loaded_at: ts };
        let v: Value = serde_json::to_value(&e).unwrap();
        assert_eq!(v["type"], "model_ready");
        assert!(v["loaded_at"].is_string());
    }

    #[test]
    fn test_model_event_unloaded_serializes() {
        let v: Value = serde_json::to_value(ModelEvent::ModelUnloaded).unwrap();
        assert_eq!(v["type"], "model_unloaded");
    }

    #[test]
    fn test_model_event_loading_serializes() {
        let v: Value = serde_json::to_value(ModelEvent::ModelLoading).unwrap();
        assert_eq!(v["type"], "model_loading");
    }

    #[test]
    fn test_model_event_waiting_serializes() {
        let e = ModelEvent::ModelWaitingForGpu { vram_needed_mb: 3000, vram_free_mb: 200, retry_in_secs: 30 };
        let v: Value = serde_json::to_value(&e).unwrap();
        assert_eq!(v["type"], "model_waiting_for_gpu");
        assert_eq!(v["vram_needed_mb"], 3000);
    }

    #[test]
    fn test_model_event_webhook_filter() {
        assert!(ModelEvent::ModelReady { loaded_at: Utc::now() }.is_webhook_event());
        assert!(ModelEvent::ModelUnloaded.is_webhook_event());
        assert!(!ModelEvent::ModelLoading.is_webhook_event());
        assert!(!ModelEvent::ModelWaitingForGpu { vram_needed_mb: 0, vram_free_mb: 0, retry_in_secs: 30 }.is_webhook_event());
    }

    // ── ModelStatusResponse ──────────────────────────────────────────────────

    #[test]
    fn test_model_status_response_roundtrip() {
        let r = ModelStatusResponse {
            state: ModelState::Ready { loaded_at: Utc::now() },
            vram_used_mb: Some(4096),
            vram_total_mb: Some(8192),
        };
        let json_str = serde_json::to_string(&r).unwrap();
        let v: Value = serde_json::from_str(&json_str).unwrap();
        assert_eq!(v["state"], "ready");
        assert_eq!(v["vram_used_mb"], 4096);
        assert_eq!(v["vram_total_mb"], 8192);
    }

    #[test]
    fn test_model_status_response_omits_nulls() {
        let r = ModelStatusResponse { state: ModelState::Loading, vram_used_mb: None, vram_total_mb: None };
        let v: Value = serde_json::to_value(&r).unwrap();
        assert_eq!(v["state"], "loading");
        assert!(v.get("vram_used_mb").is_none());
        assert!(v.get("vram_total_mb").is_none());
    }
}
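
The unit tests above pin down the tag string and the Retry-After value for each model state. The following is a minimal std-only sketch of a type that would satisfy those expectations. It is not the crate's actual `ModelState`: the real type derives serde traits and stores a chrono timestamp, and the use of `SystemTime` for `loaded_at` here is an assumption made only to keep the example dependency-free.

```rust
use std::time::SystemTime;

/// Simplified stand-in for `ModelState` (assumption: no serde, std timestamp).
#[derive(Debug, Clone, PartialEq)]
pub enum ModelState {
    Unloaded,
    Loading,
    WaitingForGpu { vram_needed_mb: u64, vram_free_mb: u64, retry_in_secs: u64 },
    Ready { loaded_at: SystemTime },
}

impl ModelState {
    /// Short tag, matching the `state` strings asserted in the tests.
    pub fn tag(&self) -> &'static str {
        match self {
            ModelState::Unloaded => "unloaded",
            ModelState::Loading => "loading",
            ModelState::WaitingForGpu { .. } => "waiting_for_gpu",
            ModelState::Ready { .. } => "ready",
        }
    }

    /// True only once the model is loaded.
    pub fn is_ready(&self) -> bool {
        matches!(self, ModelState::Ready { .. })
    }

    /// Seconds a client should wait before retrying, per the tests:
    /// 30 when unloaded, 10 while loading, the poll interval while
    /// waiting for the GPU, and 0 when ready.
    pub fn retry_after_secs(&self) -> u64 {
        match self {
            ModelState::Unloaded => 30,
            ModelState::Loading => 10,
            ModelState::WaitingForGpu { retry_in_secs, .. } => *retry_in_secs,
            ModelState::Ready { .. } => 0,
        }
    }
}
```

A handler can read `tag()` and `retry_after_secs()` under a single lock acquisition and then build the 503 response without holding the lock.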
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ use crate::{models::HealthResponse, AppState, Result};
|
|||||||
)]
|
)]
|
||||||
pub async fn health(State(state): State<AppState>) -> Result<Json<HealthResponse>> {
|
pub async fn health(State(state): State<AppState>) -> Result<Json<HealthResponse>> {
|
||||||
let (gpu_name, vram_total_mb) = gpu_info(state.gpu_device);
|
let (gpu_name, vram_total_mb) = gpu_info(state.gpu_device);
|
||||||
|
let model_state_tag = state.model_state.read().await.tag().to_string();
|
||||||
|
|
||||||
Ok(Json(HealthResponse {
|
Ok(Json(HealthResponse {
|
||||||
status: "ok".into(),
|
status: "ok".into(),
|
||||||
@@ -23,6 +24,7 @@ pub async fn health(State(state): State<AppState>) -> Result<Json<HealthResponse
|
|||||||
vram_total_mb,
|
vram_total_mb,
|
||||||
model: state.model_name.to_string(),
|
model: state.model_name.to_string(),
|
||||||
queue_depth: state.queue_depth.load(Ordering::Relaxed),
|
queue_depth: state.queue_depth.load(Ordering::Relaxed),
|
||||||
|
model_state: model_state_tag,
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ use uuid::Uuid;
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
models::{Job, JobId, JobStatus, SubmitResponse},
|
models::{Job, JobId, JobStatus, SubmitResponse},
|
||||||
worker::{audio_path_for, ProgressEvent},
|
worker::{audio_path_for, ProgressEvent, WorkerCmd},
|
||||||
AppError, AppState, Result,
|
AppError, AppState, Result,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -107,6 +107,36 @@ pub async fn submit_job(
|
|||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Check model state before accepting the job.
|
||||||
|
let (model_ready, retry_after_secs, state_tag) = {
|
||||||
|
let ms = state.model_state.read().await;
|
||||||
|
let ready = ms.is_ready();
|
||||||
|
let retry = ms.retry_after_secs();
|
||||||
|
let tag = ms.tag().to_string();
|
||||||
|
(ready, retry, tag)
|
||||||
|
};
|
||||||
|
|
||||||
|
// Register the webhook URL regardless of model state — so model lifecycle
|
||||||
|
// events are delivered even if the job itself is rejected.
|
||||||
|
if let Some(url) = &webhook_url {
|
||||||
|
state.webhook_registry.lock()
|
||||||
|
.unwrap_or_else(|e| e.into_inner())
|
||||||
|
.insert(url.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
if !model_ready {
|
||||||
|
// Trigger a load if the model is simply unloaded (not already loading).
|
||||||
|
if state_tag == "unloaded" {
|
||||||
|
let _ = state.cmd_tx.try_send(WorkerCmd::Load);
|
||||||
|
}
|
||||||
|
// Clean up the audio file we already wrote to disk.
|
||||||
|
let _ = tokio::fs::remove_file(&audio_path).await;
|
||||||
|
return Err(AppError::ModelNotReady {
|
||||||
|
state: state_tag,
|
||||||
|
retry_after_secs,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
let mut job = Job::new(id, task, webhook_url, filename);
|
let mut job = Job::new(id, task, webhook_url, filename);
|
||||||
job.language = language;
|
job.language = language;
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
pub mod health;
|
pub mod health;
|
||||||
pub mod jobs;
|
pub mod jobs;
|
||||||
|
pub mod model;
|
||||||
|
|
||||||
use axum::{extract::DefaultBodyLimit, routing::{delete, get, post}, Router};
|
use axum::{extract::DefaultBodyLimit, routing::{delete, get, post}, Router};
|
||||||
use crate::AppState;
|
use crate::AppState;
|
||||||
@@ -17,3 +18,11 @@ pub fn health_router() -> Router<AppState> {
|
|||||||
Router::new()
|
Router::new()
|
||||||
.route("/health", get(health::health))
|
.route("/health", get(health::health))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn model_router() -> Router<AppState> {
|
||||||
|
Router::new()
|
||||||
|
.route("/model/status", get(model::model_status))
|
||||||
|
.route("/model/load", post(model::model_load))
|
||||||
|
.route("/model/unload", post(model::model_unload))
|
||||||
|
.route("/model/events", get(model::model_events))
|
||||||
|
}
|
||||||
|
|||||||
158
src/routes/model.rs
Normal file
@@ -0,0 +1,158 @@
use std::pin::Pin;

use axum::{
    extract::State,
    http::StatusCode,
    response::{
        sse::{Event, KeepAlive, Sse},
        IntoResponse,
    },
    Json,
};
use futures::Stream;
use tokio_stream::wrappers::BroadcastStream;
use futures::StreamExt;

use crate::{
    models::{ModelEvent, ModelStatusResponse},
    worker::WorkerCmd,
    AppState, Result,
};

type SseStream = Pin<Box<dyn Stream<Item = std::result::Result<Event, std::convert::Infallible>> + Send>>;

// ── GET /model/status ────────────────────────────────────────────────────────

/// Return the current model state and VRAM statistics.
#[utoipa::path(
    get,
    path = "/model/status",
    tag = "model",
    responses(
        (status = 200, description = "Model status", body = ModelStatusResponse),
    )
)]
pub async fn model_status(State(state): State<AppState>) -> Result<Json<ModelStatusResponse>> {
    let model_state = state.model_state.read().await.clone();
    let (vram_used_mb, vram_total_mb) = vram_stats(state.gpu_device);

    Ok(Json(ModelStatusResponse {
        state: model_state,
        vram_used_mb,
        vram_total_mb,
    }))
}

// ── POST /model/load ─────────────────────────────────────────────────────────

/// Request the model to be loaded into GPU memory.
/// Idempotent: if the model is already loading or ready, this is a no-op.
/// Returns 202 Accepted; poll `GET /model/status` or subscribe to
/// `GET /model/events` to know when it is ready.
#[utoipa::path(
    post,
    path = "/model/load",
    tag = "model",
    responses(
        (status = 202, description = "Load initiated or already in progress"),
        (status = 200, description = "Model already ready"),
    )
)]
pub async fn model_load(State(state): State<AppState>) -> impl IntoResponse {
    let is_ready = state.model_state.read().await.is_ready();
    if is_ready {
        return (StatusCode::OK, Json(serde_json::json!({"status": "already_ready"})));
    }
    // Ignore send errors (channel full = load already in progress).
    let _ = state.cmd_tx.try_send(WorkerCmd::Load);
    (StatusCode::ACCEPTED, Json(serde_json::json!({"status": "load_initiated"})))
}

// ── POST /model/unload ───────────────────────────────────────────────────────

/// Unload the model from GPU memory immediately.
/// Idempotent: if the model is already unloaded, returns 200 immediately.
#[utoipa::path(
    post,
    path = "/model/unload",
    tag = "model",
    responses(
        (status = 200, description = "Model unloaded or was already unloaded"),
    )
)]
pub async fn model_unload(State(state): State<AppState>) -> impl IntoResponse {
    let _ = state.cmd_tx.try_send(WorkerCmd::Unload);
    (StatusCode::OK, Json(serde_json::json!({"status": "unload_requested"})))
}

// ── GET /model/events ────────────────────────────────────────────────────────

/// Subscribe to model lifecycle events via Server-Sent Events.
///
/// Event types:
/// - `model_loading` — load initiated
/// - `model_ready` — model loaded and warmed up
/// - `model_unloaded` — model freed from GPU memory
/// - `model_waiting_for_gpu` — insufficient VRAM; retrying
#[utoipa::path(
    get,
    path = "/model/events",
    tag = "model",
    responses(
        (status = 200, description = "SSE stream of model lifecycle events"),
    )
)]
pub async fn model_events(State(state): State<AppState>) -> Sse<SseStream> {
    let rx = state.model_event_tx.subscribe();

    let stream: SseStream = Box::pin(
        BroadcastStream::new(rx).filter_map(|msg| async move {
            match msg {
                Ok(event) => {
                    let event_type = match &event {
                        ModelEvent::ModelReady { .. } => "model_ready",
                        ModelEvent::ModelUnloaded => "model_unloaded",
                        ModelEvent::ModelLoading => "model_loading",
                        ModelEvent::ModelWaitingForGpu {..} => "model_waiting_for_gpu",
                    };
                    let data = serde_json::to_string(&event).ok()?;
                    Some(Ok(Event::default().event(event_type).data(data)))
                }
                Err(_) => None,
            }
        })
    );

    Sse::new(stream).keep_alive(KeepAlive::default())
}

// ── Helpers ───────────────────────────────────────────────────────────────────

fn vram_stats(gpu_device: u32) -> (Option<u64>, Option<u64>) {
    fn inner(gpu_device: u32) -> Option<(u64, u64)> {
        let out = std::process::Command::new("nvidia-smi")
            .args([
                &format!("--id={gpu_device}"),
                "--query-gpu=memory.used,memory.total",
                "--format=csv,noheader,nounits",
            ])
            .output()
            .ok()?;

        if !out.status.success() {
            return None;
        }

        let line = String::from_utf8_lossy(&out.stdout);
        let line = line.trim();
        let mut parts = line.splitn(2, ',');
        let used = parts.next().and_then(|s| s.trim().parse::<u64>().ok())?;
        let total = parts.next().and_then(|s| s.trim().parse::<u64>().ok())?;
        Some((used, total))
    }

    match inner(gpu_device) {
        Some((u, t)) => (Some(u), Some(t)),
        None => (None, None),
    }
}
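
The parsing step inside `vram_stats` can be exercised without a GPU by pulling it into a pure function. The sketch below is an illustration only: `parse_vram_line` is a hypothetical helper name that does not appear in the commit, and it assumes the single-line `used, total` CSV shape produced by the `--format=csv,noheader,nounits` query shown above.

```rust
/// Parse one CSV line such as "4096, 8192" into (used_mib, total_mib).
/// Hypothetical helper; returns None if either field is missing or
/// not an unsigned integer.
pub fn parse_vram_line(line: &str) -> Option<(u64, u64)> {
    // Split into at most two fields so trailing commas cannot add more.
    let mut parts = line.trim().splitn(2, ',');
    let used = parts.next()?.trim().parse::<u64>().ok()?;
    let total = parts.next()?.trim().parse::<u64>().ok()?;
    Some((used, total))
}
```

Keeping the process spawn and the parse separate would let the parse be unit-tested alongside the serialization tests, while the spawn stays an integration concern.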
|
||||||
@@ -49,10 +49,24 @@ impl Transcriber {
|
|||||||
// params.flash_attn(true);
|
// params.flash_attn(true);
|
||||||
|
|
||||||
let ctx = WhisperContext::new_with_params(path, params)
|
let ctx = WhisperContext::new_with_params(path, params)
|
||||||
.map_err(|e| AppError::Internal(format!("failed to load model: {e}")))?;
|
.map_err(|e| {
|
||||||
|
let msg = format!("failed to load model: {e}");
|
||||||
|
if AppError::is_oom(&msg) {
|
||||||
|
AppError::OutOfMemory(msg)
|
||||||
|
} else {
|
||||||
|
AppError::Internal(msg)
|
||||||
|
}
|
||||||
|
})?;
|
||||||
|
|
||||||
let mut state = ctx.create_state()
|
let mut state = ctx.create_state()
|
||||||
.map_err(|e| AppError::Internal(format!("failed to create whisper state: {e}")))?;
|
.map_err(|e| {
|
||||||
|
let msg = format!("failed to create whisper state: {e}");
|
||||||
|
if AppError::is_oom(&msg) {
|
||||||
|
AppError::OutOfMemory(msg)
|
||||||
|
} else {
|
||||||
|
AppError::Internal(msg)
|
||||||
|
}
|
||||||
|
})?;
|
||||||
// ctx drops here; state holds Arc<WhisperInnerContext> so model stays loaded.
|
// ctx drops here; state holds Arc<WhisperInnerContext> so model stays loaded.
|
||||||
|
|
||||||
// ── GPU warmup ────────────────────────────────────────────────────────
|
// ── GPU warmup ────────────────────────────────────────────────────────
|
||||||
|
|||||||
482
src/worker.rs
@@ -1,20 +1,23 @@
|
|||||||
use std::{
|
use std::{
|
||||||
|
collections::HashSet,
|
||||||
path::PathBuf,
|
path::PathBuf,
|
||||||
sync::{
|
sync::{
|
||||||
atomic::{AtomicUsize, Ordering},
|
atomic::{AtomicUsize, Ordering},
|
||||||
Arc,
|
Arc, Mutex,
|
||||||
},
|
},
|
||||||
|
time::{Duration, Instant},
|
||||||
};
|
};
|
||||||
|
|
||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use tokio::sync::{broadcast, mpsc, oneshot};
|
use tokio::sync::{broadcast, mpsc, oneshot, RwLock};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
models::{Job, JobId, JobStatus, Segment},
|
models::{Job, JobId, JobStatus, ModelEvent, ModelState, Segment},
|
||||||
storage::Storage,
|
storage::Storage,
|
||||||
transcriber::Transcriber,
|
transcriber::Transcriber,
|
||||||
webhook,
|
webhook,
|
||||||
|
AppError,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Per-job broadcast channel for SSE subscribers.
|
/// Per-job broadcast channel for SSE subscribers.
|
||||||
@@ -31,83 +34,383 @@ pub enum ProgressEvent {
|
|||||||
/// Global registry: job_id → broadcast sender.
|
/// Global registry: job_id → broadcast sender.
|
||||||
pub type ProgressRegistry = Arc<dashmap::DashMap<JobId, ProgressTx>>;
|
pub type ProgressRegistry = Arc<dashmap::DashMap<JobId, ProgressTx>>;
|
||||||
|
|
||||||
// ── Transcription request/response types for the blocking thread ─────────────
|
// ── Worker command channel ────────────────────────────────────────────────────
|
||||||
|
|
||||||
struct TranscribeRequest {
|
/// Commands sent to the GPU worker OS thread.
|
||||||
pcm: Vec<f32>,
|
#[derive(Debug)]
|
||||||
language: Option<String>,
|
pub enum WorkerCmd {
|
||||||
task: String,
|
/// Request a model load. Idempotent: if already loading/ready, ignored.
|
||||||
/// Per-chunk progress callback — receives 0–100 from whisper.cpp and can
|
Load,
|
||||||
/// scale/offset it before forwarding to the job's broadcast channel.
|
/// Unload the model immediately and free GPU memory.
|
||||||
on_progress: Box<dyn Fn(u8) + Send + 'static>,
|
Unload,
|
||||||
reply: oneshot::Sender<crate::Result<(Vec<Segment>, String)>>,
|
/// Internal: run a transcription chunk.
|
||||||
|
Transcribe(TranscribeRequest),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Transcription request/response types ─────────────────────────────────────
|
||||||
|
|
||||||
|
pub struct TranscribeRequest {
|
||||||
|
pub pcm: Vec<f32>,
|
||||||
|
pub language: Option<String>,
|
||||||
|
pub task: String,
|
||||||
|
pub on_progress: Box<dyn Fn(u8) + Send + 'static>,
|
||||||
|
pub reply: oneshot::Sender<crate::Result<(Vec<Segment>, String)>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl std::fmt::Debug for TranscribeRequest {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
f.debug_struct("TranscribeRequest")
|
||||||
|
.field("language", &self.language)
|
||||||
|
.field("task", &self.task)
|
||||||
|
.finish_non_exhaustive()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Public API ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
/// Spawn the single GPU worker.
|
/// Spawn the single GPU worker.
|
||||||
/// Returns the SSE progress registry.
|
///
|
||||||
|
/// Returns the SSE progress registry and a command sender for the worker thread.
|
||||||
|
/// The model starts **unloaded**; send `WorkerCmd::Load` or submit a job to
|
||||||
|
/// trigger loading.
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub fn start(
|
pub fn start(
|
||||||
job_rx: mpsc::UnboundedReceiver<JobId>,
|
job_rx: mpsc::UnboundedReceiver<JobId>,
|
||||||
storage: Arc<Storage>,
|
storage: Arc<Storage>,
|
||||||
model_path: PathBuf,
|
model_path: PathBuf,
|
||||||
queue_depth: Arc<AtomicUsize>,
|
queue_depth: Arc<AtomicUsize>,
|
||||||
gpu_device: u32,
|
gpu_device: u32,
|
||||||
) -> ProgressRegistry {
|
model_state: Arc<RwLock<ModelState>>,
|
||||||
|
model_event_tx: broadcast::Sender<ModelEvent>,
|
||||||
|
webhook_registry: Arc<Mutex<HashSet<String>>>,
|
||||||
|
idle_timeout: Duration,
|
||||||
|
gpu_poll_interval: Duration,
|
||||||
|
) -> (ProgressRegistry, std::sync::mpsc::SyncSender<WorkerCmd>) {
|
||||||
let registry: ProgressRegistry = Arc::new(dashmap::DashMap::new());
|
let registry: ProgressRegistry = Arc::new(dashmap::DashMap::new());
|
||||||
let reg_clone = Arc::clone(&registry);
|
let reg_clone = Arc::clone(&registry);
|
||||||
|
|
||||||
let (tx_req, rx_req) = std::sync::mpsc::channel::<TranscribeRequest>();
|
// Bounded sync channel: capacity 8 is plenty (load/unload are rare).
|
||||||
|
let (cmd_tx, cmd_rx) = std::sync::mpsc::sync_channel::<WorkerCmd>(8);
|
||||||
|
let cmd_tx_clone = cmd_tx.clone();
|
||||||
|
|
||||||
|
// Capture Tokio runtime handle so the OS thread can spawn async tasks.
|
||||||
|
let rt_handle = tokio::runtime::Handle::current();
|
||||||
|
|
||||||
std::thread::Builder::new()
|
std::thread::Builder::new()
|
||||||
.name("whisper-gpu".into())
|
.name("whisper-gpu".into())
|
||||||
.spawn(move || transcriber_thread(rx_req, model_path, gpu_device))
|
.spawn(move || {
|
||||||
|
transcriber_thread(
|
||||||
|
cmd_rx,
|
||||||
|
model_path,
|
||||||
|
gpu_device,
|
||||||
|
model_state,
|
||||||
|
model_event_tx,
|
||||||
|
webhook_registry,
|
||||||
|
idle_timeout,
|
||||||
|
gpu_poll_interval,
|
||||||
|
rt_handle,
|
||||||
|
);
|
||||||
|
})
|
||||||
.expect("failed to spawn whisper-gpu thread");
|
.expect("failed to spawn whisper-gpu thread");
|
||||||
|
|
||||||
tokio::spawn(run(job_rx, storage, queue_depth, reg_clone, tx_req));
|
tokio::spawn(run(job_rx, storage, queue_depth, reg_clone, cmd_tx_clone));
|
||||||
|
|
||||||
registry
|
(registry, cmd_tx)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Dedicated OS thread that owns the Transcriber (non-Send) and runs inference.
|
// ── GPU OS thread ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// The worker OS thread that owns the `Transcriber` (non-`Send`).
|
||||||
///
|
///
|
||||||
/// The Transcriber holds a single `WhisperState` that is reused for every chunk.
|
/// Uses `recv_timeout` with a 1-second tick to drive the idle timer without a
|
||||||
/// GPU compute buffers (~700 MB) are allocated once at startup rather than on
|
/// separate thread.
|
||||||
/// every call, eliminating per-chunk `whisper_init_state` overhead and the
|
#[allow(clippy::too_many_arguments)]
|
||||||
/// VRAM churn that caused intermittent 0-segment results.
|
|
||||||
fn transcriber_thread(
|
fn transcriber_thread(
|
||||||
rx: std::sync::mpsc::Receiver<TranscribeRequest>,
|
rx: std::sync::mpsc::Receiver<WorkerCmd>,
|
||||||
model_path: PathBuf,
|
model_path: PathBuf,
|
||||||
gpu_device: u32,
|
gpu_device: u32,
|
||||||
|
model_state: Arc<RwLock<ModelState>>,
|
||||||
|
model_event_tx: broadcast::Sender<ModelEvent>,
|
||||||
|
webhook_registry: Arc<Mutex<HashSet<String>>>,
|
||||||
|
idle_timeout: Duration,
|
||||||
|
gpu_poll_interval: Duration,
|
||||||
|
rt: tokio::runtime::Handle,
|
||||||
) {
|
) {
|
||||||
let mut transcriber = match Transcriber::load(&model_path, gpu_device) {
|
let mut transcriber: Option<Transcriber> = None;
|
||||||
Ok(t) => t,
|
let mut last_job = Instant::now();
|
||||||
Err(e) => {
|
|
||||||
tracing::error!(error = %e, "failed to load whisper model — transcriber thread exiting");
|
loop {
|
||||||
return;
|
match rx.recv_timeout(Duration::from_secs(1)) {
|
||||||
|
Ok(WorkerCmd::Load) => {
|
||||||
|
if transcriber.is_some() {
|
||||||
|
tracing::debug!("WorkerCmd::Load ignored — model already loaded");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
transcriber = try_load_with_polling(
|
||||||
|
&rx,
|
||||||
|
&model_path,
|
||||||
|
gpu_device,
|
||||||
|
&model_state,
|
||||||
|
&model_event_tx,
|
||||||
|
&webhook_registry,
|
||||||
|
gpu_poll_interval,
|
||||||
|
&rt,
|
||||||
|
);
|
||||||
|
if transcriber.is_some() {
|
||||||
|
last_job = Instant::now();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(WorkerCmd::Unload) => {
|
||||||
|
do_unload(&mut transcriber, &model_state, &model_event_tx, &webhook_registry, &rt);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(WorkerCmd::Transcribe(req)) => {
|
||||||
|
let t = match &mut transcriber {
|
||||||
|
Some(t) => t,
|
||||||
|
None => {
|
||||||
|
tracing::warn!("Transcribe cmd received but model is unloaded — failing job");
|
||||||
|
let _ = req.reply.send(Err(AppError::Internal(
|
||||||
|
"model unloaded before job could run".into(),
|
||||||
|
)));
|
||||||
|
continue;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
tracing::info!(model = %model_path.display(), "GPU worker ready");
|
|
||||||
|
|
||||||
for req in rx {
|
let result = t.transcribe(
|
||||||
let on_progress = req.on_progress;
|
|
||||||
let result = transcriber.transcribe(
|
|
||||||
&req.pcm,
|
&req.pcm,
|
||||||
req.language.as_deref(),
|
req.language.as_deref(),
|
||||||
&req.task,
|
&req.task,
|
||||||
move |p| on_progress(p),
|
move |p| (req.on_progress)(p),
|
||||||
);
|
);
|
||||||
|
last_job = Instant::now();
|
||||||
let _ = req.reply.send(result);
|
let _ = req.reply.send(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {
|
||||||
|
if transcriber.is_some() && last_job.elapsed() >= idle_timeout {
|
||||||
|
tracing::info!(
|
||||||
|
elapsed_secs = last_job.elapsed().as_secs(),
|
||||||
|
"idle timeout reached — unloading model"
|
||||||
|
);
|
||||||
|
do_unload(
|
||||||
|
&mut transcriber,
|
||||||
|
&model_state,
|
||||||
|
&model_event_tx,
|
||||||
|
&webhook_registry,
|
||||||
|
&rt,
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => {
|
||||||
|
tracing::info!("worker command channel closed — shutting down GPU thread");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
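
The idle timer in `transcriber_thread` relies on `recv_timeout` returning `Timeout` once per tick, so no second thread is needed. The following is a stripped-down, std-only sketch of that pattern. The `Cmd` enum, the `run_worker` function, and the string event log are all stand-ins invented for illustration; the real loop owns a `Transcriber` and broadcasts `ModelEvent`s instead, and this sketch skips the unload event when nothing is loaded.

```rust
use std::sync::mpsc::{Receiver, RecvTimeoutError};
use std::time::{Duration, Instant};

/// Hypothetical command set, loosely mirroring `WorkerCmd`.
#[derive(Debug)]
pub enum Cmd {
    Load,
    Unload,
    Work,
}

/// Runs until every sender is dropped; returns a log of what happened.
pub fn run_worker(rx: Receiver<Cmd>, idle_timeout: Duration, tick: Duration) -> Vec<&'static str> {
    let mut loaded = false;
    let mut last_job = Instant::now();
    let mut log = Vec::new();

    loop {
        match rx.recv_timeout(tick) {
            Ok(Cmd::Load) => {
                if !loaded {
                    loaded = true;
                    last_job = Instant::now(); // a fresh load resets the idle clock
                    log.push("loaded");
                }
            }
            Ok(Cmd::Unload) => {
                if loaded {
                    loaded = false;
                    log.push("unloaded");
                }
            }
            Ok(Cmd::Work) => {
                if loaded {
                    last_job = Instant::now();
                    log.push("work");
                } else {
                    log.push("rejected");
                }
            }
            // No command arrived within one tick: check the idle clock.
            Err(RecvTimeoutError::Timeout) => {
                if loaded && last_job.elapsed() >= idle_timeout {
                    loaded = false;
                    log.push("idle_unloaded");
                }
            }
            Err(RecvTimeoutError::Disconnected) => break,
        }
    }
    log
}
```

The tick bounds how late the unload can fire: with a 1-second tick, as in the commit, the model is released at most about one second after the idle timeout elapses.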
|
||||||
|
|
||||||
|
/// Attempt to load the model, polling on VRAM failures.
|
||||||
|
///
|
||||||
|
/// While waiting for GPU, drains `rx` so that `WorkerCmd::Unload` cancels the
|
||||||
|
/// load attempt and `WorkerCmd::Transcribe` commands get a "model not ready"
|
||||||
|
/// rejection. Returns `Some(Transcriber)` on success, `None` if cancelled.
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
fn try_load_with_polling(
|
||||||
|
rx: &std::sync::mpsc::Receiver<WorkerCmd>,
|
||||||
|
model_path: &PathBuf,
|
||||||
|
gpu_device: u32,
|
||||||
|
model_state: &Arc<RwLock<ModelState>>,
|
||||||
|
model_event_tx: &broadcast::Sender<ModelEvent>,
|
||||||
|
webhook_registry: &Arc<Mutex<HashSet<String>>>,
|
||||||
|
gpu_poll_interval: Duration,
|
||||||
|
rt: &tokio::runtime::Handle,
|
||||||
|
) -> Option<Transcriber> {
|
||||||
|
loop {
|
||||||
|
set_state(model_state, ModelState::Loading);
|
||||||
|
broadcast_event(model_event_tx, ModelEvent::ModelLoading);
|
||||||
|
tracing::info!("loading whisper model...");
|
||||||
|
|
||||||
|
match Transcriber::load(model_path, gpu_device) {
|
||||||
|
Ok(t) => {
|
||||||
|
let loaded_at = Utc::now();
|
||||||
|
set_state(model_state, ModelState::Ready { loaded_at });
|
||||||
|
broadcast_event(model_event_tx, ModelEvent::ModelReady { loaded_at });
|
||||||
|
fire_webhooks(webhook_registry, ModelEvent::ModelReady { loaded_at }, rt);
|
||||||
|
tracing::info!("model loaded and ready");
|
||||||
|
return Some(t);
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(AppError::OutOfMemory(msg)) => {
|
||||||
|
let (vram_needed_mb, vram_free_mb) = parse_oom_vram(&msg, gpu_device);
|
||||||
|
let retry_in_secs = gpu_poll_interval.as_secs();
|
||||||
|
|
||||||
|
tracing::warn!(
|
||||||
|
vram_needed_mb,
|
||||||
|
vram_free_mb,
|
||||||
|
retry_in_secs,
|
||||||
|
"insufficient VRAM — will retry"
|
||||||
|
);
|
||||||
|
|
||||||
|
set_state(model_state, ModelState::WaitingForGpu {
|
||||||
|
vram_needed_mb,
|
||||||
|
vram_free_mb,
|
||||||
|
retry_in_secs,
|
||||||
|
});
|
||||||
|
broadcast_event(model_event_tx, ModelEvent::ModelWaitingForGpu {
|
||||||
|
vram_needed_mb,
|
||||||
|
vram_free_mb,
|
||||||
|
retry_in_secs,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Interruptible sleep: drain rx while waiting for gpu_poll_interval.
|
||||||
|
let deadline = Instant::now() + gpu_poll_interval;
|
||||||
|
loop {
|
||||||
|
let remaining = deadline.saturating_duration_since(Instant::now());
|
||||||
|
if remaining.is_zero() { break; }
|
||||||
|
match rx.recv_timeout(remaining.min(Duration::from_secs(1))) {
|
||||||
|
Ok(WorkerCmd::Unload) => {
|
||||||
|
tracing::info!("Unload received while waiting for GPU — cancelling load");
|
||||||
|
set_state(model_state, ModelState::Unloaded);
|
||||||
|
broadcast_event(model_event_tx, ModelEvent::ModelUnloaded);
|
||||||
|
fire_webhooks(webhook_registry, ModelEvent::ModelUnloaded, rt);
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
Ok(WorkerCmd::Load) => {} // idempotent
|
||||||
|
Ok(WorkerCmd::Transcribe(req)) => {
|
||||||
|
let _ = req.reply.send(Err(AppError::ModelNotReady {
|
||||||
|
state: "waiting_for_gpu".into(),
|
||||||
|
retry_after_secs: retry_in_secs,
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
Err(std::sync::mpsc::RecvTimeoutError::Timeout) => {}
|
||||||
|
Err(std::sync::mpsc::RecvTimeoutError::Disconnected) => return None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Loop back to retry load
|
||||||
|
}
|
||||||
|
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!(error = %e, "model load failed with non-recoverable error");
|
||||||
|
set_state(model_state, ModelState::Unloaded);
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
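
The inner loop of `try_load_with_polling` is an interruptible sleep: it waits out the poll interval in short `recv_timeout` slices so that an unload request can cancel the wait promptly. A minimal std-only sketch of that idea follows. `Signal`, `WaitOutcome`, and `interruptible_wait` are hypothetical names used only for this example and are not part of the commit.

```rust
use std::sync::mpsc::{Receiver, RecvTimeoutError};
use std::time::{Duration, Instant};

/// Hypothetical message type: `Cancel` plays the role of `WorkerCmd::Unload`.
pub enum Signal {
    Cancel,
    Other,
}

#[derive(Debug, PartialEq)]
pub enum WaitOutcome {
    Elapsed,
    Cancelled,
    Disconnected,
}

/// Wait up to `total`, waking at least every `tick` to drain the channel.
pub fn interruptible_wait(rx: &Receiver<Signal>, total: Duration, tick: Duration) -> WaitOutcome {
    let deadline = Instant::now() + total;
    loop {
        // Saturates to zero once the deadline has passed.
        let remaining = deadline.saturating_duration_since(Instant::now());
        if remaining.is_zero() {
            return WaitOutcome::Elapsed;
        }
        match rx.recv_timeout(remaining.min(tick)) {
            Ok(Signal::Cancel) => return WaitOutcome::Cancelled,
            Ok(Signal::Other) => {} // ignored; keep waiting
            Err(RecvTimeoutError::Timeout) => {}
            Err(RecvTimeoutError::Disconnected) => return WaitOutcome::Disconnected,
        }
    }
}
```

Clamping each slice to `remaining.min(tick)` keeps the final slice from overshooting the deadline.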
|
||||||
|
|
||||||
|
fn do_unload(
|
||||||
|
transcriber: &mut Option<Transcriber>,
|
||||||
|
model_state: &Arc<RwLock<ModelState>>,
|
||||||
|
model_event_tx: &broadcast::Sender<ModelEvent>,
|
||||||
|
webhook_registry: &Arc<Mutex<HashSet<String>>>,
|
||||||
|
rt: &tokio::runtime::Handle,
|
||||||
|
) {
|
||||||
|
*transcriber = None;
|
||||||
|
set_state(model_state, ModelState::Unloaded);
|
||||||
|
broadcast_event(model_event_tx, ModelEvent::ModelUnloaded);
|
||||||
|
fire_webhooks(webhook_registry, ModelEvent::ModelUnloaded, rt);
|
||||||
|
tracing::info!("model unloaded — GPU memory freed");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Helpers ───────────────────────────────────────────────────────────────────

fn set_state(arc: &Arc<RwLock<ModelState>>, state: ModelState) {
    *arc.blocking_write() = state;
}

fn broadcast_event(tx: &broadcast::Sender<ModelEvent>, event: ModelEvent) {
    let _ = tx.send(event);
}
fn fire_webhooks(
    registry: &Arc<Mutex<HashSet<String>>>,
    event: ModelEvent,
    rt: &tokio::runtime::Handle,
) {
    if !event.is_webhook_event() {
        return;
    }
    let urls: Vec<String> = registry
        .lock()
        .unwrap_or_else(|e| e.into_inner())
        .iter()
        .cloned()
        .collect();

    if urls.is_empty() { return; }

    let payload = match serde_json::to_string(&event) {
        Ok(p) => p,
        Err(e) => { tracing::error!(error = %e, "failed to serialize model event"); return; }
    };

    for url in urls {
        let body = payload.clone();
        rt.spawn(async move {
            let http = Client::builder()
                .timeout(Duration::from_secs(10))
                .build()
                .expect("http client");
            for attempt in 0..3_u32 {
                match http.post(&url)
                    .header("content-type", "application/json")
                    .body(body.clone())
                    .send()
                    .await
                {
                    Ok(r) if r.status().is_success() => {
                        tracing::debug!(url, "model event webhook delivered");
                        return;
                    }
                    Ok(r) => tracing::warn!(url, status = r.status().as_u16(), "webhook non-2xx"),
                    Err(e) => tracing::warn!(url, error = %e, attempt, "webhook delivery failed"),
                }
                if attempt < 2 {
                    tokio::time::sleep(Duration::from_secs(2u64.pow(attempt))).await;
                }
            }
            tracing::error!(url, "model event webhook failed after 3 attempts");
        });
    }
}
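The retry loop in `fire_webhooks` makes up to three attempts and sleeps `2^attempt` seconds after each of the first two failures. A minimal sketch of that delay schedule follows, assuming a hypothetical helper `backoff_delays` that is not part of the commit:

```rust
use std::time::Duration;

/// Hypothetical helper (not in the commit): the pauses slept between
/// attempts when failed attempt `n` (0-based) is followed by a 2^n second
/// sleep and no sleep happens after the final attempt.
fn backoff_delays(max_attempts: u32) -> Vec<Duration> {
    (0..max_attempts.saturating_sub(1))
        .map(|attempt| Duration::from_secs(2u64.pow(attempt)))
        .collect()
}
```

For three attempts this gives pauses of 1 s and 2 s, so an endpoint that never answers is given up on after about 3 s of sleeping plus up to three 10 s request timeouts.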
fn parse_oom_vram(msg: &str, gpu_device: u32) -> (u64, u64) {
    let needed = msg
        .split_whitespace()
        .zip(msg.split_whitespace().skip(1))
        .find(|(_, next)| *next == "MiB")
        .and_then(|(n, _)| n.parse::<f64>().ok())
        .map(|v| v as u64)
        .unwrap_or(0);

    let free = std::process::Command::new("nvidia-smi")
        .args([
            &format!("--id={gpu_device}"),
            "--query-gpu=memory.free",
            "--format=csv,noheader,nounits",
        ])
        .output()
        .ok()
        .and_then(|o| String::from_utf8(o.stdout).ok())
        .and_then(|s| s.trim().parse::<u64>().ok())
        .unwrap_or(0);

    (needed, free)
}
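The `needed` half of `parse_oom_vram` pairs each whitespace-separated token with its successor and takes the token that precedes the first `MiB`. The sketch below isolates that step in a hypothetical function `needed_mib`. The error message texts used with it are assumptions, since the commit does not show the actual OOM message format:

```rust
/// Hypothetical, trimmed-down variant of the `needed` computation in
/// `parse_oom_vram`: returns the number in front of the first "MiB" token,
/// truncated to whole MiB, or 0 when nothing parseable is found.
fn needed_mib(msg: &str) -> u64 {
    msg.split_whitespace()
        // Pair every token with the token that follows it.
        .zip(msg.split_whitespace().skip(1))
        // Only the first pair whose second element is "MiB" is considered.
        .find(|(_, next)| *next == "MiB")
        .and_then(|(n, _)| n.parse::<f64>().ok())
        .map(|v| v as u64)
        .unwrap_or(0)
}
```

Only the first `MiB` token is inspected: if the word in front of it is not a number, the result is 0 even when a later `<number> MiB` pair exists, and a value written without a space (such as `12MiB`) is not matched.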
|
||||||
|
|
||||||
|
// ── Async job runner ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
async fn run(
|
async fn run(
|
||||||
mut job_rx: mpsc::UnboundedReceiver<JobId>,
|
mut job_rx: mpsc::UnboundedReceiver<JobId>,
|
||||||
storage: Arc<Storage>,
|
storage: Arc<Storage>,
|
||||||
queue_depth: Arc<AtomicUsize>,
|
queue_depth: Arc<AtomicUsize>,
|
||||||
registry: ProgressRegistry,
|
registry: ProgressRegistry,
|
||||||
tx_req: std::sync::mpsc::Sender<TranscribeRequest>,
|
cmd_tx: std::sync::mpsc::SyncSender<WorkerCmd>,
|
||||||
) {
|
) {
|
||||||
let http = Client::builder()
|
let http = Client::builder()
|
||||||
.timeout(std::time::Duration::from_secs(30))
|
.timeout(Duration::from_secs(30))
|
||||||
.build()
|
.build()
|
||||||
.expect("failed to build reqwest client");
|
.expect("failed to build reqwest client");
|
||||||
|
|
||||||
@@ -140,7 +443,7 @@ async fn run(
|
|||||||
|
|
||||||
let audio_path = audio_path_for(&job_id);
|
let audio_path = audio_path_for(&job_id);
|
||||||
|
|
||||||
let result = process_job(&job, &audio_path, &progress_tx, &tx_req, &storage).await;
|
let result = process_job(&job, &audio_path, &progress_tx, &cmd_tx, &storage).await;
|
||||||
|
|
||||||
let _ = tokio::fs::remove_file(&audio_path).await;
|
let _ = tokio::fs::remove_file(&audio_path).await;
|
||||||
|
|
||||||
@@ -175,26 +478,18 @@ async fn run(
|
|||||||
tokio::spawn(async move { webhook::fire(&http, &url, &job).await; });
|
tokio::spawn(async move { webhook::fire(&http, &url, &job).await; });
|
||||||
}
|
}
|
||||||
|
|
||||||
tokio::time::sleep(std::time::Duration::from_secs(30)).await;
|
tokio::time::sleep(Duration::from_secs(30)).await;
|
||||||
registry.remove(&job_id);
|
registry.remove(&job_id);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── Silence-based chunking ────────────────────────────────────────────────────
|
// ── Silence-based chunking ────────────────────────────────────────────────────
|
||||||
|
|
||||||
/// Target chunk length. 60s ≈ 2× whisper's native 30s window — short enough
|
|
||||||
/// that a hallucinated phrase can't compound beyond a single window.
|
|
||||||
const TARGET_CHUNK_SECS: f32 = 60.0;
|
const TARGET_CHUNK_SECS: f32 = 60.0;
|
||||||
/// How far from the target we'll snap to a silence midpoint.
|
|
||||||
const SNAP_WINDOW_SECS: f32 = 30.0;
|
const SNAP_WINDOW_SECS: f32 = 30.0;
|
||||||
/// Silence below this level (dB) counts as a split candidate.
|
|
||||||
const SILENCE_DB: &str = "-35dB";
|
const SILENCE_DB: &str = "-35dB";
|
||||||
/// Minimum silence duration to register as a candidate split.
|
|
||||||
const SILENCE_DUR: &str = "0.4";
|
const SILENCE_DUR: &str = "0.4";
|
||||||
|
|
||||||
/// Detect silence periods and return the midpoint (seconds) of each.
|
|
||||||
/// On any error (ffmpeg missing, binary format, etc.) returns an empty vec
|
|
||||||
/// so the caller can fall back to hard cuts.
|
|
||||||
async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
|
async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
|
||||||
use tokio::process::Command;
|
use tokio::process::Command;
|
||||||
|
|
||||||
@@ -217,7 +512,6 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// silencedetect logs to stderr
|
|
||||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||||
let mut starts: Vec<f32> = Vec::new();
|
let mut starts: Vec<f32> = Vec::new();
|
||||||
let mut ends: Vec<f32> = Vec::new();
|
let mut ends: Vec<f32> = Vec::new();
|
||||||
@@ -228,7 +522,6 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
|
|||||||
starts.push(t);
|
starts.push(t);
|
||||||
}
|
}
|
||||||
} else if let Some(i) = line.find("silence_end: ") {
|
} else if let Some(i) = line.find("silence_end: ") {
|
||||||
// Format: "silence_end: 12.34 | silence_duration: 0.56"
|
|
||||||
let t_str = line[i + "silence_end: ".len()..]
|
let t_str = line[i + "silence_end: ".len()..]
|
||||||
.split(" |")
|
.split(" |")
|
||||||
.next()
|
.next()
|
||||||
@@ -248,10 +541,6 @@ async fn detect_silence_midpoints(path: &std::path::Path) -> Vec<f32> {
|
|||||||
mids
|
mids
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build cut points every `target_secs`, snapping to the nearest silence
|
|
||||||
/// midpoint within `snap_window` when one exists; otherwise a hard cut.
|
|
||||||
/// Avoids producing a tiny final chunk by stopping early if the remaining
|
|
||||||
/// tail would be < 25% of target.
|
|
||||||
fn snap_to_silence(
|
fn snap_to_silence(
|
||||||
mids: &[f32],
|
mids: &[f32],
|
||||||
total_secs: f32,
|
total_secs: f32,
|
||||||
@@ -263,13 +552,9 @@ fn snap_to_silence(
|
|||||||
|
|
||||||
while pos < total_secs - target_secs * 0.25 {
|
while pos < total_secs - target_secs * 0.25 {
|
||||||
let prev_cut = cuts.last().copied().unwrap_or(0.0);
|
let prev_cut = cuts.last().copied().unwrap_or(0.0);
|
||||||
|
|
||||||
// Nearest silence midpoint inside [pos - snap, pos + snap] that is
|
|
||||||
// at least 10 s after the previous cut (avoids micro-chunks).
|
|
||||||
let best = mids.iter().copied()
|
let best = mids.iter().copied()
|
||||||
.filter(|&t| t > prev_cut + 10.0 && (t - pos).abs() <= snap_window)
|
.filter(|&t| t > prev_cut + 10.0 && (t - pos).abs() <= snap_window)
|
||||||
.min_by(|a, b| (a - pos).abs().partial_cmp(&(b - pos).abs()).unwrap());
|
.min_by(|a, b| (a - pos).abs().partial_cmp(&(b - pos).abs()).unwrap());
|
||||||
|
|
||||||
let cut = best.unwrap_or(pos);
|
let cut = best.unwrap_or(pos);
|
||||||
cuts.push(cut);
|
cuts.push(cut);
|
||||||
pos = cut + target_secs;
|
pos = cut + target_secs;
|
||||||
@@ -278,7 +563,6 @@ fn snap_to_silence(
|
|||||||
cuts
|
cuts
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert cut points into (start_secs, end_secs) chunk pairs.
|
|
||||||
fn to_chunk_ranges(cuts: &[f32], total_secs: f32) -> Vec<(f32, f32)> {
|
fn to_chunk_ranges(cuts: &[f32], total_secs: f32) -> Vec<(f32, f32)> {
|
||||||
let mut ranges = Vec::new();
|
let mut ranges = Vec::new();
|
||||||
let mut start = 0.0_f32;
|
let mut start = 0.0_f32;
|
||||||
@@ -289,7 +573,6 @@ fn to_chunk_ranges(cuts: &[f32], total_secs: f32) -> Vec<(f32, f32)> {
|
|||||||
start = cut;
|
start = cut;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// Last chunk
|
|
||||||
if total_secs - start >= 1.0 {
|
if total_secs - start >= 1.0 {
|
||||||
ranges.push((start, total_secs));
|
ranges.push((start, total_secs));
|
||||||
}
|
}
|
||||||
@@ -302,17 +585,13 @@ async fn process_job(
|
|||||||
job: &Job,
|
job: &Job,
|
||||||
audio_path: &std::path::Path,
|
audio_path: &std::path::Path,
|
||||||
progress_tx: &ProgressTx,
|
progress_tx: &ProgressTx,
|
||||||
tx_req: &std::sync::mpsc::Sender<TranscribeRequest>,
|
cmd_tx: &std::sync::mpsc::SyncSender<WorkerCmd>,
|
||||||
storage: &Arc<Storage>,
|
storage: &Arc<Storage>,
|
||||||
) -> crate::Result<(Vec<Segment>, String, f32)> {
|
) -> crate::Result<(Vec<Segment>, String, f32)> {
|
||||||
// 1. Decode full audio to 16 kHz mono PCM.
|
|
||||||
let pcm = decode_audio(audio_path).await?;
|
let pcm = decode_audio(audio_path).await?;
|
||||||
let total_secs = pcm.len() as f32 / 16_000.0;
|
let total_secs = pcm.len() as f32 / 16_000.0;
|
||||||
|
|
||||||
// 2. Detect silence midpoints from original file.
|
|
||||||
let silence_mids = detect_silence_midpoints(audio_path).await;
|
let silence_mids = detect_silence_midpoints(audio_path).await;
|
||||||
|
|
||||||
// 3. Build silence-snapped chunk boundaries.
|
|
||||||
let cuts = snap_to_silence(&silence_mids, total_secs, TARGET_CHUNK_SECS, SNAP_WINDOW_SECS);
|
let cuts = snap_to_silence(&silence_mids, total_secs, TARGET_CHUNK_SECS, SNAP_WINDOW_SECS);
|
||||||
let chunks = to_chunk_ranges(&cuts, total_secs);
|
let chunks = to_chunk_ranges(&cuts, total_secs);
|
||||||
let n = chunks.len();
|
let n = chunks.len();
|
||||||
@@ -324,7 +603,6 @@ async fn process_job(
|
|||||||
"audio chunked by silence"
|
"audio chunked by silence"
|
||||||
);
|
);
|
||||||
|
|
||||||
// 4. Transcribe each chunk, applying a time offset to all timestamps.
|
|
||||||
let mut all_segments: Vec<Segment> = Vec::new();
|
let mut all_segments: Vec<Segment> = Vec::new();
|
||||||
let mut language = String::new();
|
let mut language = String::new();
|
||||||
|
|
||||||
@@ -334,11 +612,9 @@ async fn process_job(
|
|||||||
let mut chunk_pcm = pcm[s0..s1].to_vec();
|
let mut chunk_pcm = pcm[s0..s1].to_vec();
|
||||||
trim_trailing_silence(&mut chunk_pcm);
|
trim_trailing_silence(&mut chunk_pcm);
|
||||||
|
|
||||||
// Base percent this chunk starts at.
|
|
||||||
let base = (ci * 100 / n) as u8;
|
let base = (ci * 100 / n) as u8;
|
||||||
let span = (100usize / n).max(1) as u8;
|
let span = (100usize / n).max(1) as u8;
|
||||||
|
|
||||||
// Emit a progress event and persist it at the start of every chunk.
|
|
||||||
let _ = progress_tx.send(ProgressEvent::Progress {
|
let _ = progress_tx.send(ProgressEvent::Progress {
|
||||||
percent: base,
|
percent: base,
|
||||||
chunk: ci + 1,
|
chunk: ci + 1,
|
||||||
@@ -350,7 +626,6 @@ async fn process_job(
|
|||||||
tracing::warn!(error = %e, "failed to persist mid-job progress");
|
tracing::warn!(error = %e, "failed to persist mid-job progress");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Scale whisper's per-chunk 0–100 into the job's overall range.
|
|
||||||
let tx = progress_tx.clone();
|
let tx = progress_tx.clone();
|
||||||
let chunk_num = ci + 1;
|
let chunk_num = ci + 1;
|
||||||
let on_progress = Box::new(move |p: u8| {
|
let on_progress = Box::new(move |p: u8| {
|
||||||
@@ -363,18 +638,17 @@ async fn process_job(
|
|||||||
});
|
});
|
||||||
|
|
||||||
let (reply_tx, reply_rx) = oneshot::channel();
|
let (reply_tx, reply_rx) = oneshot::channel();
|
||||||
tx_req.send(TranscribeRequest {
|
cmd_tx.send(WorkerCmd::Transcribe(TranscribeRequest {
|
||||||
pcm: chunk_pcm,
|
pcm: chunk_pcm,
|
||||||
language: job.language.clone(),
|
language: job.language.clone(),
|
||||||
task: job.task.clone(),
|
task: job.task.clone(),
|
||||||
on_progress,
|
on_progress,
|
||||||
reply: reply_tx,
|
reply: reply_tx,
|
||||||
}).map_err(|_| crate::AppError::Internal("transcriber thread gone".into()))?;
|
})).map_err(|_| AppError::Internal("worker command channel closed".into()))?;
|
||||||
|
|
||||||
let (mut segs, lang) = reply_rx.await
|
let (mut segs, lang) = reply_rx.await
|
||||||
.map_err(|_| crate::AppError::Internal("transcriber thread dropped reply".into()))??;
|
.map_err(|_| AppError::Internal("transcriber thread dropped reply".into()))??;
|
||||||
|
|
||||||
// Shift all timestamps by chunk offset.
|
|
||||||
let offset = *chunk_start;
|
let offset = *chunk_start;
|
||||||
for seg in &mut segs {
|
for seg in &mut segs {
|
||||||
seg.start += offset;
|
seg.start += offset;
|
||||||
@@ -400,7 +674,6 @@ async fn process_job(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Renumber segment indices across the merged output.
|
|
||||||
for (i, seg) in all_segments.iter_mut().enumerate() {
|
for (i, seg) in all_segments.iter_mut().enumerate() {
|
||||||
seg.index = i as i32;
|
seg.index = i as i32;
|
||||||
}
|
}
|
||||||
@@ -409,14 +682,9 @@ async fn process_job(
|
|||||||
Ok((all_segments, language, total_secs))
|
Ok((all_segments, language, total_secs))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Trim trailing silence from a 16 kHz mono PCM buffer.
|
|
||||||
///
|
|
||||||
/// Scans backwards to find the last sample above −35 dB, then keeps
|
|
||||||
/// 0.5 s of padding after it. This prevents whisper from hallucinating
|
|
||||||
/// filler tokens into end-of-chunk silence.
|
|
||||||
fn trim_trailing_silence(pcm: &mut Vec<f32>) {
|
fn trim_trailing_silence(pcm: &mut Vec<f32>) {
|
||||||
const THRESHOLD: f32 = 0.017_8; // −35 dB (10^(−35/20))
|
const THRESHOLD: f32 = 0.017_8;
|
||||||
const PADDING: usize = 8_000; // 0.5 s at 16 kHz
|
const PADDING: usize = 8_000;
|
||||||
|
|
||||||
if let Some(last_loud) = pcm.iter().rposition(|&s| s.abs() > THRESHOLD) {
|
if let Some(last_loud) = pcm.iter().rposition(|&s| s.abs() > THRESHOLD) {
|
||||||
let new_len = (last_loud + 1 + PADDING).min(pcm.len());
|
let new_len = (last_loud + 1 + PADDING).min(pcm.len());
|
||||||
@@ -429,10 +697,8 @@ fn trim_trailing_silence(pcm: &mut Vec<f32>) {
|
|||||||
pcm.truncate(new_len);
|
pcm.truncate(new_len);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
// All-silent chunk: keep as-is — whisper will produce zero segments, which is correct.
|
|
||||||
}
|
}
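The two constants in `trim_trailing_silence` come from the values stated in the removed comments: −35 dB expressed as a linear amplitude, and 0.5 s of padding at 16 kHz. A small sketch of both conversions, using hypothetical helper names that do not appear in the commit:

```rust
/// Convert a level in dB relative to full scale into a linear amplitude,
/// where full scale is 1.0. Hypothetical helper, not part of the commit.
fn db_to_amplitude(db: f32) -> f32 {
    10f32.powf(db / 20.0)
}

/// Number of samples that cover `secs` seconds at `sample_rate` Hz.
/// Hypothetical helper, not part of the commit.
fn samples_for(secs: f32, sample_rate: u32) -> usize {
    (secs * sample_rate as f32).round() as usize
}
```

`db_to_amplitude(-35.0)` is approximately 0.0178, matching `THRESHOLD`, and `samples_for(0.5, 16_000)` gives 8000, matching `PADDING`.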
|
||||||
|
|
||||||
/// Decode any audio file to 16 kHz mono PCM f32 using ffmpeg.
|
|
||||||
async fn decode_audio(path: &std::path::Path) -> crate::Result<Vec<f32>> {
|
async fn decode_audio(path: &std::path::Path) -> crate::Result<Vec<f32>> {
|
||||||
use tokio::process::Command;
|
use tokio::process::Command;
|
||||||
|
|
||||||
@@ -447,11 +713,11 @@ async fn decode_audio(path: &std::path::Path) -> crate::Result<Vec<f32>> {
|
|||||||
])
|
])
|
||||||
.output()
|
.output()
|
||||||
.await
|
.await
|
||||||
.map_err(|e| crate::AppError::Internal(format!("ffmpeg spawn failed: {e}")))?;
|
.map_err(|e| AppError::Internal(format!("ffmpeg spawn failed: {e}")))?;
|
||||||
|
|
||||||
if !output.status.success() {
|
if !output.status.success() {
|
||||||
let stderr = String::from_utf8_lossy(&output.stderr);
|
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||||
return Err(crate::AppError::Internal(format!(
|
return Err(AppError::Internal(format!(
|
||||||
"ffmpeg exited with {}: {}",
|
"ffmpeg exited with {}: {}",
|
||||||
output.status, stderr
|
output.status, stderr
|
||||||
)));
|
)));
|
||||||
@@ -459,7 +725,7 @@ async fn decode_audio(path: &std::path::Path) -> crate::Result<Vec<f32>> {
|
|||||||
|
|
||||||
let bytes = output.stdout;
|
let bytes = output.stdout;
|
||||||
if bytes.len() % 4 != 0 {
|
if bytes.len() % 4 != 0 {
|
||||||
return Err(crate::AppError::Internal(
|
return Err(AppError::Internal(
|
||||||
"ffmpeg output length not a multiple of 4".into(),
|
"ffmpeg output length not a multiple of 4".into(),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
@@ -473,3 +739,51 @@ pub fn audio_path_for(id: &JobId) -> PathBuf {
|
|||||||
let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/data".into());
|
let data_dir = std::env::var("DATA_DIR").unwrap_or_else(|_| "/data".into());
|
||||||
PathBuf::from(data_dir).join(format!("{id}.audio"))
|
PathBuf::from(data_dir).join(format!("{id}.audio"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Unit tests ────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_snap_to_silence_uses_nearest_midpoint() {
        let mids = vec![55.0, 58.0, 62.0];
        let cuts = snap_to_silence(&mids, 120.0, 60.0, 30.0);
        assert!(!cuts.is_empty());
        assert!((cuts[0] - 58.0).abs() < 0.01, "expected ~58.0, got {}", cuts[0]);
    }

    #[test]
    fn test_snap_to_silence_hard_cut_when_no_silence() {
        let cuts = snap_to_silence(&[], 120.0, 60.0, 30.0);
        assert_eq!(cuts, vec![60.0]);
    }

    #[test]
    fn test_to_chunk_ranges_single_chunk() {
        let ranges = to_chunk_ranges(&[], 30.0);
        assert_eq!(ranges, vec![(0.0, 30.0)]);
    }

    #[test]
    fn test_to_chunk_ranges_two_chunks() {
        let ranges = to_chunk_ranges(&[60.0], 120.0);
        assert_eq!(ranges, vec![(0.0, 60.0), (60.0, 120.0)]);
    }

    #[test]
    fn test_trim_trailing_silence_all_silent() {
        let mut pcm = vec![0.0f32; 1000];
        trim_trailing_silence(&mut pcm);
        assert_eq!(pcm.len(), 1000);
    }

    #[test]
    fn test_trim_trailing_silence_trims_to_padding() {
        let mut pcm = vec![0.0f32; 32_000];
        pcm[10_000] = 1.0;
        trim_trailing_silence(&mut pcm);
        assert_eq!(pcm.len(), (10_001 + 8_000).min(32_000));
    }
}
|
||||||
|
|||||||
97
test_all.sh
Normal file → Executable file
@@ -6,8 +6,9 @@ BASE="${WHISPER_BASE_URL:-http://localhost:8080}"
|
|||||||
AUDIO="${TEST_AUDIO:-/home/moze/Sources/youtube-transcriber/docker/tmp/audio-b2167046-a236-4fcd-b739-78177542fd23.wav}"
|
AUDIO="${TEST_AUDIO:-/home/moze/Sources/youtube-transcriber/docker/tmp/audio-b2167046-a236-4fcd-b739-78177542fd23.wav}"
|
||||||
|
|
||||||
GREEN='\033[0;32m'; RED='\033[0;31m'; NC='\033[0m'
|
GREEN='\033[0;32m'; RED='\033[0;31m'; NC='\033[0m'
|
||||||
|
FAILS=0
|
||||||
ok() { echo -e "${GREEN}[PASS]${NC} $*"; }
|
ok() { echo -e "${GREEN}[PASS]${NC} $*"; }
|
||||||
fail(){ echo -e "${RED}[FAIL]${NC} $*"; exit 1; }
|
fail(){ echo -e "${RED}[FAIL]${NC} $*"; FAILS=$((FAILS + 1)); }
|
||||||
|
|
||||||
echo "=== Whisper API test suite ==="
|
echo "=== Whisper API test suite ==="
|
||||||
echo " BASE : $BASE"
|
echo " BASE : $BASE"
|
||||||
@@ -17,11 +18,16 @@ echo ""
|
|||||||
echo "=== 1. GET /health ==="
|
echo "=== 1. GET /health ==="
|
||||||
HEALTH=$(curl -sf "$BASE/health")
|
HEALTH=$(curl -sf "$BASE/health")
|
||||||
echo "$HEALTH" | python3 -m json.tool
|
echo "$HEALTH" | python3 -m json.tool
|
||||||
echo "$HEALTH" | python3 -c "import sys,json; d=json.load(sys.stdin); assert d['status']=='ok', f'status={d[\"status\"]}'" && ok "health ok"
|
python3 -c "
|
||||||
|
import sys, json
|
||||||
|
d = json.load(sys.stdin)
|
||||||
|
assert d['status'] == 'ok', f'status={d[\"status\"]}'
|
||||||
|
assert 'model_state' in d, 'model_state field missing from health response'
|
||||||
|
" <<< "$HEALTH" && ok "health ok + model_state present" || fail "health check"
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "=== 2. GET /docs (Swagger UI reachable) ==="
|
echo "=== 2. GET /docs (Swagger UI reachable) ==="
|
||||||
curl -sf "$BASE/docs" | grep -qi "swagger" && ok "swagger UI reachable"
|
curl -sf "$BASE/docs" | grep -qi "swagger" && ok "swagger UI reachable" || fail "swagger UI"
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "=== 3. Webhook receiver (background Python HTTP server) ==="
|
echo "=== 3. Webhook receiver (background Python HTTP server) ==="
|
||||||
@@ -33,7 +39,7 @@ class H(http.server.BaseHTTPRequestHandler):
|
|||||||
n = int(self.headers.get('Content-Length', 0))
|
n = int(self.headers.get('Content-Length', 0))
|
||||||
body = self.rfile.read(n)
|
body = self.rfile.read(n)
|
||||||
data = json.loads(body)
|
data = json.loads(body)
|
||||||
print(f"\n[WEBHOOK] status={data.get('status')} segments={len(data.get('segments', []))}")
|
print(f"\n[WEBHOOK] status={data.get('status')} segments={len(data.get('segments', []))}", flush=True)
|
||||||
self.send_response(200)
|
self.send_response(200)
|
||||||
self.end_headers()
|
self.end_headers()
|
||||||
def log_message(self, *a): pass
|
def log_message(self, *a): pass
|
||||||
@@ -48,40 +54,70 @@ sleep 1
|
|||||||
echo "Webhook receiver started (PID $WEBHOOK_PID)"
|
echo "Webhook receiver started (PID $WEBHOOK_PID)"
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "=== 4. DELETE a non-existent job → 404 ==="
|
echo "=== 4. GET /model/status — expect unloaded on fresh start ==="
|
||||||
|
MODEL_STATUS=$(curl -sf "$BASE/model/status")
|
||||||
|
echo "$MODEL_STATUS" | python3 -m json.tool
|
||||||
|
echo "$MODEL_STATUS" | python3 -c "
|
||||||
|
import sys, json
|
||||||
|
d = json.load(sys.stdin)
|
||||||
|
assert 'state' in d, 'state field missing from /model/status'
|
||||||
|
print(f' model state: {d[\"state\"]}')
|
||||||
|
" && ok "/model/status has state field" || fail "/model/status schema"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "=== 5. POST /model/load — trigger model load ==="
|
||||||
|
LOAD_RESP=$(curl -sf -X POST "$BASE/model/load")
|
||||||
|
echo "$LOAD_RESP"
|
||||||
|
ok "POST /model/load accepted"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "=== 6. Poll /model/status until ready (max 3 min) ==="
|
||||||
|
LOAD_ELAPSED=0
|
||||||
|
while true; do
|
||||||
|
sleep 5
|
||||||
|
LOAD_ELAPSED=$((LOAD_ELAPSED + 5))
|
||||||
|
MS=$(curl -sf "$BASE/model/status")
|
||||||
|
STATE=$(echo "$MS" | python3 -c "import sys,json; print(json.load(sys.stdin)['state'])")
|
||||||
|
echo " [${LOAD_ELAPSED}s] model_state=${STATE}"
|
||||||
|
if [ "$STATE" = "ready" ]; then
|
||||||
|
ok "model loaded and ready in ${LOAD_ELAPSED}s"
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
[ $LOAD_ELAPSED -gt 180 ] && { fail "model failed to load within 3 minutes"; break; }
|
||||||
|
done
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "=== 7. DELETE a non-existent job → 404 ==="
|
||||||
STATUS=$(curl -s -o /dev/null -w "%{http_code}" -X DELETE "$BASE/jobs/00000000-0000-0000-0000-000000000000")
|
STATUS=$(curl -s -o /dev/null -w "%{http_code}" -X DELETE "$BASE/jobs/00000000-0000-0000-0000-000000000000")
|
||||||
[ "$STATUS" = "404" ] && ok "DELETE unknown job → 404" || fail "expected 404, got $STATUS"
|
[ "$STATUS" = "404" ] && ok "DELETE unknown job → 404" || fail "expected 404, got $STATUS"
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "=== 5. POST /jobs — submit audio ==="
|
echo "=== 8. POST /jobs — submit audio ==="
|
||||||
# language field omitted → auto-detection. Do NOT pass "auto" — it is not a
|
|
||||||
# valid ISO 639-1 code and whisper-rs will reject it or behave unexpectedly.
|
|
||||||
SUBMIT=$(curl -sf -X POST "$BASE/jobs" \
|
SUBMIT=$(curl -sf -X POST "$BASE/jobs" \
|
||||||
-F "audio=@${AUDIO};type=audio/wav" \
|
-F "audio=@${AUDIO};type=audio/wav" \
|
||||||
-F "task=transcribe" \
|
-F "task=transcribe" \
|
||||||
-F "webhook_url=http://localhost:9999/webhook")
|
-F "webhook_url=http://localhost:9999/webhook")
|
||||||
echo "$SUBMIT"
|
echo "$SUBMIT"
|
||||||
# Submit response: { "job_id": "<uuid>" } (field is "job_id", not "id")
|
|
||||||
JOB_ID=$(echo "$SUBMIT" | python3 -c "import sys,json; print(json.load(sys.stdin)['job_id'])")
|
JOB_ID=$(echo "$SUBMIT" | python3 -c "import sys,json; print(json.load(sys.stdin)['job_id'])")
|
||||||
ok "submitted job $JOB_ID"
|
ok "submitted job $JOB_ID"
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "=== 6. GET /jobs/{id} immediately after submit ==="
|
echo "=== 9. GET /jobs/{id} immediately after submit ==="
|
||||||
JOB=$(curl -sf "$BASE/jobs/$JOB_ID")
|
JOB=$(curl -sf "$BASE/jobs/$JOB_ID")
|
||||||
echo "$JOB" | python3 -c "
|
echo "$JOB" | python3 -c "
|
||||||
import sys, json
|
import sys, json
|
||||||
d = json.load(sys.stdin)
|
d = json.load(sys.stdin)
|
||||||
assert d['status'] in ('queued', 'running'), f'unexpected status: {d[\"status\"]}'
|
assert d['status'] in ('queued', 'running'), f'unexpected status: {d[\"status\"]}'
|
||||||
" && ok "status is queued/running"
|
" && ok "status is queued/running" || fail "initial status check"
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "=== 7. SSE stream (observe first 30 events then detach) ==="
|
echo "=== 10. SSE stream (observe first 30 events then detach) ==="
|
||||||
echo "Subscribing to SSE stream for $JOB_ID …"
|
echo "Subscribing to SSE stream for $JOB_ID …"
|
||||||
curl -sN --max-time 90 "$BASE/jobs/$JOB_ID/stream" | head -60 &
|
curl -sN --max-time 90 "$BASE/jobs/$JOB_ID/stream" | head -60 &
|
||||||
SSE_PID=$!
|
SSE_PID=$!
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "=== 8. Poll until done (max 20 min) ==="
|
echo "=== 11. Poll until done (max 20 min) ==="
|
||||||
ELAPSED=0
|
ELAPSED=0
|
||||||
while true; do
|
while true; do
|
||||||
sleep 15
|
sleep 15
|
||||||
@@ -96,16 +132,15 @@ while true; do
|
|||||||
elif [ "$STATUS" = "failed" ]; then
|
elif [ "$STATUS" = "failed" ]; then
|
||||||
echo "$JOB" | python3 -m json.tool
|
echo "$JOB" | python3 -m json.tool
|
||||||
fail "job failed"
|
fail "job failed"
|
||||||
|
break
|
||||||
fi
|
fi
|
||||||
[ $ELAPSED -gt 1200 ] && fail "timeout after 20 minutes"
|
[ $ELAPSED -gt 1200 ] && { fail "timeout after 20 minutes"; break; }
|
||||||
done
|
done
|
||||||
kill $SSE_PID 2>/dev/null || true
|
kill $SSE_PID 2>/dev/null || true
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "=== 9. Inspect transcription quality ==="
|
echo "=== 12. Inspect transcription quality ==="
|
||||||
RESULT=$(curl -sf "$BASE/jobs/$JOB_ID")
|
RESULT=$(curl -sf "$BASE/jobs/$JOB_ID")
|
||||||
# Note: can't pipe into a heredoc-driven python3 (heredoc takes stdin, pipe is ignored).
|
|
||||||
# Write to a temp file instead.
|
|
||||||
TMPJSON=$(mktemp /tmp/whisper_test_XXXXXX.json)
|
TMPJSON=$(mktemp /tmp/whisper_test_XXXXXX.json)
|
||||||
echo "$RESULT" > "$TMPJSON"
|
echo "$RESULT" > "$TMPJSON"
|
||||||
python3 - "$TMPJSON" << 'PYCHECK'
|
python3 - "$TMPJSON" << 'PYCHECK'
|
||||||
@@ -149,18 +184,16 @@ for seg in segments[:5]:
|
|||||||
PYCHECK
|
PYCHECK
|
||||||
PYEXIT=$?
|
PYEXIT=$?
|
||||||
rm -f "$TMPJSON"
|
rm -f "$TMPJSON"
|
||||||
[ $PYEXIT -eq 0 ] && ok "quality check passed" || { echo "[FAIL] quality check"; FAILS=$((FAILS+1)); }
|
[ $PYEXIT -eq 0 ] && ok "quality check passed" || fail "quality check"
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "=== 10. DELETE completed job → 200 ==="
|
echo "=== 13. DELETE completed job → 409 Conflict ==="
|
||||||
# Completed jobs return 409 Conflict on DELETE (terminal state).
|
|
||||||
# Verify we get 409, not 200 (delete is only for cancellation of active jobs).
|
|
||||||
DEL_STATUS=$(curl -s -o /dev/null -w "%{http_code}" -X DELETE "$BASE/jobs/$JOB_ID")
|
DEL_STATUS=$(curl -s -o /dev/null -w "%{http_code}" -X DELETE "$BASE/jobs/$JOB_ID")
|
||||||
[ "$DEL_STATUS" = "409" ] && ok "DELETE completed job → 409 Conflict (expected)" \
|
[ "$DEL_STATUS" = "409" ] && ok "DELETE completed job → 409 Conflict (expected)" \
|
||||||
|| echo " [INFO] DELETE returned $DEL_STATUS"
|
|| echo " [INFO] DELETE returned $DEL_STATUS"
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "=== 11. Submit + cancel a queued job ==="
|
echo "=== 14. Submit + cancel a queued job ==="
|
||||||
JOB2=$(curl -sf -X POST "$BASE/jobs" \
|
JOB2=$(curl -sf -X POST "$BASE/jobs" \
|
||||||
-F "audio=@${AUDIO};type=audio/wav" \
|
-F "audio=@${AUDIO};type=audio/wav" \
|
||||||
-F "language=en" \
|
-F "language=en" \
|
||||||
@@ -173,8 +206,24 @@ CANCEL_STATUS=$(curl -sf "$BASE/jobs/$JOB2_ID" | python3 -c "import sys,json; pr
|
|||||||
|| echo " [INFO] cancel status: $CANCEL_STATUS (may be running — worker ignores cancel mid-chunk)"
|
|| echo " [INFO] cancel status: $CANCEL_STATUS (may be running — worker ignores cancel mid-chunk)"
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
echo "=== 12. Verify webhook fired ==="
|
echo "=== 15. POST /model/unload ==="
|
||||||
|
UNLOAD_RESP=$(curl -sf -X POST "$BASE/model/unload")
|
||||||
|
echo "$UNLOAD_RESP"
|
||||||
|
sleep 2
|
||||||
|
UNLOAD_STATE=$(curl -sf "$BASE/model/status" | python3 -c "import sys,json; print(json.load(sys.stdin)['state'])")
|
||||||
|
[ "$UNLOAD_STATE" = "unloaded" ] && ok "model unloaded → state=unloaded" \
|
||||||
|
|| echo " [INFO] state after unload: $UNLOAD_STATE"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "=== 16. Verify webhook fired ==="
|
||||||
sleep 3
|
sleep 3
|
||||||
kill $WEBHOOK_PID 2>/dev/null || true
|
kill $WEBHOOK_PID 2>/dev/null || true
|
||||||
ok "all tests complete"
|
ok "webhook server stopped"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
if [ $FAILS -eq 0 ]; then
|
||||||
|
echo -e "${GREEN}=== ALL TESTS PASSED ===${NC}"
|
||||||
|
else
|
||||||
|
echo -e "${RED}=== $FAILS TEST(S) FAILED ===${NC}"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|||||||
246
tests/test_idle_timeout.sh
Executable file
@@ -0,0 +1,246 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
# tests/test_idle_timeout.sh
|
||||||
|
#
|
||||||
|
# Integration tests for the idle-timeout auto-unload feature.
|
||||||
|
# REQUIRES the server to be started with a short idle timeout:
|
||||||
|
#
|
||||||
|
# IDLE_TIMEOUT_SECS=5 ./whisper-server
|
||||||
|
# # or via Docker:
|
||||||
|
# docker run -e IDLE_TIMEOUT_SECS=5 ...
|
||||||
|
#
|
||||||
|
# The default idle timeout is 5 minutes; these tests use a 5-second window
|
||||||
|
# to keep the suite fast.
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
BASE="${WHISPER_BASE_URL:-http://localhost:8080}"
|
||||||
|
IDLE_TIMEOUT="${EXPECTED_IDLE_TIMEOUT_SECS:-5}"
|
||||||
|
AUDIO="${TEST_AUDIO:-}"
|
||||||
|
|
||||||
|
GREEN='\033[0;32m'; RED='\033[0;31m'; YELLOW='\033[0;33m'; NC='\033[0m'
|
||||||
|
|
||||||
|
PASS=0; FAIL=0
|
||||||
|
|
||||||
|
ok() { echo -e "${GREEN}[PASS]${NC} $1"; PASS=$((PASS+1)); }
|
||||||
|
fail() { echo -e "${RED}[FAIL]${NC} $1"; FAIL=$((FAIL+1)); }
|
||||||
|
skip() { echo -e "${YELLOW}[SKIP]${NC} $1"; }
|
||||||
|
info() { echo " $1"; }
|
||||||
|
|
||||||
|
echo "=== Idle Timeout Tests ==="
|
||||||
|
echo " BASE: $BASE"
|
||||||
|
echo " IDLE_TIMEOUT_SECS: $IDLE_TIMEOUT (must be configured on the server)"
|
||||||
|
echo ""
|
||||||
|
echo "NOTE: These tests require the server to be running with IDLE_TIMEOUT_SECS=$IDLE_TIMEOUT"
|
||||||
|
echo ""
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────────────────

get_state() {
    curl -sf "$BASE/model/status" | python3 -c "import sys,json; print(json.load(sys.stdin)['state'])"
}

ensure_ready() {
    local state
    state=$(get_state)
    if [ "$state" = "ready" ]; then return 0; fi
    curl -sf -X POST "$BASE/model/load" > /dev/null
    local elapsed=0
    while true; do
        sleep 3; elapsed=$((elapsed+3))
        state=$(get_state)
        [ "$state" = "ready" ] && return 0
        [ $elapsed -gt 180 ] && return 1
    done
}

ensure_unloaded() {
    curl -sf -X POST "$BASE/model/unload" > /dev/null || true
    sleep 2
}

# ── TEST 1: Load model, then wait for idle unload ────────────────────────────
echo "--- Test 1: Idle timeout triggers auto-unload ---"

ensure_unloaded
ensure_ready || { fail "T1: model load failed"; }

WAIT_SECS=$((IDLE_TIMEOUT + 3))
info "Model is ready. Waiting $WAIT_SECS seconds (idle timeout=$IDLE_TIMEOUT + 3s buffer)..."
sleep $WAIT_SECS

STATE=$(get_state)
if [ "$STATE" = "unloaded" ]; then
    ok "T1: model auto-unloaded after ${IDLE_TIMEOUT}s idle"
else
    fail "T1: expected unloaded after idle timeout, got $STATE"
    info "Is the server running with IDLE_TIMEOUT_SECS=$IDLE_TIMEOUT?"
fi

# ── TEST 2: model_unloaded webhook fires on idle timeout ─────────────────────
echo ""
echo "--- Test 2: model_unloaded webhook fires on idle timeout ---"

ensure_unloaded

# Start webhook receiver
python3 - <<'PYEOF' &
import http.server, json, sys, signal

class H(http.server.BaseHTTPRequestHandler):
    def do_POST(self):
        n = int(self.headers.get('Content-Length', 0))
        body = json.loads(self.rfile.read(n))
        with open('/tmp/idle_wh_event.json', 'w') as f:
            json.dump(body, f)
        self.send_response(200); self.end_headers()
    def log_message(self, *a): pass

signal.signal(signal.SIGTERM, lambda *_: sys.exit(0))
http.server.HTTPServer(('', 9995), H).serve_forever()
PYEOF
WH_PID=$!
sleep 1

# Register webhook via a job submission (will 503 since unloaded)
curl -sf -X POST "$BASE/jobs" \
    -F "audio=@/dev/urandom;type=audio/wav" \
    -F "webhook_url=http://localhost:9995/wh" \
    --max-time 5 > /dev/null 2>&1 || true

# Load model
ensure_ready || { fail "T2: model load failed"; kill $WH_PID 2>/dev/null || true; }

# Wait for idle timeout
WAIT_SECS=$((IDLE_TIMEOUT + 5))
info "Waiting ${WAIT_SECS}s for idle timeout..."
sleep $WAIT_SECS

kill $WH_PID 2>/dev/null || true
wait $WH_PID 2>/dev/null || true

if [ -f /tmp/idle_wh_event.json ]; then
    EVENT_TYPE=$(python3 -c "import json; print(json.load(open('/tmp/idle_wh_event.json')).get('type','?'))")
    rm -f /tmp/idle_wh_event.json
    [ "$EVENT_TYPE" = "model_unloaded" ] && ok "T2: model_unloaded webhook fired on idle timeout" \
        || fail "T2: webhook type=$EVENT_TYPE (expected model_unloaded)"
else
    fail "T2: no webhook received within timeout"
fi

# ── TEST 3: Job submission after idle timeout → 503 → triggers reload ─────────
echo ""
echo "--- Test 3: Job triggers reload after idle unload ---"

ensure_unloaded
ensure_ready || { fail "T3: initial load failed"; }

# Wait for auto-unload
WAIT_SECS=$((IDLE_TIMEOUT + 3))
info "Waiting ${WAIT_SECS}s for idle unload..."
sleep $WAIT_SECS

STATE=$(get_state)
[ "$STATE" = "unloaded" ] || info "Note: state=$STATE (expected unloaded)"

# Submit job → 503, triggers reload
HTTP=$(curl -s -o /tmp/t3_body.json -w "%{http_code}" -X POST "$BASE/jobs" \
    -F "audio=@/dev/urandom;type=audio/wav" \
    --max-time 5 2>/dev/null || echo "000")

if [ "$HTTP" = "503" ]; then
    ok "T3a: POST /jobs → 503 after idle unload"
else
    skip "T3a: POST /jobs returned $HTTP (model may have reloaded)"
fi

# State should be loading or ready (reload triggered by job submission)
sleep 2
STATE=$(get_state)
if [ "$STATE" = "loading" ] || [ "$STATE" = "ready" ]; then
    ok "T3b: reload triggered by job submission ($STATE)"
else
    fail "T3b: expected loading/ready, got $STATE"
fi

rm -f /tmp/t3_body.json

# ── TEST 4: Model stays ready until the timeout elapses (wait IDLE_TIMEOUT-1s) ─
echo ""
echo "--- Test 4: Model stays ready until the idle timeout elapses ---"

ensure_unloaded
ensure_ready || { fail "T4: model load failed"; }

HALF_WAIT=$((IDLE_TIMEOUT - 1))
info "Waiting ${HALF_WAIT}s (less than idle timeout)..."
sleep $HALF_WAIT

STATE=$(get_state)
if [ "$STATE" = "ready" ]; then
    ok "T4a: model still ready after ${HALF_WAIT}s (less than ${IDLE_TIMEOUT}s timeout)"
else
    fail "T4a: model unexpectedly $STATE after only ${HALF_WAIT}s"
fi

# Wait for full unload
REMAINING=$((IDLE_TIMEOUT - HALF_WAIT + 3))
info "Waiting another ${REMAINING}s for full idle unload..."
sleep $REMAINING
STATE=$(get_state)
[ "$STATE" = "unloaded" ] && ok "T4b: model unloaded after total > ${IDLE_TIMEOUT}s idle" \
    || fail "T4b: expected unloaded, got $STATE"

# ── TEST 5: Job resets idle timer ─────────────────────────────────────────────
echo ""
echo "--- Test 5: Completing a job resets the idle timer ---"

if [ -z "$AUDIO" ]; then
    skip "T5: TEST_AUDIO not set — skipping timer-reset test"
else
    ensure_unloaded
    ensure_ready || { fail "T5: model load failed"; }

    # Submit a job
    SUBMIT=$(curl -sf -X POST "$BASE/jobs" \
        -F "audio=@${AUDIO};type=audio/wav" \
        -F "task=transcribe" 2>&1)
    JOB_ID=$(echo "$SUBMIT" | python3 -c "import sys,json; print(json.load(sys.stdin)['job_id'])" 2>/dev/null || echo "")

    if [ -z "$JOB_ID" ]; then
        fail "T5: job submission failed"
    else
        # Wait for job to finish
        elapsed=0
        while true; do
            sleep 5; elapsed=$((elapsed+5))
            STATUS=$(curl -sf "$BASE/jobs/$JOB_ID" | python3 -c "import sys,json; print(json.load(sys.stdin)['status'])")
            [ "$STATUS" = "done" ] || [ "$STATUS" = "failed" ] && break
            [ $elapsed -gt 300 ] && break
        done
        info "Job finished in ${elapsed}s with status=$STATUS"

        # Now wait IDLE_TIMEOUT - 2 seconds — should still be ready
        SAFE_WAIT=$((IDLE_TIMEOUT - 2))
        [ $SAFE_WAIT -lt 1 ] && SAFE_WAIT=1
        info "Waiting ${SAFE_WAIT}s after job completion (less than idle timeout)..."
        sleep $SAFE_WAIT
        STATE=$(get_state)
        [ "$STATE" = "ready" ] && ok "T5a: model still ready ${SAFE_WAIT}s after job completion" \
            || fail "T5a: model unexpectedly $STATE after job"

        # Wait for idle timeout
        REMAINING=$((IDLE_TIMEOUT - SAFE_WAIT + 3))
        info "Waiting ${REMAINING}s more for idle unload..."
        sleep $REMAINING
        STATE=$(get_state)
        [ "$STATE" = "unloaded" ] && ok "T5b: model auto-unloaded after idle period post-job" \
            || fail "T5b: expected unloaded, got $STATE"
    fi
fi

# ── Summary ────────────────────────────────────────────────────────────────────
echo ""
echo "=========================================="
echo " Results: ${PASS} passed, ${FAIL} failed"
echo "=========================================="
[ $FAIL -eq 0 ] && echo -e "${GREEN}ALL PASSED${NC}" || { echo -e "${RED}FAILURES: $FAIL${NC}"; exit 1; }
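The idle-timeout tests above check state after a fixed `IDLE_TIMEOUT + buffer` sleep. One possible alternative is to poll with a deadline, which may be less sensitive to timing jitter. The sketch below is illustrative only: `wait_for_state` is a hypothetical helper that is not part of this commit, and the `get_state` shown here is a stand-in stub so the snippet runs without a server (the real scripts query `GET /model/status`).

```shell
#!/usr/bin/env bash
# Hypothetical sketch, not part of the commit.

# Stub for illustration: reports "ready" on the first two calls and
# "unloaded" afterwards. The call count lives in a file because
# get_state is invoked inside $(...), i.e. in a subshell.
STATE_CALLS_FILE=$(mktemp)
echo 0 > "$STATE_CALLS_FILE"
get_state() {
    local n
    n=$(cat "$STATE_CALLS_FILE")
    echo $((n + 1)) > "$STATE_CALLS_FILE"
    if [ "$n" -lt 2 ]; then echo "ready"; else echo "unloaded"; fi
}

# Poll get_state until it prints the target state or max_secs elapse.
# Checks once immediately, then every `interval` seconds.
# Returns 0 when the target was reached, 1 on timeout.
wait_for_state() {
    local target="$1" max_secs="${2:-30}" interval="${3:-1}"
    local deadline=$((SECONDS + max_secs))
    while true; do
        [ "$(get_state)" = "$target" ] && return 0
        [ "$SECONDS" -ge "$deadline" ] && return 1
        sleep "$interval"
    done
}
```

With the real `get_state`, a call such as `wait_for_state unloaded $((IDLE_TIMEOUT + 5))` could stand in for the `sleep` plus single state check, under the assumption that reaching the state early is acceptable for the test in question.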
470
tests/test_model_lifecycle.sh
Executable file
@@ -0,0 +1,470 @@
#!/usr/bin/env bash
# tests/test_model_lifecycle.sh
#
# Integration tests for dynamic model loading/unloading.
# Requires a running whisper-server with GPU access.
#
# Usage:
#   WHISPER_BASE_URL=http://localhost:8080 bash tests/test_model_lifecycle.sh
#
# Tests are designed to be independent; each section that needs a specific
# state resets it explicitly at the start.

set -euo pipefail

BASE="${WHISPER_BASE_URL:-http://localhost:8080}"
AUDIO="${TEST_AUDIO:-}"

GREEN='\033[0;32m'; RED='\033[0;31m'; YELLOW='\033[0;33m'; NC='\033[0m'

PASS=0; FAIL=0

ok() { echo -e "${GREEN}[PASS]${NC} $1"; PASS=$((PASS+1)); }
fail() { echo -e "${RED}[FAIL]${NC} $1"; FAIL=$((FAIL+1)); }
skip() { echo -e "${YELLOW}[SKIP]${NC} $1"; }
info() { echo " $1"; }

echo "=== Model Lifecycle Integration Tests ==="
echo " BASE: $BASE"
echo ""

# ── Helpers ──────────────────────────────────────────────────────────────────

get_state() {
    curl -sf "$BASE/model/status" | python3 -c "import sys,json; print(json.load(sys.stdin)['state'])"
}

ensure_unloaded() {
    curl -sf -X POST "$BASE/model/unload" > /dev/null
    sleep 2
    local s
    s=$(get_state)
    if [ "$s" != "unloaded" ]; then
        echo " WARNING: expected unloaded, got $s — waiting 5s"
        sleep 5
    fi
}

ensure_ready() {
    local state
    state=$(get_state)
    if [ "$state" = "ready" ]; then return 0; fi
    curl -sf -X POST "$BASE/model/load" > /dev/null
    local elapsed=0
    while true; do
        sleep 5; elapsed=$((elapsed+5))
        state=$(get_state)
        [ "$state" = "ready" ] && return 0
        [ $elapsed -gt 180 ] && echo " TIMEOUT: model did not become ready" && return 1
    done
}

poll_state_transition() {
    local target="$1" max_secs="${2:-120}"
    local elapsed=0
    while true; do
        sleep 2; elapsed=$((elapsed+2))
        local s
        s=$(get_state)
        [ "$s" = "$target" ] && return 0
        [ $elapsed -ge $max_secs ] && return 1
    done
}

# ── TEST 1: Startup state is unloaded ────────────────────────────────────────
echo "--- Test 1: Startup state is unloaded (or after explicit unload) ---"
ensure_unloaded
STATE=$(get_state)
if [ "$STATE" = "unloaded" ]; then
    ok "T1: state=unloaded after explicit unload"
else
    fail "T1: expected unloaded, got $STATE"
fi

# ── TEST 2: POST /model/load returns 202 ─────────────────────────────────────
echo ""
echo "--- Test 2: POST /model/load returns 202 ---"
ensure_unloaded
HTTP=$(curl -s -o /dev/null -w "%{http_code}" -X POST "$BASE/model/load")
if [ "$HTTP" = "202" ]; then
    ok "T2: POST /model/load → 202 Accepted"
else
    fail "T2: expected 202, got $HTTP"
fi
# Cancel the in-progress load to clean up
curl -sf -X POST "$BASE/model/unload" > /dev/null || true
sleep 2

# ── TEST 3: State transitions to loading/ready after load trigger ─────────────
echo ""
echo "--- Test 3: State transitions to loading (not stuck at unloaded) ---"
ensure_unloaded
curl -sf -X POST "$BASE/model/load" > /dev/null
sleep 1
STATE=$(get_state)
if [ "$STATE" = "loading" ] || [ "$STATE" = "ready" ]; then
    ok "T3: state transitioned to $STATE (not stuck at unloaded)"
else
    fail "T3: expected loading or ready, got $STATE"
fi

# ── TEST 4: Model reaches ready state and loaded_at is set ───────────────────
echo ""
echo "--- Test 4: Model reaches ready state with loaded_at timestamp ---"
# Already loading from T3 — wait for ready
if ! poll_state_transition "ready" 180; then
    fail "T4: model did not become ready within 3 minutes"
else
    STATUS_JSON=$(curl -sf "$BASE/model/status")
    LOADED_AT=$(echo "$STATUS_JSON" | python3 -c "import sys,json; d=json.load(sys.stdin); print(d.get('loaded_at','MISSING'))" 2>/dev/null || echo "MISSING")
    if [ "$LOADED_AT" != "MISSING" ] && [ "$LOADED_AT" != "null" ] && [ -n "$LOADED_AT" ]; then
        ok "T4: model=ready, loaded_at=$LOADED_AT"
    else
        fail "T4: model ready but loaded_at is missing or null"
    fi
fi

# ── TEST 5: Idempotent load — POST /model/load when ready returns 200 ─────────
echo ""
echo "--- Test 5: POST /model/load when already ready → 200 ---"
ensure_ready || { fail "T5: could not load model"; }
HTTP=$(curl -s -o /dev/null -w "%{http_code}" -X POST "$BASE/model/load")
STATE=$(get_state)
if [ "$HTTP" = "200" ] && [ "$STATE" = "ready" ]; then
    ok "T5: idempotent load → 200, state stays ready"
elif [ "$HTTP" = "202" ] && [ "$STATE" = "ready" ]; then
    ok "T5: idempotent load → 202, state stays ready"
else
    fail "T5: expected 200 (or 202) and ready, got HTTP=$HTTP state=$STATE"
fi

# ── TEST 6: Job accepted when ready (segments > 0) ────────────────────────────
echo ""
echo "--- Test 6: Job accepted when model is ready ---"
if [ -z "$AUDIO" ]; then
    skip "T6: TEST_AUDIO not set — skipping job submission test"
else
    ensure_ready || { fail "T6: model load failed"; }
    SUBMIT=$(curl -sf -X POST "$BASE/jobs" -F "audio=@${AUDIO};type=audio/wav" -F "task=transcribe" 2>&1)
    JOB_ID=$(echo "$SUBMIT" | python3 -c "import sys,json; print(json.load(sys.stdin)['job_id'])" 2>/dev/null || echo "")
    if [ -n "$JOB_ID" ]; then
        ok "T6: job accepted, id=$JOB_ID"
        # Poll to done
        elapsed=0
        while true; do
            sleep 10; elapsed=$((elapsed+10))
            STATUS=$(curl -sf "$BASE/jobs/$JOB_ID" | python3 -c "import sys,json; print(json.load(sys.stdin)['status'])")
            [ "$STATUS" = "done" ] && break
            [ "$STATUS" = "failed" ] && break
            [ $elapsed -gt 600 ] && break
        done
        SEGS=$(curl -sf "$BASE/jobs/$JOB_ID" | python3 -c "import sys,json; d=json.load(sys.stdin); print(len(d.get('segments',[])))")
        [ "$SEGS" -gt 0 ] && ok "T6b: job done with $SEGS segments" || fail "T6b: job done but 0 segments"
    else
        fail "T6: job submission failed: $SUBMIT"
    fi
fi

# ── TEST 7: POST /model/unload → state=unloaded ───────────────────────────────
echo ""
echo "--- Test 7: POST /model/unload ---"
ensure_ready || { fail "T7: model load failed"; }
curl -sf -X POST "$BASE/model/unload" > /dev/null
sleep 3
STATE=$(get_state)
if [ "$STATE" = "unloaded" ]; then
    ok "T7: POST /model/unload → state=unloaded"
else
    fail "T7: expected unloaded after unload, got $STATE"
fi

# ── TEST 8: POST /jobs when unloaded → 503 + Retry-After ─────────────────────
echo ""
echo "--- Test 8: POST /jobs when unloaded → 503 + Retry-After ---"
ensure_unloaded
# Submit a dummy payload from /dev/urandom (not valid audio, which is fine here; --max-time bounds the request)
HTTP=$(curl -s -o /tmp/t8_body.json -w "%{http_code}" -X POST "$BASE/jobs" \
    -F "audio=@/dev/urandom;type=audio/wav" \
    --max-time 5 2>/dev/null || echo "000")
# If the model auto-loads it might start processing; check for 503 first
if [ "$HTTP" = "503" ]; then
    # Dump the response headers of a real POST (-I cannot be combined with -F)
    RETRY_AFTER=$(curl -s -o /dev/null -D - -X POST "$BASE/jobs" \
        -F "audio=@/dev/urandom;type=audio/wav" \
        --max-time 5 2>/dev/null | grep -i "retry-after" | awk '{print $2}' | tr -d '\r' || echo "")
    BODY=$(cat /tmp/t8_body.json 2>/dev/null || echo "{}")
    HAS_STATE=$(echo "$BODY" | python3 -c "import sys,json; d=json.load(sys.stdin); print('state' in d)" 2>/dev/null || echo "False")
    HAS_RETRY=$(echo "$BODY" | python3 -c "import sys,json; d=json.load(sys.stdin); print('retry_after_secs' in d)" 2>/dev/null || echo "False")
    if [ "$HAS_STATE" = "True" ] && [ "$HAS_RETRY" = "True" ]; then
        ok "T8: 503 with state + retry_after_secs in body"
    else
        fail "T8: 503 but body missing state/retry_after_secs. body=$BODY"
    fi
    if [ -n "$RETRY_AFTER" ]; then
        ok "T8b: Retry-After header present: $RETRY_AFTER"
    else
        fail "T8b: Retry-After header missing from 503 response"
    fi
else
    skip "T8: got HTTP $HTTP (model may have loaded before check) — skipping"
fi

# ── TEST 9: Rejected job triggers load ────────────────────────────────────────
echo ""
echo "--- Test 9: Job rejection triggers model load ---"
ensure_unloaded
# Send a job (we expect 503)
curl -sf -X POST "$BASE/jobs" \
    -F "audio=@/dev/urandom;type=audio/wav" \
    --max-time 5 > /dev/null 2>&1 || true
sleep 2
STATE=$(get_state)
if [ "$STATE" = "loading" ] || [ "$STATE" = "ready" ]; then
    ok "T9: model started loading after job rejection ($STATE)"
else
    fail "T9: expected loading/ready after job rejection, got $STATE"
fi
# Stop the load to clean up
curl -sf -X POST "$BASE/model/unload" > /dev/null || true
sleep 2

# ── TEST 10: Retry-After values ───────────────────────────────────────────────
echo ""
echo "--- Test 10: Retry-After values match state ---"
ensure_unloaded
# Unloaded → Retry-After: 30
RESP_UNLOADED=$(curl -si -X POST "$BASE/jobs" -F "audio=@/dev/urandom;type=audio/wav" --max-time 5 2>/dev/null || echo "")
RA_UNLOADED=$(echo "$RESP_UNLOADED" | grep -i "retry-after" | awk '{print $2}' | tr -d '\r' || echo "")
[ "$RA_UNLOADED" = "30" ] && ok "T10a: Retry-After=30 when unloaded" \
    || skip "T10a: Retry-After=$RA_UNLOADED (expected 30) — model may have started loading"

# ── TEST 11: Retry-After=10 during loading ────────────────────────────────────
echo ""
echo "--- Test 11: Retry-After=10 when loading ---"
ensure_unloaded
curl -sf -X POST "$BASE/model/load" > /dev/null
sleep 1  # In loading state
STATE=$(get_state)
if [ "$STATE" = "loading" ]; then
    RESP_LOADING=$(curl -si -X POST "$BASE/jobs" -F "audio=@/dev/urandom;type=audio/wav" --max-time 5 2>/dev/null || echo "")
    RA_LOADING=$(echo "$RESP_LOADING" | grep -i "retry-after" | awk '{print $2}' | tr -d '\r' || echo "")
    [ "$RA_LOADING" = "10" ] && ok "T11: Retry-After=10 when loading" \
        || fail "T11: expected Retry-After=10, got '$RA_LOADING' (state=$STATE)"
else
    skip "T11: model already $STATE — can't test loading state Retry-After"
fi

# ── TEST 12: 503 body schema validation ──────────────────────────────────────
echo ""
echo "--- Test 12: 503 body schema validation ---"
ensure_unloaded
# No -f here: with -f, curl discards the body of a 503 response
BODY=$(curl -s -X POST "$BASE/jobs" -F "audio=@/dev/urandom;type=audio/wav" --max-time 5 2>/dev/null || echo "{}")
# Pass the body through the environment so quotes in it cannot break the Python source;
# wrap the check in `if` so a failing check is reported instead of aborting under set -e
if BODY="$BODY" python3 - <<'PYCHECK'
import json, os, sys
body = json.loads(os.environ['BODY'])
required = {'error', 'state', 'retry_after_secs'}
missing = required - set(body.keys())
if missing:
    print(f"MISSING: {missing}")
    sys.exit(1)
assert body['error'] == 'model_not_ready', f"error={body['error']}"
assert isinstance(body['retry_after_secs'], int), f"retry_after_secs not int: {body['retry_after_secs']}"
print("schema ok")
PYCHECK
then
    ok "T12: 503 body has correct schema"
else
    fail "T12: 503 body schema invalid"
fi

# ── TEST 13: GET /health has model_state field ────────────────────────────────
echo ""
echo "--- Test 13: GET /health has model_state ---"
HEALTH=$(curl -sf "$BASE/health")
HAS_MODEL_STATE=$(echo "$HEALTH" | python3 -c "import sys,json; d=json.load(sys.stdin); print('model_state' in d)")
[ "$HAS_MODEL_STATE" = "True" ] && ok "T13: /health has model_state" || fail "T13: /health missing model_state"

# ── TEST 14: SSE /model/events delivers model_ready event ─────────────────────
echo ""
echo "--- Test 14: GET /model/events SSE delivers model_ready ---"
ensure_unloaded

# Collect SSE events for up to 3 minutes
SSE_LOG=$(mktemp /tmp/sse_events_XXXXXX.txt)
curl -sN --max-time 180 "$BASE/model/events" > "$SSE_LOG" &
SSE_PID=$!
sleep 1

# Trigger load
curl -sf -X POST "$BASE/model/load" > /dev/null
poll_state_transition "ready" 180 || true

sleep 2
kill $SSE_PID 2>/dev/null || true
wait $SSE_PID 2>/dev/null || true

if grep -q "model_loading" "$SSE_LOG" 2>/dev/null; then
    ok "T14a: SSE received model_loading event"
else
    fail "T14a: SSE did not receive model_loading event"
fi

if grep -q "model_ready" "$SSE_LOG" 2>/dev/null; then
    ok "T14b: SSE received model_ready event"
else
    fail "T14b: SSE did not receive model_ready event"
fi

# Subscribe again before unloading so the model_unloaded event is captured
SSE_LOG2=$(mktemp /tmp/sse_events_XXXXXX.txt)
curl -sN --max-time 10 "$BASE/model/events" > "$SSE_LOG2" &
SSE_PID2=$!
sleep 1

curl -sf -X POST "$BASE/model/unload" > /dev/null
sleep 2
kill $SSE_PID2 2>/dev/null || true
wait $SSE_PID2 2>/dev/null || true

# model_unloaded fires immediately on unload command
if grep -q "model_unloaded" "$SSE_LOG" 2>/dev/null || grep -q "model_unloaded" "$SSE_LOG2" 2>/dev/null; then
    ok "T14c: SSE received model_unloaded event"
else
    fail "T14c: SSE did not receive model_unloaded event"
fi

rm -f "$SSE_LOG" "$SSE_LOG2"

# ── TEST 15: model_ready webhook fires after load ──────────────────────────────
echo ""
echo "--- Test 15: model_ready webhook ---"
ensure_unloaded

# Start webhook receiver
python3 - <<'PYEOF' &
import http.server, json, sys, signal

class H(http.server.BaseHTTPRequestHandler):
    def do_POST(self):
        n = int(self.headers.get('Content-Length', 0))
        body = json.loads(self.rfile.read(n))
        with open('/tmp/t15_webhook.json', 'w') as f:
            json.dump(body, f)
        self.send_response(200); self.end_headers()
    def log_message(self, *a): pass

signal.signal(signal.SIGTERM, lambda *_: sys.exit(0))
http.server.HTTPServer(('', 9998), H).serve_forever()
PYEOF
WBOOK_PID=$!
sleep 1

# Register a webhook via a (doomed) job submission — this registers the URL
# even though the model is unloaded (and the job will 503)
curl -sf -X POST "$BASE/jobs" \
    -F "audio=@/dev/urandom;type=audio/wav" \
    -F "webhook_url=http://localhost:9998/wh" \
    --max-time 5 > /dev/null 2>&1 || true

# Now load the model
curl -sf -X POST "$BASE/model/load" > /dev/null
poll_state_transition "ready" 180 || true
sleep 3

kill $WBOOK_PID 2>/dev/null || true
wait $WBOOK_PID 2>/dev/null || true

if [ -f /tmp/t15_webhook.json ]; then
    EVENT_TYPE=$(python3 -c "import json; d=json.load(open('/tmp/t15_webhook.json')); print(d.get('type','?'))")
    [ "$EVENT_TYPE" = "model_ready" ] && ok "T15: model_ready webhook fired" \
        || fail "T15: webhook fired but type=$EVENT_TYPE (expected model_ready)"
    rm -f /tmp/t15_webhook.json
else
    fail "T15: model_ready webhook not received within timeout"
fi

# ── TEST 16: model_unloaded webhook fires ─────────────────────────────────────
echo ""
echo "--- Test 16: model_unloaded webhook ---"

python3 - <<'PYEOF' &
import http.server, json, sys, signal

class H(http.server.BaseHTTPRequestHandler):
    def do_POST(self):
        n = int(self.headers.get('Content-Length', 0))
        body = json.loads(self.rfile.read(n))
        with open('/tmp/t16_webhook.json', 'w') as f:
            json.dump(body, f)
        self.send_response(200); self.end_headers()
    def log_message(self, *a): pass

signal.signal(signal.SIGTERM, lambda *_: sys.exit(0))
http.server.HTTPServer(('', 9997), H).serve_forever()
PYEOF
WBOOK2_PID=$!
sleep 1

# Register webhook URL
curl -sf -X POST "$BASE/jobs" \
    -F "audio=@/dev/urandom;type=audio/wav" \
    -F "webhook_url=http://localhost:9997/wh" \
    --max-time 5 > /dev/null 2>&1 || true

ensure_ready || { fail "T16: model load failed"; }

# Unload
curl -sf -X POST "$BASE/model/unload" > /dev/null
sleep 5

kill $WBOOK2_PID 2>/dev/null || true
wait $WBOOK2_PID 2>/dev/null || true

if [ -f /tmp/t16_webhook.json ]; then
    EVENT_TYPE=$(python3 -c "import json; d=json.load(open('/tmp/t16_webhook.json')); print(d.get('type','?'))")
    [ "$EVENT_TYPE" = "model_unloaded" ] && ok "T16: model_unloaded webhook fired" \
        || fail "T16: webhook type=$EVENT_TYPE (expected model_unloaded)"
    rm -f /tmp/t16_webhook.json
else
    fail "T16: model_unloaded webhook not received"
fi

# ── TEST 17: Concurrent load requests — single load, stable ready ─────────────
echo ""
echo "--- Test 17: Concurrent POST /model/load requests ---"
ensure_unloaded
# Send 3 concurrent load requests
curl -sf -X POST "$BASE/model/load" > /dev/null &
curl -sf -X POST "$BASE/model/load" > /dev/null &
curl -sf -X POST "$BASE/model/load" > /dev/null &
wait
poll_state_transition "ready" 180 || true
STATE=$(get_state)
[ "$STATE" = "ready" ] && ok "T17: concurrent loads handled cleanly, state=ready" \
    || fail "T17: expected ready after concurrent loads, got $STATE"

# ── TEST 18: POST /model/unload during loading → clean unloaded ───────────────
echo ""
echo "--- Test 18: POST /model/unload during loading ---"
ensure_unloaded
curl -sf -X POST "$BASE/model/load" > /dev/null
sleep 1  # Hopefully still in loading state
curl -sf -X POST "$BASE/model/unload" > /dev/null
# Allow time for the unload to propagate
sleep 5
STATE=$(get_state)
if [ "$STATE" = "unloaded" ]; then
    ok "T18: unload during loading → clean unloaded"
elif [ "$STATE" = "ready" ]; then
    # Load completed before unload arrived — immediately unload
    curl -sf -X POST "$BASE/model/unload" > /dev/null
    sleep 3
    STATE=$(get_state)
    [ "$STATE" = "unloaded" ] && ok "T18: load completed then unloaded (race condition OK)" \
        || fail "T18: state=$STATE after load+unload"
else
    fail "T18: unexpected state after unload-during-load: $STATE"
fi

# ── Summary ────────────────────────────────────────────────────────────────────
echo ""
echo "=========================================="
echo " Results: ${PASS} passed, ${FAIL} failed"
echo "=========================================="
[ $FAIL -eq 0 ] && echo -e "${GREEN}ALL PASSED${NC}" || { echo -e "${RED}FAILURES: $FAIL${NC}"; exit 1; }
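Tests 8, 10 and 11 exercise the 503 + `Retry-After` contract: 30 seconds when the model is unloaded and 10 while it is loading. A client could honour that contract roughly as sketched below. This is an assumption-laden illustration and not code from the commit: `parse_retry_after` and `submit_with_retry` are hypothetical names, the 10-second fallback is arbitrary, and the parser assumes a space follows the colon in the header line.

```shell
#!/usr/bin/env bash
# Hypothetical client-side sketch, not part of the commit.

# Read raw HTTP response headers on stdin and print the Retry-After
# value (header name matched case-insensitively, trailing CR stripped).
# Prints nothing if the header is absent.
parse_retry_after() {
    awk 'tolower($1) == "retry-after:" { gsub(/\r/, "", $2); print $2; exit }'
}

# Submit an audio file to POST $url/jobs. On 503, sleep for the
# advertised Retry-After (falling back to 10s) and try again, up to
# max_attempts times. Prints the final response body; returns 0 only
# if the final status code was 2xx.
submit_with_retry() {
    local url="$1" audio="$2" max_attempts="${3:-10}"
    local attempt=1 hdrs body code wait_secs
    hdrs=$(mktemp); body=$(mktemp)
    while [ "$attempt" -le "$max_attempts" ]; do
        code=$(curl -s -D "$hdrs" -o "$body" -w '%{http_code}' \
            -X POST "$url/jobs" -F "audio=@${audio};type=audio/wav") || code="000"
        if [ "$code" != "503" ]; then
            cat "$body"
            rm -f "$hdrs" "$body"
            [ "${code:0:1}" = "2" ]
            return
        fi
        wait_secs=$(parse_retry_after < "$hdrs")
        sleep "${wait_secs:-10}"
        attempt=$((attempt + 1))
    done
    rm -f "$hdrs" "$body"
    return 1
}
```

A call might look like `submit_with_retry "$BASE" "$TEST_AUDIO"`. Reading the delay from the header rather than hard-coding it keeps the client aligned with whatever value the server chooses for each state.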