diff --git a/src/routes/model.rs b/src/routes/model.rs index 426171c..14a48c4 100644 --- a/src/routes/model.rs +++ b/src/routes/model.rs @@ -88,7 +88,12 @@ pub async fn model_load(State(state): State) -> impl IntoResponse { ) )] pub async fn model_unload(State(state): State) -> impl IntoResponse { - let _ = state.cmd_tx.try_send(WorkerCmd::Unload); + if !matches!( + *state.model_state.read().await, + crate::models::ModelState::Unloaded + ) { + let _ = state.cmd_tx.try_send(WorkerCmd::Unload); + } ( StatusCode::OK, Json(serde_json::json!({"status": "unload_requested"})), @@ -164,3 +169,76 @@ fn vram_stats(gpu_device: u32) -> (Option, Option) { None => (None, None), } } + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::{atomic::AtomicUsize, Arc, Mutex}; + + use axum::response::IntoResponse; + use chrono::Utc; + use tempfile::tempdir; + use tokio::sync::{broadcast, mpsc, RwLock}; + + use crate::{models::ModelState, storage::Storage, worker::ProgressRegistry, AppState}; + + async fn test_state( + model_state: ModelState, + ) -> ( + AppState, + tempfile::TempDir, + std::sync::mpsc::Receiver, + ) { + let tmp = tempdir().expect("tempdir"); + let storage = Arc::new(Storage::new(tmp.path()).await.expect("storage")); + let (job_tx, _job_rx) = mpsc::unbounded_channel(); + let (cmd_tx, cmd_rx) = std::sync::mpsc::sync_channel(8); + let progress: ProgressRegistry = Arc::new(dashmap::DashMap::new()); + let (model_event_tx, _) = broadcast::channel(8); + + ( + AppState { + job_tx, + cmd_tx, + storage, + progress, + model_name: "test".into(), + queue_depth: Arc::new(AtomicUsize::new(0)), + gpu_device: 0, + model_state: Arc::new(RwLock::new(model_state)), + model_event_tx, + webhook_registry: Arc::new(Mutex::new(Default::default())), + idle_timeout: std::time::Duration::from_secs(300), + gpu_poll_interval: std::time::Duration::from_secs(30), + }, + tmp, + cmd_rx, + ) + } + + #[tokio::test] + async fn test_model_unload_skips_command_when_already_unloaded() { + let (state, _tmp, cmd_rx) = test_state(ModelState::Unloaded).await; + + let response = model_unload(State(state)).await.into_response(); + + assert_eq!(response.status(), StatusCode::OK); + assert!( + cmd_rx.try_recv().is_err(), + "unexpected unload command queued" + ); + } + + #[tokio::test] + async fn test_model_unload_queues_command_when_ready() { + let (state, _tmp, cmd_rx) = test_state(ModelState::Ready { + loaded_at: Utc::now(), + }) + .await; + + let response = model_unload(State(state)).await.into_response(); + + assert_eq!(response.status(), StatusCode::OK); + assert!(matches!(cmd_rx.try_recv(), Ok(WorkerCmd::Unload))); + } +}