feat: dynamic model loading/unloading with GPU polling
All checks were successful
Build & Push Docker Image / build-and-push (push) Successful in 8m41s
All checks were successful
Build & Push Docker Image / build-and-push (push) Successful in 8m41s
- Model starts unloaded (lazy); loads on first job or POST /model/load
- Auto-unloads after IDLE_TIMEOUT_SECS seconds (default 300) of inactivity
- POST /model/unload for immediate manual release
- GPU-busy detection: on VRAM OOM, enters WaitingForGpu and retries the load
indefinitely, once every GPU_POLL_INTERVAL_SECS seconds (default 30)
- POST /jobs while the model is unloaded → 503 + Retry-After header, and triggers a load
- AppError::OutOfMemory and AppError::ModelNotReady variants
- WorkerCmd channel (SyncSender<WorkerCmd>) replaces bare tx_req channel
- Idle timer via recv_timeout(1s) tick inside OS thread (no extra thread)
- Model lifecycle events broadcast via tokio broadcast channel (SSE + webhooks)
- webhook_registry: every webhook_url ever submitted with a job receives
model_ready and model_unloaded webhooks
- GPU warmup retained on every (re)load
New routes:
GET /model/status — current state + VRAM stats
POST /model/load — trigger load (idempotent)
POST /model/unload — immediate unload
GET /model/events — SSE stream of model lifecycle events
New env vars:
IDLE_TIMEOUT_SECS (default 300)
GPU_POLL_INTERVAL_SECS (default 30)
Tests:
tests/test_model_lifecycle.sh — 18 integration tests (full state machine,
SSE events, webhooks, concurrency, unload-during-load)
tests/test_idle_timeout.sh — 5 tests with short IDLE_TIMEOUT_SECS=5
test_all.sh updated: loads model before job submission, asserts
model_state in /health, adds POST /model/unload at end
Docs:
docs/USAGE.md: model lifecycle section, new env vars, 503 retry pattern,
updated /health response shape
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
64
src/main.rs
64
src/main.rs
@@ -1,7 +1,7 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::Router;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::{broadcast, mpsc, RwLock};
|
||||
use tower_http::{cors::CorsLayer, trace::TraceLayer};
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
|
||||
use utoipa::OpenApi;
|
||||
@@ -21,8 +21,10 @@ pub use error::{AppError, Result};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
/// Channel to submit jobs to the single GPU worker.
|
||||
/// Channel to submit jobs to the single GPU worker (job IDs only).
|
||||
pub job_tx: mpsc::UnboundedSender<models::JobId>,
|
||||
/// Channel to send control commands to the worker OS thread.
|
||||
pub cmd_tx: std::sync::mpsc::SyncSender<worker::WorkerCmd>,
|
||||
/// Shared handle to the on-disk job store.
|
||||
pub storage: Arc<storage::Storage>,
|
||||
/// SSE broadcast registry: job_id → sender.
|
||||
@@ -33,6 +35,17 @@ pub struct AppState {
|
||||
pub queue_depth: Arc<std::sync::atomic::AtomicUsize>,
|
||||
/// CUDA device index used for inference.
|
||||
pub gpu_device: u32,
|
||||
/// Current state of the whisper model.
|
||||
pub model_state: Arc<RwLock<models::ModelState>>,
|
||||
/// Broadcast channel for model lifecycle events (SSE + webhooks).
|
||||
pub model_event_tx: broadcast::Sender<models::ModelEvent>,
|
||||
/// All webhook URLs ever registered via job submission.
|
||||
/// Used to fire model_ready / model_unloaded notifications.
|
||||
pub webhook_registry: Arc<std::sync::Mutex<std::collections::HashSet<String>>>,
|
||||
/// How long the model stays loaded with no active jobs.
|
||||
pub idle_timeout: std::time::Duration,
|
||||
/// How often to retry loading when GPU is busy.
|
||||
pub gpu_poll_interval: std::time::Duration,
|
||||
}
|
||||
|
||||
// ── OpenAPI spec root ────────────────────────────────────────────────────────
|
||||
@@ -50,6 +63,10 @@ pub struct AppState {
|
||||
routes::jobs::stream_job,
|
||||
routes::jobs::delete_job,
|
||||
routes::health::health,
|
||||
routes::model::model_status,
|
||||
routes::model::model_load,
|
||||
routes::model::model_unload,
|
||||
routes::model::model_events,
|
||||
),
|
||||
components(schemas(
|
||||
models::Job,
|
||||
@@ -58,10 +75,14 @@ pub struct AppState {
|
||||
models::Word,
|
||||
models::SubmitResponse,
|
||||
models::HealthResponse,
|
||||
models::ModelState,
|
||||
models::ModelEvent,
|
||||
models::ModelStatusResponse,
|
||||
)),
|
||||
tags(
|
||||
(name = "jobs", description = "Transcription job management"),
|
||||
(name = "system", description = "Service health"),
|
||||
(name = "model", description = "Model lifecycle management"),
|
||||
)
|
||||
)]
|
||||
struct ApiDoc;
|
||||
@@ -85,6 +106,20 @@ async fn main() -> anyhow::Result<()> {
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(0);
|
||||
let idle_timeout_secs: u64 = std::env::var("IDLE_TIMEOUT_SECS")
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(300);
|
||||
let gpu_poll_interval_secs: u64 = std::env::var("GPU_POLL_INTERVAL_SECS")
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(30);
|
||||
|
||||
tracing::info!(
|
||||
idle_timeout_secs,
|
||||
gpu_poll_interval_secs,
|
||||
"dynamic model loading configured"
|
||||
);
|
||||
|
||||
let storage = Arc::new(storage::Storage::new(&data_dir).await?);
|
||||
|
||||
@@ -94,28 +129,45 @@ async fn main() -> anyhow::Result<()> {
|
||||
let (job_tx, job_rx) = mpsc::unbounded_channel::<models::JobId>();
|
||||
let queue_depth = Arc::new(std::sync::atomic::AtomicUsize::new(0));
|
||||
|
||||
// Spawn single GPU worker; get back the SSE broadcast registry.
|
||||
let progress = worker::start(
|
||||
// Model starts unloaded — lazy load on first job or POST /model/load.
|
||||
let model_state = Arc::new(RwLock::new(models::ModelState::Unloaded));
|
||||
let (model_event_tx, _) = broadcast::channel::<models::ModelEvent>(32);
|
||||
let webhook_registry = Arc::new(std::sync::Mutex::new(std::collections::HashSet::<String>::new()));
|
||||
|
||||
// Spawn single GPU worker; get back the SSE broadcast registry and cmd channel.
|
||||
let (progress, cmd_tx) = worker::start(
|
||||
job_rx,
|
||||
Arc::clone(&storage),
|
||||
model_path.clone().into(),
|
||||
Arc::clone(&queue_depth),
|
||||
gpu_device,
|
||||
Arc::clone(&model_state),
|
||||
model_event_tx.clone(),
|
||||
Arc::clone(&webhook_registry),
|
||||
std::time::Duration::from_secs(idle_timeout_secs),
|
||||
std::time::Duration::from_secs(gpu_poll_interval_secs),
|
||||
);
|
||||
|
||||
let state = AppState {
|
||||
job_tx,
|
||||
cmd_tx,
|
||||
storage: Arc::clone(&storage),
|
||||
progress,
|
||||
model_name: model_name.as_str().into(),
|
||||
queue_depth: Arc::clone(&queue_depth),
|
||||
model_name: model_name.as_str().into(),
|
||||
queue_depth: Arc::clone(&queue_depth),
|
||||
gpu_device,
|
||||
model_state,
|
||||
model_event_tx,
|
||||
webhook_registry,
|
||||
idle_timeout: std::time::Duration::from_secs(idle_timeout_secs),
|
||||
gpu_poll_interval: std::time::Duration::from_secs(gpu_poll_interval_secs),
|
||||
};
|
||||
|
||||
let app = Router::new()
|
||||
.merge(SwaggerUi::new("/docs").url("/openapi.json", ApiDoc::openapi()))
|
||||
.merge(routes::jobs_router())
|
||||
.merge(routes::health_router())
|
||||
.merge(routes::model_router())
|
||||
.with_state(state)
|
||||
.layer(CorsLayer::permissive())
|
||||
.layer(TraceLayer::new_for_http());
|
||||
|
||||
Reference in New Issue
Block a user