feat: GPU model lazy-load/unload lifecycle management
All checks were successful
Build and publish Docker image / Build and push CPU image (push) Successful in 2m33s
Build and publish Docker image / Build and push GPU image (push) Successful in 3m15s

- Domain: add ModelState, ModelStateEvent, ModelNotReady, ManageModelLifecycle
  (in-port), ModelLoader and ModelStateEventBus (out-ports)
- Application: InMemoryModelStateEventBus; ModelLifecycleService — state
  machine (ReentrantLock), lazy load on first request, idle-timeout auto-unload
  (configurable via trueref.embedding.idle-timeout-seconds, default 300 s),
  job-guard (skips unload while ingestion running), platform-thread CUDA executor
- Adapters: OnnxModelLoader wires embedder + reranker start/stop; remove
  @PostConstruct/@PreDestroy from OnnxEmbeddingService and OnnxRerankerService;
  requireStarted() now throws ModelNotReady instead of IllegalStateException
- REST: GET /api/model/status, POST /api/model/unload (409 when jobs running,
  force=true to override), GET /api/model/status/stream (SSE)
- GlobalExceptionHandler: ModelNotReady -> 503 + Retry-After header
- HybridSearchService: calls lifecycle.ensureReady() before every search so
  both REST and MCP paths get ModelNotReady (-> 503 / MCP error) when unloaded
- TrueRefMcpTools: catches ModelNotReady, returns retry hint in MCP error text
- Tests: InMemoryModelStateEventBusTest, ModelLifecycleServiceTest (10 cases),
  OnnxModelLoaderTest, GlobalExceptionHandlerTest — all 41 tests green

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
moze
2026-05-09 15:44:33 +02:00
parent 943a38fd36
commit 5c6085df99
24 changed files with 1144 additions and 17 deletions

View File

@@ -0,0 +1,35 @@
package com.trueref.application.model;
import com.trueref.domain.model.ModelStateEvent;
import com.trueref.domain.port.out.ModelStateEventBus;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.Consumer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/** In-memory, thread-safe implementation of {@link ModelStateEventBus}. */
public class InMemoryModelStateEventBus implements ModelStateEventBus {
private static final Logger log = LoggerFactory.getLogger(InMemoryModelStateEventBus.class);
private final CopyOnWriteArrayList<Consumer<ModelStateEvent>> subscribers =
new CopyOnWriteArrayList<>();
@Override
public void publish(ModelStateEvent event) {
for (Consumer<ModelStateEvent> subscriber : subscribers) {
try {
subscriber.accept(event);
} catch (Exception e) {
log.warn("Model state subscriber threw during publish; removing: {}", e.toString());
subscribers.remove(subscriber);
}
}
}
@Override
public AutoCloseable subscribe(Consumer<ModelStateEvent> subscriber) {
subscribers.add(subscriber);
return () -> subscribers.remove(subscriber);
}
}

View File

