feat: GPU model lazy-load/unload lifecycle management
All checks were successful
Build and publish Docker image / Build and push CPU image (push) Successful in 2m33s
Build and publish Docker image / Build and push GPU image (push) Successful in 3m15s

- Domain: add ModelState, ModelStateEvent, ModelNotReady, ManageModelLifecycle
  (in-port), ModelLoader and ModelStateEventBus (out-ports)
- Application: InMemoryModelStateEventBus; ModelLifecycleService — state
  machine (ReentrantLock), lazy load on first request, idle-timeout auto-unload
  (configurable via trueref.embedding.idle-timeout-seconds, default 300 s),
  job-guard (skips unload while ingestion running), platform-thread CUDA executor
- Adapters: OnnxModelLoader wires embedder + reranker start/stop; remove
  @PostConstruct/@PreDestroy from OnnxEmbeddingService and OnnxRerankerService;
  requireStarted() now throws ModelNotReady instead of IllegalStateException
- REST: GET /api/model/status, POST /api/model/unload (409 when jobs running,
  force=true to override), GET /api/model/status/stream (SSE)
- GlobalExceptionHandler: ModelNotReady -> 503 + Retry-After header
- HybridSearchService: calls lifecycle.ensureReady() before every search so
  both REST and MCP paths get ModelNotReady (-> 503 / MCP error) when unloaded
- TrueRefMcpTools: catches ModelNotReady, returns retry hint in MCP error text
- Tests: InMemoryModelStateEventBusTest, ModelLifecycleServiceTest (10 cases),
  OnnxModelLoaderTest, GlobalExceptionHandlerTest — all 41 tests green

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
moze
2026-05-09 15:44:33 +02:00
parent 943a38fd36
commit 5c6085df99
24 changed files with 1144 additions and 17 deletions

View File

