fix(model): skip stale unload command
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -88,7 +88,12 @@ pub async fn model_load(State(state): State<AppState>) -> impl IntoResponse {
|
|||||||
)
|
)
|
||||||
)]
|
)]
|
||||||
pub async fn model_unload(State(state): State<AppState>) -> impl IntoResponse {
|
pub async fn model_unload(State(state): State<AppState>) -> impl IntoResponse {
|
||||||
let _ = state.cmd_tx.try_send(WorkerCmd::Unload);
|
if !matches!(
|
||||||
|
*state.model_state.read().await,
|
||||||
|
crate::models::ModelState::Unloaded
|
||||||
|
) {
|
||||||
|
let _ = state.cmd_tx.try_send(WorkerCmd::Unload);
|
||||||
|
}
|
||||||
(
|
(
|
||||||
StatusCode::OK,
|
StatusCode::OK,
|
||||||
Json(serde_json::json!({"status": "unload_requested"})),
|
Json(serde_json::json!({"status": "unload_requested"})),
|
||||||
@@ -164,3 +169,76 @@ fn vram_stats(gpu_device: u32) -> (Option<u64>, Option<u64>) {
|
|||||||
None => (None, None),
|
None => (None, None),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::sync::{atomic::AtomicUsize, Arc, Mutex};
|
||||||
|
|
||||||
|
use axum::response::IntoResponse;
|
||||||
|
use chrono::Utc;
|
||||||
|
use tempfile::tempdir;
|
||||||
|
use tokio::sync::{broadcast, mpsc, RwLock};
|
||||||
|
|
||||||
|
use crate::{models::ModelState, storage::Storage, worker::ProgressRegistry, AppState};
|
||||||
|
|
||||||
|
async fn test_state(
|
||||||
|
model_state: ModelState,
|
||||||
|
) -> (
|
||||||
|
AppState,
|
||||||
|
tempfile::TempDir,
|
||||||
|
std::sync::mpsc::Receiver<WorkerCmd>,
|
||||||
|
) {
|
||||||
|
let tmp = tempdir().expect("tempdir");
|
||||||
|
let storage = Arc::new(Storage::new(tmp.path()).await.expect("storage"));
|
||||||
|
let (job_tx, _job_rx) = mpsc::unbounded_channel();
|
||||||
|
let (cmd_tx, cmd_rx) = std::sync::mpsc::sync_channel(8);
|
||||||
|
let progress: ProgressRegistry = Arc::new(dashmap::DashMap::new());
|
||||||
|
let (model_event_tx, _) = broadcast::channel(8);
|
||||||
|
|
||||||
|
(
|
||||||
|
AppState {
|
||||||
|
job_tx,
|
||||||
|
cmd_tx,
|
||||||
|
storage,
|
||||||
|
progress,
|
||||||
|
model_name: "test".into(),
|
||||||
|
queue_depth: Arc::new(AtomicUsize::new(0)),
|
||||||
|
gpu_device: 0,
|
||||||
|
model_state: Arc::new(RwLock::new(model_state)),
|
||||||
|
model_event_tx,
|
||||||
|
webhook_registry: Arc::new(Mutex::new(Default::default())),
|
||||||
|
idle_timeout: std::time::Duration::from_secs(300),
|
||||||
|
gpu_poll_interval: std::time::Duration::from_secs(30),
|
||||||
|
},
|
||||||
|
tmp,
|
||||||
|
cmd_rx,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_model_unload_skips_command_when_already_unloaded() {
|
||||||
|
let (state, _tmp, cmd_rx) = test_state(ModelState::Unloaded).await;
|
||||||
|
|
||||||
|
let response = model_unload(State(state)).await.into_response();
|
||||||
|
|
||||||
|
assert_eq!(response.status(), StatusCode::OK);
|
||||||
|
assert!(
|
||||||
|
cmd_rx.try_recv().is_err(),
|
||||||
|
"unexpected unload command queued"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_model_unload_queues_command_when_ready() {
|
||||||
|
let (state, _tmp, cmd_rx) = test_state(ModelState::Ready {
|
||||||
|
loaded_at: Utc::now(),
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let response = model_unload(State(state)).await.into_response();
|
||||||
|
|
||||||
|
assert_eq!(response.status(), StatusCode::OK);
|
||||||
|
assert!(matches!(cmd_rx.try_recv(), Ok(WorkerCmd::Unload)));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user