feat: dynamic model loading/unloading with GPU polling
All checks were successful
Build & Push Docker Image / build-and-push (push) Successful in 8m41s

- Model starts unloaded (lazy); loads on first job or POST /model/load
- Auto-unloads after IDLE_TIMEOUT_SECS (default 300) of inactivity
- POST /model/unload for immediate manual release
- GPU-busy detection: on VRAM OOM, enters WaitingForGpu and retries
  every GPU_POLL_INTERVAL_SECS (default 30) indefinitely
- POST /jobs when unloaded → 503 + Retry-After header, triggers load
- AppError::OutOfMemory and AppError::ModelNotReady variants
- WorkerCmd channel (SyncSender<WorkerCmd>) replaces bare tx_req channel
- Idle timer via recv_timeout(1s) tick inside OS thread (no extra thread)
- Model lifecycle events broadcast via tokio broadcast channel (SSE + webhooks)
- webhook_registry: all clients that ever submitted a webhook_url receive
  model_ready and model_unloaded webhooks
- GPU warmup retained on every (re)load

New routes:
  GET  /model/status  — current state + VRAM stats
  POST /model/load    — trigger load (idempotent)
  POST /model/unload  — immediate unload
  GET  /model/events  — SSE stream of model lifecycle events

New env vars:
  IDLE_TIMEOUT_SECS       (default 300)
  GPU_POLL_INTERVAL_SECS  (default 30)

Tests:
  tests/test_model_lifecycle.sh — 18 integration tests (full state machine,
    SSE events, webhooks, concurrency, unload-during-load)
  tests/test_idle_timeout.sh    — 5 tests with short IDLE_TIMEOUT_SECS=5
  test_all.sh updated: loads model before job submission, asserts
    model_state in /health, adds POST /model/unload at end

Docs:
  docs/USAGE.md: model lifecycle section, new env vars, 503 retry pattern,
    updated /health response shape

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
mozempk
2026-05-08 17:57:20 +02:00
parent 78c6fab81b
commit b191fbe200
13 changed files with 2053 additions and 148 deletions

View File

@@ -1,7 +1,7 @@
use std::sync::Arc;
use axum::Router;
use tokio::sync::mpsc;
use tokio::sync::{broadcast, mpsc, RwLock};
use tower_http::{cors::CorsLayer, trace::TraceLayer};
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
use utoipa::OpenApi;
@@ -21,8 +21,10 @@ pub use error::{AppError, Result};
#[derive(Clone)]
pub struct AppState {
/// Channel to submit jobs to the single GPU worker.
/// Channel to submit jobs to the single GPU worker (job IDs only).
pub job_tx: mpsc::UnboundedSender<models::JobId>,
/// Channel to send control commands to the worker OS thread.
pub cmd_tx: std::sync::mpsc::SyncSender<worker::WorkerCmd>,
/// Shared handle to the on-disk job store.
pub storage: Arc<storage::Storage>,
/// SSE broadcast registry: job_id → sender.
@@ -33,6 +35,17 @@ pub struct AppState {
pub queue_depth: Arc<std::sync::atomic::AtomicUsize>,
/// CUDA device index used for inference.
pub gpu_device: u32,
/// Current state of the whisper model.
pub model_state: Arc<RwLock<models::ModelState>>,
/// Broadcast channel for model lifecycle events (SSE + webhooks).
pub model_event_tx: broadcast::Sender<models::ModelEvent>,
/// All webhook URLs ever registered via job submission.
/// Used to fire model_ready / model_unloaded notifications.
pub webhook_registry: Arc<std::sync::Mutex<std::collections::HashSet<String>>>,
/// How long the model stays loaded with no active jobs.
pub idle_timeout: std::time::Duration,
/// How often to retry loading when GPU is busy.
pub gpu_poll_interval: std::time::Duration,
}
// ── OpenAPI spec root ────────────────────────────────────────────────────────
@@ -50,6 +63,10 @@ pub struct AppState {
routes::jobs::stream_job,
routes::jobs::delete_job,
routes::health::health,
routes::model::model_status,
routes::model::model_load,
routes::model::model_unload,
routes::model::model_events,
),
components(schemas(
models::Job,
@@ -58,10 +75,14 @@ pub struct AppState {
models::Word,
models::SubmitResponse,
models::HealthResponse,
models::ModelState,
models::ModelEvent,
models::ModelStatusResponse,
)),
tags(
(name = "jobs", description = "Transcription job management"),
(name = "system", description = "Service health"),
(name = "model", description = "Model lifecycle management"),
)
)]
struct ApiDoc;
@@ -85,6 +106,20 @@ async fn main() -> anyhow::Result<()> {
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(0);
let idle_timeout_secs: u64 = std::env::var("IDLE_TIMEOUT_SECS")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(300);
let gpu_poll_interval_secs: u64 = std::env::var("GPU_POLL_INTERVAL_SECS")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(30);
tracing::info!(
idle_timeout_secs,
gpu_poll_interval_secs,
"dynamic model loading configured"
);
let storage = Arc::new(storage::Storage::new(&data_dir).await?);
@@ -94,28 +129,45 @@ async fn main() -> anyhow::Result<()> {
let (job_tx, job_rx) = mpsc::unbounded_channel::<models::JobId>();
let queue_depth = Arc::new(std::sync::atomic::AtomicUsize::new(0));
// Spawn single GPU worker; get back the SSE broadcast registry.
let progress = worker::start(
// Model starts unloaded — lazy load on first job or POST /model/load.
let model_state = Arc::new(RwLock::new(models::ModelState::Unloaded));
let (model_event_tx, _) = broadcast::channel::<models::ModelEvent>(32);
let webhook_registry = Arc::new(std::sync::Mutex::new(std::collections::HashSet::<String>::new()));
// Spawn single GPU worker; get back the SSE broadcast registry and cmd channel.
let (progress, cmd_tx) = worker::start(
job_rx,
Arc::clone(&storage),
model_path.clone().into(),
Arc::clone(&queue_depth),
gpu_device,
Arc::clone(&model_state),
model_event_tx.clone(),
Arc::clone(&webhook_registry),
std::time::Duration::from_secs(idle_timeout_secs),
std::time::Duration::from_secs(gpu_poll_interval_secs),
);
let state = AppState {
job_tx,
cmd_tx,
storage: Arc::clone(&storage),
progress,
model_name: model_name.as_str().into(),
queue_depth: Arc::clone(&queue_depth),
model_name: model_name.as_str().into(),
queue_depth: Arc::clone(&queue_depth),
gpu_device,
model_state,
model_event_tx,
webhook_registry,
idle_timeout: std::time::Duration::from_secs(idle_timeout_secs),
gpu_poll_interval: std::time::Duration::from_secs(gpu_poll_interval_secs),
};
let app = Router::new()
.merge(SwaggerUi::new("/docs").url("/openapi.json", ApiDoc::openapi()))
.merge(routes::jobs_router())
.merge(routes::health_router())
.merge(routes::model_router())
.with_state(state)
.layer(CorsLayer::permissive())
.layer(TraceLayer::new_for_http());