feat: GPU model lazy-load/unload lifecycle management
- Domain: add ModelState, ModelStateEvent, ModelNotReady, ManageModelLifecycle (in-port), ModelLoader and ModelStateEventBus (out-ports)
- Application: InMemoryModelStateEventBus; ModelLifecycleService — state machine (ReentrantLock), lazy load on first request, idle-timeout auto-unload (configurable via trueref.embedding.idle-timeout-seconds, default 300 s), job-guard (skips unload while ingestion running), platform-thread CUDA executor
- Adapters: OnnxModelLoader wires embedder + reranker start/stop; remove @PostConstruct/@PreDestroy from OnnxEmbeddingService and OnnxRerankerService; requireStarted() now throws ModelNotReady instead of IllegalStateException
- REST: GET /api/model/status, POST /api/model/unload (409 when jobs running, force=true to override), GET /api/model/status/stream (SSE)
- GlobalExceptionHandler: ModelNotReady -> 503 + Retry-After header
- HybridSearchService: calls lifecycle.ensureReady() before every search so both REST and MCP paths get ModelNotReady (-> 503 / MCP error) when unloaded
- TrueRefMcpTools: catches ModelNotReady, returns retry hint in MCP error text
- Tests: InMemoryModelStateEventBusTest, ModelLifecycleServiceTest (10 cases), OnnxModelLoaderTest, GlobalExceptionHandlerTest — all 41 tests green

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -6,6 +6,7 @@ import com.trueref.domain.model.SearchHit;
|
||||
import com.trueref.domain.model.SearchScope;
|
||||
import com.trueref.domain.model.Version;
|
||||
import com.trueref.domain.model.VersionStatus;
|
||||
import com.trueref.domain.error.ModelNotReady;
|
||||
import com.trueref.domain.port.in.IndexVersion;
|
||||
import com.trueref.domain.port.in.QueryCatalog;
|
||||
import com.trueref.domain.port.in.ResolveLibraryId;
|
||||
@@ -153,6 +154,9 @@ public class TrueRefMcpTools {
|
||||
SearchLibraryDocs.Result res;
|
||||
try {
|
||||
res = search.search(q);
|
||||
} catch (ModelNotReady e) {
|
||||
return "[model_not_ready] The inference model is loading. "
|
||||
+ "Please retry in ~" + e.retryAfterSeconds() + " seconds.";
|
||||
} catch (Exception e) {
|
||||
log.warn("MCP search failed for {}: {}", libraryId, e.toString());
|
||||
return "Search failed for " + libraryId + ": " + e.getMessage();
|
||||
|
||||
@@ -2,6 +2,7 @@ package com.trueref.adapter.in.rest;
|
||||
|
||||
import com.trueref.domain.error.IngestionFailed;
|
||||
import com.trueref.domain.error.InvalidSearchRequest;
|
||||
import com.trueref.domain.error.ModelNotReady;
|
||||
import com.trueref.domain.error.RepositoryAlreadyRegistered;
|
||||
import com.trueref.domain.error.RepositoryNotFound;
|
||||
import com.trueref.domain.error.TagNotFound;
|
||||
@@ -12,6 +13,7 @@ import jakarta.validation.ConstraintViolationException;
|
||||
import java.util.List;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
import org.springframework.validation.FieldError;
|
||||
@@ -28,6 +30,15 @@ public class GlobalExceptionHandler {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(GlobalExceptionHandler.class);
|
||||
|
||||
@ExceptionHandler(ModelNotReady.class)
|
||||
public ResponseEntity<ErrorResponse> handleModelNotReady(ModelNotReady ex) {
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.set(HttpHeaders.RETRY_AFTER, String.valueOf(ex.retryAfterSeconds()));
|
||||
return ResponseEntity.status(HttpStatus.SERVICE_UNAVAILABLE)
|
||||
.headers(headers)
|
||||
.body(ErrorResponse.of(ex.code(), ex.getMessage()));
|
||||
}
|
||||
|
||||
@ExceptionHandler({RepositoryNotFound.class, VersionNotFound.class, TagNotFound.class})
|
||||
public ResponseEntity<ErrorResponse> handleNotFound(TrueRefException ex) {
|
||||
return status(HttpStatus.NOT_FOUND, ex);
|
||||
|
||||
@@ -0,0 +1,106 @@
|
||||
package com.trueref.adapter.in.rest;
|
||||
|
||||
import com.trueref.adapter.in.rest.dto.ModelStatusDto;
|
||||
import com.trueref.domain.model.ModelState;
|
||||
import com.trueref.domain.port.in.ManageModelLifecycle;
|
||||
import io.swagger.v3.oas.annotations.Operation;
|
||||
import io.swagger.v3.oas.annotations.tags.Tag;
|
||||
import java.io.IOException;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.http.MediaType;
|
||||
import org.springframework.web.bind.annotation.GetMapping;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.RequestMapping;
|
||||
import org.springframework.web.bind.annotation.RequestParam;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
import org.springframework.web.server.ResponseStatusException;
|
||||
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
|
||||
|
||||
/** REST resource: {@code /api/model}. */
|
||||
@RestController
|
||||
@RequestMapping("/api/model")
|
||||
@Tag(name = "model", description = "GPU model lifecycle management (load state, force unload, SSE status stream).")
|
||||
public class ModelController {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(ModelController.class);
|
||||
|
||||
private final ManageModelLifecycle lifecycle;
|
||||
|
||||
public ModelController(ManageModelLifecycle lifecycle) {
|
||||
this.lifecycle = lifecycle;
|
||||
}
|
||||
|
||||
@Operation(summary = "Current model lifecycle status.")
|
||||
@GetMapping("/status")
|
||||
public ModelStatusDto status() {
|
||||
return ModelStatusDto.of(lifecycle.getStatus());
|
||||
}
|
||||
|
||||
@Operation(summary = "Force-unload the GPU models from VRAM. "
|
||||
+ "Returns 409 if ingestion jobs are running (use force=true to override).")
|
||||
@PostMapping("/unload")
|
||||
public ModelStatusDto unload(@RequestParam(value = "force", defaultValue = "false") boolean force) {
|
||||
boolean initiated = lifecycle.forceUnload(force);
|
||||
if (!initiated) {
|
||||
throw new ResponseStatusException(
|
||||
HttpStatus.CONFLICT,
|
||||
"model_unload_blocked_by_running_jobs");
|
||||
}
|
||||
return ModelStatusDto.of(lifecycle.getStatus());
|
||||
}
|
||||
|
||||
@Operation(summary = "Server-Sent Events stream of model state transitions "
|
||||
+ "(LOADING, LOADED, UNLOADING, UNLOADED).")
|
||||
@GetMapping(value = "/status/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
|
||||
public SseEmitter statusStream() {
|
||||
SseEmitter emitter = new SseEmitter(0L);
|
||||
|
||||
// Flush headers to the client immediately so EventSource fires 'open'.
|
||||
try {
|
||||
emitter.send(SseEmitter.event().name("ping").data(""));
|
||||
// Send current state immediately so the client doesn't have to wait for the next transition.
|
||||
emitter.send(SseEmitter.event()
|
||||
.name("model-status")
|
||||
.data(lifecycle.getStatus().state().name()));
|
||||
} catch (IOException e) {
|
||||
emitter.completeWithError(e);
|
||||
return emitter;
|
||||
}
|
||||
|
||||
AutoCloseable subscription = lifecycle.subscribeState(event -> {
|
||||
try {
|
||||
emitter.send(SseEmitter.event().name("model-status").data(event));
|
||||
} catch (IOException ex) {
|
||||
emitter.completeWithError(ex);
|
||||
}
|
||||
});
|
||||
|
||||
Thread keepalive = Thread.startVirtualThread(() -> {
|
||||
try {
|
||||
while (!Thread.currentThread().isInterrupted()) {
|
||||
Thread.sleep(20_000);
|
||||
emitter.send(SseEmitter.event().name("ping").data(""));
|
||||
}
|
||||
} catch (InterruptedException ignored) {
|
||||
// normal shutdown
|
||||
} catch (Exception ignored) {
|
||||
// emitter already completed; exit
|
||||
}
|
||||
});
|
||||
|
||||
Runnable cleanup = () -> {
|
||||
keepalive.interrupt();
|
||||
try {
|
||||
subscription.close();
|
||||
} catch (Exception ex) {
|
||||
log.debug("failed to close model status SSE subscription: {}", ex.toString());
|
||||
}
|
||||
};
|
||||
emitter.onCompletion(cleanup);
|
||||
emitter.onTimeout(cleanup);
|
||||
emitter.onError(e -> cleanup.run());
|
||||
return emitter;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
package com.trueref.adapter.in.rest.dto;
|
||||
|
||||
import com.trueref.domain.model.ModelState;
|
||||
import com.trueref.domain.port.in.ManageModelLifecycle;
|
||||
import java.time.Instant;
|
||||
import org.jspecify.annotations.Nullable;
|
||||
|
||||
/** Snapshot of the GPU model lifecycle state, returned by {@code GET /api/model/status}. */
|
||||
public record ModelStatusDto(
|
||||
ModelState state,
|
||||
@Nullable Instant loadedAt,
|
||||
@Nullable Instant lastActivityAt,
|
||||
long idleTimeoutSeconds) {
|
||||
|
||||
public static ModelStatusDto of(ManageModelLifecycle.Status status) {
|
||||
return new ModelStatusDto(
|
||||
status.state(),
|
||||
status.loadedAt(),
|
||||
status.lastActivityAt(),
|
||||
status.idleTimeoutSeconds());
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,11 @@
|
||||
package com.trueref.adapter.out.embedding.onnx;
|
||||
|
||||
import com.trueref.application.model.InMemoryModelStateEventBus;
|
||||
import com.trueref.application.model.ModelLifecycleService;
|
||||
import com.trueref.domain.port.in.ManageModelLifecycle;
|
||||
import com.trueref.domain.port.in.ObserveJobs;
|
||||
import com.trueref.domain.port.out.ModelLoader;
|
||||
import com.trueref.domain.port.out.ModelStateEventBus;
|
||||
import org.springframework.beans.factory.annotation.Value;
|
||||
import org.springframework.boot.context.properties.EnableConfigurationProperties;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
@@ -9,7 +15,8 @@ import java.nio.file.Path;
|
||||
|
||||
/**
|
||||
* Spring wiring for the ONNX embedding/reranker stack. Produces the shared {@link GpuSemaphore}
|
||||
* (permits = {@code session-count}) and the resolved models-home {@link Path}.
|
||||
* (permits = {@code session-count}), the resolved models-home {@link Path}, and the
|
||||
* {@link ModelLifecycleService} that manages lazy loading and idle-timeout unloading.
|
||||
*/
|
||||
@Configuration
|
||||
@EnableConfigurationProperties(OnnxProperties.class)
|
||||
@@ -29,4 +36,27 @@ public class OnnxEmbeddingConfig {
|
||||
Path home = properties.home();
|
||||
return home != null ? home : Path.of(trueRefHome).resolve("models");
|
||||
}
|
||||
|
||||
@Bean
|
||||
ModelStateEventBus modelStateEventBus() {
|
||||
return new InMemoryModelStateEventBus();
|
||||
}
|
||||
|
||||
@Bean
|
||||
ModelLoader onnxModelLoader(OnnxEmbeddingService embeddingService, OnnxRerankerService rerankerService) {
|
||||
return new OnnxModelLoader(embeddingService, rerankerService);
|
||||
}
|
||||
|
||||
@Bean(initMethod = "init", destroyMethod = "shutdown")
|
||||
ManageModelLifecycle modelLifecycleService(
|
||||
ModelLoader modelLoader,
|
||||
ModelStateEventBus modelStateEventBus,
|
||||
ObserveJobs observeJobs,
|
||||
OnnxProperties properties) {
|
||||
return new ModelLifecycleService(
|
||||
modelLoader,
|
||||
modelStateEventBus,
|
||||
observeJobs,
|
||||
properties.idleTimeoutSecondsOrDefault());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,9 +7,8 @@ import ai.onnxruntime.OrtException;
|
||||
import ai.onnxruntime.OrtSession;
|
||||
import com.trueref.adapter.out.embedding.onnx.HuggingFaceTokenizerWrapper.EncodedBatch;
|
||||
import com.trueref.domain.error.IngestionFailed;
|
||||
import com.trueref.domain.error.ModelNotReady;
|
||||
import com.trueref.domain.port.out.EmbeddingService;
|
||||
import jakarta.annotation.PostConstruct;
|
||||
import jakarta.annotation.PreDestroy;
|
||||
import java.nio.LongBuffer;
|
||||
import java.nio.file.Path;
|
||||
import java.util.ArrayList;
|
||||
@@ -68,8 +67,13 @@ public class OnnxEmbeddingService implements EmbeddingService {
|
||||
this.modelsHome = modelsHome;
|
||||
}
|
||||
|
||||
@PostConstruct
|
||||
void start() {
|
||||
/** Returns {@code true} if the model is currently loaded and accepting requests. */
|
||||
public boolean isStarted() {
|
||||
return pool != null && tokenizer != null;
|
||||
}
|
||||
|
||||
/** Loads the ONNX model into memory. Called by {@link OnnxModelLoader} on demand. */
|
||||
public void start() {
|
||||
if (properties.sessionCountOrDefault() <= 0) {
|
||||
log.warn("OnnxEmbeddingService DISABLED (session-count=0); embed() calls will fail");
|
||||
return;
|
||||
@@ -104,8 +108,8 @@ public class OnnxEmbeddingService implements EmbeddingService {
|
||||
usesTokenTypeIds);
|
||||
}
|
||||
|
||||
@PreDestroy
|
||||
void stop() {
|
||||
/** Unloads the ONNX model from memory. Called by {@link OnnxModelLoader} on unload. */
|
||||
public void stop() {
|
||||
if (pool != null) {
|
||||
pool.close();
|
||||
pool = null;
|
||||
@@ -257,7 +261,7 @@ public class OnnxEmbeddingService implements EmbeddingService {
|
||||
|
||||
private static <T> T requireStarted(@org.jspecify.annotations.Nullable T t, String what) {
|
||||
if (t == null) {
|
||||
throw new IllegalStateException("OnnxEmbeddingService not started: " + what + " is null");
|
||||
throw new ModelNotReady(30);
|
||||
}
|
||||
return t;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
package com.trueref.adapter.out.embedding.onnx;
|
||||
|
||||
import com.trueref.domain.port.out.ModelLoader;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
/**
|
||||
* Implements {@link ModelLoader} by delegating to the ONNX embedding and reranker services.
|
||||
* Both services are loaded/unloaded as a unit. This bean drives their lifecycle instead of
|
||||
* {@code @PostConstruct} / {@code @PreDestroy} so that models are loaded lazily on first demand.
|
||||
*/
|
||||
public class OnnxModelLoader implements ModelLoader {
|
||||
|
||||
private static final Logger log = LoggerFactory.getLogger(OnnxModelLoader.class);
|
||||
|
||||
private final OnnxEmbeddingService embeddingService;
|
||||
private final OnnxRerankerService rerankerService;
|
||||
|
||||
public OnnxModelLoader(OnnxEmbeddingService embeddingService, OnnxRerankerService rerankerService) {
|
||||
this.embeddingService = embeddingService;
|
||||
this.rerankerService = rerankerService;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void load() {
|
||||
log.info("OnnxModelLoader: loading embedder…");
|
||||
embeddingService.start();
|
||||
log.info("OnnxModelLoader: loading reranker…");
|
||||
rerankerService.start();
|
||||
log.info("OnnxModelLoader: both models loaded");
|
||||
}
|
||||
|
||||
@Override
|
||||
public void unload() {
|
||||
log.info("OnnxModelLoader: unloading reranker…");
|
||||
try {
|
||||
rerankerService.stop();
|
||||
} catch (Exception e) {
|
||||
log.warn("Reranker stop error (continuing): {}", e.toString());
|
||||
}
|
||||
log.info("OnnxModelLoader: unloading embedder…");
|
||||
try {
|
||||
embeddingService.stop();
|
||||
} catch (Exception e) {
|
||||
log.warn("Embedder stop error (continuing): {}", e.toString());
|
||||
}
|
||||
log.info("OnnxModelLoader: both models unloaded");
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isLoaded() {
|
||||
return embeddingService.isStarted() && rerankerService.isStarted();
|
||||
}
|
||||
}
|
||||
@@ -21,6 +21,7 @@ public record OnnxProperties(
|
||||
@Nullable Integer maxSeqLen,
|
||||
@Nullable Integer gpuDeviceId,
|
||||
@Nullable Long gpuMemLimitBytes,
|
||||
@Nullable Long idleTimeoutSeconds,
|
||||
@Nullable Path home,
|
||||
@Nullable Map<String, Map<String, List<String>>> modelSources) {
|
||||
|
||||
@@ -37,6 +38,7 @@ public record OnnxProperties(
|
||||
if (maxSeqLen == null || maxSeqLen <= 0) maxSeqLen = 512;
|
||||
if (gpuDeviceId == null || gpuDeviceId < 0) gpuDeviceId = 0;
|
||||
if (gpuMemLimitBytes == null || gpuMemLimitBytes <= 0L) gpuMemLimitBytes = 0L; // 0 = no cap
|
||||
if (idleTimeoutSeconds == null || idleTimeoutSeconds <= 0L) idleTimeoutSeconds = 300L; // 5 min
|
||||
if (modelSources == null) modelSources = Map.of();
|
||||
}
|
||||
|
||||
@@ -75,4 +77,8 @@ public record OnnxProperties(
|
||||
public long gpuMemLimitBytesOrDefault() {
|
||||
return gpuMemLimitBytes == null ? 0L : gpuMemLimitBytes;
|
||||
}
|
||||
|
||||
public long idleTimeoutSecondsOrDefault() {
|
||||
return idleTimeoutSeconds == null ? 300L : idleTimeoutSeconds;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,10 +7,9 @@ import ai.onnxruntime.OrtException;
|
||||
import ai.onnxruntime.OrtSession;
|
||||
import com.trueref.adapter.out.embedding.onnx.HuggingFaceTokenizerWrapper.EncodedBatch;
|
||||
import com.trueref.domain.error.IngestionFailed;
|
||||
import com.trueref.domain.error.ModelNotReady;
|
||||
import com.trueref.domain.model.SearchHit;
|
||||
import com.trueref.domain.port.out.RerankerService;
|
||||
import jakarta.annotation.PostConstruct;
|
||||
import jakarta.annotation.PreDestroy;
|
||||
import java.nio.LongBuffer;
|
||||
import java.nio.file.Path;
|
||||
import java.util.ArrayList;
|
||||
@@ -62,8 +61,13 @@ public class OnnxRerankerService implements RerankerService {
|
||||
this.modelsHome = modelsHome;
|
||||
}
|
||||
|
||||
@PostConstruct
|
||||
void start() {
|
||||
/** Returns {@code true} if the reranker model is currently loaded and accepting requests. */
|
||||
public boolean isStarted() {
|
||||
return pool != null && tokenizer != null;
|
||||
}
|
||||
|
||||
/** Loads the reranker model into memory. Called by {@link OnnxModelLoader} on demand. */
|
||||
public void start() {
|
||||
if (properties.rerankerSessionCountOrDefault() <= 0) {
|
||||
log.warn("OnnxRerankerService DISABLED (reranker-session-count=0); rerank() passes through input order");
|
||||
return;
|
||||
@@ -96,8 +100,8 @@ public class OnnxRerankerService implements RerankerService {
|
||||
usesTokenTypeIds);
|
||||
}
|
||||
|
||||
@PreDestroy
|
||||
void stop() {
|
||||
/** Unloads the reranker model from memory. Called by {@link OnnxModelLoader} on unload. */
|
||||
public void stop() {
|
||||
if (pool != null) {
|
||||
pool.close();
|
||||
pool = null;
|
||||
@@ -252,7 +256,7 @@ public class OnnxRerankerService implements RerankerService {
|
||||
|
||||
private static <T> T requireStarted(@org.jspecify.annotations.Nullable T t, String what) {
|
||||
if (t == null) {
|
||||
throw new IllegalStateException("OnnxRerankerService not started: " + what + " is null");
|
||||
throw new ModelNotReady(30);
|
||||
}
|
||||
return t;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,44 @@
|
||||
package com.trueref.adapter.in.rest;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
import com.trueref.domain.error.ModelNotReady;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.http.ResponseEntity;
|
||||
|
||||
class GlobalExceptionHandlerTest {
|
||||
|
||||
private final GlobalExceptionHandler handler = new GlobalExceptionHandler();
|
||||
|
||||
@Test
|
||||
void modelNotReadyMapsto503WithRetryAfterHeader() {
|
||||
ModelNotReady ex = new ModelNotReady(30);
|
||||
|
||||
ResponseEntity<ErrorResponse> response = handler.handleModelNotReady(ex);
|
||||
|
||||
assertThat(response.getStatusCode()).isEqualTo(HttpStatus.SERVICE_UNAVAILABLE);
|
||||
assertThat(response.getHeaders().getFirst(HttpHeaders.RETRY_AFTER)).isEqualTo("30");
|
||||
}
|
||||
|
||||
@Test
|
||||
void modelNotReadyResponseBodyHasModelNotReadyCode() {
|
||||
ModelNotReady ex = new ModelNotReady(15);
|
||||
|
||||
ResponseEntity<ErrorResponse> response = handler.handleModelNotReady(ex);
|
||||
|
||||
assertThat(response.getBody()).isNotNull();
|
||||
assertThat(response.getBody().code()).isEqualTo("model_not_ready");
|
||||
assertThat(response.getBody().message()).contains("15");
|
||||
}
|
||||
|
||||
@Test
|
||||
void retryAfterReflectsExceptionValue() {
|
||||
ModelNotReady ex = new ModelNotReady(60);
|
||||
|
||||
ResponseEntity<ErrorResponse> response = handler.handleModelNotReady(ex);
|
||||
|
||||
assertThat(response.getHeaders().getFirst(HttpHeaders.RETRY_AFTER)).isEqualTo("60");
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,71 @@
|
||||
package com.trueref.adapter.out.embedding.onnx;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.mockito.Mockito.doThrow;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.never;
|
||||
import static org.mockito.Mockito.verify;
|
||||
import static org.mockito.Mockito.when;
|
||||
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class OnnxModelLoaderTest {
|
||||
|
||||
@Test
|
||||
void loadCallsStartOnBothServices() {
|
||||
OnnxEmbeddingService embedder = mock(OnnxEmbeddingService.class);
|
||||
OnnxRerankerService reranker = mock(OnnxRerankerService.class);
|
||||
OnnxModelLoader loader = new OnnxModelLoader(embedder, reranker);
|
||||
|
||||
loader.load();
|
||||
|
||||
verify(embedder).start();
|
||||
verify(reranker).start();
|
||||
}
|
||||
|
||||
@Test
|
||||
void unloadCallsStopOnBothServices() {
|
||||
OnnxEmbeddingService embedder = mock(OnnxEmbeddingService.class);
|
||||
OnnxRerankerService reranker = mock(OnnxRerankerService.class);
|
||||
OnnxModelLoader loader = new OnnxModelLoader(embedder, reranker);
|
||||
|
||||
loader.unload();
|
||||
|
||||
verify(embedder).stop();
|
||||
verify(reranker).stop();
|
||||
}
|
||||
|
||||
@Test
|
||||
void unloadContinuesEvenIfRerankerStopThrows() {
|
||||
OnnxEmbeddingService embedder = mock(OnnxEmbeddingService.class);
|
||||
OnnxRerankerService reranker = mock(OnnxRerankerService.class);
|
||||
doThrow(new RuntimeException("VRAM error")).when(reranker).stop();
|
||||
OnnxModelLoader loader = new OnnxModelLoader(embedder, reranker);
|
||||
|
||||
loader.unload(); // must not throw
|
||||
|
||||
verify(embedder).stop();
|
||||
}
|
||||
|
||||
@Test
|
||||
void isLoadedReturnsTrueWhenBothServicesStarted() {
|
||||
OnnxEmbeddingService embedder = mock(OnnxEmbeddingService.class);
|
||||
OnnxRerankerService reranker = mock(OnnxRerankerService.class);
|
||||
when(embedder.isStarted()).thenReturn(true);
|
||||
when(reranker.isStarted()).thenReturn(true);
|
||||
OnnxModelLoader loader = new OnnxModelLoader(embedder, reranker);
|
||||
|
||||
assertThat(loader.isLoaded()).isTrue();
|
||||
}
|
||||
|
||||
@Test
|
||||
void isLoadedReturnsFalseWhenEitherServiceNotStarted() {
|
||||
OnnxEmbeddingService embedder = mock(OnnxEmbeddingService.class);
|
||||
OnnxRerankerService reranker = mock(OnnxRerankerService.class);
|
||||
when(embedder.isStarted()).thenReturn(true);
|
||||
when(reranker.isStarted()).thenReturn(false);
|
||||
OnnxModelLoader loader = new OnnxModelLoader(embedder, reranker);
|
||||
|
||||
assertThat(loader.isLoaded()).isFalse();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,75 @@
|
||||
package com.trueref.application.model;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
|
||||
import com.trueref.domain.model.ModelState;
|
||||
import com.trueref.domain.model.ModelStateEvent;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.atomic.AtomicInteger;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class InMemoryModelStateEventBusTest {
|
||||
|
||||
@Test
|
||||
void singleSubscriberReceivesPublishedEvent() {
|
||||
InMemoryModelStateEventBus bus = new InMemoryModelStateEventBus();
|
||||
List<ModelStateEvent> received = new ArrayList<>();
|
||||
bus.subscribe(received::add);
|
||||
|
||||
bus.publish(ModelStateEvent.of(ModelState.LOADING));
|
||||
|
||||
assertThat(received).hasSize(1);
|
||||
assertThat(received.get(0).state()).isEqualTo(ModelState.LOADING);
|
||||
}
|
||||
|
||||
@Test
|
||||
void multipleSubscribersAllReceiveEvent() {
|
||||
InMemoryModelStateEventBus bus = new InMemoryModelStateEventBus();
|
||||
List<ModelStateEvent> a = new ArrayList<>();
|
||||
List<ModelStateEvent> b = new ArrayList<>();
|
||||
bus.subscribe(a::add);
|
||||
bus.subscribe(b::add);
|
||||
|
||||
bus.publish(ModelStateEvent.of(ModelState.LOADED));
|
||||
|
||||
assertThat(a).hasSize(1);
|
||||
assertThat(b).hasSize(1);
|
||||
}
|
||||
|
||||
@Test
|
||||
void closingSubscriptionRemovesIt() throws Exception {
|
||||
InMemoryModelStateEventBus bus = new InMemoryModelStateEventBus();
|
||||
List<ModelStateEvent> received = new ArrayList<>();
|
||||
AutoCloseable handle = bus.subscribe(received::add);
|
||||
|
||||
handle.close();
|
||||
bus.publish(ModelStateEvent.of(ModelState.UNLOADED));
|
||||
|
||||
assertThat(received).isEmpty();
|
||||
}
|
||||
|
||||
@Test
|
||||
void exceptionInSubscriberDoesNotBreakOthers() {
|
||||
InMemoryModelStateEventBus bus = new InMemoryModelStateEventBus();
|
||||
AtomicInteger badCalls = new AtomicInteger();
|
||||
List<ModelStateEvent> good = new ArrayList<>();
|
||||
|
||||
bus.subscribe(e -> {
|
||||
badCalls.incrementAndGet();
|
||||
throw new RuntimeException("intentional test failure");
|
||||
});
|
||||
bus.subscribe(good::add);
|
||||
|
||||
bus.publish(ModelStateEvent.of(ModelState.LOADING));
|
||||
|
||||
assertThat(badCalls.get()).isEqualTo(1);
|
||||
assertThat(good).hasSize(1);
|
||||
}
|
||||
|
||||
@Test
|
||||
void noSubscribersMeansPublishIsNoop() {
|
||||
InMemoryModelStateEventBus bus = new InMemoryModelStateEventBus();
|
||||
bus.publish(ModelStateEvent.of(ModelState.UNLOADED)); // must not throw
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,227 @@
|
||||
package com.trueref.application.model;
|
||||
|
||||
import static org.assertj.core.api.Assertions.assertThat;
|
||||
import static org.assertj.core.api.Assertions.assertThatThrownBy;
|
||||
|
||||
import com.trueref.domain.error.ModelNotReady;
|
||||
import com.trueref.domain.model.IngestionJob;
|
||||
import com.trueref.domain.model.JobId;
|
||||
import com.trueref.domain.model.JobStatus;
|
||||
import com.trueref.domain.model.JobType;
|
||||
import com.trueref.domain.model.ModelState;
|
||||
import com.trueref.domain.model.ModelStateEvent;
|
||||
import com.trueref.domain.model.RepositoryId;
|
||||
import com.trueref.domain.port.in.ManageModelLifecycle;
|
||||
import com.trueref.domain.port.in.ObserveJobs;
|
||||
import com.trueref.domain.port.out.ModelLoader;
|
||||
import com.trueref.domain.port.out.ModelStateEventBus;
|
||||
import java.time.Instant;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
import java.util.concurrent.TimeUnit;
|
||||
import java.util.concurrent.atomic.AtomicBoolean;
|
||||
import java.util.function.Consumer;
|
||||
import org.junit.jupiter.api.AfterEach;
|
||||
import org.junit.jupiter.api.BeforeEach;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
class ModelLifecycleServiceTest {
|
||||
|
||||
// ---- Test doubles ----
|
||||
|
||||
private final AtomicBoolean loaderLoaded = new AtomicBoolean(false);
|
||||
private volatile boolean slowLoad = false;
|
||||
|
||||
private final ModelLoader loader = new ModelLoader() {
|
||||
@Override
|
||||
public void load() {
|
||||
if (slowLoad) {
|
||||
try { Thread.sleep(200); } catch (InterruptedException e) { Thread.currentThread().interrupt(); }
|
||||
}
|
||||
loaderLoaded.set(true);
|
||||
}
|
||||
@Override
|
||||
public void unload() { loaderLoaded.set(false); }
|
||||
@Override
|
||||
public boolean isLoaded() { return loaderLoaded.get(); }
|
||||
};
|
||||
|
||||
private final InMemoryModelStateEventBus eventBus = new InMemoryModelStateEventBus();
|
||||
|
||||
private volatile List<IngestionJob> runningJobs = List.of();
|
||||
|
||||
private final ObserveJobs observeJobs = new ObserveJobs() {
|
||||
@Override
|
||||
public Optional<IngestionJob> findJob(JobId id) { return Optional.empty(); }
|
||||
@Override
|
||||
public List<IngestionJob> listJobs(com.trueref.domain.model.RepositoryId r,
|
||||
com.trueref.domain.model.VersionId v, JobStatus status, int limit) {
|
||||
if (status == JobStatus.RUNNING) return runningJobs;
|
||||
return List.of();
|
||||
}
|
||||
@Override
|
||||
public AutoCloseable subscribeJobs(Consumer<IngestionJob> listener) { return () -> {}; }
|
||||
@Override
|
||||
public AutoCloseable subscribeLogs(JobId jobId, Consumer<com.trueref.domain.model.JobLogEvent> listener) { return () -> {}; }
|
||||
};
|
||||
|
||||
private ModelLifecycleService service;
|
||||
|
||||
@BeforeEach
|
||||
void setUp() {
|
||||
loaderLoaded.set(false);
|
||||
slowLoad = false;
|
||||
runningJobs = List.of();
|
||||
service = new ModelLifecycleService(loader, eventBus, observeJobs, 300L);
|
||||
service.init();
|
||||
}
|
||||
|
||||
@AfterEach
|
||||
void tearDown() {
|
||||
service.shutdown();
|
||||
}
|
||||
|
||||
// ---- Tests ----
|
||||
|
||||
@Test
|
||||
void startsInUnloadedState() {
|
||||
assertThat(service.getStatus().state()).isEqualTo(ModelState.UNLOADED);
|
||||
}
|
||||
|
||||
@Test
|
||||
void ensureReadyOnUnloadedTriggersLoadAndThrows() throws InterruptedException {
|
||||
List<ModelStateEvent> events = new ArrayList<>();
|
||||
eventBus.subscribe(events::add);
|
||||
|
||||
assertThatThrownBy(() -> service.ensureReady())
|
||||
.isInstanceOf(ModelNotReady.class);
|
||||
|
||||
// Wait for async load to complete (max 2 s in test)
|
||||
for (int i = 0; i < 20 && service.getStatus().state() != ModelState.LOADED; i++) {
|
||||
Thread.sleep(100);
|
||||
}
|
||||
|
||||
assertThat(service.getStatus().state()).isEqualTo(ModelState.LOADED);
|
||||
assertThat(loaderLoaded.get()).isTrue();
|
||||
assertThat(events).extracting(ModelStateEvent::state)
|
||||
.containsSequence(ModelState.LOADING, ModelState.LOADED);
|
||||
}
|
||||
|
||||
@Test
|
||||
void ensureReadyOnLoadedIsNoop() throws InterruptedException {
|
||||
triggerLoad();
|
||||
|
||||
// Should not throw
|
||||
service.ensureReady();
|
||||
service.ensureReady();
|
||||
|
||||
assertThat(service.getStatus().state()).isEqualTo(ModelState.LOADED);
|
||||
}
|
||||
|
||||
@Test
|
||||
void ensureReadyOnLoadingThrowsImmediately() {
|
||||
slowLoad = true;
|
||||
assertThatThrownBy(() -> service.ensureReady())
|
||||
.isInstanceOf(ModelNotReady.class);
|
||||
// state is LOADING, second call also throws
|
||||
assertThatThrownBy(() -> service.ensureReady())
|
||||
.isInstanceOf(ModelNotReady.class);
|
||||
}
|
||||
|
||||
@Test
|
||||
void forceUnloadFromLoadedState() throws InterruptedException {
|
||||
triggerLoad();
|
||||
|
||||
boolean initiated = service.forceUnload(false);
|
||||
|
||||
assertThat(initiated).isTrue();
|
||||
// Wait for async unload
|
||||
for (int i = 0; i < 20 && service.getStatus().state() != ModelState.UNLOADED; i++) {
|
||||
Thread.sleep(100);
|
||||
}
|
||||
assertThat(service.getStatus().state()).isEqualTo(ModelState.UNLOADED);
|
||||
assertThat(loaderLoaded.get()).isFalse();
|
||||
}
|
||||
|
||||
@Test
|
||||
void forceUnloadBlockedByRunningJobsWithoutForce() throws InterruptedException {
|
||||
triggerLoad();
|
||||
runningJobs = List.of(runningJob());
|
||||
|
||||
boolean initiated = service.forceUnload(false);
|
||||
|
||||
assertThat(initiated).isFalse();
|
||||
assertThat(service.getStatus().state()).isEqualTo(ModelState.LOADED);
|
||||
}
|
||||
|
||||
@Test
|
||||
void forceUnloadSucceedsWithForceEvenWhenJobsRunning() throws InterruptedException {
|
||||
triggerLoad();
|
||||
runningJobs = List.of(runningJob());
|
||||
|
||||
boolean initiated = service.forceUnload(true);
|
||||
|
||||
assertThat(initiated).isTrue();
|
||||
for (int i = 0; i < 20 && service.getStatus().state() != ModelState.UNLOADED; i++) {
|
||||
Thread.sleep(100);
|
||||
}
|
||||
assertThat(service.getStatus().state()).isEqualTo(ModelState.UNLOADED);
|
||||
}
|
||||
|
||||
@Test
|
||||
void forceUnloadWhenAlreadyUnloadedIsNoop() {
|
||||
boolean initiated = service.forceUnload(false);
|
||||
assertThat(initiated).isTrue();
|
||||
assertThat(service.getStatus().state()).isEqualTo(ModelState.UNLOADED);
|
||||
}
|
||||
|
||||
@Test
|
||||
void concurrentEnsureReadyDeduplicatesLoad() throws InterruptedException {
|
||||
int callCount = 0;
|
||||
// Both calls throw ModelNotReady but only one load should be triggered
|
||||
List<ModelStateEvent> events = new ArrayList<>();
|
||||
eventBus.subscribe(events::add);
|
||||
|
||||
try { service.ensureReady(); } catch (ModelNotReady ignored) { callCount++; }
|
||||
try { service.ensureReady(); } catch (ModelNotReady ignored) { callCount++; }
|
||||
|
||||
for (int i = 0; i < 20 && service.getStatus().state() != ModelState.LOADED; i++) {
|
||||
Thread.sleep(100);
|
||||
}
|
||||
|
||||
assertThat(callCount).isEqualTo(2);
|
||||
// Only one LOADING event should be emitted (deduplicated)
|
||||
long loadingCount = events.stream().filter(e -> e.state() == ModelState.LOADING).count();
|
||||
assertThat(loadingCount).isEqualTo(1);
|
||||
}
|
||||
|
||||
@Test
|
||||
void statusIncludesLoadedAtAfterLoad() throws InterruptedException {
|
||||
assertThat(service.getStatus().loadedAt()).isNull();
|
||||
triggerLoad();
|
||||
assertThat(service.getStatus().loadedAt()).isNotNull();
|
||||
}
|
||||
|
||||
// ---- helpers ----
|
||||
|
||||
private void triggerLoad() throws InterruptedException {
|
||||
try { service.ensureReady(); } catch (ModelNotReady ignored) {}
|
||||
for (int i = 0; i < 20 && service.getStatus().state() != ModelState.LOADED; i++) {
|
||||
Thread.sleep(100);
|
||||
}
|
||||
}
|
||||
|
||||
private IngestionJob runningJob() {
|
||||
return new IngestionJob(
|
||||
JobId.of(java.util.UUID.randomUUID().toString()),
|
||||
RepositoryId.of(java.util.UUID.randomUUID().toString()),
|
||||
null,
|
||||
JobType.INDEX_VERSION,
|
||||
JobStatus.RUNNING,
|
||||
Instant.now(),
|
||||
null,
|
||||
List.of());
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user