feat: GPU model lazy-load/unload lifecycle management
- Domain: add ModelState, ModelStateEvent, ModelNotReady, ManageModelLifecycle (in-port), ModelLoader and ModelStateEventBus (out-ports) - Application: InMemoryModelStateEventBus; ModelLifecycleService — state machine (ReentrantLock), lazy load on first request, idle-timeout auto-unload (configurable via trueref.embedding.idle-timeout-seconds, default 300 s), job-guard (skips unload while ingestion running), platform-thread CUDA executor - Adapters: OnnxModelLoader wires embedder + reranker start/stop; remove @PostConstruct/@PreDestroy from OnnxEmbeddingService and OnnxRerankerService; requireStarted() now throws ModelNotReady instead of IllegalStateException - REST: GET /api/model/status, POST /api/model/unload (409 when jobs running, force=true to override), GET /api/model/status/stream (SSE) - GlobalExceptionHandler: ModelNotReady -> 503 + Retry-After header - HybridSearchService: calls lifecycle.ensureReady() before every search so both REST and MCP paths get ModelNotReady (-> 503 / MCP error) when unloaded - TrueRefMcpTools: catches ModelNotReady, returns retry hint in MCP error text - Tests: InMemoryModelStateEventBusTest, ModelLifecycleServiceTest (10 cases), OnnxModelLoaderTest, GlobalExceptionHandlerTest — all 41 tests green Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -6,6 +6,7 @@ import com.trueref.domain.model.SearchHit;
|
|||||||
import com.trueref.domain.model.SearchScope;
|
import com.trueref.domain.model.SearchScope;
|
||||||
import com.trueref.domain.model.Version;
|
import com.trueref.domain.model.Version;
|
||||||
import com.trueref.domain.model.VersionStatus;
|
import com.trueref.domain.model.VersionStatus;
|
||||||
|
import com.trueref.domain.error.ModelNotReady;
|
||||||
import com.trueref.domain.port.in.IndexVersion;
|
import com.trueref.domain.port.in.IndexVersion;
|
||||||
import com.trueref.domain.port.in.QueryCatalog;
|
import com.trueref.domain.port.in.QueryCatalog;
|
||||||
import com.trueref.domain.port.in.ResolveLibraryId;
|
import com.trueref.domain.port.in.ResolveLibraryId;
|
||||||
@@ -153,6 +154,9 @@ public class TrueRefMcpTools {
|
|||||||
SearchLibraryDocs.Result res;
|
SearchLibraryDocs.Result res;
|
||||||
try {
|
try {
|
||||||
res = search.search(q);
|
res = search.search(q);
|
||||||
|
} catch (ModelNotReady e) {
|
||||||
|
return "[model_not_ready] The inference model is loading. "
|
||||||
|
+ "Please retry in ~" + e.retryAfterSeconds() + " seconds.";
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
log.warn("MCP search failed for {}: {}", libraryId, e.toString());
|
log.warn("MCP search failed for {}: {}", libraryId, e.toString());
|
||||||
return "Search failed for " + libraryId + ": " + e.getMessage();
|
return "Search failed for " + libraryId + ": " + e.getMessage();
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package com.trueref.adapter.in.rest;
|
|||||||
|
|
||||||
import com.trueref.domain.error.IngestionFailed;
|
import com.trueref.domain.error.IngestionFailed;
|
||||||
import com.trueref.domain.error.InvalidSearchRequest;
|
import com.trueref.domain.error.InvalidSearchRequest;
|
||||||
|
import com.trueref.domain.error.ModelNotReady;
|
||||||
import com.trueref.domain.error.RepositoryAlreadyRegistered;
|
import com.trueref.domain.error.RepositoryAlreadyRegistered;
|
||||||
import com.trueref.domain.error.RepositoryNotFound;
|
import com.trueref.domain.error.RepositoryNotFound;
|
||||||
import com.trueref.domain.error.TagNotFound;
|
import com.trueref.domain.error.TagNotFound;
|
||||||
@@ -12,6 +13,7 @@ import jakarta.validation.ConstraintViolationException;
|
|||||||
import java.util.List;
|
import java.util.List;
|
||||||
import org.slf4j.Logger;
|
import org.slf4j.Logger;
|
||||||
import org.slf4j.LoggerFactory;
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.springframework.http.HttpHeaders;
|
||||||
import org.springframework.http.HttpStatus;
|
import org.springframework.http.HttpStatus;
|
||||||
import org.springframework.http.ResponseEntity;
|
import org.springframework.http.ResponseEntity;
|
||||||
import org.springframework.validation.FieldError;
|
import org.springframework.validation.FieldError;
|
||||||
@@ -28,6 +30,15 @@ public class GlobalExceptionHandler {
|
|||||||
|
|
||||||
private static final Logger log = LoggerFactory.getLogger(GlobalExceptionHandler.class);
|
private static final Logger log = LoggerFactory.getLogger(GlobalExceptionHandler.class);
|
||||||
|
|
||||||
|
@ExceptionHandler(ModelNotReady.class)
|
||||||
|
public ResponseEntity<ErrorResponse> handleModelNotReady(ModelNotReady ex) {
|
||||||
|
HttpHeaders headers = new HttpHeaders();
|
||||||
|
headers.set(HttpHeaders.RETRY_AFTER, String.valueOf(ex.retryAfterSeconds()));
|
||||||
|
return ResponseEntity.status(HttpStatus.SERVICE_UNAVAILABLE)
|
||||||
|
.headers(headers)
|
||||||
|
.body(ErrorResponse.of(ex.code(), ex.getMessage()));
|
||||||
|
}
|
||||||
|
|
||||||
@ExceptionHandler({RepositoryNotFound.class, VersionNotFound.class, TagNotFound.class})
|
@ExceptionHandler({RepositoryNotFound.class, VersionNotFound.class, TagNotFound.class})
|
||||||
public ResponseEntity<ErrorResponse> handleNotFound(TrueRefException ex) {
|
public ResponseEntity<ErrorResponse> handleNotFound(TrueRefException ex) {
|
||||||
return status(HttpStatus.NOT_FOUND, ex);
|
return status(HttpStatus.NOT_FOUND, ex);
|
||||||
|
|||||||
@@ -0,0 +1,106 @@
|
|||||||
|
package com.trueref.adapter.in.rest;
|
||||||
|
|
||||||
|
import com.trueref.adapter.in.rest.dto.ModelStatusDto;
|
||||||
|
import com.trueref.domain.model.ModelState;
|
||||||
|
import com.trueref.domain.port.in.ManageModelLifecycle;
|
||||||
|
import io.swagger.v3.oas.annotations.Operation;
|
||||||
|
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||||
|
import java.io.IOException;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
import org.springframework.http.HttpStatus;
|
||||||
|
import org.springframework.http.MediaType;
|
||||||
|
import org.springframework.web.bind.annotation.GetMapping;
|
||||||
|
import org.springframework.web.bind.annotation.PostMapping;
|
||||||
|
import org.springframework.web.bind.annotation.RequestMapping;
|
||||||
|
import org.springframework.web.bind.annotation.RequestParam;
|
||||||
|
import org.springframework.web.bind.annotation.RestController;
|
||||||
|
import org.springframework.web.server.ResponseStatusException;
|
||||||
|
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||||
|
|
||||||
|
/** REST resource: {@code /api/model}. */
|
||||||
|
@RestController
|
||||||
|
@RequestMapping("/api/model")
|
||||||
|
@Tag(name = "model", description = "GPU model lifecycle management (load state, force unload, SSE status stream).")
|
||||||
|
public class ModelController {
|
||||||
|
|
||||||
|
private static final Logger log = LoggerFactory.getLogger(ModelController.class);
|
||||||
|
|
||||||
|
private final ManageModelLifecycle lifecycle;
|
||||||
|
|
||||||
|
public ModelController(ManageModelLifecycle lifecycle) {
|
||||||
|
this.lifecycle = lifecycle;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Operation(summary = "Current model lifecycle status.")
|
||||||
|
@GetMapping("/status")
|
||||||
|
public ModelStatusDto status() {
|
||||||
|
return ModelStatusDto.of(lifecycle.getStatus());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Operation(summary = "Force-unload the GPU models from VRAM. "
|
||||||
|
+ "Returns 409 if ingestion jobs are running (use force=true to override).")
|
||||||
|
@PostMapping("/unload")
|
||||||
|
public ModelStatusDto unload(@RequestParam(value = "force", defaultValue = "false") boolean force) {
|
||||||
|
boolean initiated = lifecycle.forceUnload(force);
|
||||||
|
if (!initiated) {
|
||||||
|
throw new ResponseStatusException(
|
||||||
|
HttpStatus.CONFLICT,
|
||||||
|
"model_unload_blocked_by_running_jobs");
|
||||||
|
}
|
||||||
|
return ModelStatusDto.of(lifecycle.getStatus());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Operation(summary = "Server-Sent Events stream of model state transitions "
|
||||||
|
+ "(LOADING, LOADED, UNLOADING, UNLOADED).")
|
||||||
|
@GetMapping(value = "/status/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
||||||
|
public SseEmitter statusStream() {
|
||||||
|
SseEmitter emitter = new SseEmitter(0L);
|
||||||
|
|
||||||
|
// Flush headers to the client immediately so EventSource fires 'open'.
|
||||||
|
try {
|
||||||
|
emitter.send(SseEmitter.event().name("ping").data(""));
|
||||||
|
// Send current state immediately so the client doesn't have to wait for the next transition.
|
||||||
|
emitter.send(SseEmitter.event()
|
||||||
|
.name("model-status")
|
||||||
|
.data(lifecycle.getStatus().state().name()));
|
||||||
|
} catch (IOException e) {
|
||||||
|
emitter.completeWithError(e);
|
||||||
|
return emitter;
|
||||||
|
}
|
||||||
|
|
||||||
|
AutoCloseable subscription = lifecycle.subscribeState(event -> {
|
||||||
|
try {
|
||||||
|
emitter.send(SseEmitter.event().name("model-status").data(event));
|
||||||
|
} catch (IOException ex) {
|
||||||
|
emitter.completeWithError(ex);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Thread keepalive = Thread.startVirtualThread(() -> {
|
||||||
|
try {
|
||||||
|
while (!Thread.currentThread().isInterrupted()) {
|
||||||
|
Thread.sleep(20_000);
|
||||||
|
emitter.send(SseEmitter.event().name("ping").data(""));
|
||||||
|
}
|
||||||
|
} catch (InterruptedException ignored) {
|
||||||
|
// normal shutdown
|
||||||
|
} catch (Exception ignored) {
|
||||||
|
// emitter already completed; exit
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Runnable cleanup = () -> {
|
||||||
|
keepalive.interrupt();
|
||||||
|
try {
|
||||||
|
subscription.close();
|
||||||
|
} catch (Exception ex) {
|
||||||
|
log.debug("failed to close model status SSE subscription: {}", ex.toString());
|
||||||
|
}
|
||||||
|
};
|
||||||
|
emitter.onCompletion(cleanup);
|
||||||
|
emitter.onTimeout(cleanup);
|
||||||
|
emitter.onError(e -> cleanup.run());
|
||||||
|
return emitter;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
package com.trueref.adapter.in.rest.dto;
|
||||||
|
|
||||||
|
import com.trueref.domain.model.ModelState;
|
||||||
|
import com.trueref.domain.port.in.ManageModelLifecycle;
|
||||||
|
import java.time.Instant;
|
||||||
|
import org.jspecify.annotations.Nullable;
|
||||||
|
|
||||||
|
/** Snapshot of the GPU model lifecycle state, returned by {@code GET /api/model/status}. */
|
||||||
|
public record ModelStatusDto(
|
||||||
|
ModelState state,
|
||||||
|
@Nullable Instant loadedAt,
|
||||||
|
@Nullable Instant lastActivityAt,
|
||||||
|
long idleTimeoutSeconds) {
|
||||||
|
|
||||||
|
public static ModelStatusDto of(ManageModelLifecycle.Status status) {
|
||||||
|
return new ModelStatusDto(
|
||||||
|
status.state(),
|
||||||
|
status.loadedAt(),
|
||||||
|
status.lastActivityAt(),
|
||||||
|
status.idleTimeoutSeconds());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,5 +1,11 @@
|
|||||||
package com.trueref.adapter.out.embedding.onnx;
|
package com.trueref.adapter.out.embedding.onnx;
|
||||||
|
|
||||||
|
import com.trueref.application.model.InMemoryModelStateEventBus;
|
||||||
|
import com.trueref.application.model.ModelLifecycleService;
|
||||||
|
import com.trueref.domain.port.in.ManageModelLifecycle;
|
||||||
|
import com.trueref.domain.port.in.ObserveJobs;
|
||||||
|
import com.trueref.domain.port.out.ModelLoader;
|
||||||
|
import com.trueref.domain.port.out.ModelStateEventBus;
|
||||||
import org.springframework.beans.factory.annotation.Value;
|
import org.springframework.beans.factory.annotation.Value;
|
||||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||||
import org.springframework.context.annotation.Bean;
|
import org.springframework.context.annotation.Bean;
|
||||||
@@ -9,7 +15,8 @@ import java.nio.file.Path;
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Spring wiring for the ONNX embedding/reranker stack. Produces the shared {@link GpuSemaphore}
|
* Spring wiring for the ONNX embedding/reranker stack. Produces the shared {@link GpuSemaphore}
|
||||||
* (permits = {@code session-count}) and the resolved models-home {@link Path}.
|
* (permits = {@code session-count}), the resolved models-home {@link Path}, and the
|
||||||
|
* {@link ModelLifecycleService} that manages lazy loading and idle-timeout unloading.
|
||||||
*/
|
*/
|
||||||
@Configuration
|
@Configuration
|
||||||
@EnableConfigurationProperties(OnnxProperties.class)
|
@EnableConfigurationProperties(OnnxProperties.class)
|
||||||
@@ -29,4 +36,27 @@ public class OnnxEmbeddingConfig {
|
|||||||
Path home = properties.home();
|
Path home = properties.home();
|
||||||
return home != null ? home : Path.of(trueRefHome).resolve("models");
|
return home != null ? home : Path.of(trueRefHome).resolve("models");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Bean
|
||||||
|
ModelStateEventBus modelStateEventBus() {
|
||||||
|
return new InMemoryModelStateEventBus();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Bean
|
||||||
|
ModelLoader onnxModelLoader(OnnxEmbeddingService embeddingService, OnnxRerankerService rerankerService) {
|
||||||
|
return new OnnxModelLoader(embeddingService, rerankerService);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Bean(initMethod = "init", destroyMethod = "shutdown")
|
||||||
|
ManageModelLifecycle modelLifecycleService(
|
||||||
|
ModelLoader modelLoader,
|
||||||
|
ModelStateEventBus modelStateEventBus,
|
||||||
|
ObserveJobs observeJobs,
|
||||||
|
OnnxProperties properties) {
|
||||||
|
return new ModelLifecycleService(
|
||||||
|
modelLoader,
|
||||||
|
modelStateEventBus,
|
||||||
|
observeJobs,
|
||||||
|
properties.idleTimeoutSecondsOrDefault());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,9 +7,8 @@ import ai.onnxruntime.OrtException;
|
|||||||
import ai.onnxruntime.OrtSession;
|
import ai.onnxruntime.OrtSession;
|
||||||
import com.trueref.adapter.out.embedding.onnx.HuggingFaceTokenizerWrapper.EncodedBatch;
|
import com.trueref.adapter.out.embedding.onnx.HuggingFaceTokenizerWrapper.EncodedBatch;
|
||||||
import com.trueref.domain.error.IngestionFailed;
|
import com.trueref.domain.error.IngestionFailed;
|
||||||
|
import com.trueref.domain.error.ModelNotReady;
|
||||||
import com.trueref.domain.port.out.EmbeddingService;
|
import com.trueref.domain.port.out.EmbeddingService;
|
||||||
import jakarta.annotation.PostConstruct;
|
|
||||||
import jakarta.annotation.PreDestroy;
|
|
||||||
import java.nio.LongBuffer;
|
import java.nio.LongBuffer;
|
||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
@@ -68,8 +67,13 @@ public class OnnxEmbeddingService implements EmbeddingService {
|
|||||||
this.modelsHome = modelsHome;
|
this.modelsHome = modelsHome;
|
||||||
}
|
}
|
||||||
|
|
||||||
@PostConstruct
|
/** Returns {@code true} if the model is currently loaded and accepting requests. */
|
||||||
void start() {
|
public boolean isStarted() {
|
||||||
|
return pool != null && tokenizer != null;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Loads the ONNX model into memory. Called by {@link OnnxModelLoader} on demand. */
|
||||||
|
public void start() {
|
||||||
if (properties.sessionCountOrDefault() <= 0) {
|
if (properties.sessionCountOrDefault() <= 0) {
|
||||||
log.warn("OnnxEmbeddingService DISABLED (session-count=0); embed() calls will fail");
|
log.warn("OnnxEmbeddingService DISABLED (session-count=0); embed() calls will fail");
|
||||||
return;
|
return;
|
||||||
@@ -104,8 +108,8 @@ public class OnnxEmbeddingService implements EmbeddingService {
|
|||||||
usesTokenTypeIds);
|
usesTokenTypeIds);
|
||||||
}
|
}
|
||||||
|
|
||||||
@PreDestroy
|
/** Unloads the ONNX model from memory. Called by {@link OnnxModelLoader} on unload. */
|
||||||
void stop() {
|
public void stop() {
|
||||||
if (pool != null) {
|
if (pool != null) {
|
||||||
pool.close();
|
pool.close();
|
||||||
pool = null;
|
pool = null;
|
||||||
@@ -257,7 +261,7 @@ public class OnnxEmbeddingService implements EmbeddingService {
|
|||||||
|
|
||||||
private static <T> T requireStarted(@org.jspecify.annotations.Nullable T t, String what) {
|
private static <T> T requireStarted(@org.jspecify.annotations.Nullable T t, String what) {
|
||||||
if (t == null) {
|
if (t == null) {
|
||||||
throw new IllegalStateException("OnnxEmbeddingService not started: " + what + " is null");
|
throw new ModelNotReady(30);
|
||||||
}
|
}
|
||||||
return t;
|
return t;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,54 @@
|
|||||||
|
package com.trueref.adapter.out.embedding.onnx;
|
||||||
|
|
||||||
|
import com.trueref.domain.port.out.ModelLoader;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Implements {@link ModelLoader} by delegating to the ONNX embedding and reranker services.
|
||||||
|
* Both services are loaded/unloaded as a unit. This bean drives their lifecycle instead of
|
||||||
|
* {@code @PostConstruct} / {@code @PreDestroy} so that models are loaded lazily on first demand.
|
||||||
|
*/
|
||||||
|
public class OnnxModelLoader implements ModelLoader {
|
||||||
|
|
||||||
|
private static final Logger log = LoggerFactory.getLogger(OnnxModelLoader.class);
|
||||||
|
|
||||||
|
private final OnnxEmbeddingService embeddingService;
|
||||||
|
private final OnnxRerankerService rerankerService;
|
||||||
|
|
||||||
|
public OnnxModelLoader(OnnxEmbeddingService embeddingService, OnnxRerankerService rerankerService) {
|
||||||
|
this.embeddingService = embeddingService;
|
||||||
|
this.rerankerService = rerankerService;
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void load() {
|
||||||
|
log.info("OnnxModelLoader: loading embedder…");
|
||||||
|
embeddingService.start();
|
||||||
|
log.info("OnnxModelLoader: loading reranker…");
|
||||||
|
rerankerService.start();
|
||||||
|
log.info("OnnxModelLoader: both models loaded");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void unload() {
|
||||||
|
log.info("OnnxModelLoader: unloading reranker…");
|
||||||
|
try {
|
||||||
|
rerankerService.stop();
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.warn("Reranker stop error (continuing): {}", e.toString());
|
||||||
|
}
|
||||||
|
log.info("OnnxModelLoader: unloading embedder…");
|
||||||
|
try {
|
||||||
|
embeddingService.stop();
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.warn("Embedder stop error (continuing): {}", e.toString());
|
||||||
|
}
|
||||||
|
log.info("OnnxModelLoader: both models unloaded");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean isLoaded() {
|
||||||
|
return embeddingService.isStarted() && rerankerService.isStarted();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -21,6 +21,7 @@ public record OnnxProperties(
|
|||||||
@Nullable Integer maxSeqLen,
|
@Nullable Integer maxSeqLen,
|
||||||
@Nullable Integer gpuDeviceId,
|
@Nullable Integer gpuDeviceId,
|
||||||
@Nullable Long gpuMemLimitBytes,
|
@Nullable Long gpuMemLimitBytes,
|
||||||
|
@Nullable Long idleTimeoutSeconds,
|
||||||
@Nullable Path home,
|
@Nullable Path home,
|
||||||
@Nullable Map<String, Map<String, List<String>>> modelSources) {
|
@Nullable Map<String, Map<String, List<String>>> modelSources) {
|
||||||
|
|
||||||
@@ -37,6 +38,7 @@ public record OnnxProperties(
|
|||||||
if (maxSeqLen == null || maxSeqLen <= 0) maxSeqLen = 512;
|
if (maxSeqLen == null || maxSeqLen <= 0) maxSeqLen = 512;
|
||||||
if (gpuDeviceId == null || gpuDeviceId < 0) gpuDeviceId = 0;
|
if (gpuDeviceId == null || gpuDeviceId < 0) gpuDeviceId = 0;
|
||||||
if (gpuMemLimitBytes == null || gpuMemLimitBytes <= 0L) gpuMemLimitBytes = 0L; // 0 = no cap
|
if (gpuMemLimitBytes == null || gpuMemLimitBytes <= 0L) gpuMemLimitBytes = 0L; // 0 = no cap
|
||||||
|
if (idleTimeoutSeconds == null || idleTimeoutSeconds <= 0L) idleTimeoutSeconds = 300L; // 5 min
|
||||||
if (modelSources == null) modelSources = Map.of();
|
if (modelSources == null) modelSources = Map.of();
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -75,4 +77,8 @@ public record OnnxProperties(
|
|||||||
public long gpuMemLimitBytesOrDefault() {
|
public long gpuMemLimitBytesOrDefault() {
|
||||||
return gpuMemLimitBytes == null ? 0L : gpuMemLimitBytes;
|
return gpuMemLimitBytes == null ? 0L : gpuMemLimitBytes;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public long idleTimeoutSecondsOrDefault() {
|
||||||
|
return idleTimeoutSeconds == null ? 300L : idleTimeoutSeconds;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,10 +7,9 @@ import ai.onnxruntime.OrtException;
|
|||||||
import ai.onnxruntime.OrtSession;
|
import ai.onnxruntime.OrtSession;
|
||||||
import com.trueref.adapter.out.embedding.onnx.HuggingFaceTokenizerWrapper.EncodedBatch;
|
import com.trueref.adapter.out.embedding.onnx.HuggingFaceTokenizerWrapper.EncodedBatch;
|
||||||
import com.trueref.domain.error.IngestionFailed;
|
import com.trueref.domain.error.IngestionFailed;
|
||||||
|
import com.trueref.domain.error.ModelNotReady;
|
||||||
import com.trueref.domain.model.SearchHit;
|
import com.trueref.domain.model.SearchHit;
|
||||||
import com.trueref.domain.port.out.RerankerService;
|
import com.trueref.domain.port.out.RerankerService;
|
||||||
import jakarta.annotation.PostConstruct;
|
|
||||||
import jakarta.annotation.PreDestroy;
|
|
||||||
import java.nio.LongBuffer;
|
import java.nio.LongBuffer;
|
||||||
import java.nio.file.Path;
|
import java.nio.file.Path;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
@@ -62,8 +61,13 @@ public class OnnxRerankerService implements RerankerService {
|
|||||||
this.modelsHome = modelsHome;
|
this.modelsHome = modelsHome;
|
||||||
}
|
}
|
||||||
|
|
||||||
@PostConstruct
|
/** Returns {@code true} if the reranker model is currently loaded and accepting requests. */
|
||||||
void start() {
|
public boolean isStarted() {
|
||||||
|
return pool != null && tokenizer != null;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Loads the reranker model into memory. Called by {@link OnnxModelLoader} on demand. */
|
||||||
|
public void start() {
|
||||||
if (properties.rerankerSessionCountOrDefault() <= 0) {
|
if (properties.rerankerSessionCountOrDefault() <= 0) {
|
||||||
log.warn("OnnxRerankerService DISABLED (reranker-session-count=0); rerank() passes through input order");
|
log.warn("OnnxRerankerService DISABLED (reranker-session-count=0); rerank() passes through input order");
|
||||||
return;
|
return;
|
||||||
@@ -96,8 +100,8 @@ public class OnnxRerankerService implements RerankerService {
|
|||||||
usesTokenTypeIds);
|
usesTokenTypeIds);
|
||||||
}
|
}
|
||||||
|
|
||||||
@PreDestroy
|
/** Unloads the reranker model from memory. Called by {@link OnnxModelLoader} on unload. */
|
||||||
void stop() {
|
public void stop() {
|
||||||
if (pool != null) {
|
if (pool != null) {
|
||||||
pool.close();
|
pool.close();
|
||||||
pool = null;
|
pool = null;
|
||||||
@@ -252,7 +256,7 @@ public class OnnxRerankerService implements RerankerService {
|
|||||||
|
|
||||||
private static <T> T requireStarted(@org.jspecify.annotations.Nullable T t, String what) {
|
private static <T> T requireStarted(@org.jspecify.annotations.Nullable T t, String what) {
|
||||||
if (t == null) {
|
if (t == null) {
|
||||||
throw new IllegalStateException("OnnxRerankerService not started: " + what + " is null");
|
throw new ModelNotReady(30);
|
||||||
}
|
}
|
||||||
return t;
|
return t;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,44 @@
|
|||||||
|
package com.trueref.adapter.in.rest;
|
||||||
|
|
||||||
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
|
||||||
|
import com.trueref.domain.error.ModelNotReady;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
import org.springframework.http.HttpHeaders;
|
||||||
|
import org.springframework.http.HttpStatus;
|
||||||
|
import org.springframework.http.ResponseEntity;
|
||||||
|
|
||||||
|
class GlobalExceptionHandlerTest {
|
||||||
|
|
||||||
|
private final GlobalExceptionHandler handler = new GlobalExceptionHandler();
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void modelNotReadyMapsto503WithRetryAfterHeader() {
|
||||||
|
ModelNotReady ex = new ModelNotReady(30);
|
||||||
|
|
||||||
|
ResponseEntity<ErrorResponse> response = handler.handleModelNotReady(ex);
|
||||||
|
|
||||||
|
assertThat(response.getStatusCode()).isEqualTo(HttpStatus.SERVICE_UNAVAILABLE);
|
||||||
|
assertThat(response.getHeaders().getFirst(HttpHeaders.RETRY_AFTER)).isEqualTo("30");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void modelNotReadyResponseBodyHasModelNotReadyCode() {
|
||||||
|
ModelNotReady ex = new ModelNotReady(15);
|
||||||
|
|
||||||
|
ResponseEntity<ErrorResponse> response = handler.handleModelNotReady(ex);
|
||||||
|
|
||||||
|
assertThat(response.getBody()).isNotNull();
|
||||||
|
assertThat(response.getBody().code()).isEqualTo("model_not_ready");
|
||||||
|
assertThat(response.getBody().message()).contains("15");
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void retryAfterReflectsExceptionValue() {
|
||||||
|
ModelNotReady ex = new ModelNotReady(60);
|
||||||
|
|
||||||
|
ResponseEntity<ErrorResponse> response = handler.handleModelNotReady(ex);
|
||||||
|
|
||||||
|
assertThat(response.getHeaders().getFirst(HttpHeaders.RETRY_AFTER)).isEqualTo("60");
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,71 @@
|
|||||||
|
package com.trueref.adapter.out.embedding.onnx;
|
||||||
|
|
||||||
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
import static org.mockito.Mockito.doThrow;
|
||||||
|
import static org.mockito.Mockito.mock;
|
||||||
|
import static org.mockito.Mockito.never;
|
||||||
|
import static org.mockito.Mockito.verify;
|
||||||
|
import static org.mockito.Mockito.when;
|
||||||
|
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
class OnnxModelLoaderTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void loadCallsStartOnBothServices() {
|
||||||
|
OnnxEmbeddingService embedder = mock(OnnxEmbeddingService.class);
|
||||||
|
OnnxRerankerService reranker = mock(OnnxRerankerService.class);
|
||||||
|
OnnxModelLoader loader = new OnnxModelLoader(embedder, reranker);
|
||||||
|
|
||||||
|
loader.load();
|
||||||
|
|
||||||
|
verify(embedder).start();
|
||||||
|
verify(reranker).start();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void unloadCallsStopOnBothServices() {
|
||||||
|
OnnxEmbeddingService embedder = mock(OnnxEmbeddingService.class);
|
||||||
|
OnnxRerankerService reranker = mock(OnnxRerankerService.class);
|
||||||
|
OnnxModelLoader loader = new OnnxModelLoader(embedder, reranker);
|
||||||
|
|
||||||
|
loader.unload();
|
||||||
|
|
||||||
|
verify(embedder).stop();
|
||||||
|
verify(reranker).stop();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void unloadContinuesEvenIfRerankerStopThrows() {
|
||||||
|
OnnxEmbeddingService embedder = mock(OnnxEmbeddingService.class);
|
||||||
|
OnnxRerankerService reranker = mock(OnnxRerankerService.class);
|
||||||
|
doThrow(new RuntimeException("VRAM error")).when(reranker).stop();
|
||||||
|
OnnxModelLoader loader = new OnnxModelLoader(embedder, reranker);
|
||||||
|
|
||||||
|
loader.unload(); // must not throw
|
||||||
|
|
||||||
|
verify(embedder).stop();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void isLoadedReturnsTrueWhenBothServicesStarted() {
|
||||||
|
OnnxEmbeddingService embedder = mock(OnnxEmbeddingService.class);
|
||||||
|
OnnxRerankerService reranker = mock(OnnxRerankerService.class);
|
||||||
|
when(embedder.isStarted()).thenReturn(true);
|
||||||
|
when(reranker.isStarted()).thenReturn(true);
|
||||||
|
OnnxModelLoader loader = new OnnxModelLoader(embedder, reranker);
|
||||||
|
|
||||||
|
assertThat(loader.isLoaded()).isTrue();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void isLoadedReturnsFalseWhenEitherServiceNotStarted() {
|
||||||
|
OnnxEmbeddingService embedder = mock(OnnxEmbeddingService.class);
|
||||||
|
OnnxRerankerService reranker = mock(OnnxRerankerService.class);
|
||||||
|
when(embedder.isStarted()).thenReturn(true);
|
||||||
|
when(reranker.isStarted()).thenReturn(false);
|
||||||
|
OnnxModelLoader loader = new OnnxModelLoader(embedder, reranker);
|
||||||
|
|
||||||
|
assertThat(loader.isLoaded()).isFalse();
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,75 @@
|
|||||||
|
package com.trueref.application.model;
|
||||||
|
|
||||||
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
|
||||||
|
import com.trueref.domain.model.ModelState;
|
||||||
|
import com.trueref.domain.model.ModelStateEvent;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.concurrent.atomic.AtomicInteger;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
class InMemoryModelStateEventBusTest {
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void singleSubscriberReceivesPublishedEvent() {
|
||||||
|
InMemoryModelStateEventBus bus = new InMemoryModelStateEventBus();
|
||||||
|
List<ModelStateEvent> received = new ArrayList<>();
|
||||||
|
bus.subscribe(received::add);
|
||||||
|
|
||||||
|
bus.publish(ModelStateEvent.of(ModelState.LOADING));
|
||||||
|
|
||||||
|
assertThat(received).hasSize(1);
|
||||||
|
assertThat(received.get(0).state()).isEqualTo(ModelState.LOADING);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void multipleSubscribersAllReceiveEvent() {
|
||||||
|
InMemoryModelStateEventBus bus = new InMemoryModelStateEventBus();
|
||||||
|
List<ModelStateEvent> a = new ArrayList<>();
|
||||||
|
List<ModelStateEvent> b = new ArrayList<>();
|
||||||
|
bus.subscribe(a::add);
|
||||||
|
bus.subscribe(b::add);
|
||||||
|
|
||||||
|
bus.publish(ModelStateEvent.of(ModelState.LOADED));
|
||||||
|
|
||||||
|
assertThat(a).hasSize(1);
|
||||||
|
assertThat(b).hasSize(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void closingSubscriptionRemovesIt() throws Exception {
|
||||||
|
InMemoryModelStateEventBus bus = new InMemoryModelStateEventBus();
|
||||||
|
List<ModelStateEvent> received = new ArrayList<>();
|
||||||
|
AutoCloseable handle = bus.subscribe(received::add);
|
||||||
|
|
||||||
|
handle.close();
|
||||||
|
bus.publish(ModelStateEvent.of(ModelState.UNLOADED));
|
||||||
|
|
||||||
|
assertThat(received).isEmpty();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void exceptionInSubscriberDoesNotBreakOthers() {
|
||||||
|
InMemoryModelStateEventBus bus = new InMemoryModelStateEventBus();
|
||||||
|
AtomicInteger badCalls = new AtomicInteger();
|
||||||
|
List<ModelStateEvent> good = new ArrayList<>();
|
||||||
|
|
||||||
|
bus.subscribe(e -> {
|
||||||
|
badCalls.incrementAndGet();
|
||||||
|
throw new RuntimeException("intentional test failure");
|
||||||
|
});
|
||||||
|
bus.subscribe(good::add);
|
||||||
|
|
||||||
|
bus.publish(ModelStateEvent.of(ModelState.LOADING));
|
||||||
|
|
||||||
|
assertThat(badCalls.get()).isEqualTo(1);
|
||||||
|
assertThat(good).hasSize(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void noSubscribersMeansPublishIsNoop() {
|
||||||
|
InMemoryModelStateEventBus bus = new InMemoryModelStateEventBus();
|
||||||
|
bus.publish(ModelStateEvent.of(ModelState.UNLOADED)); // must not throw
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,227 @@
|
|||||||
|
package com.trueref.application.model;
|
||||||
|
|
||||||
|
import static org.assertj.core.api.Assertions.assertThat;
|
||||||
|
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||||
|
|
||||||
|
import com.trueref.domain.error.ModelNotReady;
|
||||||
|
import com.trueref.domain.model.IngestionJob;
|
||||||
|
import com.trueref.domain.model.JobId;
|
||||||
|
import com.trueref.domain.model.JobStatus;
|
||||||
|
import com.trueref.domain.model.JobType;
|
||||||
|
import com.trueref.domain.model.ModelState;
|
||||||
|
import com.trueref.domain.model.ModelStateEvent;
|
||||||
|
import com.trueref.domain.model.RepositoryId;
|
||||||
|
import com.trueref.domain.port.in.ManageModelLifecycle;
|
||||||
|
import com.trueref.domain.port.in.ObserveJobs;
|
||||||
|
import com.trueref.domain.port.out.ModelLoader;
|
||||||
|
import com.trueref.domain.port.out.ModelStateEventBus;
|
||||||
|
import java.time.Instant;
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Optional;
|
||||||
|
import java.util.concurrent.CountDownLatch;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import java.util.concurrent.atomic.AtomicBoolean;
|
||||||
|
import java.util.function.Consumer;
|
||||||
|
import org.junit.jupiter.api.AfterEach;
|
||||||
|
import org.junit.jupiter.api.BeforeEach;
|
||||||
|
import org.junit.jupiter.api.Test;
|
||||||
|
|
||||||
|
class ModelLifecycleServiceTest {
|
||||||
|
|
||||||
|
// ---- Test doubles ----
|
||||||
|
|
||||||
|
private final AtomicBoolean loaderLoaded = new AtomicBoolean(false);
|
||||||
|
private volatile boolean slowLoad = false;
|
||||||
|
|
||||||
|
private final ModelLoader loader = new ModelLoader() {
|
||||||
|
@Override
|
||||||
|
public void load() {
|
||||||
|
if (slowLoad) {
|
||||||
|
try { Thread.sleep(200); } catch (InterruptedException e) { Thread.currentThread().interrupt(); }
|
||||||
|
}
|
||||||
|
loaderLoaded.set(true);
|
||||||
|
}
|
||||||
|
@Override
|
||||||
|
public void unload() { loaderLoaded.set(false); }
|
||||||
|
@Override
|
||||||
|
public boolean isLoaded() { return loaderLoaded.get(); }
|
||||||
|
};
|
||||||
|
|
||||||
|
private final InMemoryModelStateEventBus eventBus = new InMemoryModelStateEventBus();
|
||||||
|
|
||||||
|
private volatile List<IngestionJob> runningJobs = List.of();
|
||||||
|
|
||||||
|
private final ObserveJobs observeJobs = new ObserveJobs() {
|
||||||
|
@Override
|
||||||
|
public Optional<IngestionJob> findJob(JobId id) { return Optional.empty(); }
|
||||||
|
@Override
|
||||||
|
public List<IngestionJob> listJobs(com.trueref.domain.model.RepositoryId r,
|
||||||
|
com.trueref.domain.model.VersionId v, JobStatus status, int limit) {
|
||||||
|
if (status == JobStatus.RUNNING) return runningJobs;
|
||||||
|
return List.of();
|
||||||
|
}
|
||||||
|
@Override
|
||||||
|
public AutoCloseable subscribeJobs(Consumer<IngestionJob> listener) { return () -> {}; }
|
||||||
|
@Override
|
||||||
|
public AutoCloseable subscribeLogs(JobId jobId, Consumer<com.trueref.domain.model.JobLogEvent> listener) { return () -> {}; }
|
||||||
|
};
|
||||||
|
|
||||||
|
private ModelLifecycleService service;
|
||||||
|
|
||||||
|
@BeforeEach
|
||||||
|
void setUp() {
|
||||||
|
loaderLoaded.set(false);
|
||||||
|
slowLoad = false;
|
||||||
|
runningJobs = List.of();
|
||||||
|
service = new ModelLifecycleService(loader, eventBus, observeJobs, 300L);
|
||||||
|
service.init();
|
||||||
|
}
|
||||||
|
|
||||||
|
@AfterEach
|
||||||
|
void tearDown() {
|
||||||
|
service.shutdown();
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Tests ----
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void startsInUnloadedState() {
|
||||||
|
assertThat(service.getStatus().state()).isEqualTo(ModelState.UNLOADED);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void ensureReadyOnUnloadedTriggersLoadAndThrows() throws InterruptedException {
|
||||||
|
List<ModelStateEvent> events = new ArrayList<>();
|
||||||
|
eventBus.subscribe(events::add);
|
||||||
|
|
||||||
|
assertThatThrownBy(() -> service.ensureReady())
|
||||||
|
.isInstanceOf(ModelNotReady.class);
|
||||||
|
|
||||||
|
// Wait for async load to complete (max 2 s in test)
|
||||||
|
for (int i = 0; i < 20 && service.getStatus().state() != ModelState.LOADED; i++) {
|
||||||
|
Thread.sleep(100);
|
||||||
|
}
|
||||||
|
|
||||||
|
assertThat(service.getStatus().state()).isEqualTo(ModelState.LOADED);
|
||||||
|
assertThat(loaderLoaded.get()).isTrue();
|
||||||
|
assertThat(events).extracting(ModelStateEvent::state)
|
||||||
|
.containsSequence(ModelState.LOADING, ModelState.LOADED);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void ensureReadyOnLoadedIsNoop() throws InterruptedException {
|
||||||
|
triggerLoad();
|
||||||
|
|
||||||
|
// Should not throw
|
||||||
|
service.ensureReady();
|
||||||
|
service.ensureReady();
|
||||||
|
|
||||||
|
assertThat(service.getStatus().state()).isEqualTo(ModelState.LOADED);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void ensureReadyOnLoadingThrowsImmediately() {
|
||||||
|
slowLoad = true;
|
||||||
|
assertThatThrownBy(() -> service.ensureReady())
|
||||||
|
.isInstanceOf(ModelNotReady.class);
|
||||||
|
// state is LOADING, second call also throws
|
||||||
|
assertThatThrownBy(() -> service.ensureReady())
|
||||||
|
.isInstanceOf(ModelNotReady.class);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void forceUnloadFromLoadedState() throws InterruptedException {
|
||||||
|
triggerLoad();
|
||||||
|
|
||||||
|
boolean initiated = service.forceUnload(false);
|
||||||
|
|
||||||
|
assertThat(initiated).isTrue();
|
||||||
|
// Wait for async unload
|
||||||
|
for (int i = 0; i < 20 && service.getStatus().state() != ModelState.UNLOADED; i++) {
|
||||||
|
Thread.sleep(100);
|
||||||
|
}
|
||||||
|
assertThat(service.getStatus().state()).isEqualTo(ModelState.UNLOADED);
|
||||||
|
assertThat(loaderLoaded.get()).isFalse();
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void forceUnloadBlockedByRunningJobsWithoutForce() throws InterruptedException {
|
||||||
|
triggerLoad();
|
||||||
|
runningJobs = List.of(runningJob());
|
||||||
|
|
||||||
|
boolean initiated = service.forceUnload(false);
|
||||||
|
|
||||||
|
assertThat(initiated).isFalse();
|
||||||
|
assertThat(service.getStatus().state()).isEqualTo(ModelState.LOADED);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void forceUnloadSucceedsWithForceEvenWhenJobsRunning() throws InterruptedException {
|
||||||
|
triggerLoad();
|
||||||
|
runningJobs = List.of(runningJob());
|
||||||
|
|
||||||
|
boolean initiated = service.forceUnload(true);
|
||||||
|
|
||||||
|
assertThat(initiated).isTrue();
|
||||||
|
for (int i = 0; i < 20 && service.getStatus().state() != ModelState.UNLOADED; i++) {
|
||||||
|
Thread.sleep(100);
|
||||||
|
}
|
||||||
|
assertThat(service.getStatus().state()).isEqualTo(ModelState.UNLOADED);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void forceUnloadWhenAlreadyUnloadedIsNoop() {
|
||||||
|
boolean initiated = service.forceUnload(false);
|
||||||
|
assertThat(initiated).isTrue();
|
||||||
|
assertThat(service.getStatus().state()).isEqualTo(ModelState.UNLOADED);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void concurrentEnsureReadyDeduplicatesLoad() throws InterruptedException {
|
||||||
|
int callCount = 0;
|
||||||
|
// Both calls throw ModelNotReady but only one load should be triggered
|
||||||
|
List<ModelStateEvent> events = new ArrayList<>();
|
||||||
|
eventBus.subscribe(events::add);
|
||||||
|
|
||||||
|
try { service.ensureReady(); } catch (ModelNotReady ignored) { callCount++; }
|
||||||
|
try { service.ensureReady(); } catch (ModelNotReady ignored) { callCount++; }
|
||||||
|
|
||||||
|
for (int i = 0; i < 20 && service.getStatus().state() != ModelState.LOADED; i++) {
|
||||||
|
Thread.sleep(100);
|
||||||
|
}
|
||||||
|
|
||||||
|
assertThat(callCount).isEqualTo(2);
|
||||||
|
// Only one LOADING event should be emitted (deduplicated)
|
||||||
|
long loadingCount = events.stream().filter(e -> e.state() == ModelState.LOADING).count();
|
||||||
|
assertThat(loadingCount).isEqualTo(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Test
|
||||||
|
void statusIncludesLoadedAtAfterLoad() throws InterruptedException {
|
||||||
|
assertThat(service.getStatus().loadedAt()).isNull();
|
||||||
|
triggerLoad();
|
||||||
|
assertThat(service.getStatus().loadedAt()).isNotNull();
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- helpers ----
|
||||||
|
|
||||||
|
private void triggerLoad() throws InterruptedException {
|
||||||
|
try { service.ensureReady(); } catch (ModelNotReady ignored) {}
|
||||||
|
for (int i = 0; i < 20 && service.getStatus().state() != ModelState.LOADED; i++) {
|
||||||
|
Thread.sleep(100);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private IngestionJob runningJob() {
|
||||||
|
return new IngestionJob(
|
||||||
|
JobId.of(java.util.UUID.randomUUID().toString()),
|
||||||
|
RepositoryId.of(java.util.UUID.randomUUID().toString()),
|
||||||
|
null,
|
||||||
|
JobType.INDEX_VERSION,
|
||||||
|
JobStatus.RUNNING,
|
||||||
|
Instant.now(),
|
||||||
|
null,
|
||||||
|
List.of());
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,35 @@
|
|||||||
|
package com.trueref.application.model;
|
||||||
|
|
||||||
|
import com.trueref.domain.model.ModelStateEvent;
|
||||||
|
import com.trueref.domain.port.out.ModelStateEventBus;
|
||||||
|
import java.util.concurrent.CopyOnWriteArrayList;
|
||||||
|
import java.util.function.Consumer;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
/** In-memory, thread-safe implementation of {@link ModelStateEventBus}. */
|
||||||
|
public class InMemoryModelStateEventBus implements ModelStateEventBus {
|
||||||
|
|
||||||
|
private static final Logger log = LoggerFactory.getLogger(InMemoryModelStateEventBus.class);
|
||||||
|
|
||||||
|
private final CopyOnWriteArrayList<Consumer<ModelStateEvent>> subscribers =
|
||||||
|
new CopyOnWriteArrayList<>();
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void publish(ModelStateEvent event) {
|
||||||
|
for (Consumer<ModelStateEvent> subscriber : subscribers) {
|
||||||
|
try {
|
||||||
|
subscriber.accept(event);
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.warn("Model state subscriber threw during publish; removing: {}", e.toString());
|
||||||
|
subscribers.remove(subscriber);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public AutoCloseable subscribe(Consumer<ModelStateEvent> subscriber) {
|
||||||
|
subscribers.add(subscriber);
|
||||||
|
return () -> subscribers.remove(subscriber);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,287 @@
|
|||||||
|
package com.trueref.application.model;
|
||||||
|
|
||||||
|
import com.trueref.domain.error.ModelNotReady;
|
||||||
|
import com.trueref.domain.model.IngestionJob;
|
||||||
|
import com.trueref.domain.model.JobStatus;
|
||||||
|
import com.trueref.domain.model.ModelState;
|
||||||
|
import com.trueref.domain.model.ModelStateEvent;
|
||||||
|
import com.trueref.domain.port.in.ManageModelLifecycle;
|
||||||
|
import com.trueref.domain.port.in.ObserveJobs;
|
||||||
|
import com.trueref.domain.port.out.ModelLoader;
|
||||||
|
import com.trueref.domain.port.out.ModelStateEventBus;
|
||||||
|
import java.time.Instant;
|
||||||
|
import java.util.concurrent.Executors;
|
||||||
|
import java.util.concurrent.Future;
|
||||||
|
import java.util.concurrent.ScheduledExecutorService;
|
||||||
|
import java.util.concurrent.ThreadFactory;
|
||||||
|
import java.util.concurrent.TimeUnit;
|
||||||
|
import java.util.concurrent.atomic.AtomicLong;
|
||||||
|
import java.util.concurrent.atomic.AtomicReference;
|
||||||
|
import java.util.concurrent.locks.ReentrantLock;
|
||||||
|
import java.util.function.Consumer;
|
||||||
|
import org.jspecify.annotations.Nullable;
|
||||||
|
import org.slf4j.Logger;
|
||||||
|
import org.slf4j.LoggerFactory;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Orchestrates the GPU model lifecycle: lazy loading, auto-unload after idle timeout, and
|
||||||
|
* force-unload. State transitions are protected by a {@link ReentrantLock}. Heavy load/unload
|
||||||
|
* operations are executed on a dedicated single-thread platform-thread executor to satisfy CUDA
|
||||||
|
* OS-thread affinity constraints.
|
||||||
|
*/
|
||||||
|
public class ModelLifecycleService implements ManageModelLifecycle {
|
||||||
|
|
||||||
|
private static final Logger log = LoggerFactory.getLogger(ModelLifecycleService.class);
|
||||||
|
|
||||||
|
/** Default retry hint in seconds when the model is LOADING. Load typically takes 5–30 s. */
|
||||||
|
private static final int LOADING_RETRY_AFTER_SECONDS = 30;
|
||||||
|
|
||||||
|
/** Idle-check interval. One check per minute is sufficient resolution for a 5-min timeout. */
|
||||||
|
private static final long IDLE_CHECK_INTERVAL_SECONDS = 60L;
|
||||||
|
|
||||||
|
private final ModelLoader modelLoader;
|
||||||
|
private final ModelStateEventBus eventBus;
|
||||||
|
private final ObserveJobs observeJobs;
|
||||||
|
private final long idleTimeoutSeconds;
|
||||||
|
|
||||||
|
// ---- state ----
|
||||||
|
private final ReentrantLock lock = new ReentrantLock();
|
||||||
|
private final AtomicReference<ModelState> state = new AtomicReference<>(ModelState.UNLOADED);
|
||||||
|
private volatile @Nullable Instant loadedAt = null;
|
||||||
|
private final AtomicLong lastActivityEpochMs = new AtomicLong(0L);
|
||||||
|
|
||||||
|
// ---- executors ----
|
||||||
|
/** Single platform OS thread for all CUDA load/unload calls. */
|
||||||
|
private final java.util.concurrent.ExecutorService cudaExecutor;
|
||||||
|
|
||||||
|
/** Scheduler that periodically checks for idle timeout. */
|
||||||
|
private final ScheduledExecutorService idleScheduler;
|
||||||
|
|
||||||
|
/** Handle to the in-flight async load Future (guarded by lock). */
|
||||||
|
private @Nullable Future<?> pendingLoad = null;
|
||||||
|
|
||||||
|
public ModelLifecycleService(
|
||||||
|
ModelLoader modelLoader,
|
||||||
|
ModelStateEventBus eventBus,
|
||||||
|
ObserveJobs observeJobs,
|
||||||
|
long idleTimeoutSeconds) {
|
||||||
|
this.modelLoader = modelLoader;
|
||||||
|
this.eventBus = eventBus;
|
||||||
|
this.observeJobs = observeJobs;
|
||||||
|
this.idleTimeoutSeconds = idleTimeoutSeconds;
|
||||||
|
|
||||||
|
ThreadFactory platformFactory = Thread.ofPlatform()
|
||||||
|
.name("model-lifecycle-cuda")
|
||||||
|
.factory();
|
||||||
|
this.cudaExecutor = Executors.newSingleThreadExecutor(platformFactory);
|
||||||
|
this.idleScheduler = Executors.newSingleThreadScheduledExecutor(
|
||||||
|
Thread.ofPlatform().name("model-idle-check").factory());
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Initializes the idle timer and job-event subscription. Called by the Spring bean lifecycle. */
|
||||||
|
public void init() {
|
||||||
|
// Subscribe to job events so that finishing a job resets the idle clock.
|
||||||
|
observeJobs.subscribeJobs(this::onJobEvent);
|
||||||
|
// Start the periodic idle check.
|
||||||
|
idleScheduler.scheduleAtFixedRate(
|
||||||
|
this::checkIdleTimeout,
|
||||||
|
IDLE_CHECK_INTERVAL_SECONDS,
|
||||||
|
IDLE_CHECK_INTERVAL_SECONDS,
|
||||||
|
TimeUnit.SECONDS);
|
||||||
|
log.info("ModelLifecycleService started; idleTimeoutSeconds={} state=UNLOADED", idleTimeoutSeconds);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Shuts down schedulers and unloads models if loaded. Called by the Spring bean lifecycle. */
|
||||||
|
public void shutdown() {
|
||||||
|
idleScheduler.shutdownNow();
|
||||||
|
// Unload synchronously on shutdown so VRAM is released cleanly.
|
||||||
|
if (state.get() == ModelState.LOADED || state.get() == ModelState.LOADING) {
|
||||||
|
doUnload("JVM shutdown");
|
||||||
|
}
|
||||||
|
cudaExecutor.shutdownNow();
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- ManageModelLifecycle ----
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void ensureReady() {
|
||||||
|
// Fast path: already loaded.
|
||||||
|
if (state.get() == ModelState.LOADED) {
|
||||||
|
touchActivity();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
lock.lock();
|
||||||
|
try {
|
||||||
|
ModelState current = state.get();
|
||||||
|
switch (current) {
|
||||||
|
case LOADED -> {
|
||||||
|
touchActivity();
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
case LOADING -> throw new ModelNotReady(LOADING_RETRY_AFTER_SECONDS);
|
||||||
|
case UNLOADING -> throw new ModelNotReady(LOADING_RETRY_AFTER_SECONDS);
|
||||||
|
case UNLOADED -> {
|
||||||
|
triggerAsyncLoad("triggered by incoming request");
|
||||||
|
throw new ModelNotReady(LOADING_RETRY_AFTER_SECONDS);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
lock.unlock();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public boolean forceUnload(boolean force) {
|
||||||
|
lock.lock();
|
||||||
|
try {
|
||||||
|
ModelState current = state.get();
|
||||||
|
if (current == ModelState.UNLOADED || current == ModelState.UNLOADING) {
|
||||||
|
log.info("forceUnload: already {} — no-op", current);
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
if (!force && hasRunningJobs()) {
|
||||||
|
log.info("forceUnload: blocked by running ingestion jobs (use force=true to override)");
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (current == ModelState.LOADING && pendingLoad != null) {
|
||||||
|
pendingLoad.cancel(false); // try to cancel; unload will follow after transition
|
||||||
|
}
|
||||||
|
doUnload("force-unload via API");
|
||||||
|
return true;
|
||||||
|
} finally {
|
||||||
|
lock.unlock();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public Status getStatus() {
|
||||||
|
return new Status(state.get(), loadedAt, epochMsToInstant(lastActivityEpochMs.get()), idleTimeoutSeconds);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public AutoCloseable subscribeState(Consumer<ModelStateEvent> subscriber) {
|
||||||
|
return eventBus.subscribe(subscriber);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- idle timer ----
|
||||||
|
|
||||||
|
private void checkIdleTimeout() {
|
||||||
|
if (state.get() != ModelState.LOADED) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
long lastMs = lastActivityEpochMs.get();
|
||||||
|
if (lastMs == 0L) {
|
||||||
|
return; // no activity yet
|
||||||
|
}
|
||||||
|
long idleMs = System.currentTimeMillis() - lastMs;
|
||||||
|
if (idleMs < idleTimeoutSeconds * 1000L) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
lock.lock();
|
||||||
|
try {
|
||||||
|
if (state.get() != ModelState.LOADED) {
|
||||||
|
return; // state changed while we waited for the lock
|
||||||
|
}
|
||||||
|
if (hasRunningJobs()) {
|
||||||
|
log.debug("idle timeout reached but ingestion jobs are running; skipping unload");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
long nowIdleMs = System.currentTimeMillis() - lastActivityEpochMs.get();
|
||||||
|
if (nowIdleMs < idleTimeoutSeconds * 1000L) {
|
||||||
|
return; // activity happened while we waited for the lock
|
||||||
|
}
|
||||||
|
doUnload("idle timeout after " + (nowIdleMs / 1000) + " s");
|
||||||
|
} finally {
|
||||||
|
lock.unlock();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- internal helpers (always called while holding lock) ----
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Submits an async load to the CUDA executor. Must be called while holding {@code lock}.
|
||||||
|
* Transitions state UNLOADED → LOADING.
|
||||||
|
*/
|
||||||
|
private void triggerAsyncLoad(String reason) {
|
||||||
|
if (pendingLoad != null && !pendingLoad.isDone()) {
|
||||||
|
return; // already loading — deduplicate
|
||||||
|
}
|
||||||
|
transition(ModelState.LOADING, "load requested: " + reason);
|
||||||
|
pendingLoad = cudaExecutor.submit(() -> {
|
||||||
|
try {
|
||||||
|
log.info("Loading GPU models…");
|
||||||
|
modelLoader.load();
|
||||||
|
lock.lock();
|
||||||
|
try {
|
||||||
|
loadedAt = Instant.now();
|
||||||
|
touchActivity();
|
||||||
|
transition(ModelState.LOADED, null);
|
||||||
|
log.info("GPU models loaded successfully");
|
||||||
|
} finally {
|
||||||
|
lock.unlock();
|
||||||
|
}
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("GPU model load failed", e);
|
||||||
|
lock.lock();
|
||||||
|
try {
|
||||||
|
transition(ModelState.UNLOADED, "load failed: " + e.getMessage());
|
||||||
|
} finally {
|
||||||
|
lock.unlock();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Runs an unload synchronously on the calling thread (must already be the CUDA executor thread
|
||||||
|
* or called during shutdown). For runtime calls it is submitted to the cuda executor.
|
||||||
|
* Must be called while holding {@code lock}.
|
||||||
|
*/
|
||||||
|
private void doUnload(String reason) {
|
||||||
|
transition(ModelState.UNLOADING, reason);
|
||||||
|
// Run unload on the CUDA executor so it stays on the same OS thread as the sessions.
|
||||||
|
cudaExecutor.execute(() -> {
|
||||||
|
try {
|
||||||
|
log.info("Unloading GPU models…");
|
||||||
|
modelLoader.unload();
|
||||||
|
log.info("GPU models unloaded");
|
||||||
|
} catch (Exception e) {
|
||||||
|
log.error("GPU model unload error (models may still be in VRAM)", e);
|
||||||
|
} finally {
|
||||||
|
lock.lock();
|
||||||
|
try {
|
||||||
|
loadedAt = null;
|
||||||
|
transition(ModelState.UNLOADED, null);
|
||||||
|
} finally {
|
||||||
|
lock.unlock();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
private void transition(ModelState next, @Nullable String message) {
|
||||||
|
ModelState prev = state.getAndSet(next);
|
||||||
|
if (prev != next) {
|
||||||
|
log.debug("Model state: {} → {}{}", prev, next, message != null ? " [" + message + "]" : "");
|
||||||
|
eventBus.publish(ModelStateEvent.of(next, message != null ? message : ""));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private void touchActivity() {
|
||||||
|
lastActivityEpochMs.set(System.currentTimeMillis());
|
||||||
|
}
|
||||||
|
|
||||||
|
private boolean hasRunningJobs() {
|
||||||
|
return !observeJobs.listJobs(null, null, JobStatus.RUNNING, 1).isEmpty();
|
||||||
|
}
|
||||||
|
|
||||||
|
private void onJobEvent(IngestionJob job) {
|
||||||
|
if (job.status() == JobStatus.SUCCEEDED || job.status() == JobStatus.FAILED) {
|
||||||
|
touchActivity(); // reset idle clock when a job finishes
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private static @Nullable Instant epochMsToInstant(long ms) {
|
||||||
|
return ms == 0L ? null : Instant.ofEpochMilli(ms);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -6,6 +6,7 @@ import com.trueref.domain.model.Repository;
|
|||||||
import com.trueref.domain.model.SearchHit;
|
import com.trueref.domain.model.SearchHit;
|
||||||
import com.trueref.domain.model.SearchScope;
|
import com.trueref.domain.model.SearchScope;
|
||||||
import com.trueref.domain.model.Version;
|
import com.trueref.domain.model.Version;
|
||||||
|
import com.trueref.domain.port.in.ManageModelLifecycle;
|
||||||
import com.trueref.domain.port.in.SearchLibraryDocs;
|
import com.trueref.domain.port.in.SearchLibraryDocs;
|
||||||
import com.trueref.domain.port.out.ChunkStore;
|
import com.trueref.domain.port.out.ChunkStore;
|
||||||
import com.trueref.domain.port.out.EmbeddingService;
|
import com.trueref.domain.port.out.EmbeddingService;
|
||||||
@@ -42,6 +43,7 @@ public final class HybridSearchService implements SearchLibraryDocs {
|
|||||||
private final EmbeddingService embedder;
|
private final EmbeddingService embedder;
|
||||||
private final RerankerService reranker;
|
private final RerankerService reranker;
|
||||||
private final RepositoryStore repos;
|
private final RepositoryStore repos;
|
||||||
|
private final ManageModelLifecycle lifecycle;
|
||||||
private final int rrfK;
|
private final int rrfK;
|
||||||
private final int rerankTopK;
|
private final int rerankTopK;
|
||||||
private final int finalTopK;
|
private final int finalTopK;
|
||||||
@@ -51,6 +53,7 @@ public final class HybridSearchService implements SearchLibraryDocs {
|
|||||||
EmbeddingService embedder,
|
EmbeddingService embedder,
|
||||||
RerankerService reranker,
|
RerankerService reranker,
|
||||||
RepositoryStore repos,
|
RepositoryStore repos,
|
||||||
|
ManageModelLifecycle lifecycle,
|
||||||
int rrfK,
|
int rrfK,
|
||||||
int rerankTopK,
|
int rerankTopK,
|
||||||
int finalTopK) {
|
int finalTopK) {
|
||||||
@@ -58,6 +61,7 @@ public final class HybridSearchService implements SearchLibraryDocs {
|
|||||||
this.embedder = embedder;
|
this.embedder = embedder;
|
||||||
this.reranker = reranker;
|
this.reranker = reranker;
|
||||||
this.repos = repos;
|
this.repos = repos;
|
||||||
|
this.lifecycle = lifecycle;
|
||||||
this.rrfK = rrfK;
|
this.rrfK = rrfK;
|
||||||
this.rerankTopK = rerankTopK;
|
this.rerankTopK = rerankTopK;
|
||||||
this.finalTopK = finalTopK;
|
this.finalTopK = finalTopK;
|
||||||
@@ -65,6 +69,9 @@ public final class HybridSearchService implements SearchLibraryDocs {
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public Result search(Query q) {
|
public Result search(Query q) {
|
||||||
|
// Ensure models are loaded; throws ModelNotReady (→ HTTP 503) if not.
|
||||||
|
lifecycle.ensureReady();
|
||||||
|
|
||||||
if (q.text() == null || q.text().isBlank()) {
|
if (q.text() == null || q.text().isBlank()) {
|
||||||
throw new InvalidSearchRequest("query text must not be blank");
|
throw new InvalidSearchRequest("query text must not be blank");
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import com.trueref.application.observability.InMemoryJobEventBus;
|
|||||||
import com.trueref.application.observability.JobObservationService;
|
import com.trueref.application.observability.JobObservationService;
|
||||||
import com.trueref.application.resolve.LibraryResolver;
|
import com.trueref.application.resolve.LibraryResolver;
|
||||||
import com.trueref.application.search.HybridSearchService;
|
import com.trueref.application.search.HybridSearchService;
|
||||||
|
import com.trueref.domain.port.in.ManageModelLifecycle;
|
||||||
import com.trueref.domain.port.out.ChunkStore;
|
import com.trueref.domain.port.out.ChunkStore;
|
||||||
import com.trueref.domain.port.out.CodeParser;
|
import com.trueref.domain.port.out.CodeParser;
|
||||||
import com.trueref.domain.port.out.EmbeddingCache;
|
import com.trueref.domain.port.out.EmbeddingCache;
|
||||||
@@ -76,10 +77,11 @@ public class ApplicationBeans {
|
|||||||
EmbeddingService embedder,
|
EmbeddingService embedder,
|
||||||
RerankerService reranker,
|
RerankerService reranker,
|
||||||
RepositoryStore repos,
|
RepositoryStore repos,
|
||||||
|
ManageModelLifecycle lifecycle,
|
||||||
@Value("${trueref.search.rrf-k:60}") int rrfK,
|
@Value("${trueref.search.rrf-k:60}") int rrfK,
|
||||||
@Value("${trueref.reranker.top-k:50}") int rerankTopK,
|
@Value("${trueref.reranker.top-k:50}") int rerankTopK,
|
||||||
@Value("${trueref.search.final-top-k:20}") int finalTopK) {
|
@Value("${trueref.search.final-top-k:20}") int finalTopK) {
|
||||||
return new HybridSearchService(chunks, embedder, reranker, repos, rrfK, rerankTopK, finalTopK);
|
return new HybridSearchService(chunks, embedder, reranker, repos, lifecycle, rrfK, rerankTopK, finalTopK);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Bean
|
@Bean
|
||||||
|
|||||||
@@ -0,0 +1,23 @@
|
|||||||
|
package com.trueref.domain.error;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Thrown when an inference request arrives while the GPU models are not yet loaded (or are being
|
||||||
|
* unloaded). Callers should surface this as HTTP 503 with a {@code Retry-After} header, or as an
|
||||||
|
* MCP error inviting retry.
|
||||||
|
*/
|
||||||
|
public final class ModelNotReady extends TrueRefException {
|
||||||
|
|
||||||
|
private final int retryAfterSeconds;
|
||||||
|
|
||||||
|
public ModelNotReady(int retryAfterSeconds) {
|
||||||
|
super("model_not_ready",
|
||||||
|
"Model is not ready, retry in ~" + retryAfterSeconds + " seconds",
|
||||||
|
null);
|
||||||
|
this.retryAfterSeconds = retryAfterSeconds;
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Suggested number of seconds the caller should wait before retrying. */
|
||||||
|
public int retryAfterSeconds() {
|
||||||
|
return retryAfterSeconds;
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -10,7 +10,8 @@ public abstract sealed class TrueRefException extends RuntimeException
|
|||||||
VersionNotIndexed,
|
VersionNotIndexed,
|
||||||
TagNotFound,
|
TagNotFound,
|
||||||
IngestionFailed,
|
IngestionFailed,
|
||||||
InvalidSearchRequest {
|
InvalidSearchRequest,
|
||||||
|
ModelNotReady {
|
||||||
|
|
||||||
private final String code;
|
private final String code;
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,13 @@
|
|||||||
|
package com.trueref.domain.model;
|
||||||
|
|
||||||
|
/** Life-cycle state of the GPU inference models (embedder + reranker). */
|
||||||
|
public enum ModelState {
|
||||||
|
/** Models are not loaded; no VRAM is consumed. */
|
||||||
|
UNLOADED,
|
||||||
|
/** Models are being loaded into GPU memory (load in progress). */
|
||||||
|
LOADING,
|
||||||
|
/** Models are loaded and ready to serve inference requests. */
|
||||||
|
LOADED,
|
||||||
|
/** Models are being unloaded from GPU memory. */
|
||||||
|
UNLOADING
|
||||||
|
}
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
package com.trueref.domain.model;
|
||||||
|
|
||||||
|
import java.time.Instant;
|
||||||
|
import org.jspecify.annotations.Nullable;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Immutable event emitted whenever the GPU model lifecycle transitions to a new {@link ModelState}.
|
||||||
|
*/
|
||||||
|
public record ModelStateEvent(ModelState state, Instant ts, @Nullable String message) {
|
||||||
|
|
||||||
|
public static ModelStateEvent of(ModelState state) {
|
||||||
|
return new ModelStateEvent(state, Instant.now(), null);
|
||||||
|
}
|
||||||
|
|
||||||
|
public static ModelStateEvent of(ModelState state, String message) {
|
||||||
|
return new ModelStateEvent(state, Instant.now(), message);
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,47 @@
|
|||||||
|
package com.trueref.domain.port.in;
|
||||||
|
|
||||||
|
import com.trueref.domain.model.ModelState;
|
||||||
|
import com.trueref.domain.model.ModelStateEvent;
|
||||||
|
import java.time.Instant;
|
||||||
|
import java.util.function.Consumer;
|
||||||
|
import org.jspecify.annotations.Nullable;
|
||||||
|
|
||||||
|
/** Use-case port: manage the lifecycle of the GPU inference models. */
|
||||||
|
public interface ManageModelLifecycle {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Ensures the model is loaded and ready. If the model is already LOADED, returns immediately.
|
||||||
|
* If the model is UNLOADED, triggers an async load and throws
|
||||||
|
* {@link com.trueref.domain.error.ModelNotReady}. If the model is LOADING or UNLOADING,
|
||||||
|
* throws {@link com.trueref.domain.error.ModelNotReady}.
|
||||||
|
*
|
||||||
|
* <p>Also records the current time as the last-activity timestamp used by the idle timer.
|
||||||
|
*/
|
||||||
|
void ensureReady();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Forces an unload of the GPU models. Blocked if ingestion jobs are running, unless
|
||||||
|
* {@code force=true}.
|
||||||
|
*
|
||||||
|
* @param force when {@code true}, unloads even while jobs are running
|
||||||
|
* @return {@code true} if unload was initiated; {@code false} if blocked by running jobs
|
||||||
|
*/
|
||||||
|
boolean forceUnload(boolean force);
|
||||||
|
|
||||||
|
/** Returns a snapshot of the current model lifecycle status. */
|
||||||
|
Status getStatus();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Registers a subscriber to receive model state-change events.
|
||||||
|
*
|
||||||
|
* @return an {@link AutoCloseable} that removes the subscription when closed
|
||||||
|
*/
|
||||||
|
AutoCloseable subscribeState(Consumer<ModelStateEvent> subscriber);
|
||||||
|
|
||||||
|
/** Snapshot of the current model lifecycle status. */
|
||||||
|
record Status(
|
||||||
|
ModelState state,
|
||||||
|
@Nullable Instant loadedAt,
|
||||||
|
@Nullable Instant lastActivityAt,
|
||||||
|
long idleTimeoutSeconds) {}
|
||||||
|
}
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
package com.trueref.domain.port.out;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Loads and unloads the GPU inference models (embedder + reranker) as a shared lifecycle unit.
|
||||||
|
* Implementations are expected to run load/unload on a platform OS thread to satisfy CUDA
|
||||||
|
* context affinity constraints.
|
||||||
|
*/
|
||||||
|
public interface ModelLoader {
|
||||||
|
|
||||||
|
/** Loads both models into GPU memory. Blocks until ready or throws on error. */
|
||||||
|
void load();
|
||||||
|
|
||||||
|
/** Releases both models from GPU memory. Idempotent. */
|
||||||
|
void unload();
|
||||||
|
|
||||||
|
/** Returns {@code true} if both models are currently loaded and accepting inference. */
|
||||||
|
boolean isLoaded();
|
||||||
|
}
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
package com.trueref.domain.port.out;
|
||||||
|
|
||||||
|
import com.trueref.domain.model.ModelStateEvent;
|
||||||
|
import java.util.function.Consumer;
|
||||||
|
|
||||||
|
/** Event bus for broadcasting {@link ModelStateEvent}s to subscribers (e.g. SSE connections). */
|
||||||
|
public interface ModelStateEventBus {
|
||||||
|
|
||||||
|
/** Publishes a state-change event to all current subscribers. */
|
||||||
|
void publish(ModelStateEvent event);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Registers a subscriber to receive future events.
|
||||||
|
*
|
||||||
|
* @return an {@link AutoCloseable} that removes the subscription when closed
|
||||||
|
*/
|
||||||
|
AutoCloseable subscribe(Consumer<ModelStateEvent> subscriber);
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user