Files
whisper-rtx2080/src/routes/model.rs
mozempk b191fbe200
All checks were successful
Build & Push Docker Image / build-and-push (push) Successful in 8m41s
feat: dynamic model loading/unloading with GPU polling
- Model starts unloaded (lazy); loads on first job or POST /model/load
- Auto-unloads after IDLE_TIMEOUT_SECS (default 300) of inactivity
- POST /model/unload for immediate manual release
- GPU-busy detection: on VRAM OOM, enters WaitingForGpu and retries
  every GPU_POLL_INTERVAL_SECS (default 30) indefinitely
- POST /jobs when unloaded → 503 + Retry-After header, triggers load
- AppError::OutOfMemory and AppError::ModelNotReady variants
- WorkerCmd channel (SyncSender<WorkerCmd>) replaces bare tx_req channel
- Idle timer via recv_timeout(1s) tick inside OS thread (no extra thread)
- Model lifecycle events broadcast via tokio broadcast channel (SSE + webhooks)
- webhook_registry: all clients that ever submitted a webhook_url receive
  model_ready and model_unloaded webhooks
- GPU warmup retained on every (re)load

New routes:
  GET  /model/status  — current state + VRAM stats
  POST /model/load    — trigger load (idempotent)
  POST /model/unload  — immediate unload
  GET  /model/events  — SSE stream of model lifecycle events

New env vars:
  IDLE_TIMEOUT_SECS       (default 300)
  GPU_POLL_INTERVAL_SECS  (default 30)

Tests:
  tests/test_model_lifecycle.sh — 18 integration tests (full state machine,
    SSE events, webhooks, concurrency, unload-during-load)
  tests/test_idle_timeout.sh    — 5 tests with short IDLE_TIMEOUT_SECS=5
  test_all.sh updated: loads model before job submission, asserts
    model_state in /health, adds POST /model/unload at end

Docs:
  docs/USAGE.md: model lifecycle section, new env vars, 503 retry pattern,
    updated /health response shape

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-05-08 17:57:20 +02:00

159 lines
5.8 KiB
Rust

