use std::pin::Pin; use axum::{ extract::State, http::StatusCode, response::{ sse::{Event, KeepAlive, Sse}, IntoResponse, }, Json, }; use futures::Stream; use tokio_stream::wrappers::BroadcastStream; use futures::StreamExt; use crate::{ models::{ModelEvent, ModelStatusResponse}, worker::WorkerCmd, AppState, Result, }; type SseStream = Pin> + Send>>; // ── GET /model/status ──────────────────────────────────────────────────────── /// Return the current model state and VRAM statistics. #[utoipa::path( get, path = "/model/status", tag = "model", responses( (status = 200, description = "Model status", body = ModelStatusResponse), ) )] pub async fn model_status(State(state): State) -> Result> { let model_state = state.model_state.read().await.clone(); let (vram_used_mb, vram_total_mb) = vram_stats(state.gpu_device); Ok(Json(ModelStatusResponse { state: model_state, vram_used_mb, vram_total_mb, })) } // ── POST /model/load ───────────────────────────────────────────────────────── /// Request the model to be loaded into GPU memory. /// Idempotent: if the model is already loading or ready, this is a no-op. /// Returns 202 Accepted; poll `GET /model/status` or subscribe to /// `GET /model/events` to know when it is ready. #[utoipa::path( post, path = "/model/load", tag = "model", responses( (status = 202, description = "Load initiated or already in progress"), (status = 200, description = "Model already ready"), ) )] pub async fn model_load(State(state): State) -> impl IntoResponse { let is_ready = state.model_state.read().await.is_ready(); if is_ready { return (StatusCode::OK, Json(serde_json::json!({"status": "already_ready"}))); } // Ignore send errors (channel full = load already in progress). let _ = state.cmd_tx.try_send(WorkerCmd::Load); (StatusCode::ACCEPTED, Json(serde_json::json!({"status": "load_initiated"}))) } // ── POST /model/unload ─────────────────────────────────────────────────────── /// Unload the model from GPU memory immediately. /// Idempotent: if the model is already unloaded, returns 200 immediately. #[utoipa::path( post, path = "/model/unload", tag = "model", responses( (status = 200, description = "Model unloaded or was already unloaded"), ) )] pub async fn model_unload(State(state): State) -> impl IntoResponse { let _ = state.cmd_tx.try_send(WorkerCmd::Unload); (StatusCode::OK, Json(serde_json::json!({"status": "unload_requested"}))) } // ── GET /model/events ──────────────────────────────────────────────────────── /// Subscribe to model lifecycle events via Server-Sent Events. /// /// Event types: /// - `model_loading` — load initiated /// - `model_ready` — model loaded and warmed up /// - `model_unloaded` — model freed from GPU memory /// - `model_waiting_for_gpu` — insufficient VRAM; retrying #[utoipa::path( get, path = "/model/events", tag = "model", responses( (status = 200, description = "SSE stream of model lifecycle events"), ) )] pub async fn model_events(State(state): State) -> Sse { let rx = state.model_event_tx.subscribe(); let stream: SseStream = Box::pin( BroadcastStream::new(rx).filter_map(|msg| async move { match msg { Ok(event) => { let event_type = match &event { ModelEvent::ModelReady { .. } => "model_ready", ModelEvent::ModelUnloaded => "model_unloaded", ModelEvent::ModelLoading => "model_loading", ModelEvent::ModelWaitingForGpu {..} => "model_waiting_for_gpu", }; let data = serde_json::to_string(&event).ok()?; Some(Ok(Event::default().event(event_type).data(data))) } Err(_) => None, } }) ); Sse::new(stream).keep_alive(KeepAlive::default()) } // ── Helpers ─────────────────────────────────────────────────────────────────── fn vram_stats(gpu_device: u32) -> (Option, Option) { fn inner(gpu_device: u32) -> Option<(u64, u64)> { let out = std::process::Command::new("nvidia-smi") .args([ &format!("--id={gpu_device}"), "--query-gpu=memory.used,memory.total", "--format=csv,noheader,nounits", ]) .output() .ok()?; if !out.status.success() { return None; } let line = String::from_utf8_lossy(&out.stdout); let line = line.trim(); let mut parts = line.splitn(2, ','); let used = parts.next().and_then(|s| s.trim().parse::().ok())?; let total = parts.next().and_then(|s| s.trim().parse::().ok())?; Some((used, total)) } match inner(gpu_device) { Some((u, t)) => (Some(u), Some(t)), None => (None, None), } }