feat: GPU model lazy-load/unload lifecycle management
- Domain: add ModelState, ModelStateEvent, ModelNotReady, ManageModelLifecycle (in-port), ModelLoader and ModelStateEventBus (out-ports) - Application: InMemoryModelStateEventBus; ModelLifecycleService — state machine (ReentrantLock), lazy load on first request, idle-timeout auto-unload (configurable via trueref.embedding.idle-timeout-seconds, default 300 s), job-guard (skips unload while ingestion running), platform-thread CUDA executor - Adapters: OnnxModelLoader wires embedder + reranker start/stop; remove @PostConstruct/@PreDestroy from OnnxEmbeddingService and OnnxRerankerService; requireStarted() now throws ModelNotReady instead of IllegalStateException - REST: GET /api/model/status, POST /api/model/unload (409 when jobs running, force=true to override), GET /api/model/status/stream (SSE) - GlobalExceptionHandler: ModelNotReady -> 503 + Retry-After header - HybridSearchService: calls lifecycle.ensureReady() before every search so both REST and MCP paths get ModelNotReady (-> 503 / MCP error) when unloaded - TrueRefMcpTools: catches ModelNotReady, returns retry hint in MCP error text - Tests: InMemoryModelStateEventBusTest, ModelLifecycleServiceTest (10 cases), OnnxModelLoaderTest, GlobalExceptionHandlerTest — all 41 tests green Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,35 @@
|
||||
package com.trueref.application.model;
|
||||
|
||||
import com.trueref.domain.model.ModelStateEvent;
|
||||
import com.trueref.domain.port.out.ModelStateEventBus;
|
||||
import java.util.concurrent.CopyOnWriteArrayList;
|
||||
import java.util.function.Consumer;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/** In-memory, thread-safe implementation of {@link ModelStateEventBus}. */
|
||||
public class InMemoryModelStateEventBus implements ModelStateEventBus {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(InMemoryModelStateEventBus.class);
|
||||
|
||||
private final CopyOnWriteArrayList<Consumer<ModelStateEvent>> subscribers =
|
||||
new CopyOnWriteArrayList<>();
|
||||
|
||||
@Override
|
||||
public void publish(ModelStateEvent event) {
|
||||
for (Consumer<ModelStateEvent> subscriber : subscribers) {
|
||||
try {
|
||||
subscriber.accept(event);
|
||||
} catch (Exception e) {
|
||||
log.warn("Model state subscriber threw during publish; removing: {}", e.toString());
|
||||
subscribers.remove(subscriber);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public AutoCloseable subscribe(Consumer<ModelStateEvent> subscriber) {
|
||||
subscribers.add(subscriber);
|
||||
return () -> subscribers.remove(subscriber);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,287 @@
|
||||
package com.trueref.application.model;
|
||||
|
||||
import com.trueref.domain.error.ModelNotReady;
|
||||
import com.trueref.domain.model.IngestionJob;
|
||||
import com.trueref.domain.model.JobStatus;
|
||||
import com.trueref.domain.model.ModelState;
|
||||
import com.trueref.domain.model.ModelStateEvent;
|
||||
import com.trueref.domain.port.in.ManageModelLifecycle;
|
||||
import com.trueref.domain.port.in.ObserveJobs;
|
||||
import com.trueref.domain.port.out.ModelLoader;
|
||||
import com.trueref.domain.port.out.ModelStateEventBus;
|
||||
import java.time.Instant;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.Future;
|
||||
import java.util.concurrent.ScheduledExecutorService;
|
||||
import java.util.concurrent.ThreadFactory;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicLong;
|
||||
import java.util.concurrent.atomic.AtomicReference;
|
||||
import java.util.concurrent.locks.ReentrantLock;
|
||||
import java.util.function.Consumer;
|
||||
import org.jspecify.annotations.Nullable;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* Orchestrates the GPU model lifecycle: lazy loading, auto-unload after idle timeout, and
|
||||
* force-unload. State transitions are protected by a {@link ReentrantLock}. Heavy load/unload
|
||||
* operations are executed on a dedicated single-thread platform-thread executor to satisfy CUDA
|
||||
* OS-thread affinity constraints.
|
||||
*/
|
||||
public class ModelLifecycleService implements ManageModelLifecycle {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(ModelLifecycleService.class);
|
||||
|
||||
/** Default retry hint in seconds when the model is LOADING. Load typically takes 5–30 s. */
|
||||
private static final int LOADING_RETRY_AFTER_SECONDS = 30;
|
||||
|
||||
/** Idle-check interval. One check per minute is sufficient resolution for a 5-min timeout. */
|
||||
private static final long IDLE_CHECK_INTERVAL_SECONDS = 60L;
|
||||
|
||||
private final ModelLoader modelLoader;
|
||||
private final ModelStateEventBus eventBus;
|
||||
private final ObserveJobs observeJobs;
|
||||
private final long idleTimeoutSeconds;
|
||||
|
||||
// ---- state ----
|
||||
private final ReentrantLock lock = new ReentrantLock();
|
||||
private final AtomicReference<ModelState> state = new AtomicReference<>(ModelState.UNLOADED);
|
||||
private volatile @Nullable Instant loadedAt = null;
|
||||
private final AtomicLong lastActivityEpochMs = new AtomicLong(0L);
|
||||
|
||||
// ---- executors ----
|
||||
/** Single platform OS thread for all CUDA load/unload calls. */
|
||||
private final java.util.concurrent.ExecutorService cudaExecutor;
|
||||
|
||||
/** Scheduler that periodically checks for idle timeout. */
|
||||
private final ScheduledExecutorService idleScheduler;
|
||||
|
||||
/** Handle to the in-flight async load Future (guarded by lock). */
|
||||
private @Nullable Future<?> pendingLoad = null;
|
||||
|
||||
public ModelLifecycleService(
|
||||
ModelLoader modelLoader,
|
||||
ModelStateEventBus eventBus,
|
||||
ObserveJobs observeJobs,
|
||||
long idleTimeoutSeconds) {
|
||||
this.modelLoader = modelLoader;
|
||||
this.eventBus = eventBus;
|
||||
this.observeJobs = observeJobs;
|
||||
this.idleTimeoutSeconds = idleTimeoutSeconds;
|
||||
|
||||
ThreadFactory platformFactory = Thread.ofPlatform()
|
||||
.name("model-lifecycle-cuda")
|
||||
.factory();
|
||||
this.cudaExecutor = Executors.newSingleThreadExecutor(platformFactory);
|
||||
this.idleScheduler = Executors.newSingleThreadScheduledExecutor(
|
||||
Thread.ofPlatform().name("model-idle-check").factory());
|
||||
}
|
||||
|
||||
/** Initializes the idle timer and job-event subscription. Called by the Spring bean lifecycle. */
|
||||
public void init() {
|
||||
// Subscribe to job events so that finishing a job resets the idle clock.
|
||||
observeJobs.subscribeJobs(this::onJobEvent);
|
||||
// Start the periodic idle check.
|
||||
idleScheduler.scheduleAtFixedRate(
|
||||
this::checkIdleTimeout,
|
||||
IDLE_CHECK_INTERVAL_SECONDS,
|
||||
IDLE_CHECK_INTERVAL_SECONDS,
|
||||
TimeUnit.SECONDS);
|
||||
log.info("ModelLifecycleService started; idleTimeoutSeconds={} state=UNLOADED", idleTimeoutSeconds);
|
||||
}
|
||||
|
||||
/** Shuts down schedulers and unloads models if loaded. Called by the Spring bean lifecycle. */
|
||||
public void shutdown() {
|
||||
idleScheduler.shutdownNow();
|
||||
// Unload synchronously on shutdown so VRAM is released cleanly.
|
||||
if (state.get() == ModelState.LOADED || state.get() == ModelState.LOADING) {
|
||||
doUnload("JVM shutdown");
|
||||
}
|
||||
cudaExecutor.shutdownNow();
|
||||
}
|
||||
|
||||
// ---- ManageModelLifecycle ----
|
||||
|
||||
@Override
|
||||
public void ensureReady() {
|
||||
// Fast path: already loaded.
|
||||
if (state.get() == ModelState.LOADED) {
|
||||
touchActivity();
|
||||
return;
|
||||
}
|
||||
lock.lock();
|
||||
try {
|
||||
ModelState current = state.get();
|
||||
switch (current) {
|
||||
case LOADED -> {
|
||||
touchActivity();
|
||||
return;
|
||||
}
|
||||
case LOADING -> throw new ModelNotReady(LOADING_RETRY_AFTER_SECONDS);
|
||||
case UNLOADING -> throw new ModelNotReady(LOADING_RETRY_AFTER_SECONDS);
|
||||
case UNLOADED -> {
|
||||
triggerAsyncLoad("triggered by incoming request");
|
||||
throw new ModelNotReady(LOADING_RETRY_AFTER_SECONDS);
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
lock.unlock();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean forceUnload(boolean force) {
|
||||
lock.lock();
|
||||
try {
|
||||
ModelState current = state.get();
|
||||
if (current == ModelState.UNLOADED || current == ModelState.UNLOADING) {
|
||||
log.info("forceUnload: already {} — no-op", current);
|
||||
return true;
|
||||
}
|
||||
if (!force && hasRunningJobs()) {
|
||||
log.info("forceUnload: blocked by running ingestion jobs (use force=true to override)");
|
||||
return false;
|
||||
}
|
||||
if (current == ModelState.LOADING && pendingLoad != null) {
|
||||
pendingLoad.cancel(false); // try to cancel; unload will follow after transition
|
||||
}
|
||||
doUnload("force-unload via API");
|
||||
return true;
|
||||
} finally {
|
||||
lock.unlock();
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public Status getStatus() {
|
||||
return new Status(state.get(), loadedAt, epochMsToInstant(lastActivityEpochMs.get()), idleTimeoutSeconds);
|
||||
}
|
||||
|
||||
@Override
|
||||
public AutoCloseable subscribeState(Consumer<ModelStateEvent> subscriber) {
|
||||
return eventBus.subscribe(subscriber);
|
||||
}
|
||||
|
||||
// ---- idle timer ----
|
||||
|
||||
private void checkIdleTimeout() {
|
||||
if (state.get() != ModelState.LOADED) {
|
||||
return;
|
||||
}
|
||||
long lastMs = lastActivityEpochMs.get();
|
||||
if (lastMs == 0L) {
|
||||
return; // no activity yet
|
||||
}
|
||||
long idleMs = System.currentTimeMillis() - lastMs;
|
||||
if (idleMs < idleTimeoutSeconds * 1000L) {
|
||||
return;
|
||||
}
|
||||
lock.lock();
|
||||
try {
|
||||
if (state.get() != ModelState.LOADED) {
|
||||
return; // state changed while we waited for the lock
|
||||
}
|
||||
if (hasRunningJobs()) {
|
||||
log.debug("idle timeout reached but ingestion jobs are running; skipping unload");
|
||||
return;
|
||||
}
|
||||
long nowIdleMs = System.currentTimeMillis() - lastActivityEpochMs.get();
|
||||
if (nowIdleMs < idleTimeoutSeconds * 1000L) {
|
||||
return; // activity happened while we waited for the lock
|
||||
}
|
||||
doUnload("idle timeout after " + (nowIdleMs / 1000) + " s");
|
||||
} finally {
|
||||
lock.unlock();
|
||||
}
|
||||
}
|
||||
|
||||
// ---- internal helpers (always called while holding lock) ----
|
||||
|
||||
/**
|
||||
* Submits an async load to the CUDA executor. Must be called while holding {@code lock}.
|
||||
* Transitions state UNLOADED → LOADING.
|
||||
*/
|
||||
private void triggerAsyncLoad(String reason) {
|
||||
if (pendingLoad != null && !pendingLoad.isDone()) {
|
||||
return; // already loading — deduplicate
|
||||
}
|
||||
transition(ModelState.LOADING, "load requested: " + reason);
|
||||
pendingLoad = cudaExecutor.submit(() -> {
|
||||
try {
|
||||
log.info("Loading GPU models…");
|
||||
modelLoader.load();
|
||||
lock.lock();
|
||||
try {
|
||||
loadedAt = Instant.now();
|
||||
touchActivity();
|
||||
transition(ModelState.LOADED, null);
|
||||
log.info("GPU models loaded successfully");
|
||||
} finally {
|
||||
lock.unlock();
|
||||
}
|
||||
} catch (Exception e) {
|
||||
log.error("GPU model load failed", e);
|
||||
lock.lock();
|
||||
try {
|
||||
transition(ModelState.UNLOADED, "load failed: " + e.getMessage());
|
||||
} finally {
|
||||
lock.unlock();
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Runs an unload synchronously on the calling thread (must already be the CUDA executor thread
|
||||
* or called during shutdown). For runtime calls it is submitted to the cuda executor.
|
||||
* Must be called while holding {@code lock}.
|
||||
*/
|
||||
private void doUnload(String reason) {
|
||||
transition(ModelState.UNLOADING, reason);
|
||||
// Run unload on the CUDA executor so it stays on the same OS thread as the sessions.
|
||||
cudaExecutor.execute(() -> {
|
||||
try {
|
||||
log.info("Unloading GPU models…");
|
||||
modelLoader.unload();
|
||||
log.info("GPU models unloaded");
|
||||
} catch (Exception e) {
|
||||
log.error("GPU model unload error (models may still be in VRAM)", e);
|
||||
} finally {
|
||||
lock.lock();
|
||||
try {
|
||||
loadedAt = null;
|
||||
transition(ModelState.UNLOADED, null);
|
||||
} finally {
|
||||
lock.unlock();
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private void transition(ModelState next, @Nullable String message) {
|
||||
ModelState prev = state.getAndSet(next);
|
||||
if (prev != next) {
|
||||
log.debug("Model state: {} → {}{}", prev, next, message != null ? " [" + message + "]" : "");
|
||||
eventBus.publish(ModelStateEvent.of(next, message != null ? message : ""));
|
||||
}
|
||||
}
|
||||
|
||||
private void touchActivity() {
|
||||
lastActivityEpochMs.set(System.currentTimeMillis());
|
||||
}
|
||||
|
||||
private boolean hasRunningJobs() {
|
||||
return !observeJobs.listJobs(null, null, JobStatus.RUNNING, 1).isEmpty();
|
||||
}
|
||||
|
||||
private void onJobEvent(IngestionJob job) {
|
||||
if (job.status() == JobStatus.SUCCEEDED || job.status() == JobStatus.FAILED) {
|
||||
touchActivity(); // reset idle clock when a job finishes
|
||||
}
|
||||
}
|
||||
|
||||
private static @Nullable Instant epochMsToInstant(long ms) {
|
||||
return ms == 0L ? null : Instant.ofEpochMilli(ms);
|
||||
}
|
||||
}
|
||||
@@ -6,6 +6,7 @@ import com.trueref.domain.model.Repository;
|
||||
import com.trueref.domain.model.SearchHit;
|
||||
import com.trueref.domain.model.SearchScope;
|
||||
import com.trueref.domain.model.Version;
|
||||
import com.trueref.domain.port.in.ManageModelLifecycle;
|
||||
import com.trueref.domain.port.in.SearchLibraryDocs;
|
||||
import com.trueref.domain.port.out.ChunkStore;
|
||||
import com.trueref.domain.port.out.EmbeddingService;
|
||||
@@ -42,6 +43,7 @@ public final class HybridSearchService implements SearchLibraryDocs {
|
||||
private final EmbeddingService embedder;
|
||||
private final RerankerService reranker;
|
||||
private final RepositoryStore repos;
|
||||
private final ManageModelLifecycle lifecycle;
|
||||
private final int rrfK;
|
||||
private final int rerankTopK;
|
||||
private final int finalTopK;
|
||||
@@ -51,6 +53,7 @@ public final class HybridSearchService implements SearchLibraryDocs {
|
||||
EmbeddingService embedder,
|
||||
RerankerService reranker,
|
||||
RepositoryStore repos,
|
||||
ManageModelLifecycle lifecycle,
|
||||
int rrfK,
|
||||
int rerankTopK,
|
||||
int finalTopK) {
|
||||
@@ -58,6 +61,7 @@ public final class HybridSearchService implements SearchLibraryDocs {
|
||||
this.embedder = embedder;
|
||||
this.reranker = reranker;
|
||||
this.repos = repos;
|
||||
this.lifecycle = lifecycle;
|
||||
this.rrfK = rrfK;
|
||||
this.rerankTopK = rerankTopK;
|
||||
this.finalTopK = finalTopK;
|
||||
@@ -65,6 +69,9 @@ public final class HybridSearchService implements SearchLibraryDocs {
|
||||
|
||||
@Override
|
||||
public Result search(Query q) {
|
||||
// Ensure models are loaded; throws ModelNotReady (→ HTTP 503) if not.
|
||||
lifecycle.ensureReady();
|
||||
|
||||
if (q.text() == null || q.text().isBlank()) {
|
||||
throw new InvalidSearchRequest("query text must not be blank");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user