@@ -6,6 +6,7 @@ import com.trueref.domain.model.SearchHit;
import com.trueref.domain.model.SearchScope;
import com.trueref.domain.model.Version;
import com.trueref.domain.model.VersionStatus;
import com.trueref.domain.error.ModelNotReady;
import com.trueref.domain.port.in.IndexVersion;
import com.trueref.domain.port.in.QueryCatalog;
import com.trueref.domain.port.in.ResolveLibraryId;
@@ -153,6 +154,9 @@ public class TrueRefMcpTools {
SearchLibraryDocs.Result res;
try {
res = search.search(q);
} catch (ModelNotReady e) {
return "[model_not_ready] The inference model is loading. "
+ "Please retry in ~" + e.retryAfterSeconds() + " seconds.";
} catch (Exception e) {
log.warn("MCP search failed for {}: {}", libraryId, e.toString());
return "Search failed for " + libraryId + ": " + e.getMessage();

View File

@@ -2,6 +2,7 @@ package com.trueref.adapter.in.rest;
import com.trueref.domain.error.IngestionFailed;
import com.trueref.domain.error.InvalidSearchRequest;
import com.trueref.domain.error.ModelNotReady;
import com.trueref.domain.error.RepositoryAlreadyRegistered;
import com.trueref.domain.error.RepositoryNotFound;
import com.trueref.domain.error.TagNotFound;
@@ -12,6 +13,7 @@ import jakarta.validation.ConstraintViolationException;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.validation.FieldError;
@@ -28,6 +30,15 @@ public class GlobalExceptionHandler {
private static final Logger log = LoggerFactory.getLogger(GlobalExceptionHandler.class);
/**
 * Maps {@link ModelNotReady} to {@code 503 Service Unavailable}.
 *
 * <p>The value of {@code ex.retryAfterSeconds()} is echoed in the {@code Retry-After} header,
 * and the body carries the exception's code and message.
 *
 * @param ex the exception raised while the model is not available
 * @return a 503 response with a {@code Retry-After} header and an {@link ErrorResponse} body
 */
@ExceptionHandler(ModelNotReady.class)
public ResponseEntity<ErrorResponse> handleModelNotReady(ModelNotReady ex) {
    String retryAfter = String.valueOf(ex.retryAfterSeconds());
    ErrorResponse payload = ErrorResponse.of(ex.code(), ex.getMessage());
    return ResponseEntity.status(HttpStatus.SERVICE_UNAVAILABLE)
            .header(HttpHeaders.RETRY_AFTER, retryAfter)
            .body(payload);
}
@ExceptionHandler({RepositoryNotFound.class, VersionNotFound.class, TagNotFound.class})
public ResponseEntity<ErrorResponse> handleNotFound(TrueRefException ex) {
return status(HttpStatus.NOT_FOUND, ex);

View File

@@ -0,0 +1,106 @@
package com.trueref.adapter.in.rest;
import com.trueref.adapter.in.rest.dto.ModelStatusDto;
import com.trueref.domain.model.ModelState;
import com.trueref.domain.port.in.ManageModelLifecycle;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import java.io.IOException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.server.ResponseStatusException;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
/**
 * REST resource: {@code /api/model}.
 *
 * <p>Exposes the model lifecycle port over HTTP: a status snapshot, a manual unload trigger and
 * a Server-Sent Events stream of state transitions.
 */
@RestController
@RequestMapping("/api/model")
@Tag(name = "model", description = "GPU model lifecycle management (load state, force unload, SSE status stream).")
public class ModelController {

    private static final Logger log = LoggerFactory.getLogger(ModelController.class);

    /** Pause between two keep-alive {@code ping} events on the SSE stream. */
    private static final long KEEPALIVE_INTERVAL_MILLIS = 20_000;

    private final ManageModelLifecycle lifecycle;

    public ModelController(ManageModelLifecycle lifecycle) {
        this.lifecycle = lifecycle;
    }

    /** Returns the current lifecycle status as a DTO. */
    @Operation(summary = "Current model lifecycle status.")
    @GetMapping("/status")
    public ModelStatusDto status() {
        return ModelStatusDto.of(lifecycle.getStatus());
    }

    /**
     * Requests an unload of the models.
     *
     * @param force passed through to {@code lifecycle.forceUnload}; per the endpoint contract
     *     {@code true} overrides the running-jobs guard
     * @return the lifecycle status observed right after the unload request was accepted
     * @throws ResponseStatusException with status 409 when the lifecycle refuses the unload
     */
    @Operation(summary = "Force-unload the GPU models from VRAM. "
            + "Returns 409 if ingestion jobs are running (use force=true to override).")
    @PostMapping("/unload")
    public ModelStatusDto unload(@RequestParam(value = "force", defaultValue = "false") boolean force) {
        boolean initiated = lifecycle.forceUnload(force);
        if (!initiated) {
            throw new ResponseStatusException(
                    HttpStatus.CONFLICT,
                    "model_unload_blocked_by_running_jobs");
        }
        return ModelStatusDto.of(lifecycle.getStatus());
    }

    /**
     * Opens an SSE stream that carries an initial {@code ping}, the current state, every later
     * state transition, and a periodic keep-alive {@code ping}.
     *
     * <p>The subscription is registered <em>before</em> the current state is read. A transition
     * that happens between the two steps is therefore still delivered; subscribing afterwards
     * would silently drop it and leave the client on a stale state until the next transition.
     * The price is that a client may see a transition event immediately followed by a snapshot
     * of that same state.
     *
     * @return the emitter backing the stream
     */
    @Operation(summary = "Server-Sent Events stream of model state transitions "
            + "(LOADING, LOADED, UNLOADING, UNLOADED).")
    @GetMapping(value = "/status/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
    public SseEmitter statusStream() {
        // NOTE(review): a timeout of 0 presumably disables the async request timeout - confirm
        // against the servlet container configuration.
        SseEmitter emitter = new SseEmitter(0L);

        AutoCloseable subscription = lifecycle.subscribeState(event -> {
            try {
                // NOTE(review): transitions are sent as the event object, while the initial
                // snapshot below is the bare state name - confirm the client handles both shapes.
                emitter.send(SseEmitter.event().name("model-status").data(event));
            } catch (IOException ex) {
                emitter.completeWithError(ex);
            }
        });

        try {
            // Flush headers to the client immediately so EventSource fires 'open'.
            emitter.send(SseEmitter.event().name("ping").data(""));
            // Send current state immediately so the client doesn't have to wait for the next transition.
            emitter.send(SseEmitter.event()
                    .name("model-status")
                    .data(lifecycle.getStatus().state().name()));
        } catch (IOException e) {
            // The stream never became usable: release the listener registered above.
            closeSubscription(subscription);
            emitter.completeWithError(e);
            return emitter;
        } catch (RuntimeException e) {
            // Do not leak the listener when the snapshot itself fails; let the error propagate.
            closeSubscription(subscription);
            throw e;
        }

        Thread keepalive = Thread.startVirtualThread(() -> {
            try {
                while (!Thread.currentThread().isInterrupted()) {
                    Thread.sleep(KEEPALIVE_INTERVAL_MILLIS);
                    emitter.send(SseEmitter.event().name("ping").data(""));
                }
            } catch (InterruptedException ignored) {
                // normal shutdown
            } catch (Exception ignored) {
                // emitter already completed; exit
            }
        });

        Runnable cleanup = () -> {
            keepalive.interrupt();
            closeSubscription(subscription);
        };
        emitter.onCompletion(cleanup);
        emitter.onTimeout(cleanup);
        emitter.onError(e -> cleanup.run());
        return emitter;
    }

    /** Closes a state subscription, logging (not propagating) any failure. */
    private static void closeSubscription(AutoCloseable subscription) {
        try {
            subscription.close();
        } catch (Exception ex) {
            log.debug("failed to close model status SSE subscription: {}", ex.toString());
        }
    }
}

View File

@@ -0,0 +1,22 @@
package com.trueref.adapter.in.rest.dto;
import com.trueref.domain.model.ModelState;
import com.trueref.domain.port.in.ManageModelLifecycle;
import java.time.Instant;
import org.jspecify.annotations.Nullable;
/**
 * Snapshot of the GPU model lifecycle state, returned by {@code GET /api/model/status}.
 *
 * @param state current lifecycle state
 * @param loadedAt value of the status' {@code loadedAt()}; may be {@code null}
 * @param lastActivityAt value of the status' {@code lastActivityAt()}; may be {@code null}
 * @param idleTimeoutSeconds value of the status' {@code idleTimeoutSeconds()}
 */
public record ModelStatusDto(
        ModelState state,
        @Nullable Instant loadedAt,
        @Nullable Instant lastActivityAt,
        long idleTimeoutSeconds) {

    /**
     * Builds a DTO by copying the four accessor values of a lifecycle status.
     *
     * @param status the lifecycle status to copy from
     * @return a DTO carrying the same values
     */
    public static ModelStatusDto of(ManageModelLifecycle.Status status) {
        ModelState currentState = status.state();
        @Nullable Instant loaded = status.loadedAt();
        @Nullable Instant lastActivity = status.lastActivityAt();
        long idleTimeout = status.idleTimeoutSeconds();
        return new ModelStatusDto(currentState, loaded, lastActivity, idleTimeout);
    }
}

View File

@@ -1,5 +1,11 @@
package com.trueref.adapter.out.embedding.onnx;
import com.trueref.application.model.InMemoryModelStateEventBus;
import com.trueref.application.model.ModelLifecycleService;
import com.trueref.domain.port.in.ManageModelLifecycle;
import com.trueref.domain.port.in.ObserveJobs;
import com.trueref.domain.port.out.ModelLoader;
import com.trueref.domain.port.out.ModelStateEventBus;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
@@ -9,7 +15,8 @@ import java.nio.file.Path;
/**
* Spring wiring for the ONNX embedding/reranker stack. Produces the shared {@link GpuSemaphore}
* (permits = {@code session-count}) and the resolved models-home {@link Path}.
* (permits = {@code session-count}), the resolved models-home {@link Path}, and the
* {@link ModelLifecycleService} that manages lazy loading and idle-timeout unloading.
*/
@Configuration
@EnableConfigurationProperties(OnnxProperties.class)
@@ -29,4 +36,27 @@ public class OnnxEmbeddingConfig {
Path home = properties.home();
return home != null ? home : Path.of(trueRefHome).resolve("models");
}
/** Provides the {@link ModelStateEventBus} bean, backed by {@link InMemoryModelStateEventBus}. */
@Bean
ModelStateEventBus modelStateEventBus() {
return new InMemoryModelStateEventBus();
}
/**
 * Provides the {@link ModelLoader} bean: an {@link OnnxModelLoader} wrapping the embedding and
 * reranker services.
 */
@Bean
ModelLoader onnxModelLoader(OnnxEmbeddingService embeddingService, OnnxRerankerService rerankerService) {
return new OnnxModelLoader(embeddingService, rerankerService);
}
/**
 * Provides the {@link ManageModelLifecycle} bean as a {@link ModelLifecycleService}.
 *
 * <p>The bean definition names {@code init} as its init method and {@code shutdown} as its
 * destroy method. The idle timeout comes from {@link OnnxProperties#idleTimeoutSecondsOrDefault()}.
 */
@Bean(initMethod = "init", destroyMethod = "shutdown")
ManageModelLifecycle modelLifecycleService(
        ModelLoader modelLoader,
        ModelStateEventBus modelStateEventBus,
        ObserveJobs observeJobs,
        OnnxProperties properties) {
    long idleTimeoutSeconds = properties.idleTimeoutSecondsOrDefault();
    ModelLifecycleService lifecycleService =
            new ModelLifecycleService(modelLoader, modelStateEventBus, observeJobs, idleTimeoutSeconds);
    return lifecycleService;
}
}

View File

@@ -7,9 +7,8 @@ import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import com.trueref.adapter.out.embedding.onnx.HuggingFaceTokenizerWrapper.EncodedBatch;
import com.trueref.domain.error.IngestionFailed;
import com.trueref.domain.error.ModelNotReady;
import com.trueref.domain.port.out.EmbeddingService;
import jakarta.annotation.PostConstruct;
import jakarta.annotation.PreDestroy;
import java.nio.LongBuffer;
import java.nio.file.Path;
import java.util.ArrayList;
@@ -68,8 +67,13 @@ public class OnnxEmbeddingService implements EmbeddingService {
this.modelsHome = modelsHome;
}
@PostConstruct
void start() {
/**
 * Returns {@code true} if the model is currently loaded and accepting requests.
 *
 * <p>Concretely: both {@code pool} and {@code tokenizer} are non-null.
 * NOTE(review): {@code start()} returns early when session-count is 0, so this presumably stays
 * {@code false} for a disabled service - confirm that callers treat that case as intended.
 */
public boolean isStarted() {
return pool != null && tokenizer != null;
}
/** Loads the ONNX model into memory. Called by {@link OnnxModelLoader} on demand. */
public void start() {
if (properties.sessionCountOrDefault() <= 0) {
log.warn("OnnxEmbeddingService DISABLED (session-count=0); embed() calls will fail");
return;
@@ -104,8 +108,8 @@ public class OnnxEmbeddingService implements EmbeddingService {
usesTokenTypeIds);
}
@PreDestroy
void stop() {
/** Unloads the ONNX model from memory. Called by {@link OnnxModelLoader} on unload. */
public void stop() {
if (pool != null) {
pool.close();
pool = null;
@@ -257,7 +261,7 @@ public class OnnxEmbeddingService implements EmbeddingService {
private static <T> T requireStarted(@org.jspecify.annotations.Nullable T t, String what) {
if (t == null) {
throw new IllegalStateException("OnnxEmbeddingService not started: " + what + " is null");
throw new ModelNotReady(30);
}
return t;
}

View File

@@ -0,0 +1,54 @@
package com.trueref.adapter.out.embedding.onnx;
import com.trueref.domain.port.out.ModelLoader;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Implements {@link ModelLoader} by delegating to the ONNX embedding and reranker services.
 * Both services are loaded/unloaded as a unit. This bean drives their lifecycle instead of
 * {@code @PostConstruct} / {@code @PreDestroy} so that models are loaded lazily on first demand.
 */
public class OnnxModelLoader implements ModelLoader {

    private static final Logger log = LoggerFactory.getLogger(OnnxModelLoader.class);

    private final OnnxEmbeddingService embeddingService;
    private final OnnxRerankerService rerankerService;

    public OnnxModelLoader(OnnxEmbeddingService embeddingService, OnnxRerankerService rerankerService) {
        this.embeddingService = embeddingService;
        this.rerankerService = rerankerService;
    }

    /**
     * Starts the embedder, then the reranker.
     *
     * <p>If the reranker fails to start, the already-started embedder is stopped again before
     * the failure is rethrown, so a failed load does not leave one model resident. A failure
     * while stopping the embedder is attached to the original failure as a suppressed exception.
     */
    @Override
    public void load() {
        log.info("OnnxModelLoader: loading embedder…");
        embeddingService.start();
        log.info("OnnxModelLoader: loading reranker…");
        try {
            rerankerService.start();
        } catch (RuntimeException | Error e) {
            log.warn("Reranker load failed; rolling back embedder: {}", e.toString());
            try {
                embeddingService.stop();
            } catch (RuntimeException stopError) {
                e.addSuppressed(stopError);
            }
            throw e;
        }
        log.info("OnnxModelLoader: both models loaded");
    }

    /**
     * Stops the reranker, then the embedder. Each stop is attempted independently: an exception
     * from one is logged and does not prevent the other from being stopped.
     */
    @Override
    public void unload() {
        log.info("OnnxModelLoader: unloading reranker…");
        try {
            rerankerService.stop();
        } catch (Exception e) {
            log.warn("Reranker stop error (continuing): {}", e.toString());
        }
        log.info("OnnxModelLoader: unloading embedder…");
        try {
            embeddingService.stop();
        } catch (Exception e) {
            log.warn("Embedder stop error (continuing): {}", e.toString());
        }
        log.info("OnnxModelLoader: both models unloaded");
    }

    /**
     * Returns {@code true} only when both services report started.
     *
     * <p>NOTE(review): a service configured with a session count of 0 presumably never reports
     * started, in which case this always returns {@code false} - confirm that is intended.
     */
    @Override
    public boolean isLoaded() {
        return embeddingService.isStarted() && rerankerService.isStarted();
    }
}

View File

@@ -21,6 +21,7 @@ public record OnnxProperties(
@Nullable Integer maxSeqLen,
@Nullable Integer gpuDeviceId,
@Nullable Long gpuMemLimitBytes,
@Nullable Long idleTimeoutSeconds,
@Nullable Path home,
@Nullable Map<String, Map<String, List<String>>> modelSources) {
@@ -37,6 +38,7 @@ public record OnnxProperties(
if (maxSeqLen == null || maxSeqLen <= 0) maxSeqLen = 512;
if (gpuDeviceId == null || gpuDeviceId < 0) gpuDeviceId = 0;
if (gpuMemLimitBytes == null || gpuMemLimitBytes <= 0L) gpuMemLimitBytes = 0L; // 0 = no cap
if (idleTimeoutSeconds == null || idleTimeoutSeconds <= 0L) idleTimeoutSeconds = 300L; // 5 min
if (modelSources == null) modelSources = Map.of();
}
@@ -75,4 +77,8 @@ public record OnnxProperties(
public long gpuMemLimitBytesOrDefault() {
return gpuMemLimitBytes == null ? 0L : gpuMemLimitBytes;
}
/** Returns the configured idle timeout in seconds, or {@code 300} when none is set. */
public long idleTimeoutSecondsOrDefault() {
    if (idleTimeoutSeconds != null) {
        return idleTimeoutSeconds;
    }
    return 300L;
}
}

View File

@@ -7,10 +7,9 @@ import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import com.trueref.adapter.out.embedding.onnx.HuggingFaceTokenizerWrapper.EncodedBatch;
import com.trueref.domain.error.IngestionFailed;
import com.trueref.domain.error.ModelNotReady;
import com.trueref.domain.model.SearchHit;
import com.trueref.domain.port.out.RerankerService;
import jakarta.annotation.PostConstruct;
import jakarta.annotation.PreDestroy;
import java.nio.LongBuffer;
import java.nio.file.Path;
import java.util.ArrayList;
@@ -62,8 +61,13 @@ public class OnnxRerankerService implements RerankerService {
this.modelsHome = modelsHome;
}
@PostConstruct
void start() {
/**
 * Returns {@code true} if the reranker model is currently loaded and accepting requests.
 *
 * <p>Concretely: both {@code pool} and {@code tokenizer} are non-null.
 * NOTE(review): {@code start()} returns early when reranker-session-count is 0, so this
 * presumably stays {@code false} for a disabled reranker - confirm that callers which combine
 * this flag with the embedder's treat that case as intended.
 */
public boolean isStarted() {
return pool != null && tokenizer != null;
}
/** Loads the reranker model into memory. Called by {@link OnnxModelLoader} on demand. */
public void start() {
if (properties.rerankerSessionCountOrDefault() <= 0) {
log.warn("OnnxRerankerService DISABLED (reranker-session-count=0); rerank() passes through input order");
return;
@@ -96,8 +100,8 @@ public class OnnxRerankerService implements RerankerService {
usesTokenTypeIds);
}
@PreDestroy
void stop() {
/** Unloads the reranker model from memory. Called by {@link OnnxModelLoader} on unload. */
public void stop() {
if (pool != null) {
pool.close();
pool = null;
@@ -252,7 +256,7 @@ public class OnnxRerankerService implements RerankerService {
private static <T> T requireStarted(@org.jspecify.annotations.Nullable T t, String what) {
if (t == null) {
throw new IllegalStateException("OnnxRerankerService not started: " + what + " is null");
throw new ModelNotReady(30);
}
return t;
}

View File

@@ -0,0 +1,44 @@
package com.trueref.adapter.in.rest;
import static org.assertj.core.api.Assertions.assertThat;
import com.trueref.domain.error.ModelNotReady;
import org.junit.jupiter.api.Test;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
/** Unit tests for the {@link ModelNotReady} mapping of {@link GlobalExceptionHandler}. */
class GlobalExceptionHandlerTest {

    private final GlobalExceptionHandler handler = new GlobalExceptionHandler();

    @Test
    void modelNotReadyMapsto503WithRetryAfterHeader() {
        ResponseEntity<ErrorResponse> response = handler.handleModelNotReady(new ModelNotReady(30));

        assertThat(response.getStatusCode()).isEqualTo(HttpStatus.SERVICE_UNAVAILABLE);
        assertThat(retryAfterOf(response)).isEqualTo("30");
    }

    @Test
    void modelNotReadyResponseBodyHasModelNotReadyCode() {
        ResponseEntity<ErrorResponse> response = handler.handleModelNotReady(new ModelNotReady(15));

        ErrorResponse body = response.getBody();
        assertThat(body).isNotNull();
        assertThat(body.code()).isEqualTo("model_not_ready");
        assertThat(body.message()).contains("15");
    }

    @Test
    void retryAfterReflectsExceptionValue() {
        ResponseEntity<ErrorResponse> response = handler.handleModelNotReady(new ModelNotReady(60));

        assertThat(retryAfterOf(response)).isEqualTo("60");
    }

    /** Reads the first {@code Retry-After} header value of a response. */
    private static String retryAfterOf(ResponseEntity<ErrorResponse> response) {
        return response.getHeaders().getFirst(HttpHeaders.RETRY_AFTER);
    }
}

View File

@@ -0,0 +1,71 @@
package com.trueref.adapter.out.embedding.onnx;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import org.junit.jupiter.api.Test;
/**
 * Unit tests for {@link OnnxModelLoader}, using Mockito mocks for the two ONNX services so no
 * model is actually loaded.
 */
class OnnxModelLoaderTest {

    @Test
    void loadCallsStartOnBothServices() {
        OnnxEmbeddingService embedder = mock(OnnxEmbeddingService.class);
        OnnxRerankerService reranker = mock(OnnxRerankerService.class);
        OnnxModelLoader loader = new OnnxModelLoader(embedder, reranker);

        loader.load();

        verify(embedder).start();
        verify(reranker).start();
    }

    @Test
    void unloadCallsStopOnBothServices() {
        OnnxEmbeddingService embedder = mock(OnnxEmbeddingService.class);
        OnnxRerankerService reranker = mock(OnnxRerankerService.class);
        OnnxModelLoader loader = new OnnxModelLoader(embedder, reranker);

        loader.unload();

        verify(embedder).stop();
        verify(reranker).stop();
    }

    @Test
    void unloadContinuesEvenIfRerankerStopThrows() {
        OnnxEmbeddingService embedder = mock(OnnxEmbeddingService.class);
        OnnxRerankerService reranker = mock(OnnxRerankerService.class);
        doThrow(new RuntimeException("VRAM error")).when(reranker).stop();
        OnnxModelLoader loader = new OnnxModelLoader(embedder, reranker);

        loader.unload(); // must not throw

        verify(embedder).stop();
    }

    @Test
    void isLoadedReturnsTrueWhenBothServicesStarted() {
        OnnxEmbeddingService embedder = mock(OnnxEmbeddingService.class);
        OnnxRerankerService reranker = mock(OnnxRerankerService.class);
        when(embedder.isStarted()).thenReturn(true);
        when(reranker.isStarted()).thenReturn(true);
        OnnxModelLoader loader = new OnnxModelLoader(embedder, reranker);

        assertThat(loader.isLoaded()).isTrue();
    }

    @Test
    void isLoadedReturnsFalseWhenEitherServiceNotStarted() {
        OnnxEmbeddingService embedder = mock(OnnxEmbeddingService.class);
        OnnxRerankerService reranker = mock(OnnxRerankerService.class);
        OnnxModelLoader loader = new OnnxModelLoader(embedder, reranker);

        // Only the reranker is down.
        when(embedder.isStarted()).thenReturn(true);
        when(reranker.isStarted()).thenReturn(false);
        assertThat(loader.isLoaded()).isFalse();

        // Only the embedder is down ("either" must hold in both directions).
        when(embedder.isStarted()).thenReturn(false);
        when(reranker.isStarted()).thenReturn(true);
        assertThat(loader.isLoaded()).isFalse();

        // Both are down.
        when(reranker.isStarted()).thenReturn(false);
        assertThat(loader.isLoaded()).isFalse();
    }
}

View File

@@ -0,0 +1,75 @@
package com.trueref.application.model;
import static org.assertj.core.api.Assertions.assertThat;
import com.trueref.domain.model.ModelState;
import com.trueref.domain.model.ModelStateEvent;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.jupiter.api.Test;
/** Unit tests for the publish/subscribe behaviour of {@link InMemoryModelStateEventBus}. */
class InMemoryModelStateEventBusTest {

    @Test
    void singleSubscriberReceivesPublishedEvent() {
        InMemoryModelStateEventBus eventBus = new InMemoryModelStateEventBus();
        List<ModelStateEvent> seen = new ArrayList<>();
        eventBus.subscribe(event -> seen.add(event));

        eventBus.publish(ModelStateEvent.of(ModelState.LOADING));

        assertThat(seen).hasSize(1);
        ModelStateEvent only = seen.get(0);
        assertThat(only.state()).isEqualTo(ModelState.LOADING);
    }

    @Test
    void multipleSubscribersAllReceiveEvent() {
        InMemoryModelStateEventBus eventBus = new InMemoryModelStateEventBus();
        List<ModelStateEvent> first = new ArrayList<>();
        List<ModelStateEvent> second = new ArrayList<>();
        eventBus.subscribe(event -> first.add(event));
        eventBus.subscribe(event -> second.add(event));

        eventBus.publish(ModelStateEvent.of(ModelState.LOADED));

        assertThat(first).hasSize(1);
        assertThat(second).hasSize(1);
    }

    @Test
    void closingSubscriptionRemovesIt() throws Exception {
        InMemoryModelStateEventBus eventBus = new InMemoryModelStateEventBus();
        List<ModelStateEvent> seen = new ArrayList<>();
        AutoCloseable subscription = eventBus.subscribe(event -> seen.add(event));

        // Unsubscribe before anything is published.
        subscription.close();
        eventBus.publish(ModelStateEvent.of(ModelState.UNLOADED));

        assertThat(seen).isEmpty();
    }

    @Test
    void exceptionInSubscriberDoesNotBreakOthers() {
        InMemoryModelStateEventBus eventBus = new InMemoryModelStateEventBus();
        AtomicInteger failingCalls = new AtomicInteger();
        List<ModelStateEvent> healthy = new ArrayList<>();
        // The failing listener is registered first, ahead of the healthy one.
        eventBus.subscribe(event -> {
            failingCalls.incrementAndGet();
            throw new RuntimeException("intentional test failure");
        });
        eventBus.subscribe(event -> healthy.add(event));

        eventBus.publish(ModelStateEvent.of(ModelState.LOADING));

        assertThat(failingCalls.get()).isEqualTo(1);
        assertThat(healthy).hasSize(1);
    }

    @Test
    void noSubscribersMeansPublishIsNoop() {
        InMemoryModelStateEventBus eventBus = new InMemoryModelStateEventBus();

        eventBus.publish(ModelStateEvent.of(ModelState.UNLOADED)); // must not throw
    }
}

View File

@@ -0,0 +1,227 @@
package com.trueref.application.model;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import com.trueref.domain.error.ModelNotReady;
import com.trueref.domain.model.IngestionJob;
import com.trueref.domain.model.JobId;
import com.trueref.domain.model.JobStatus;
import com.trueref.domain.model.JobType;
import com.trueref.domain.model.ModelState;
import com.trueref.domain.model.ModelStateEvent;
import com.trueref.domain.model.RepositoryId;
import com.trueref.domain.port.in.ManageModelLifecycle;
import com.trueref.domain.port.in.ObserveJobs;
import com.trueref.domain.port.out.ModelLoader;
import com.trueref.domain.port.out.ModelStateEventBus;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
/**
 * Tests for {@link ModelLifecycleService} against an in-memory loader, the in-memory event bus
 * and a stubbed {@link ObserveJobs}.
 *
 * <p>Loading and unloading complete asynchronously, so the tests poll the reported state (for at
 * most 2 s) through {@link #awaitState(ModelState)}.
 */
class ModelLifecycleServiceTest {

    // ---- Test doubles ----

    private final AtomicBoolean loaderLoaded = new AtomicBoolean(false);

    /** When set, {@code loader.load()} takes 200 ms so tests can observe the intermediate state. */
    private volatile boolean slowLoad = false;

    private final ModelLoader loader = new ModelLoader() {
        @Override
        public void load() {
            if (slowLoad) {
                try {
                    Thread.sleep(200);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            }
            loaderLoaded.set(true);
        }

        @Override
        public void unload() {
            loaderLoaded.set(false);
        }

        @Override
        public boolean isLoaded() {
            return loaderLoaded.get();
        }
    };

    private final InMemoryModelStateEventBus eventBus = new InMemoryModelStateEventBus();

    /** Jobs reported by the stub when RUNNING jobs are queried. */
    private volatile List<IngestionJob> runningJobs = List.of();

    private final ObserveJobs observeJobs = new ObserveJobs() {
        @Override
        public Optional<IngestionJob> findJob(JobId id) {
            return Optional.empty();
        }

        @Override
        public List<IngestionJob> listJobs(com.trueref.domain.model.RepositoryId r,
                com.trueref.domain.model.VersionId v, JobStatus status, int limit) {
            if (status == JobStatus.RUNNING) return runningJobs;
            return List.of();
        }

        @Override
        public AutoCloseable subscribeJobs(Consumer<IngestionJob> listener) {
            return () -> {};
        }

        @Override
        public AutoCloseable subscribeLogs(JobId jobId, Consumer<com.trueref.domain.model.JobLogEvent> listener) {
            return () -> {};
        }
    };

    private ModelLifecycleService service;

    @BeforeEach
    void setUp() {
        loaderLoaded.set(false);
        slowLoad = false;
        runningJobs = List.of();
        service = new ModelLifecycleService(loader, eventBus, observeJobs, 300L);
        service.init();
    }

    @AfterEach
    void tearDown() {
        service.shutdown();
    }

    // ---- Tests ----

    @Test
    void startsInUnloadedState() {
        assertThat(service.getStatus().state()).isEqualTo(ModelState.UNLOADED);
    }

    @Test
    void ensureReadyOnUnloadedTriggersLoadAndThrows() throws InterruptedException {
        // Filled while the asynchronous load is in flight and read by this test thread, so the
        // list must tolerate concurrent access.
        List<ModelStateEvent> events = new java.util.concurrent.CopyOnWriteArrayList<>();
        eventBus.subscribe(events::add);

        assertThatThrownBy(() -> service.ensureReady())
                .isInstanceOf(ModelNotReady.class);

        // Wait for async load to complete (max 2 s in test)
        awaitState(ModelState.LOADED);

        assertThat(service.getStatus().state()).isEqualTo(ModelState.LOADED);
        assertThat(loaderLoaded.get()).isTrue();
        assertThat(events).extracting(ModelStateEvent::state)
                .containsSequence(ModelState.LOADING, ModelState.LOADED);
    }

    @Test
    void ensureReadyOnLoadedIsNoop() throws InterruptedException {
        triggerLoad();

        // Should not throw
        service.ensureReady();
        service.ensureReady();

        assertThat(service.getStatus().state()).isEqualTo(ModelState.LOADED);
    }

    @Test
    void ensureReadyOnLoadingThrowsImmediately() {
        slowLoad = true;

        assertThatThrownBy(() -> service.ensureReady())
                .isInstanceOf(ModelNotReady.class);

        // state is LOADING, second call also throws
        assertThatThrownBy(() -> service.ensureReady())
                .isInstanceOf(ModelNotReady.class);
    }

    @Test
    void forceUnloadFromLoadedState() throws InterruptedException {
        triggerLoad();

        boolean initiated = service.forceUnload(false);
        assertThat(initiated).isTrue();

        // Wait for async unload
        awaitState(ModelState.UNLOADED);

        assertThat(service.getStatus().state()).isEqualTo(ModelState.UNLOADED);
        assertThat(loaderLoaded.get()).isFalse();
    }

    @Test
    void forceUnloadBlockedByRunningJobsWithoutForce() throws InterruptedException {
        triggerLoad();
        runningJobs = List.of(runningJob());

        boolean initiated = service.forceUnload(false);

        assertThat(initiated).isFalse();
        assertThat(service.getStatus().state()).isEqualTo(ModelState.LOADED);
    }

    @Test
    void forceUnloadSucceedsWithForceEvenWhenJobsRunning() throws InterruptedException {
        triggerLoad();
        runningJobs = List.of(runningJob());

        boolean initiated = service.forceUnload(true);
        assertThat(initiated).isTrue();

        awaitState(ModelState.UNLOADED);

        assertThat(service.getStatus().state()).isEqualTo(ModelState.UNLOADED);
    }

    @Test
    void forceUnloadWhenAlreadyUnloadedIsNoop() {
        boolean initiated = service.forceUnload(false);

        assertThat(initiated).isTrue();
        assertThat(service.getStatus().state()).isEqualTo(ModelState.UNLOADED);
    }

    @Test
    void concurrentEnsureReadyDeduplicatesLoad() throws InterruptedException {
        // Keep the first load in flight while the second call arrives. With the instant loader
        // the second call could find the model already LOADED and not throw, which made the
        // callCount assertion below timing-dependent.
        slowLoad = true;
        int callCount = 0;

        // Both calls throw ModelNotReady but only one load should be triggered
        List<ModelStateEvent> events = new java.util.concurrent.CopyOnWriteArrayList<>();
        eventBus.subscribe(events::add);

        try { service.ensureReady(); } catch (ModelNotReady ignored) { callCount++; }
        try { service.ensureReady(); } catch (ModelNotReady ignored) { callCount++; }

        awaitState(ModelState.LOADED);

        assertThat(callCount).isEqualTo(2);
        // Only one LOADING event should be emitted (deduplicated)
        long loadingCount = events.stream().filter(e -> e.state() == ModelState.LOADING).count();
        assertThat(loadingCount).isEqualTo(1);
    }

    @Test
    void statusIncludesLoadedAtAfterLoad() throws InterruptedException {
        assertThat(service.getStatus().loadedAt()).isNull();

        triggerLoad();

        assertThat(service.getStatus().loadedAt()).isNotNull();
    }

    // ---- helpers ----

    /** Requests a load (swallowing the expected {@link ModelNotReady}) and waits for LOADED. */
    private void triggerLoad() throws InterruptedException {
        try { service.ensureReady(); } catch (ModelNotReady ignored) {}
        awaitState(ModelState.LOADED);
    }

    /**
     * Polls the service every 100 ms, at most 20 times, until it reports {@code expected}.
     * Returns silently on timeout; callers assert the final state themselves.
     */
    private void awaitState(ModelState expected) throws InterruptedException {
        for (int i = 0; i < 20 && service.getStatus().state() != expected; i++) {
            Thread.sleep(100);
        }
    }

    /** Builds a RUNNING {@code INDEX_VERSION} job with random ids. */
    private IngestionJob runningJob() {
        return new IngestionJob(
                JobId.of(java.util.UUID.randomUUID().toString()),
                RepositoryId.of(java.util.UUID.randomUUID().toString()),
                null,
                JobType.INDEX_VERSION,
                JobStatus.RUNNING,
                Instant.now(),
                null,
                List.of());
    }
}