use std::pin::Pin;
use axum::{
extract::State,
http::StatusCode,
response::{
sse::{Event, KeepAlive, Sse},
IntoResponse,
},
Json,
};
use futures::Stream;
use tokio_stream::wrappers::BroadcastStream;
use futures::StreamExt;
use crate::{
models::{ModelEvent, ModelStatusResponse},
worker::WorkerCmd,
AppState, Result,
};
/// Boxed, type-erased stream of SSE events returned by `model_events`.
///
/// The error type is `Infallible`, so the stream never yields an error item.
type SseStream = Pin<Box<dyn Stream<Item = std::result::Result<Event, std::convert::Infallible>> + Send>>;
// ── GET /model/status ────────────────────────────────────────────────────────
/// Return the current model state and VRAM statistics.
#[utoipa::path(
get,
path = "/model/status",
tag = "model",
responses(
(status = 200, description = "Model status", body = ModelStatusResponse),
)
)]
pub async fn model_status(State(state): State<AppState>) -> Result<Json<ModelStatusResponse>> {
let model_state = state.model_state.read().await.clone();
let (vram_used_mb, vram_total_mb) = vram_stats(state.gpu_device);
Ok(Json(ModelStatusResponse {
state: model_state,
vram_used_mb,
vram_total_mb,
}))
}
// ── POST /model/load ─────────────────────────────────────────────────────────
/// Request the model to be loaded into GPU memory.
/// Idempotent: if the model is already loading or ready, this is a no-op.
/// Returns 202 Accepted; poll `GET /model/status` or subscribe to
/// `GET /model/events` to know when it is ready.
#[utoipa::path(
post,
path = "/model/load",
tag = "model",
responses(
(status = 202, description = "Load initiated or already in progress"),
(status = 200, description = "Model already ready"),
)
)]
pub async fn model_load(State(state): State<AppState>) -> impl IntoResponse {
let is_ready = state.model_state.read().await.is_ready();
if is_ready {
return (StatusCode::OK, Json(serde_json::json!({"status": "already_ready"})));
}
// Ignore send errors (channel full = load already in progress).
let _ = state.cmd_tx.try_send(WorkerCmd::Load);
(StatusCode::ACCEPTED, Json(serde_json::json!({"status": "load_initiated"})))
}
// ── POST /model/unload ───────────────────────────────────────────────────────
/// Unload the model from GPU memory immediately.
/// Idempotent: if the model is already unloaded, returns 200 immediately.
#[utoipa::path(
post,
path = "/model/unload",
tag = "model",
responses(
(status = 200, description = "Model unloaded or was already unloaded"),
)
)]
pub async fn model_unload(State(state): State<AppState>) -> impl IntoResponse {
let _ = state.cmd_tx.try_send(WorkerCmd::Unload);
(StatusCode::OK, Json(serde_json::json!({"status": "unload_requested"})))
}
// ── GET /model/events ────────────────────────────────────────────────────────
/// Stream model lifecycle events to the client as Server-Sent Events.
///
/// Each `ModelEvent` is serialized to JSON and sent with an SSE event name:
/// - `model_loading` — load initiated
/// - `model_ready` — model loaded and warmed up
/// - `model_unloaded` — model freed from GPU memory
/// - `model_waiting_for_gpu` — insufficient VRAM; retrying
///
/// Receive errors from the broadcast channel are skipped, as are events that
/// fail to serialize. Keep-alive comments are sent using axum's defaults.
#[utoipa::path(
    get,
    path = "/model/events",
    tag = "model",
    responses(
        (status = 200, description = "SSE stream of model lifecycle events"),
    )
)]
pub async fn model_events(State(state): State<AppState>) -> Sse<SseStream> {
    let receiver = state.model_event_tx.subscribe();
    let events = BroadcastStream::new(receiver).filter_map(|item| async move {
        // A receive error yields `None`, so that item is simply dropped.
        let event = item.ok()?;
        let name = match &event {
            ModelEvent::ModelLoading => "model_loading",
            ModelEvent::ModelReady { .. } => "model_ready",
            ModelEvent::ModelUnloaded => "model_unloaded",
            ModelEvent::ModelWaitingForGpu { .. } => "model_waiting_for_gpu",
        };
        let payload = serde_json::to_string(&event).ok()?;
        let sse_event = Event::default().event(name).data(payload);
        Some(Ok::<_, std::convert::Infallible>(sse_event))
    });
    let boxed: SseStream = Box::pin(events);
    Sse::new(boxed).keep_alive(KeepAlive::default())
}
// ── Helpers ───────────────────────────────────────────────────────────────────
/// Ask `nvidia-smi` for the used and total memory of GPU `gpu_device`.
///
/// Returns `(used, total)` exactly as `nvidia-smi` prints them with `nounits`.
/// Both are `None` in any of these cases:
/// - the command cannot be spawned;
/// - it exits unsuccessfully;
/// - its output is not two comma-separated integers.
fn vram_stats(gpu_device: u32) -> (Option<u64>, Option<u64>) {
    let id_flag = format!("--id={gpu_device}");
    let output = match std::process::Command::new("nvidia-smi")
        .args([
            id_flag.as_str(),
            "--query-gpu=memory.used,memory.total",
            "--format=csv,noheader,nounits",
        ])
        .output()
    {
        Ok(out) if out.status.success() => out,
        _ => return (None, None),
    };
    let stdout = String::from_utf8_lossy(&output.stdout);
    let number = |field: &str| field.trim().parse::<u64>().ok();
    stdout
        .trim()
        .split_once(',')
        .and_then(|(used, total)| Some((number(used)?, number(total)?)))
        .map_or((None, None), |(used, total)| (Some(used), Some(total)))
}