fix(model): skip stale unload command
All checks were successful
Build & Push Docker Image / test (push) Successful in 6m9s
Build & Push Docker Image / build-and-push (push) Successful in 6m33s

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
2026-05-12 01:08:35 +02:00
parent d8a73e150a
commit 7c01d7f77f

View File

@@ -88,7 +88,12 @@ pub async fn model_load(State(state): State<AppState>) -> impl IntoResponse {
) )
)] )]
pub async fn model_unload(State(state): State<AppState>) -> impl IntoResponse { pub async fn model_unload(State(state): State<AppState>) -> impl IntoResponse {
if !matches!(
*state.model_state.read().await,
crate::models::ModelState::Unloaded
) {
let _ = state.cmd_tx.try_send(WorkerCmd::Unload); let _ = state.cmd_tx.try_send(WorkerCmd::Unload);
}
( (
StatusCode::OK, StatusCode::OK,
Json(serde_json::json!({"status": "unload_requested"})), Json(serde_json::json!({"status": "unload_requested"})),
@@ -164,3 +169,76 @@ fn vram_stats(gpu_device: u32) -> (Option<u64>, Option<u64>) {
None => (None, None), None => (None, None),
} }
} }
#[cfg(test)]
mod tests {
use super::*;
use std::sync::{atomic::AtomicUsize, Arc, Mutex};
use axum::response::IntoResponse;
use chrono::Utc;
use tempfile::tempdir;
use tokio::sync::{broadcast, mpsc, RwLock};
use crate::{models::ModelState, storage::Storage, worker::ProgressRegistry, AppState};
async fn test_state(
model_state: ModelState,
) -> (
AppState,
tempfile::TempDir,
std::sync::mpsc::Receiver<WorkerCmd>,
) {
let tmp = tempdir().expect("tempdir");
let storage = Arc::new(Storage::new(tmp.path()).await.expect("storage"));
let (job_tx, _job_rx) = mpsc::unbounded_channel();
let (cmd_tx, cmd_rx) = std::sync::mpsc::sync_channel(8);
let progress: ProgressRegistry = Arc::new(dashmap::DashMap::new());
let (model_event_tx, _) = broadcast::channel(8);
(
AppState {
job_tx,
cmd_tx,
storage,
progress,
model_name: "test".into(),
queue_depth: Arc::new(AtomicUsize::new(0)),
gpu_device: 0,
model_state: Arc::new(RwLock::new(model_state)),
model_event_tx,
webhook_registry: Arc::new(Mutex::new(Default::default())),
idle_timeout: std::time::Duration::from_secs(300),
gpu_poll_interval: std::time::Duration::from_secs(30),
},
tmp,
cmd_rx,
)
}
#[tokio::test]
async fn test_model_unload_skips_command_when_already_unloaded() {
let (state, _tmp, cmd_rx) = test_state(ModelState::Unloaded).await;
let response = model_unload(State(state)).await.into_response();
assert_eq!(response.status(), StatusCode::OK);
assert!(
cmd_rx.try_recv().is_err(),
"unexpected unload command queued"
);
}
#[tokio::test]
async fn test_model_unload_queues_command_when_ready() {
let (state, _tmp, cmd_rx) = test_state(ModelState::Ready {
loaded_at: Utc::now(),
})
.await;
let response = model_unload(State(state)).await.into_response();
assert_eq!(response.status(), StatusCode::OK);
assert!(matches!(cmd_rx.try_recv(), Ok(WorkerCmd::Unload)));
}
}