feat: dynamic model loading/unloading with GPU polling
All checks were successful
Build & Push Docker Image / build-and-push (push) Successful in 8m41s
All checks were successful
Build & Push Docker Image / build-and-push (push) Successful in 8m41s
- Model starts unloaded (lazy); loads on first job or POST /model/load
- Auto-unloads after IDLE_TIMEOUT_SECS (default 300) of inactivity
- POST /model/unload for immediate manual release
- GPU-busy detection: on VRAM OOM, enters WaitingForGpu and retries
every GPU_POLL_INTERVAL_SECS (default 30) indefinitely
- POST /jobs when unloaded → 503 + Retry-After header, triggers load
- AppError::OutOfMemory and AppError::ModelNotReady variants
- WorkerCmd channel (SyncSender<WorkerCmd>) replaces bare tx_req channel
- Idle timer via recv_timeout(1s) tick inside OS thread (no extra thread)
- Model lifecycle events broadcast via tokio broadcast channel (SSE + webhooks)
- webhook_registry: all clients that ever submitted a webhook_url receive
model_ready and model_unloaded webhooks
- GPU warmup retained on every (re)load
New routes:
GET /model/status — current state + VRAM stats
POST /model/load — trigger load (idempotent)
POST /model/unload — immediate unload
GET /model/events — SSE stream of model lifecycle events
New env vars:
IDLE_TIMEOUT_SECS (default 300)
GPU_POLL_INTERVAL_SECS (default 30)
Tests:
tests/test_model_lifecycle.sh — 18 integration tests (full state machine,
SSE events, webhooks, concurrency, unload-during-load)
tests/test_idle_timeout.sh — 5 tests with short IDLE_TIMEOUT_SECS=5
test_all.sh updated: loads model before job submission, asserts
model_state in /health, adds POST /model/unload at end
Docs:
docs/USAGE.md: model lifecycle section, new env vars, 503 retry pattern,
updated /health response shape
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
158
src/routes/model.rs
Normal file
158
src/routes/model.rs
Normal file
@@ -0,0 +1,158 @@
|
||||
use std::pin::Pin;
|
||||
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::{
|
||||
sse::{Event, KeepAlive, Sse},
|
||||
IntoResponse,
|
||||
},
|
||||
Json,
|
||||
};
|
||||
use futures::Stream;
|
||||
use tokio_stream::wrappers::BroadcastStream;
|
||||
use futures::StreamExt;
|
||||
|
||||
use crate::{
|
||||
models::{ModelEvent, ModelStatusResponse},
|
||||
worker::WorkerCmd,
|
||||
AppState, Result,
|
||||
};
|
||||
|
||||
/// Boxed, pinned, `Send` stream of SSE [`Event`]s used as the body type of
/// `GET /model/events`. The error type is `Infallible`, so the stream never
/// yields an error item.
type SseStream = Pin<Box<dyn Stream<Item = std::result::Result<Event, std::convert::Infallible>> + Send>>;
|
||||
|
||||
// ── GET /model/status ────────────────────────────────────────────────────────
|
||||
|
||||
/// Return the current model state and VRAM statistics.
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/model/status",
|
||||
tag = "model",
|
||||
responses(
|
||||
(status = 200, description = "Model status", body = ModelStatusResponse),
|
||||
)
|
||||
)]
|
||||
pub async fn model_status(State(state): State<AppState>) -> Result<Json<ModelStatusResponse>> {
|
||||
let model_state = state.model_state.read().await.clone();
|
||||
let (vram_used_mb, vram_total_mb) = vram_stats(state.gpu_device);
|
||||
|
||||
Ok(Json(ModelStatusResponse {
|
||||
state: model_state,
|
||||
vram_used_mb,
|
||||
vram_total_mb,
|
||||
}))
|
||||
}
|
||||
|
||||
// ── POST /model/load ─────────────────────────────────────────────────────────
|
||||
|
||||
/// Request the model to be loaded into GPU memory.
|
||||
/// Idempotent: if the model is already loading or ready, this is a no-op.
|
||||
/// Returns 202 Accepted; poll `GET /model/status` or subscribe to
|
||||
/// `GET /model/events` to know when it is ready.
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/model/load",
|
||||
tag = "model",
|
||||
responses(
|
||||
(status = 202, description = "Load initiated or already in progress"),
|
||||
(status = 200, description = "Model already ready"),
|
||||
)
|
||||
)]
|
||||
pub async fn model_load(State(state): State<AppState>) -> impl IntoResponse {
|
||||
let is_ready = state.model_state.read().await.is_ready();
|
||||
if is_ready {
|
||||
return (StatusCode::OK, Json(serde_json::json!({"status": "already_ready"})));
|
||||
}
|
||||
// Ignore send errors (channel full = load already in progress).
|
||||
let _ = state.cmd_tx.try_send(WorkerCmd::Load);
|
||||
(StatusCode::ACCEPTED, Json(serde_json::json!({"status": "load_initiated"})))
|
||||
}
|
||||
|
||||
// ── POST /model/unload ───────────────────────────────────────────────────────
|
||||
|
||||
/// Unload the model from GPU memory immediately.
|
||||
/// Idempotent: if the model is already unloaded, returns 200 immediately.
|
||||
#[utoipa::path(
|
||||
post,
|
||||
path = "/model/unload",
|
||||
tag = "model",
|
||||
responses(
|
||||
(status = 200, description = "Model unloaded or was already unloaded"),
|
||||
)
|
||||
)]
|
||||
pub async fn model_unload(State(state): State<AppState>) -> impl IntoResponse {
|
||||
let _ = state.cmd_tx.try_send(WorkerCmd::Unload);
|
||||
(StatusCode::OK, Json(serde_json::json!({"status": "unload_requested"})))
|
||||
}
|
||||
|
||||
// ── GET /model/events ────────────────────────────────────────────────────────
|
||||
|
||||
/// Subscribe to model lifecycle events via Server-Sent Events.
|
||||
///
|
||||
/// Event types:
|
||||
/// - `model_loading` — load initiated
|
||||
/// - `model_ready` — model loaded and warmed up
|
||||
/// - `model_unloaded` — model freed from GPU memory
|
||||
/// - `model_waiting_for_gpu` — insufficient VRAM; retrying
|
||||
#[utoipa::path(
|
||||
get,
|
||||
path = "/model/events",
|
||||
tag = "model",
|
||||
responses(
|
||||
(status = 200, description = "SSE stream of model lifecycle events"),
|
||||
)
|
||||
)]
|
||||
pub async fn model_events(State(state): State<AppState>) -> Sse<SseStream> {
|
||||
let rx = state.model_event_tx.subscribe();
|
||||
|
||||
let stream: SseStream = Box::pin(
|
||||
BroadcastStream::new(rx).filter_map(|msg| async move {
|
||||
match msg {
|
||||
Ok(event) => {
|
||||
let event_type = match &event {
|
||||
ModelEvent::ModelReady { .. } => "model_ready",
|
||||
ModelEvent::ModelUnloaded => "model_unloaded",
|
||||
ModelEvent::ModelLoading => "model_loading",
|
||||
ModelEvent::ModelWaitingForGpu {..} => "model_waiting_for_gpu",
|
||||
};
|
||||
let data = serde_json::to_string(&event).ok()?;
|
||||
Some(Ok(Event::default().event(event_type).data(data)))
|
||||
}
|
||||
Err(_) => None,
|
||||
}
|
||||
})
|
||||
);
|
||||
|
||||
Sse::new(stream).keep_alive(KeepAlive::default())
|
||||
}
|
||||
|
||||
// ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Query `nvidia-smi` for the used and total memory of GPU `gpu_device`.
///
/// Returns `(Some(used), Some(total))` as printed by `nvidia-smi` for
/// `memory.used,memory.total` with `nounits`. Returns `(None, None)` when the
/// command cannot be run, exits unsuccessfully, or prints something that does
/// not parse as two comma-separated unsigned integers.
fn vram_stats(gpu_device: u32) -> (Option<u64>, Option<u64>) {
    // All failure modes collapse to `None` via `?`.
    let query = || -> Option<(u64, u64)> {
        let output = std::process::Command::new("nvidia-smi")
            .arg(format!("--id={gpu_device}"))
            .arg("--query-gpu=memory.used,memory.total")
            .arg("--format=csv,noheader,nounits")
            .output()
            .ok()?;

        if !output.status.success() {
            return None;
        }

        // Expected shape is a single CSV row: "<used>, <total>".
        let text = String::from_utf8_lossy(&output.stdout);
        let (used_field, total_field) = text.trim().split_once(',')?;
        let used = used_field.trim().parse::<u64>().ok()?;
        let total = total_field.trim().parse::<u64>().ok()?;
        Some((used, total))
    };

    query().map_or((None, None), |(used, total)| (Some(used), Some(total)))
}
|
||||
Reference in New Issue
Block a user