fix(model): skip stale unload command
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -88,7 +88,12 @@ pub async fn model_load(State(state): State<AppState>) -> impl IntoResponse {
|
||||
)
|
||||
)]
|
||||
pub async fn model_unload(State(state): State<AppState>) -> impl IntoResponse {
|
||||
let _ = state.cmd_tx.try_send(WorkerCmd::Unload);
|
||||
if !matches!(
|
||||
*state.model_state.read().await,
|
||||
crate::models::ModelState::Unloaded
|
||||
) {
|
||||
let _ = state.cmd_tx.try_send(WorkerCmd::Unload);
|
||||
}
|
||||
(
|
||||
StatusCode::OK,
|
||||
Json(serde_json::json!({"status": "unload_requested"})),
|
||||
@@ -164,3 +169,76 @@ fn vram_stats(gpu_device: u32) -> (Option<u64>, Option<u64>) {
|
||||
None => (None, None),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    // Handler-level tests for `model_unload`, driven through a hand-built `AppState`.

    use super::*;
    use std::sync::{atomic::AtomicUsize, Arc, Mutex};

    use axum::response::IntoResponse;
    use chrono::Utc;
    use tempfile::tempdir;
    use tokio::sync::{broadcast, mpsc, RwLock};

    use crate::{models::ModelState, storage::Storage, worker::ProgressRegistry, AppState};

    /// Builds an `AppState` whose model starts out in `model_state`.
    ///
    /// Returns, in order:
    /// * the assembled state,
    /// * the temporary directory whose path was handed to `Storage::new`
    ///   (returned so the caller controls how long it lives),
    /// * the receiving half of the worker command channel, so a test can
    ///   inspect which commands the handler queued.
    async fn test_state(
        model_state: ModelState,
    ) -> (
        AppState,
        tempfile::TempDir,
        std::sync::mpsc::Receiver<WorkerCmd>,
    ) {
        // Channels first: only the command receiver is handed back to the caller.
        let (cmd_tx, cmd_rx) = std::sync::mpsc::sync_channel(8);
        let (job_tx, _job_receiver) = mpsc::unbounded_channel();
        let (model_event_tx, _) = broadcast::channel(8);

        // Storage is rooted in a fresh temporary directory.
        let scratch = tempdir().expect("tempdir");
        let storage = Storage::new(scratch.path()).await.expect("storage");
        let progress: ProgressRegistry = Arc::new(dashmap::DashMap::new());

        let app_state = AppState {
            job_tx,
            cmd_tx,
            storage: Arc::new(storage),
            progress,
            model_name: "test".into(),
            queue_depth: Arc::new(AtomicUsize::new(0)),
            gpu_device: 0,
            model_state: Arc::new(RwLock::new(model_state)),
            model_event_tx,
            webhook_registry: Arc::new(Mutex::new(Default::default())),
            idle_timeout: std::time::Duration::from_secs(300),
            gpu_poll_interval: std::time::Duration::from_secs(30),
        };

        (app_state, scratch, cmd_rx)
    }

    /// With the model already `Unloaded`, the handler must answer 200 OK
    /// without placing anything on the worker command channel.
    #[tokio::test]
    async fn test_model_unload_skips_command_when_already_unloaded() {
        let (state, _tmp, cmd_rx) = test_state(ModelState::Unloaded).await;

        let status = model_unload(State(state)).await.into_response().status();
        assert_eq!(status, StatusCode::OK);

        let queued = cmd_rx.try_recv();
        assert!(queued.is_err(), "unexpected unload command queued");
    }

    /// With the model `Ready`, the handler must answer 200 OK and queue a
    /// `WorkerCmd::Unload` for the worker.
    #[tokio::test]
    async fn test_model_unload_queues_command_when_ready() {
        let ready = ModelState::Ready {
            loaded_at: Utc::now(),
        };
        let (state, _tmp, cmd_rx) = test_state(ready).await;

        let status = model_unload(State(state)).await.into_response().status();
        assert_eq!(status, StatusCode::OK);

        let queued = cmd_rx.try_recv();
        assert!(matches!(queued, Ok(WorkerCmd::Unload)));
    }
}
|
||||
|
||||
Reference in New Issue
Block a user