From 5c6085df99eee83c989c1f0fb5a19f09adcee66e Mon Sep 17 00:00:00 2001 From: moze Date: Sat, 9 May 2026 15:44:33 +0200 Subject: [PATCH] feat: GPU model lazy-load/unload lifecycle management MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Domain: add ModelState, ModelStateEvent, ModelNotReady, ManageModelLifecycle (in-port), ModelLoader and ModelStateEventBus (out-ports) - Application: InMemoryModelStateEventBus; ModelLifecycleService — state machine (ReentrantLock), lazy load on first request, idle-timeout auto-unload (configurable via trueref.embedding.idle-timeout-seconds, default 300 s), job-guard (skips unload while ingestion running), platform-thread CUDA executor - Adapters: OnnxModelLoader wires embedder + reranker start/stop; remove @PostConstruct/@PreDestroy from OnnxEmbeddingService and OnnxRerankerService; requireStarted() now throws ModelNotReady instead of IllegalStateException - REST: GET /api/model/status, POST /api/model/unload (409 when jobs running, force=true to override), GET /api/model/status/stream (SSE) - GlobalExceptionHandler: ModelNotReady -> 503 + Retry-After header - HybridSearchService: calls lifecycle.ensureReady() before every search so both REST and MCP paths get ModelNotReady (-> 503 / MCP error) when unloaded - TrueRefMcpTools: catches ModelNotReady, returns retry hint in MCP error text - Tests: InMemoryModelStateEventBusTest, ModelLifecycleServiceTest (10 cases), OnnxModelLoaderTest, GlobalExceptionHandlerTest — all 41 tests green Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- .../adapter/in/mcp/TrueRefMcpTools.java | 4 + .../in/rest/GlobalExceptionHandler.java | 11 + .../adapter/in/rest/ModelController.java | 106 +++++++ .../adapter/in/rest/dto/ModelStatusDto.java | 22 ++ .../embedding/onnx/OnnxEmbeddingConfig.java | 32 +- .../embedding/onnx/OnnxEmbeddingService.java | 18 +- .../out/embedding/onnx/OnnxModelLoader.java | 54 ++++ .../out/embedding/onnx/OnnxProperties.java | 6 + .../embedding/onnx/OnnxRerankerService.java | 18 +- .../in/rest/GlobalExceptionHandlerTest.java | 44 +++ .../embedding/onnx/OnnxModelLoaderTest.java | 71 +++++ .../model/InMemoryModelStateEventBusTest.java | 75 +++++ .../model/ModelLifecycleServiceTest.java | 227 ++++++++++++++ .../model/InMemoryModelStateEventBus.java | 35 +++ .../model/ModelLifecycleService.java | 287 ++++++++++++++++++ .../search/HybridSearchService.java | 7 + .../trueref/bootstrap/ApplicationBeans.java | 4 +- .../trueref/domain/error/ModelNotReady.java | 23 ++ .../domain/error/TrueRefException.java | 3 +- .../com/trueref/domain/model/ModelState.java | 13 + .../trueref/domain/model/ModelStateEvent.java | 18 ++ .../domain/port/in/ManageModelLifecycle.java | 47 +++ .../trueref/domain/port/out/ModelLoader.java | 18 ++ .../domain/port/out/ModelStateEventBus.java | 18 ++ 24 files changed, 1144 insertions(+), 17 deletions(-) create mode 100644 trueref-adapters/src/main/java/com/trueref/adapter/in/rest/ModelController.java create mode 100644 trueref-adapters/src/main/java/com/trueref/adapter/in/rest/dto/ModelStatusDto.java create mode 100644 trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxModelLoader.java create mode 100644 trueref-adapters/src/test/java/com/trueref/adapter/in/rest/GlobalExceptionHandlerTest.java create mode 100644 trueref-adapters/src/test/java/com/trueref/adapter/out/embedding/onnx/OnnxModelLoaderTest.java create mode 100644 trueref-adapters/src/test/java/com/trueref/application/model/InMemoryModelStateEventBusTest.java create mode 100644 trueref-adapters/src/test/java/com/trueref/application/model/ModelLifecycleServiceTest.java create mode 100644 trueref-application/src/main/java/com/trueref/application/model/InMemoryModelStateEventBus.java create mode 100644 trueref-application/src/main/java/com/trueref/application/model/ModelLifecycleService.java create mode 100644 trueref-domain/src/main/java/com/trueref/domain/error/ModelNotReady.java create mode 100644 trueref-domain/src/main/java/com/trueref/domain/model/ModelState.java create mode 100644 trueref-domain/src/main/java/com/trueref/domain/model/ModelStateEvent.java create mode 100644 trueref-domain/src/main/java/com/trueref/domain/port/in/ManageModelLifecycle.java create mode 100644 trueref-domain/src/main/java/com/trueref/domain/port/out/ModelLoader.java create mode 100644 trueref-domain/src/main/java/com/trueref/domain/port/out/ModelStateEventBus.java diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/in/mcp/TrueRefMcpTools.java b/trueref-adapters/src/main/java/com/trueref/adapter/in/mcp/TrueRefMcpTools.java index c133abc..26c2d84 100644 --- a/trueref-adapters/src/main/java/com/trueref/adapter/in/mcp/TrueRefMcpTools.java +++ b/trueref-adapters/src/main/java/com/trueref/adapter/in/mcp/TrueRefMcpTools.java @@ -6,6 +6,7 @@ import com.trueref.domain.model.SearchHit; import com.trueref.domain.model.SearchScope; import com.trueref.domain.model.Version; import com.trueref.domain.model.VersionStatus; +import com.trueref.domain.error.ModelNotReady; import com.trueref.domain.port.in.IndexVersion; import com.trueref.domain.port.in.QueryCatalog; import com.trueref.domain.port.in.ResolveLibraryId; @@ -153,6 +154,9 @@ public class TrueRefMcpTools { SearchLibraryDocs.Result res; try { res = search.search(q); + } catch (ModelNotReady e) { + return "[model_not_ready] The inference model is loading. " + + "Please retry in ~" + e.retryAfterSeconds() + " seconds."; } catch (Exception e) { log.warn("MCP search failed for {}: {}", libraryId, e.toString()); return "Search failed for " + libraryId + ": " + e.getMessage(); diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/in/rest/GlobalExceptionHandler.java b/trueref-adapters/src/main/java/com/trueref/adapter/in/rest/GlobalExceptionHandler.java index c66cd36..93c1c1b 100644 --- a/trueref-adapters/src/main/java/com/trueref/adapter/in/rest/GlobalExceptionHandler.java +++ b/trueref-adapters/src/main/java/com/trueref/adapter/in/rest/GlobalExceptionHandler.java @@ -2,6 +2,7 @@ package com.trueref.adapter.in.rest; import com.trueref.domain.error.IngestionFailed; import com.trueref.domain.error.InvalidSearchRequest; +import com.trueref.domain.error.ModelNotReady; import com.trueref.domain.error.RepositoryAlreadyRegistered; import com.trueref.domain.error.RepositoryNotFound; import com.trueref.domain.error.TagNotFound; @@ -12,6 +13,7 @@ import jakarta.validation.ConstraintViolationException; import java.util.List; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.http.ResponseEntity; import org.springframework.validation.FieldError; @@ -28,6 +30,15 @@ public class GlobalExceptionHandler { private static final Logger log = LoggerFactory.getLogger(GlobalExceptionHandler.class); + @ExceptionHandler(ModelNotReady.class) + public ResponseEntity handleModelNotReady(ModelNotReady ex) { + HttpHeaders headers = new HttpHeaders(); + headers.set(HttpHeaders.RETRY_AFTER, String.valueOf(ex.retryAfterSeconds())); + return ResponseEntity.status(HttpStatus.SERVICE_UNAVAILABLE) + .headers(headers) + .body(ErrorResponse.of(ex.code(), ex.getMessage())); + } + @ExceptionHandler({RepositoryNotFound.class, VersionNotFound.class, TagNotFound.class}) public ResponseEntity handleNotFound(TrueRefException ex) { return status(HttpStatus.NOT_FOUND, ex); diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/in/rest/ModelController.java b/trueref-adapters/src/main/java/com/trueref/adapter/in/rest/ModelController.java new file mode 100644 index 0000000..b0773aa --- /dev/null +++ b/trueref-adapters/src/main/java/com/trueref/adapter/in/rest/ModelController.java @@ -0,0 +1,106 @@ +package com.trueref.adapter.in.rest; + +import com.trueref.adapter.in.rest.dto.ModelStatusDto; +import com.trueref.domain.model.ModelState; +import com.trueref.domain.port.in.ManageModelLifecycle; +import io.swagger.v3.oas.annotations.Operation; +import io.swagger.v3.oas.annotations.tags.Tag; +import java.io.IOException; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestMapping; +import org.springframework.web.bind.annotation.RequestParam; +import org.springframework.web.bind.annotation.RestController; +import org.springframework.web.server.ResponseStatusException; +import org.springframework.web.servlet.mvc.method.annotation.SseEmitter; + +/** REST resource: {@code /api/model}. */ +@RestController +@RequestMapping("/api/model") +@Tag(name = "model", description = "GPU model lifecycle management (load state, force unload, SSE status stream).") +public class ModelController { + + private static final Logger log = LoggerFactory.getLogger(ModelController.class); + + private final ManageModelLifecycle lifecycle; + + public ModelController(ManageModelLifecycle lifecycle) { + this.lifecycle = lifecycle; + } + + @Operation(summary = "Current model lifecycle status.") + @GetMapping("/status") + public ModelStatusDto status() { + return ModelStatusDto.of(lifecycle.getStatus()); + } + + @Operation(summary = "Force-unload the GPU models from VRAM. " + + "Returns 409 if ingestion jobs are running (use force=true to override).") + @PostMapping("/unload") + public ModelStatusDto unload(@RequestParam(value = "force", defaultValue = "false") boolean force) { + boolean initiated = lifecycle.forceUnload(force); + if (!initiated) { + throw new ResponseStatusException( + HttpStatus.CONFLICT, + "model_unload_blocked_by_running_jobs"); + } + return ModelStatusDto.of(lifecycle.getStatus()); + } + + @Operation(summary = "Server-Sent Events stream of model state transitions " + + "(LOADING, LOADED, UNLOADING, UNLOADED).") + @GetMapping(value = "/status/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE) + public SseEmitter statusStream() { + SseEmitter emitter = new SseEmitter(0L); + + // Flush headers to the client immediately so EventSource fires 'open'. + try { + emitter.send(SseEmitter.event().name("ping").data("")); + // Send current state immediately so the client doesn't have to wait for the next transition. + emitter.send(SseEmitter.event() + .name("model-status") + .data(lifecycle.getStatus().state().name())); + } catch (IOException e) { + emitter.completeWithError(e); + return emitter; + } + + AutoCloseable subscription = lifecycle.subscribeState(event -> { + try { + emitter.send(SseEmitter.event().name("model-status").data(event)); + } catch (IOException ex) { + emitter.completeWithError(ex); + } + }); + + Thread keepalive = Thread.startVirtualThread(() -> { + try { + while (!Thread.currentThread().isInterrupted()) { + Thread.sleep(20_000); + emitter.send(SseEmitter.event().name("ping").data("")); + } + } catch (InterruptedException ignored) { + // normal shutdown + } catch (Exception ignored) { + // emitter already completed; exit + } + }); + + Runnable cleanup = () -> { + keepalive.interrupt(); + try { + subscription.close(); + } catch (Exception ex) { + log.debug("failed to close model status SSE subscription: {}", ex.toString()); + } + }; + emitter.onCompletion(cleanup); + emitter.onTimeout(cleanup); + emitter.onError(e -> cleanup.run()); + return emitter; + } +} diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/in/rest/dto/ModelStatusDto.java b/trueref-adapters/src/main/java/com/trueref/adapter/in/rest/dto/ModelStatusDto.java new file mode 100644 index 0000000..be08b2c --- /dev/null +++ b/trueref-adapters/src/main/java/com/trueref/adapter/in/rest/dto/ModelStatusDto.java @@ -0,0 +1,22 @@ +package com.trueref.adapter.in.rest.dto; + +import com.trueref.domain.model.ModelState; +import com.trueref.domain.port.in.ManageModelLifecycle; +import java.time.Instant; +import org.jspecify.annotations.Nullable; + +/** Snapshot of the GPU model lifecycle state, returned by {@code GET /api/model/status}. */ +public record ModelStatusDto( + ModelState state, + @Nullable Instant loadedAt, + @Nullable Instant lastActivityAt, + long idleTimeoutSeconds) { + + public static ModelStatusDto of(ManageModelLifecycle.Status status) { + return new ModelStatusDto( + status.state(), + status.loadedAt(), + status.lastActivityAt(), + status.idleTimeoutSeconds()); + } +} diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxEmbeddingConfig.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxEmbeddingConfig.java index 95816b6..fdca7e2 100644 --- a/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxEmbeddingConfig.java +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxEmbeddingConfig.java @@ -1,5 +1,11 @@ package com.trueref.adapter.out.embedding.onnx; +import com.trueref.application.model.InMemoryModelStateEventBus; +import com.trueref.application.model.ModelLifecycleService; +import com.trueref.domain.port.in.ManageModelLifecycle; +import com.trueref.domain.port.in.ObserveJobs; +import com.trueref.domain.port.out.ModelLoader; +import com.trueref.domain.port.out.ModelStateEventBus; import org.springframework.beans.factory.annotation.Value; import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.context.annotation.Bean; @@ -9,7 +15,8 @@ import java.nio.file.Path; /** * Spring wiring for the ONNX embedding/reranker stack. Produces the shared {@link GpuSemaphore} - * (permits = {@code session-count}) and the resolved models-home {@link Path}. + * (permits = {@code session-count}), the resolved models-home {@link Path}, and the + * {@link ModelLifecycleService} that manages lazy loading and idle-timeout unloading. */ @Configuration @EnableConfigurationProperties(OnnxProperties.class) @@ -29,4 +36,27 @@ public class OnnxEmbeddingConfig { Path home = properties.home(); return home != null ? home : Path.of(trueRefHome).resolve("models"); } + + @Bean + ModelStateEventBus modelStateEventBus() { + return new InMemoryModelStateEventBus(); + } + + @Bean + ModelLoader onnxModelLoader(OnnxEmbeddingService embeddingService, OnnxRerankerService rerankerService) { + return new OnnxModelLoader(embeddingService, rerankerService); + } + + @Bean(initMethod = "init", destroyMethod = "shutdown") + ManageModelLifecycle modelLifecycleService( + ModelLoader modelLoader, + ModelStateEventBus modelStateEventBus, + ObserveJobs observeJobs, + OnnxProperties properties) { + return new ModelLifecycleService( + modelLoader, + modelStateEventBus, + observeJobs, + properties.idleTimeoutSecondsOrDefault()); + } } diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxEmbeddingService.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxEmbeddingService.java index 242aa24..6a5caf3 100644 --- a/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxEmbeddingService.java +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxEmbeddingService.java @@ -7,9 +7,8 @@ import ai.onnxruntime.OrtException; import ai.onnxruntime.OrtSession; import com.trueref.adapter.out.embedding.onnx.HuggingFaceTokenizerWrapper.EncodedBatch; import com.trueref.domain.error.IngestionFailed; +import com.trueref.domain.error.ModelNotReady; import com.trueref.domain.port.out.EmbeddingService; -import jakarta.annotation.PostConstruct; -import jakarta.annotation.PreDestroy; import java.nio.LongBuffer; import java.nio.file.Path; import java.util.ArrayList; @@ -68,8 +67,13 @@ public class OnnxEmbeddingService implements EmbeddingService { this.modelsHome = modelsHome; } - @PostConstruct - void start() { + /** Returns {@code true} if the model is currently loaded and accepting requests. */ + public boolean isStarted() { + return pool != null && tokenizer != null; + } + + /** Loads the ONNX model into memory. Called by {@link OnnxModelLoader} on demand. */ + public void start() { if (properties.sessionCountOrDefault() <= 0) { log.warn("OnnxEmbeddingService DISABLED (session-count=0); embed() calls will fail"); return; @@ -104,8 +108,8 @@ public class OnnxEmbeddingService implements EmbeddingService { usesTokenTypeIds); } - @PreDestroy - void stop() { + /** Unloads the ONNX model from memory. Called by {@link OnnxModelLoader} on unload. */ + public void stop() { if (pool != null) { pool.close(); pool = null; @@ -257,7 +261,7 @@ public class OnnxEmbeddingService implements EmbeddingService { private static T requireStarted(@org.jspecify.annotations.Nullable T t, String what) { if (t == null) { - throw new IllegalStateException("OnnxEmbeddingService not started: " + what + " is null"); + throw new ModelNotReady(30); } return t; } diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxModelLoader.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxModelLoader.java new file mode 100644 index 0000000..4b8197e --- /dev/null +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxModelLoader.java @@ -0,0 +1,54 @@ +package com.trueref.adapter.out.embedding.onnx; + +import com.trueref.domain.port.out.ModelLoader; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Implements {@link ModelLoader} by delegating to the ONNX embedding and reranker services. + * Both services are loaded/unloaded as a unit. This bean drives their lifecycle instead of + * {@code @PostConstruct} / {@code @PreDestroy} so that models are loaded lazily on first demand. + */ +public class OnnxModelLoader implements ModelLoader { + + private static final Logger log = LoggerFactory.getLogger(OnnxModelLoader.class); + + private final OnnxEmbeddingService embeddingService; + private final OnnxRerankerService rerankerService; + + public OnnxModelLoader(OnnxEmbeddingService embeddingService, OnnxRerankerService rerankerService) { + this.embeddingService = embeddingService; + this.rerankerService = rerankerService; + } + + @Override + public void load() { + log.info("OnnxModelLoader: loading embedder…"); + embeddingService.start(); + log.info("OnnxModelLoader: loading reranker…"); + rerankerService.start(); + log.info("OnnxModelLoader: both models loaded"); + } + + @Override + public void unload() { + log.info("OnnxModelLoader: unloading reranker…"); + try { + rerankerService.stop(); + } catch (Exception e) { + log.warn("Reranker stop error (continuing): {}", e.toString()); + } + log.info("OnnxModelLoader: unloading embedder…"); + try { + embeddingService.stop(); + } catch (Exception e) { + log.warn("Embedder stop error (continuing): {}", e.toString()); + } + log.info("OnnxModelLoader: both models unloaded"); + } + + @Override + public boolean isLoaded() { + return embeddingService.isStarted() && rerankerService.isStarted(); + } +} diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxProperties.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxProperties.java index 4642042..95e58b0 100644 --- a/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxProperties.java +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxProperties.java @@ -21,6 +21,7 @@ public record OnnxProperties( @Nullable Integer maxSeqLen, @Nullable Integer gpuDeviceId, @Nullable Long gpuMemLimitBytes, + @Nullable Long idleTimeoutSeconds, @Nullable Path home, @Nullable Map>> modelSources) { @@ -37,6 +38,7 @@ public record OnnxProperties( if (maxSeqLen == null || maxSeqLen <= 0) maxSeqLen = 512; if (gpuDeviceId == null || gpuDeviceId < 0) gpuDeviceId = 0; if (gpuMemLimitBytes == null || gpuMemLimitBytes <= 0L) gpuMemLimitBytes = 0L; // 0 = no cap + if (idleTimeoutSeconds == null || idleTimeoutSeconds <= 0L) idleTimeoutSeconds = 300L; // 5 min if (modelSources == null) modelSources = Map.of(); } @@ -75,4 +77,8 @@ public record OnnxProperties( public long gpuMemLimitBytesOrDefault() { return gpuMemLimitBytes == null ? 0L : gpuMemLimitBytes; } + + public long idleTimeoutSecondsOrDefault() { + return idleTimeoutSeconds == null ? 300L : idleTimeoutSeconds; + } } diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxRerankerService.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxRerankerService.java index 75ad4b2..7cbd61a 100644 --- a/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxRerankerService.java +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxRerankerService.java @@ -7,10 +7,9 @@ import ai.onnxruntime.OrtException; import ai.onnxruntime.OrtSession; import com.trueref.adapter.out.embedding.onnx.HuggingFaceTokenizerWrapper.EncodedBatch; import com.trueref.domain.error.IngestionFailed; +import com.trueref.domain.error.ModelNotReady; import com.trueref.domain.model.SearchHit; import com.trueref.domain.port.out.RerankerService; -import jakarta.annotation.PostConstruct; -import jakarta.annotation.PreDestroy; import java.nio.LongBuffer; import java.nio.file.Path; import java.util.ArrayList; @@ -62,8 +61,13 @@ public class OnnxRerankerService implements RerankerService { this.modelsHome = modelsHome; } - @PostConstruct - void start() { + /** Returns {@code true} if the reranker model is currently loaded and accepting requests. */ + public boolean isStarted() { + return pool != null && tokenizer != null; + } + + /** Loads the reranker model into memory. Called by {@link OnnxModelLoader} on demand. */ + public void start() { if (properties.rerankerSessionCountOrDefault() <= 0) { log.warn("OnnxRerankerService DISABLED (reranker-session-count=0); rerank() passes through input order"); return; @@ -96,8 +100,8 @@ public class OnnxRerankerService implements RerankerService { usesTokenTypeIds); } - @PreDestroy - void stop() { + /** Unloads the reranker model from memory. Called by {@link OnnxModelLoader} on unload. */ + public void stop() { if (pool != null) { pool.close(); pool = null; @@ -252,7 +256,7 @@ public class OnnxRerankerService implements RerankerService { private static T requireStarted(@org.jspecify.annotations.Nullable T t, String what) { if (t == null) { - throw new IllegalStateException("OnnxRerankerService not started: " + what + " is null"); + throw new ModelNotReady(30); } return t; } diff --git a/trueref-adapters/src/test/java/com/trueref/adapter/in/rest/GlobalExceptionHandlerTest.java b/trueref-adapters/src/test/java/com/trueref/adapter/in/rest/GlobalExceptionHandlerTest.java new file mode 100644 index 0000000..a5798cb --- /dev/null +++ b/trueref-adapters/src/test/java/com/trueref/adapter/in/rest/GlobalExceptionHandlerTest.java @@ -0,0 +1,44 @@ +package com.trueref.adapter.in.rest; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.trueref.domain.error.ModelNotReady; +import org.junit.jupiter.api.Test; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpStatus; +import org.springframework.http.ResponseEntity; + +class GlobalExceptionHandlerTest { + + private final GlobalExceptionHandler handler = new GlobalExceptionHandler(); + + @Test + void modelNotReadyMapsto503WithRetryAfterHeader() { + ModelNotReady ex = new ModelNotReady(30); + + ResponseEntity response = handler.handleModelNotReady(ex); + + assertThat(response.getStatusCode()).isEqualTo(HttpStatus.SERVICE_UNAVAILABLE); + assertThat(response.getHeaders().getFirst(HttpHeaders.RETRY_AFTER)).isEqualTo("30"); + } + + @Test + void modelNotReadyResponseBodyHasModelNotReadyCode() { + ModelNotReady ex = new ModelNotReady(15); + + ResponseEntity response = handler.handleModelNotReady(ex); + + assertThat(response.getBody()).isNotNull(); + assertThat(response.getBody().code()).isEqualTo("model_not_ready"); + assertThat(response.getBody().message()).contains("15"); + } + + @Test + void retryAfterReflectsExceptionValue() { + ModelNotReady ex = new ModelNotReady(60); + + ResponseEntity response = handler.handleModelNotReady(ex); + + assertThat(response.getHeaders().getFirst(HttpHeaders.RETRY_AFTER)).isEqualTo("60"); + } +} diff --git a/trueref-adapters/src/test/java/com/trueref/adapter/out/embedding/onnx/OnnxModelLoaderTest.java b/trueref-adapters/src/test/java/com/trueref/adapter/out/embedding/onnx/OnnxModelLoaderTest.java new file mode 100644 index 0000000..404cf46 --- /dev/null +++ b/trueref-adapters/src/test/java/com/trueref/adapter/out/embedding/onnx/OnnxModelLoaderTest.java @@ -0,0 +1,71 @@ +package com.trueref.adapter.out.embedding.onnx; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import org.junit.jupiter.api.Test; + +class OnnxModelLoaderTest { + + @Test + void loadCallsStartOnBothServices() { + OnnxEmbeddingService embedder = mock(OnnxEmbeddingService.class); + OnnxRerankerService reranker = mock(OnnxRerankerService.class); + OnnxModelLoader loader = new OnnxModelLoader(embedder, reranker); + + loader.load(); + + verify(embedder).start(); + verify(reranker).start(); + } + + @Test + void unloadCallsStopOnBothServices() { + OnnxEmbeddingService embedder = mock(OnnxEmbeddingService.class); + OnnxRerankerService reranker = mock(OnnxRerankerService.class); + OnnxModelLoader loader = new OnnxModelLoader(embedder, reranker); + + loader.unload(); + + verify(embedder).stop(); + verify(reranker).stop(); + } + + @Test + void unloadContinuesEvenIfRerankerStopThrows() { + OnnxEmbeddingService embedder = mock(OnnxEmbeddingService.class); + OnnxRerankerService reranker = mock(OnnxRerankerService.class); + doThrow(new RuntimeException("VRAM error")).when(reranker).stop(); + OnnxModelLoader loader = new OnnxModelLoader(embedder, reranker); + + loader.unload(); // must not throw + + verify(embedder).stop(); + } + + @Test + void isLoadedReturnsTrueWhenBothServicesStarted() { + OnnxEmbeddingService embedder = mock(OnnxEmbeddingService.class); + OnnxRerankerService reranker = mock(OnnxRerankerService.class); + when(embedder.isStarted()).thenReturn(true); + when(reranker.isStarted()).thenReturn(true); + OnnxModelLoader loader = new OnnxModelLoader(embedder, reranker); + + assertThat(loader.isLoaded()).isTrue(); + } + + @Test + void isLoadedReturnsFalseWhenEitherServiceNotStarted() { + OnnxEmbeddingService embedder = mock(OnnxEmbeddingService.class); + OnnxRerankerService reranker = mock(OnnxRerankerService.class); + when(embedder.isStarted()).thenReturn(true); + when(reranker.isStarted()).thenReturn(false); + OnnxModelLoader loader = new OnnxModelLoader(embedder, reranker); + + assertThat(loader.isLoaded()).isFalse(); + } +} diff --git a/trueref-adapters/src/test/java/com/trueref/application/model/InMemoryModelStateEventBusTest.java b/trueref-adapters/src/test/java/com/trueref/application/model/InMemoryModelStateEventBusTest.java new file mode 100644 index 0000000..502ed20 --- /dev/null +++ b/trueref-adapters/src/test/java/com/trueref/application/model/InMemoryModelStateEventBusTest.java @@ -0,0 +1,75 @@ +package com.trueref.application.model; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.trueref.domain.model.ModelState; +import com.trueref.domain.model.ModelStateEvent; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.Test; + +class InMemoryModelStateEventBusTest { + + @Test + void singleSubscriberReceivesPublishedEvent() { + InMemoryModelStateEventBus bus = new InMemoryModelStateEventBus(); + List received = new ArrayList<>(); + bus.subscribe(received::add); + + bus.publish(ModelStateEvent.of(ModelState.LOADING)); + + assertThat(received).hasSize(1); + assertThat(received.get(0).state()).isEqualTo(ModelState.LOADING); + } + + @Test + void multipleSubscribersAllReceiveEvent() { + InMemoryModelStateEventBus bus = new InMemoryModelStateEventBus(); + List a = new ArrayList<>(); + List b = new ArrayList<>(); + bus.subscribe(a::add); + bus.subscribe(b::add); + + bus.publish(ModelStateEvent.of(ModelState.LOADED)); + + assertThat(a).hasSize(1); + assertThat(b).hasSize(1); + } + + @Test + void closingSubscriptionRemovesIt() throws Exception { + InMemoryModelStateEventBus bus = new InMemoryModelStateEventBus(); + List received = new ArrayList<>(); + AutoCloseable handle = bus.subscribe(received::add); + + handle.close(); + bus.publish(ModelStateEvent.of(ModelState.UNLOADED)); + + assertThat(received).isEmpty(); + } + + @Test + void exceptionInSubscriberDoesNotBreakOthers() { + InMemoryModelStateEventBus bus = new InMemoryModelStateEventBus(); + AtomicInteger badCalls = new AtomicInteger(); + List good = new ArrayList<>(); + + bus.subscribe(e -> { + badCalls.incrementAndGet(); + throw new RuntimeException("intentional test failure"); + }); + bus.subscribe(good::add); + + bus.publish(ModelStateEvent.of(ModelState.LOADING)); + + assertThat(badCalls.get()).isEqualTo(1); + assertThat(good).hasSize(1); + } + + @Test + void noSubscribersMeansPublishIsNoop() { + InMemoryModelStateEventBus bus = new InMemoryModelStateEventBus(); + bus.publish(ModelStateEvent.of(ModelState.UNLOADED)); // must not throw + } +} diff --git a/trueref-adapters/src/test/java/com/trueref/application/model/ModelLifecycleServiceTest.java b/trueref-adapters/src/test/java/com/trueref/application/model/ModelLifecycleServiceTest.java new file mode 100644 index 0000000..cfbf16e --- /dev/null +++ b/trueref-adapters/src/test/java/com/trueref/application/model/ModelLifecycleServiceTest.java @@ -0,0 +1,227 @@ +package com.trueref.application.model; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import com.trueref.domain.error.ModelNotReady; +import com.trueref.domain.model.IngestionJob; +import com.trueref.domain.model.JobId; +import com.trueref.domain.model.JobStatus; +import com.trueref.domain.model.JobType; +import com.trueref.domain.model.ModelState; +import com.trueref.domain.model.ModelStateEvent; +import com.trueref.domain.model.RepositoryId; +import com.trueref.domain.port.in.ManageModelLifecycle; +import com.trueref.domain.port.in.ObserveJobs; +import com.trueref.domain.port.out.ModelLoader; +import com.trueref.domain.port.out.ModelStateEventBus; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.Consumer; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class ModelLifecycleServiceTest { + + // ---- Test doubles ---- + + private final AtomicBoolean loaderLoaded = new AtomicBoolean(false); + private volatile boolean slowLoad = false; + + private final ModelLoader loader = new ModelLoader() { + @Override + public void load() { + if (slowLoad) { + try { Thread.sleep(200); } catch (InterruptedException e) { Thread.currentThread().interrupt(); } + } + loaderLoaded.set(true); + } + @Override + public void unload() { loaderLoaded.set(false); } + @Override + public boolean isLoaded() { return loaderLoaded.get(); } + }; + + private final InMemoryModelStateEventBus eventBus = new InMemoryModelStateEventBus(); + + private volatile List runningJobs = List.of(); + + private final ObserveJobs observeJobs = new ObserveJobs() { + @Override + public Optional findJob(JobId id) { return Optional.empty(); } + @Override + public List listJobs(com.trueref.domain.model.RepositoryId r, + com.trueref.domain.model.VersionId v, JobStatus status, int limit) { + if (status == JobStatus.RUNNING) return runningJobs; + return List.of(); + } + @Override + public AutoCloseable subscribeJobs(Consumer listener) { return () -> {}; } + @Override + public AutoCloseable subscribeLogs(JobId jobId, Consumer listener) { return () -> {}; } + }; + + private ModelLifecycleService service; + + @BeforeEach + void setUp() { + loaderLoaded.set(false); + slowLoad = false; + runningJobs = List.of(); + service = new ModelLifecycleService(loader, eventBus, observeJobs, 300L); + service.init(); + } + + @AfterEach + void tearDown() { + service.shutdown(); + } + + // ---- Tests ---- + + @Test + void startsInUnloadedState() { + assertThat(service.getStatus().state()).isEqualTo(ModelState.UNLOADED); + } + + @Test + void ensureReadyOnUnloadedTriggersLoadAndThrows() throws InterruptedException { + List events = new ArrayList<>(); + eventBus.subscribe(events::add); + + assertThatThrownBy(() -> service.ensureReady()) + .isInstanceOf(ModelNotReady.class); + + // Wait for async load to complete (max 2 s in test) + for (int i = 0; i < 20 && service.getStatus().state() != ModelState.LOADED; i++) { + Thread.sleep(100); + } + + assertThat(service.getStatus().state()).isEqualTo(ModelState.LOADED); + assertThat(loaderLoaded.get()).isTrue(); + assertThat(events).extracting(ModelStateEvent::state) + .containsSequence(ModelState.LOADING, ModelState.LOADED); + } + + @Test + void ensureReadyOnLoadedIsNoop() throws InterruptedException { + triggerLoad(); + + // Should not throw + service.ensureReady(); + service.ensureReady(); + + assertThat(service.getStatus().state()).isEqualTo(ModelState.LOADED); + } + + @Test + void ensureReadyOnLoadingThrowsImmediately() { + slowLoad = true; + assertThatThrownBy(() -> service.ensureReady()) + .isInstanceOf(ModelNotReady.class); + // state is LOADING, second call also throws + assertThatThrownBy(() -> service.ensureReady()) + .isInstanceOf(ModelNotReady.class); + } + + @Test + void forceUnloadFromLoadedState() throws InterruptedException { + triggerLoad(); + + boolean initiated = service.forceUnload(false); + + assertThat(initiated).isTrue(); + // Wait for async unload + for (int i = 0; i < 20 && service.getStatus().state() != ModelState.UNLOADED; i++) { + Thread.sleep(100); + } + assertThat(service.getStatus().state()).isEqualTo(ModelState.UNLOADED); + assertThat(loaderLoaded.get()).isFalse(); + } + + @Test + void forceUnloadBlockedByRunningJobsWithoutForce() throws InterruptedException { + triggerLoad(); + runningJobs = List.of(runningJob()); + + boolean initiated = service.forceUnload(false); + + assertThat(initiated).isFalse(); + assertThat(service.getStatus().state()).isEqualTo(ModelState.LOADED); + } + + @Test + void forceUnloadSucceedsWithForceEvenWhenJobsRunning() throws InterruptedException { + triggerLoad(); + runningJobs = List.of(runningJob()); + + boolean initiated = service.forceUnload(true); + + assertThat(initiated).isTrue(); + for (int i = 0; i < 20 && service.getStatus().state() != ModelState.UNLOADED; i++) { + Thread.sleep(100); + } + assertThat(service.getStatus().state()).isEqualTo(ModelState.UNLOADED); + } + + @Test + void forceUnloadWhenAlreadyUnloadedIsNoop() { + boolean initiated = service.forceUnload(false); + assertThat(initiated).isTrue(); + assertThat(service.getStatus().state()).isEqualTo(ModelState.UNLOADED); + } + + @Test + void concurrentEnsureReadyDeduplicatesLoad() throws InterruptedException { + int callCount = 0; + // Both calls throw ModelNotReady but only one load should be triggered + List events = new ArrayList<>(); + eventBus.subscribe(events::add); + + try { service.ensureReady(); } catch (ModelNotReady ignored) { callCount++; } + try { service.ensureReady(); } catch (ModelNotReady ignored) { callCount++; } + + for (int i = 0; i < 20 && service.getStatus().state() != ModelState.LOADED; i++) { + Thread.sleep(100); + } + + assertThat(callCount).isEqualTo(2); + // Only one LOADING event should be emitted (deduplicated) + long loadingCount = events.stream().filter(e -> e.state() == ModelState.LOADING).count(); + assertThat(loadingCount).isEqualTo(1); + } + + @Test + void statusIncludesLoadedAtAfterLoad() throws InterruptedException { + assertThat(service.getStatus().loadedAt()).isNull(); + triggerLoad(); + assertThat(service.getStatus().loadedAt()).isNotNull(); + } + + // ---- helpers ---- + + private void triggerLoad() throws InterruptedException { + try { service.ensureReady(); } catch (ModelNotReady ignored) {} + for (int i = 0; i < 20 && service.getStatus().state() != ModelState.LOADED; i++) { + Thread.sleep(100); + } + } + + private IngestionJob runningJob() { + return new IngestionJob( + JobId.of(java.util.UUID.randomUUID().toString()), + RepositoryId.of(java.util.UUID.randomUUID().toString()), + null, + JobType.INDEX_VERSION, + JobStatus.RUNNING, + Instant.now(), + null, + List.of()); + } +} diff --git a/trueref-application/src/main/java/com/trueref/application/model/InMemoryModelStateEventBus.java b/trueref-application/src/main/java/com/trueref/application/model/InMemoryModelStateEventBus.java new file mode 100644 index 0000000..8f0ec3a --- /dev/null +++ b/trueref-application/src/main/java/com/trueref/application/model/InMemoryModelStateEventBus.java @@ -0,0 +1,35 @@ +package com.trueref.application.model; + +import com.trueref.domain.model.ModelStateEvent; +import com.trueref.domain.port.out.ModelStateEventBus; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.function.Consumer; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** In-memory, thread-safe implementation of {@link ModelStateEventBus}. */ +public class InMemoryModelStateEventBus implements ModelStateEventBus { + + private static final Logger log = LoggerFactory.getLogger(InMemoryModelStateEventBus.class); + + private final CopyOnWriteArrayList> subscribers = + new CopyOnWriteArrayList<>(); + + @Override + public void publish(ModelStateEvent event) { + for (Consumer subscriber : subscribers) { + try { + subscriber.accept(event); + } catch (Exception e) { + log.warn("Model state subscriber threw during publish; removing: {}", e.toString()); + subscribers.remove(subscriber); + } + } + } + + @Override + public AutoCloseable subscribe(Consumer subscriber) { + subscribers.add(subscriber); + return () -> subscribers.remove(subscriber); + } +} diff --git a/trueref-application/src/main/java/com/trueref/application/model/ModelLifecycleService.java b/trueref-application/src/main/java/com/trueref/application/model/ModelLifecycleService.java new file mode 100644 index 0000000..ee4cfc7 --- /dev/null +++ b/trueref-application/src/main/java/com/trueref/application/model/ModelLifecycleService.java @@ -0,0 +1,287 @@ +package com.trueref.application.model; + +import com.trueref.domain.error.ModelNotReady; +import com.trueref.domain.model.IngestionJob; +import com.trueref.domain.model.JobStatus; +import com.trueref.domain.model.ModelState; +import com.trueref.domain.model.ModelStateEvent; +import com.trueref.domain.port.in.ManageModelLifecycle; +import com.trueref.domain.port.in.ObserveJobs; +import com.trueref.domain.port.out.ModelLoader; +import com.trueref.domain.port.out.ModelStateEventBus; +import java.time.Instant; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Consumer; +import org.jspecify.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Orchestrates the GPU model lifecycle: lazy loading, auto-unload after idle timeout, and + * force-unload. State transitions are protected by a {@link ReentrantLock}. Heavy load/unload + * operations are executed on a dedicated single-thread platform-thread executor to satisfy CUDA + * OS-thread affinity constraints. + */ +public class ModelLifecycleService implements ManageModelLifecycle { + + private static final Logger log = LoggerFactory.getLogger(ModelLifecycleService.class); + + /** Default retry hint in seconds when the model is LOADING. Load typically takes 5–30 s. */ + private static final int LOADING_RETRY_AFTER_SECONDS = 30; + + /** Idle-check interval. One check per minute is sufficient resolution for a 5-min timeout. */ + private static final long IDLE_CHECK_INTERVAL_SECONDS = 60L; + + private final ModelLoader modelLoader; + private final ModelStateEventBus eventBus; + private final ObserveJobs observeJobs; + private final long idleTimeoutSeconds; + + // ---- state ---- + private final ReentrantLock lock = new ReentrantLock(); + private final AtomicReference state = new AtomicReference<>(ModelState.UNLOADED); + private volatile @Nullable Instant loadedAt = null; + private final AtomicLong lastActivityEpochMs = new AtomicLong(0L); + + // ---- executors ---- + /** Single platform OS thread for all CUDA load/unload calls. */ + private final java.util.concurrent.ExecutorService cudaExecutor; + + /** Scheduler that periodically checks for idle timeout. */ + private final ScheduledExecutorService idleScheduler; + + /** Handle to the in-flight async load Future (guarded by lock). */ + private @Nullable Future pendingLoad = null; + + public ModelLifecycleService( + ModelLoader modelLoader, + ModelStateEventBus eventBus, + ObserveJobs observeJobs, + long idleTimeoutSeconds) { + this.modelLoader = modelLoader; + this.eventBus = eventBus; + this.observeJobs = observeJobs; + this.idleTimeoutSeconds = idleTimeoutSeconds; + + ThreadFactory platformFactory = Thread.ofPlatform() + .name("model-lifecycle-cuda") + .factory(); + this.cudaExecutor = Executors.newSingleThreadExecutor(platformFactory); + this.idleScheduler = Executors.newSingleThreadScheduledExecutor( + Thread.ofPlatform().name("model-idle-check").factory()); + } + + /** Initializes the idle timer and job-event subscription. Called by the Spring bean lifecycle. */ + public void init() { + // Subscribe to job events so that finishing a job resets the idle clock. + observeJobs.subscribeJobs(this::onJobEvent); + // Start the periodic idle check. + idleScheduler.scheduleAtFixedRate( + this::checkIdleTimeout, + IDLE_CHECK_INTERVAL_SECONDS, + IDLE_CHECK_INTERVAL_SECONDS, + TimeUnit.SECONDS); + log.info("ModelLifecycleService started; idleTimeoutSeconds={} state=UNLOADED", idleTimeoutSeconds); + } + + /** Shuts down schedulers and unloads models if loaded. Called by the Spring bean lifecycle. */ + public void shutdown() { + idleScheduler.shutdownNow(); + // Unload synchronously on shutdown so VRAM is released cleanly. + if (state.get() == ModelState.LOADED || state.get() == ModelState.LOADING) { + doUnload("JVM shutdown"); + } + cudaExecutor.shutdownNow(); + } + + // ---- ManageModelLifecycle ---- + + @Override + public void ensureReady() { + // Fast path: already loaded. + if (state.get() == ModelState.LOADED) { + touchActivity(); + return; + } + lock.lock(); + try { + ModelState current = state.get(); + switch (current) { + case LOADED -> { + touchActivity(); + return; + } + case LOADING -> throw new ModelNotReady(LOADING_RETRY_AFTER_SECONDS); + case UNLOADING -> throw new ModelNotReady(LOADING_RETRY_AFTER_SECONDS); + case UNLOADED -> { + triggerAsyncLoad("triggered by incoming request"); + throw new ModelNotReady(LOADING_RETRY_AFTER_SECONDS); + } + } + } finally { + lock.unlock(); + } + } + + @Override + public boolean forceUnload(boolean force) { + lock.lock(); + try { + ModelState current = state.get(); + if (current == ModelState.UNLOADED || current == ModelState.UNLOADING) { + log.info("forceUnload: already {} — no-op", current); + return true; + } + if (!force && hasRunningJobs()) { + log.info("forceUnload: blocked by running ingestion jobs (use force=true to override)"); + return false; + } + if (current == ModelState.LOADING && pendingLoad != null) { + pendingLoad.cancel(false); // try to cancel; unload will follow after transition + } + doUnload("force-unload via API"); + return true; + } finally { + lock.unlock(); + } + } + + @Override + public Status getStatus() { + return new Status(state.get(), loadedAt, epochMsToInstant(lastActivityEpochMs.get()), idleTimeoutSeconds); + } + + @Override + public AutoCloseable subscribeState(Consumer subscriber) { + return eventBus.subscribe(subscriber); + } + + // ---- idle timer ---- + + private void checkIdleTimeout() { + if (state.get() != ModelState.LOADED) { + return; + } + long lastMs = lastActivityEpochMs.get(); + if (lastMs == 0L) { + return; // no activity yet + } + long idleMs = System.currentTimeMillis() - lastMs; + if (idleMs < idleTimeoutSeconds * 1000L) { + return; + } + lock.lock(); + try { + if (state.get() != ModelState.LOADED) { + return; // state changed while we waited for the lock + } + if (hasRunningJobs()) { + log.debug("idle timeout reached but ingestion jobs are running; skipping unload"); + return; + } + long nowIdleMs = System.currentTimeMillis() - lastActivityEpochMs.get(); + if (nowIdleMs < idleTimeoutSeconds * 1000L) { + return; // activity happened while we waited for the lock + } + doUnload("idle timeout after " + (nowIdleMs / 1000) + " s"); + } finally { + lock.unlock(); + } + } + + // ---- internal helpers (always called while holding lock) ---- + + /** + * Submits an async load to the CUDA executor. Must be called while holding {@code lock}. + * Transitions state UNLOADED → LOADING. + */ + private void triggerAsyncLoad(String reason) { + if (pendingLoad != null && !pendingLoad.isDone()) { + return; // already loading — deduplicate + } + transition(ModelState.LOADING, "load requested: " + reason); + pendingLoad = cudaExecutor.submit(() -> { + try { + log.info("Loading GPU models…"); + modelLoader.load(); + lock.lock(); + try { + loadedAt = Instant.now(); + touchActivity(); + transition(ModelState.LOADED, null); + log.info("GPU models loaded successfully"); + } finally { + lock.unlock(); + } + } catch (Exception e) { + log.error("GPU model load failed", e); + lock.lock(); + try { + transition(ModelState.UNLOADED, "load failed: " + e.getMessage()); + } finally { + lock.unlock(); + } + } + }); + } + + /** + * Runs an unload synchronously on the calling thread (must already be the CUDA executor thread + * or called during shutdown). For runtime calls it is submitted to the cuda executor. + * Must be called while holding {@code lock}. + */ + private void doUnload(String reason) { + transition(ModelState.UNLOADING, reason); + // Run unload on the CUDA executor so it stays on the same OS thread as the sessions. + cudaExecutor.execute(() -> { + try { + log.info("Unloading GPU models…"); + modelLoader.unload(); + log.info("GPU models unloaded"); + } catch (Exception e) { + log.error("GPU model unload error (models may still be in VRAM)", e); + } finally { + lock.lock(); + try { + loadedAt = null; + transition(ModelState.UNLOADED, null); + } finally { + lock.unlock(); + } + } + }); + } + + private void transition(ModelState next, @Nullable String message) { + ModelState prev = state.getAndSet(next); + if (prev != next) { + log.debug("Model state: {} → {}{}", prev, next, message != null ? " [" + message + "]" : ""); + eventBus.publish(ModelStateEvent.of(next, message != null ? message : "")); + } + } + + private void touchActivity() { + lastActivityEpochMs.set(System.currentTimeMillis()); + } + + private boolean hasRunningJobs() { + return !observeJobs.listJobs(null, null, JobStatus.RUNNING, 1).isEmpty(); + } + + private void onJobEvent(IngestionJob job) { + if (job.status() == JobStatus.SUCCEEDED || job.status() == JobStatus.FAILED) { + touchActivity(); // reset idle clock when a job finishes + } + } + + private static @Nullable Instant epochMsToInstant(long ms) { + return ms == 0L ? null : Instant.ofEpochMilli(ms); + } +} diff --git a/trueref-application/src/main/java/com/trueref/application/search/HybridSearchService.java b/trueref-application/src/main/java/com/trueref/application/search/HybridSearchService.java index d9a8b38..6023e5a 100644 --- a/trueref-application/src/main/java/com/trueref/application/search/HybridSearchService.java +++ b/trueref-application/src/main/java/com/trueref/application/search/HybridSearchService.java @@ -6,6 +6,7 @@ import com.trueref.domain.model.Repository; import com.trueref.domain.model.SearchHit; import com.trueref.domain.model.SearchScope; import com.trueref.domain.model.Version; +import com.trueref.domain.port.in.ManageModelLifecycle; import com.trueref.domain.port.in.SearchLibraryDocs; import com.trueref.domain.port.out.ChunkStore; import com.trueref.domain.port.out.EmbeddingService; @@ -42,6 +43,7 @@ public final class HybridSearchService implements SearchLibraryDocs { private final EmbeddingService embedder; private final RerankerService reranker; private final RepositoryStore repos; + private final ManageModelLifecycle lifecycle; private final int rrfK; private final int rerankTopK; private final int finalTopK; @@ -51,6 +53,7 @@ public final class HybridSearchService implements SearchLibraryDocs { EmbeddingService embedder, RerankerService reranker, RepositoryStore repos, + ManageModelLifecycle lifecycle, int rrfK, int rerankTopK, int finalTopK) { @@ -58,6 +61,7 @@ public final class HybridSearchService implements SearchLibraryDocs { this.embedder = embedder; this.reranker = reranker; this.repos = repos; + this.lifecycle = lifecycle; this.rrfK = rrfK; this.rerankTopK = rerankTopK; this.finalTopK = finalTopK; @@ -65,6 +69,9 @@ public final class HybridSearchService implements SearchLibraryDocs { @Override public Result search(Query q) { + // Ensure models are loaded; throws ModelNotReady (→ HTTP 503) if not. + lifecycle.ensureReady(); + if (q.text() == null || q.text().isBlank()) { throw new InvalidSearchRequest("query text must not be blank"); } diff --git a/trueref-bootstrap/src/main/java/com/trueref/bootstrap/ApplicationBeans.java b/trueref-bootstrap/src/main/java/com/trueref/bootstrap/ApplicationBeans.java index c3c2b76..fa529bd 100644 --- a/trueref-bootstrap/src/main/java/com/trueref/bootstrap/ApplicationBeans.java +++ b/trueref-bootstrap/src/main/java/com/trueref/bootstrap/ApplicationBeans.java @@ -7,6 +7,7 @@ import com.trueref.application.observability.InMemoryJobEventBus; import com.trueref.application.observability.JobObservationService; import com.trueref.application.resolve.LibraryResolver; import com.trueref.application.search.HybridSearchService; +import com.trueref.domain.port.in.ManageModelLifecycle; import com.trueref.domain.port.out.ChunkStore; import com.trueref.domain.port.out.CodeParser; import com.trueref.domain.port.out.EmbeddingCache; @@ -76,10 +77,11 @@ public class ApplicationBeans { EmbeddingService embedder, RerankerService reranker, RepositoryStore repos, + ManageModelLifecycle lifecycle, @Value("${trueref.search.rrf-k:60}") int rrfK, @Value("${trueref.reranker.top-k:50}") int rerankTopK, @Value("${trueref.search.final-top-k:20}") int finalTopK) { - return new HybridSearchService(chunks, embedder, reranker, repos, rrfK, rerankTopK, finalTopK); + return new HybridSearchService(chunks, embedder, reranker, repos, lifecycle, rrfK, rerankTopK, finalTopK); } @Bean diff --git a/trueref-domain/src/main/java/com/trueref/domain/error/ModelNotReady.java b/trueref-domain/src/main/java/com/trueref/domain/error/ModelNotReady.java new file mode 100644 index 0000000..e27e179 --- /dev/null +++ b/trueref-domain/src/main/java/com/trueref/domain/error/ModelNotReady.java @@ -0,0 +1,23 @@ +package com.trueref.domain.error; + +/** + * Thrown when an inference request arrives while the GPU models are not yet loaded (or are being + * unloaded). Callers should surface this as HTTP 503 with a {@code Retry-After} header, or as an + * MCP error inviting retry. + */ +public final class ModelNotReady extends TrueRefException { + + private final int retryAfterSeconds; + + public ModelNotReady(int retryAfterSeconds) { + super("model_not_ready", + "Model is not ready, retry in ~" + retryAfterSeconds + " seconds", + null); + this.retryAfterSeconds = retryAfterSeconds; + } + + /** Suggested number of seconds the caller should wait before retrying. */ + public int retryAfterSeconds() { + return retryAfterSeconds; + } +} diff --git a/trueref-domain/src/main/java/com/trueref/domain/error/TrueRefException.java b/trueref-domain/src/main/java/com/trueref/domain/error/TrueRefException.java index 4c53b5d..76515c4 100644 --- a/trueref-domain/src/main/java/com/trueref/domain/error/TrueRefException.java +++ b/trueref-domain/src/main/java/com/trueref/domain/error/TrueRefException.java @@ -10,7 +10,8 @@ public abstract sealed class TrueRefException extends RuntimeException VersionNotIndexed, TagNotFound, IngestionFailed, - InvalidSearchRequest { + InvalidSearchRequest, + ModelNotReady { private final String code; diff --git a/trueref-domain/src/main/java/com/trueref/domain/model/ModelState.java b/trueref-domain/src/main/java/com/trueref/domain/model/ModelState.java new file mode 100644 index 0000000..33dc1c6 --- /dev/null +++ b/trueref-domain/src/main/java/com/trueref/domain/model/ModelState.java @@ -0,0 +1,13 @@ +package com.trueref.domain.model; + +/** Life-cycle state of the GPU inference models (embedder + reranker). */ +public enum ModelState { + /** Models are not loaded; no VRAM is consumed. */ + UNLOADED, + /** Models are being loaded into GPU memory (load in progress). */ + LOADING, + /** Models are loaded and ready to serve inference requests. */ + LOADED, + /** Models are being unloaded from GPU memory. */ + UNLOADING +} diff --git a/trueref-domain/src/main/java/com/trueref/domain/model/ModelStateEvent.java b/trueref-domain/src/main/java/com/trueref/domain/model/ModelStateEvent.java new file mode 100644 index 0000000..f579198 --- /dev/null +++ b/trueref-domain/src/main/java/com/trueref/domain/model/ModelStateEvent.java @@ -0,0 +1,18 @@ +package com.trueref.domain.model; + +import java.time.Instant; +import org.jspecify.annotations.Nullable; + +/** + * Immutable event emitted whenever the GPU model lifecycle transitions to a new {@link ModelState}. + */ +public record ModelStateEvent(ModelState state, Instant ts, @Nullable String message) { + + public static ModelStateEvent of(ModelState state) { + return new ModelStateEvent(state, Instant.now(), null); + } + + public static ModelStateEvent of(ModelState state, String message) { + return new ModelStateEvent(state, Instant.now(), message); + } +} diff --git a/trueref-domain/src/main/java/com/trueref/domain/port/in/ManageModelLifecycle.java b/trueref-domain/src/main/java/com/trueref/domain/port/in/ManageModelLifecycle.java new file mode 100644 index 0000000..b317d77 --- /dev/null +++ b/trueref-domain/src/main/java/com/trueref/domain/port/in/ManageModelLifecycle.java @@ -0,0 +1,47 @@ +package com.trueref.domain.port.in; + +import com.trueref.domain.model.ModelState; +import com.trueref.domain.model.ModelStateEvent; +import java.time.Instant; +import java.util.function.Consumer; +import org.jspecify.annotations.Nullable; + +/** Use-case port: manage the lifecycle of the GPU inference models. */ +public interface ManageModelLifecycle { + + /** + * Ensures the model is loaded and ready. If the model is already LOADED, returns immediately. + * If the model is UNLOADED, triggers an async load and throws + * {@link com.trueref.domain.error.ModelNotReady}. If the model is LOADING or UNLOADING, + * throws {@link com.trueref.domain.error.ModelNotReady}. + * + *

Also records the current time as the last-activity timestamp used by the idle timer. + */ + void ensureReady(); + + /** + * Forces an unload of the GPU models. Blocked if ingestion jobs are running, unless + * {@code force=true}. + * + * @param force when {@code true}, unloads even while jobs are running + * @return {@code true} if unload was initiated; {@code false} if blocked by running jobs + */ + boolean forceUnload(boolean force); + + /** Returns a snapshot of the current model lifecycle status. */ + Status getStatus(); + + /** + * Registers a subscriber to receive model state-change events. + * + * @return an {@link AutoCloseable} that removes the subscription when closed + */ + AutoCloseable subscribeState(Consumer subscriber); + + /** Snapshot of the current model lifecycle status. */ + record Status( + ModelState state, + @Nullable Instant loadedAt, + @Nullable Instant lastActivityAt, + long idleTimeoutSeconds) {} +} diff --git a/trueref-domain/src/main/java/com/trueref/domain/port/out/ModelLoader.java b/trueref-domain/src/main/java/com/trueref/domain/port/out/ModelLoader.java new file mode 100644 index 0000000..e9c6e98 --- /dev/null +++ b/trueref-domain/src/main/java/com/trueref/domain/port/out/ModelLoader.java @@ -0,0 +1,18 @@ +package com.trueref.domain.port.out; + +/** + * Loads and unloads the GPU inference models (embedder + reranker) as a shared lifecycle unit. + * Implementations are expected to run load/unload on a platform OS thread to satisfy CUDA + * context affinity constraints. + */ +public interface ModelLoader { + + /** Loads both models into GPU memory. Blocks until ready or throws on error. */ + void load(); + + /** Releases both models from GPU memory. Idempotent. */ + void unload(); + + /** Returns {@code true} if both models are currently loaded and accepting inference. */ + boolean isLoaded(); +} diff --git a/trueref-domain/src/main/java/com/trueref/domain/port/out/ModelStateEventBus.java b/trueref-domain/src/main/java/com/trueref/domain/port/out/ModelStateEventBus.java new file mode 100644 index 0000000..7a6a464 --- /dev/null +++ b/trueref-domain/src/main/java/com/trueref/domain/port/out/ModelStateEventBus.java @@ -0,0 +1,18 @@ +package com.trueref.domain.port.out; + +import com.trueref.domain.model.ModelStateEvent; +import java.util.function.Consumer; + +/** Event bus for broadcasting {@link ModelStateEvent}s to subscribers (e.g. SSE connections). */ +public interface ModelStateEventBus { + + /** Publishes a state-change event to all current subscribers. */ + void publish(ModelStateEvent event); + + /** + * Registers a subscriber to receive future events. + * + * @return an {@link AutoCloseable} that removes the subscription when closed + */ + AutoCloseable subscribe(Consumer subscriber); +}