@@ -0,0 +1,287 @@
package com.trueref.application.model;
import com.trueref.domain.error.ModelNotReady;
import com.trueref.domain.model.IngestionJob;
import com.trueref.domain.model.JobStatus;
import com.trueref.domain.model.ModelState;
import com.trueref.domain.model.ModelStateEvent;
import com.trueref.domain.port.in.ManageModelLifecycle;
import com.trueref.domain.port.in.ObserveJobs;
import com.trueref.domain.port.out.ModelLoader;
import com.trueref.domain.port.out.ModelStateEventBus;
import java.time.Instant;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
import org.jspecify.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Orchestrates the GPU model lifecycle: lazy loading, auto-unload after idle timeout, and
* force-unload. State transitions are protected by a {@link ReentrantLock}. Heavy load/unload
* operations are executed on a dedicated single-thread platform-thread executor to satisfy CUDA
* OS-thread affinity constraints.
*/
public class ModelLifecycleService implements ManageModelLifecycle {
private static final Logger log = LoggerFactory.getLogger(ModelLifecycleService.class);
/** Default retry hint in seconds when the model is LOADING. Load typically takes 530 s. */
private static final int LOADING_RETRY_AFTER_SECONDS = 30;
/** Idle-check interval. One check per minute is sufficient resolution for a 5-min timeout. */
private static final long IDLE_CHECK_INTERVAL_SECONDS = 60L;
private final ModelLoader modelLoader;
private final ModelStateEventBus eventBus;
private final ObserveJobs observeJobs;
private final long idleTimeoutSeconds;
// ---- state ----
private final ReentrantLock lock = new ReentrantLock();
private final AtomicReference<ModelState> state = new AtomicReference<>(ModelState.UNLOADED);
private volatile @Nullable Instant loadedAt = null;
private final AtomicLong lastActivityEpochMs = new AtomicLong(0L);
// ---- executors ----
/** Single platform OS thread for all CUDA load/unload calls. */
private final java.util.concurrent.ExecutorService cudaExecutor;
/** Scheduler that periodically checks for idle timeout. */
private final ScheduledExecutorService idleScheduler;
/** Handle to the in-flight async load Future (guarded by lock). */
private @Nullable Future<?> pendingLoad = null;
public ModelLifecycleService(
ModelLoader modelLoader,
ModelStateEventBus eventBus,
ObserveJobs observeJobs,
long idleTimeoutSeconds) {
this.modelLoader = modelLoader;
this.eventBus = eventBus;
this.observeJobs = observeJobs;
this.idleTimeoutSeconds = idleTimeoutSeconds;
ThreadFactory platformFactory = Thread.ofPlatform()
.name("model-lifecycle-cuda")
.factory();
this.cudaExecutor = Executors.newSingleThreadExecutor(platformFactory);
this.idleScheduler = Executors.newSingleThreadScheduledExecutor(
Thread.ofPlatform().name("model-idle-check").factory());
}
/** Initializes the idle timer and job-event subscription. Called by the Spring bean lifecycle. */
public void init() {
// Subscribe to job events so that finishing a job resets the idle clock.
observeJobs.subscribeJobs(this::onJobEvent);
// Start the periodic idle check.
idleScheduler.scheduleAtFixedRate(
this::checkIdleTimeout,
IDLE_CHECK_INTERVAL_SECONDS,
IDLE_CHECK_INTERVAL_SECONDS,
TimeUnit.SECONDS);
log.info("ModelLifecycleService started; idleTimeoutSeconds={} state=UNLOADED", idleTimeoutSeconds);
}
/** Shuts down schedulers and unloads models if loaded. Called by the Spring bean lifecycle. */
public void shutdown() {
idleScheduler.shutdownNow();
// Unload synchronously on shutdown so VRAM is released cleanly.
if (state.get() == ModelState.LOADED || state.get() == ModelState.LOADING) {
doUnload("JVM shutdown");
}
cudaExecutor.shutdownNow();
}
// ---- ManageModelLifecycle ----
@Override
public void ensureReady() {
// Fast path: already loaded.
if (state.get() == ModelState.LOADED) {
touchActivity();
return;
}
lock.lock();
try {
ModelState current = state.get();
switch (current) {
case LOADED -> {
touchActivity();
return;
}
case LOADING -> throw new ModelNotReady(LOADING_RETRY_AFTER_SECONDS);
case UNLOADING -> throw new ModelNotReady(LOADING_RETRY_AFTER_SECONDS);
case UNLOADED -> {
triggerAsyncLoad("triggered by incoming request");
throw new ModelNotReady(LOADING_RETRY_AFTER_SECONDS);
}
}
} finally {
lock.unlock();
}
}
@Override
public boolean forceUnload(boolean force) {
lock.lock();
try {
ModelState current = state.get();
if (current == ModelState.UNLOADED || current == ModelState.UNLOADING) {
log.info("forceUnload: already {} — no-op", current);
return true;
}
if (!force && hasRunningJobs()) {
log.info("forceUnload: blocked by running ingestion jobs (use force=true to override)");
return false;
}
if (current == ModelState.LOADING && pendingLoad != null) {
pendingLoad.cancel(false); // try to cancel; unload will follow after transition
}
doUnload("force-unload via API");
return true;
} finally {
lock.unlock();
}
}
@Override
public Status getStatus() {
return new Status(state.get(), loadedAt, epochMsToInstant(lastActivityEpochMs.get()), idleTimeoutSeconds);
}
@Override
public AutoCloseable subscribeState(Consumer<ModelStateEvent> subscriber) {
return eventBus.subscribe(subscriber);
}
// ---- idle timer ----
private void checkIdleTimeout() {
if (state.get() != ModelState.LOADED) {
return;
}
long lastMs = lastActivityEpochMs.get();
if (lastMs == 0L) {
return; // no activity yet
}
long idleMs = System.currentTimeMillis() - lastMs;
if (idleMs < idleTimeoutSeconds * 1000L) {
return;
}
lock.lock();
try {
if (state.get() != ModelState.LOADED) {
return; // state changed while we waited for the lock
}
if (hasRunningJobs()) {
log.debug("idle timeout reached but ingestion jobs are running; skipping unload");
return;
}
long nowIdleMs = System.currentTimeMillis() - lastActivityEpochMs.get();
if (nowIdleMs < idleTimeoutSeconds * 1000L) {
return; // activity happened while we waited for the lock
}
doUnload("idle timeout after " + (nowIdleMs / 1000) + " s");
} finally {
lock.unlock();
}
}
// ---- internal helpers (always called while holding lock) ----
/**
* Submits an async load to the CUDA executor. Must be called while holding {@code lock}.
* Transitions state UNLOADED → LOADING.
*/
private void triggerAsyncLoad(String reason) {
if (pendingLoad != null && !pendingLoad.isDone()) {
return; // already loading — deduplicate
}
transition(ModelState.LOADING, "load requested: " + reason);
pendingLoad = cudaExecutor.submit(() -> {
try {
log.info("Loading GPU models…");
modelLoader.load();
lock.lock();
try {
loadedAt = Instant.now();
touchActivity();
transition(ModelState.LOADED, null);
log.info("GPU models loaded successfully");
} finally {
lock.unlock();
}
} catch (Exception e) {
log.error("GPU model load failed", e);
lock.lock();
try {
transition(ModelState.UNLOADED, "load failed: " + e.getMessage());
} finally {
lock.unlock();
}
}
});
}
/**
* Runs an unload synchronously on the calling thread (must already be the CUDA executor thread
* or called during shutdown). For runtime calls it is submitted to the cuda executor.
* Must be called while holding {@code lock}.
*/
private void doUnload(String reason) {
transition(ModelState.UNLOADING, reason);
// Run unload on the CUDA executor so it stays on the same OS thread as the sessions.
cudaExecutor.execute(() -> {
try {
log.info("Unloading GPU models…");
modelLoader.unload();
log.info("GPU models unloaded");
} catch (Exception e) {
log.error("GPU model unload error (models may still be in VRAM)", e);
} finally {
lock.lock();
try {
loadedAt = null;
transition(ModelState.UNLOADED, null);
} finally {
lock.unlock();
}
}
});
}
private void transition(ModelState next, @Nullable String message) {
ModelState prev = state.getAndSet(next);
if (prev != next) {
log.debug("Model state: {} → {}{}", prev, next, message != null ? " [" + message + "]" : "");
eventBus.publish(ModelStateEvent.of(next, message != null ? message : ""));
}
}
private void touchActivity() {
lastActivityEpochMs.set(System.currentTimeMillis());
}
private boolean hasRunningJobs() {
return !observeJobs.listJobs(null, null, JobStatus.RUNNING, 1).isEmpty();
}
private void onJobEvent(IngestionJob job) {
if (job.status() == JobStatus.SUCCEEDED || job.status() == JobStatus.FAILED) {
touchActivity(); // reset idle clock when a job finishes
}
}
private static @Nullable Instant epochMsToInstant(long ms) {
return ms == 0L ? null : Instant.ofEpochMilli(ms);
}
}

View File

@@ -6,6 +6,7 @@ import com.trueref.domain.model.Repository;
import com.trueref.domain.model.SearchHit;
import com.trueref.domain.model.SearchScope;
import com.trueref.domain.model.Version;
import com.trueref.domain.port.in.ManageModelLifecycle;
import com.trueref.domain.port.in.SearchLibraryDocs;
import com.trueref.domain.port.out.ChunkStore;
import com.trueref.domain.port.out.EmbeddingService;
@@ -42,6 +43,7 @@ public final class HybridSearchService implements SearchLibraryDocs {
private final EmbeddingService embedder;
private final RerankerService reranker;
private final RepositoryStore repos;
private final ManageModelLifecycle lifecycle;
private final int rrfK;
private final int rerankTopK;
private final int finalTopK;
@@ -51,6 +53,7 @@ public final class HybridSearchService implements SearchLibraryDocs {
EmbeddingService embedder,
RerankerService reranker,
RepositoryStore repos,
ManageModelLifecycle lifecycle,
int rrfK,
int rerankTopK,
int finalTopK) {
@@ -58,6 +61,7 @@ public final class HybridSearchService implements SearchLibraryDocs {
this.embedder = embedder;
this.reranker = reranker;
this.repos = repos;
this.lifecycle = lifecycle;
this.rrfK = rrfK;
this.rerankTopK = rerankTopK;
this.finalTopK = finalTopK;
@@ -65,6 +69,9 @@ public final class HybridSearchService implements SearchLibraryDocs {
@Override
public Result search(Query q) {
// Ensure models are loaded; throws ModelNotReady (→ HTTP 503) if not.
lifecycle.ensureReady();
if (q.text() == null || q.text().isBlank()) {
throw new InvalidSearchRequest("query text must not be blank");
}