feat: GPU model lazy-load/unload lifecycle management
All checks were successful
Build and publish Docker image / Build and push CPU image (push) Successful in 2m33s
Build and publish Docker image / Build and push GPU image (push) Successful in 3m15s

- Domain: add ModelState, ModelStateEvent, ModelNotReady, ManageModelLifecycle
  (in-port), ModelLoader and ModelStateEventBus (out-ports)
- Application: InMemoryModelStateEventBus; ModelLifecycleService — state
  machine (ReentrantLock), lazy load on first request, idle-timeout auto-unload
  (configurable via trueref.embedding.idle-timeout-seconds, default 300 s),
  job-guard (skips unload while ingestion running), platform-thread CUDA executor
- Adapters: OnnxModelLoader wires embedder + reranker start/stop; remove
  @PostConstruct/@PreDestroy from OnnxEmbeddingService and OnnxRerankerService;
  requireStarted() now throws ModelNotReady instead of IllegalStateException
- REST: GET /api/model/status, POST /api/model/unload (409 when jobs running,
  force=true to override), GET /api/model/status/stream (SSE)
- GlobalExceptionHandler: ModelNotReady -> 503 + Retry-After header
- HybridSearchService: calls lifecycle.ensureReady() before every search so
  both REST and MCP paths get ModelNotReady (-> 503 / MCP error) when unloaded
- TrueRefMcpTools: catches ModelNotReady, returns retry hint in MCP error text
- Tests: InMemoryModelStateEventBusTest, ModelLifecycleServiceTest (10 cases),
  OnnxModelLoaderTest, GlobalExceptionHandlerTest — all 41 tests green

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
moze
2026-05-09 15:44:33 +02:00
parent 943a38fd36
commit 5c6085df99
24 changed files with 1144 additions and 17 deletions

View File

@@ -6,6 +6,7 @@ import com.trueref.domain.model.SearchHit;
import com.trueref.domain.model.SearchScope;
import com.trueref.domain.model.Version;
import com.trueref.domain.model.VersionStatus;
import com.trueref.domain.error.ModelNotReady;
import com.trueref.domain.port.in.IndexVersion;
import com.trueref.domain.port.in.QueryCatalog;
import com.trueref.domain.port.in.ResolveLibraryId;
@@ -153,6 +154,9 @@ public class TrueRefMcpTools {
SearchLibraryDocs.Result res;
try {
res = search.search(q);
} catch (ModelNotReady e) {
return "[model_not_ready] The inference model is loading. "
+ "Please retry in ~" + e.retryAfterSeconds() + " seconds.";
} catch (Exception e) {
log.warn("MCP search failed for {}: {}", libraryId, e.toString());
return "Search failed for " + libraryId + ": " + e.getMessage();

View File

@@ -2,6 +2,7 @@ package com.trueref.adapter.in.rest;
import com.trueref.domain.error.IngestionFailed;
import com.trueref.domain.error.InvalidSearchRequest;
import com.trueref.domain.error.ModelNotReady;
import com.trueref.domain.error.RepositoryAlreadyRegistered;
import com.trueref.domain.error.RepositoryNotFound;
import com.trueref.domain.error.TagNotFound;
@@ -12,6 +13,7 @@ import jakarta.validation.ConstraintViolationException;
import java.util.List;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
import org.springframework.validation.FieldError;
@@ -28,6 +30,15 @@ public class GlobalExceptionHandler {
private static final Logger log = LoggerFactory.getLogger(GlobalExceptionHandler.class);
/**
 * Translates a {@link ModelNotReady} failure into an HTTP 503 response.
 *
 * <p>The exception's retry hint is exposed through the {@code Retry-After} header and the body
 * carries the exception's code and message.
 *
 * @param ex the failure raised while the model is not available
 * @return a 503 response with a {@code Retry-After} header and an {@link ErrorResponse} body
 */
@ExceptionHandler(ModelNotReady.class)
public ResponseEntity<ErrorResponse> handleModelNotReady(ModelNotReady ex) {
    String retryAfter = String.valueOf(ex.retryAfterSeconds());
    ErrorResponse payload = ErrorResponse.of(ex.code(), ex.getMessage());
    return ResponseEntity.status(HttpStatus.SERVICE_UNAVAILABLE)
            .header(HttpHeaders.RETRY_AFTER, retryAfter)
            .body(payload);
}
@ExceptionHandler({RepositoryNotFound.class, VersionNotFound.class, TagNotFound.class})
public ResponseEntity<ErrorResponse> handleNotFound(TrueRefException ex) {
return status(HttpStatus.NOT_FOUND, ex);

View File

@@ -0,0 +1,106 @@
package com.trueref.adapter.in.rest;
import com.trueref.adapter.in.rest.dto.ModelStatusDto;
import com.trueref.domain.model.ModelState;
import com.trueref.domain.port.in.ManageModelLifecycle;
import io.swagger.v3.oas.annotations.Operation;
import io.swagger.v3.oas.annotations.tags.Tag;
import java.io.IOException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;
import org.springframework.web.server.ResponseStatusException;
import org.springframework.web.servlet.mvc.method.annotation.SseEmitter;
/**
 * REST resource: {@code /api/model}.
 *
 * <p>Exposes the model lifecycle through three endpoints: a status snapshot, a forced unload, and
 * a Server-Sent Events stream of state changes. All work is delegated to
 * {@link ManageModelLifecycle}.
 */
@RestController
@RequestMapping("/api/model")
@Tag(name = "model", description = "GPU model lifecycle management (load state, force unload, SSE status stream).")
public class ModelController {
private static final Logger log = LoggerFactory.getLogger(ModelController.class);
private final ManageModelLifecycle lifecycle;
/**
 * Creates the controller.
 *
 * @param lifecycle the lifecycle port every endpoint delegates to
 */
public ModelController(ManageModelLifecycle lifecycle) {
this.lifecycle = lifecycle;
}
/**
 * Returns a snapshot of the current lifecycle status.
 *
 * @return the status reported by the lifecycle port, converted to a DTO
 */
@Operation(summary = "Current model lifecycle status.")
@GetMapping("/status")
public ModelStatusDto status() {
return ModelStatusDto.of(lifecycle.getStatus());
}
/**
 * Requests an unload of the models.
 *
 * @param force forwarded to {@code forceUnload}; defaults to {@code false} when the request
 *     parameter is absent
 * @return the lifecycle status read after the unload request was accepted
 * @throws ResponseStatusException with status 409 when {@code forceUnload} returns {@code false}
 */
@Operation(summary = "Force-unload the GPU models from VRAM. "
+ "Returns 409 if ingestion jobs are running (use force=true to override).")
@PostMapping("/unload")
public ModelStatusDto unload(@RequestParam(value = "force", defaultValue = "false") boolean force) {
boolean initiated = lifecycle.forceUnload(force);
if (!initiated) {
throw new ResponseStatusException(
HttpStatus.CONFLICT,
"model_unload_blocked_by_running_jobs");
}
return ModelStatusDto.of(lifecycle.getStatus());
}
/**
 * Opens a Server-Sent Events stream of model state changes.
 *
 * <p>The stream starts with a {@code ping} event and a {@code model-status} event carrying the
 * current state name, then forwards every event received from the lifecycle subscription. A
 * virtual thread sends a {@code ping} every 20 seconds. When the emitter completes, times out
 * or errors, the keepalive thread is interrupted and the subscription is closed.
 *
 * @return the emitter backing the stream; already completed with an error if the initial
 *     events could not be written
 */
@Operation(summary = "Server-Sent Events stream of model state transitions "
+ "(LOADING, LOADED, UNLOADING, UNLOADED).")
@GetMapping(value = "/status/stream", produces = MediaType.TEXT_EVENT_STREAM_VALUE)
public SseEmitter statusStream() {
// NOTE(review): a timeout of 0 presumably disables the async request timeout — confirm
// against the servlet container in use.
SseEmitter emitter = new SseEmitter(0L);
// Flush headers to the client immediately so EventSource fires 'open'.
try {
emitter.send(SseEmitter.event().name("ping").data(""));
// Send current state immediately so the client doesn't have to wait for the next transition.
// NOTE(review): this initial event carries the state NAME, while the subscription below
// sends the event object itself — the two payload shapes may differ; confirm what
// clients expect.
emitter.send(SseEmitter.event()
.name("model-status")
.data(lifecycle.getStatus().state().name()));
} catch (IOException e) {
// Nothing has been subscribed or started yet, so there is nothing else to release.
emitter.completeWithError(e);
return emitter;
}
// NOTE(review): a transition that happens between the snapshot above and this subscription
// is not delivered to the client.
AutoCloseable subscription = lifecycle.subscribeState(event -> {
try {
emitter.send(SseEmitter.event().name("model-status").data(event));
} catch (IOException ex) {
emitter.completeWithError(ex);
}
});
// Periodic ping on a virtual thread; it exits on interrupt or on the first failed send.
Thread keepalive = Thread.startVirtualThread(() -> {
try {
while (!Thread.currentThread().isInterrupted()) {
Thread.sleep(20_000);
emitter.send(SseEmitter.event().name("ping").data(""));
}
} catch (InterruptedException ignored) {
// normal shutdown
} catch (Exception ignored) {
// emitter already completed; exit
}
});
// Shared teardown: stop the keepalive thread and drop the lifecycle subscription.
Runnable cleanup = () -> {
keepalive.interrupt();
try {
subscription.close();
} catch (Exception ex) {
log.debug("failed to close model status SSE subscription: {}", ex.toString());
}
};
emitter.onCompletion(cleanup);
emitter.onTimeout(cleanup);
emitter.onError(e -> cleanup.run());
return emitter;
}
}

View File

@@ -0,0 +1,22 @@
package com.trueref.adapter.in.rest.dto;
import com.trueref.domain.model.ModelState;
import com.trueref.domain.port.in.ManageModelLifecycle;
import java.time.Instant;
import org.jspecify.annotations.Nullable;
/**
 * Wire representation of the model lifecycle status, as returned by
 * {@code GET /api/model/status}.
 *
 * @param state current lifecycle state
 * @param loadedAt load timestamp, or {@code null} when none is reported
 * @param lastActivityAt last-activity timestamp, or {@code null} when none is reported
 * @param idleTimeoutSeconds idle timeout reported by the lifecycle status, in seconds
 */
public record ModelStatusDto(
        ModelState state,
        @Nullable Instant loadedAt,
        @Nullable Instant lastActivityAt,
        long idleTimeoutSeconds) {

    /**
     * Copies the four fields of a lifecycle status into a DTO.
     *
     * @param status the status obtained from the lifecycle port
     * @return a DTO mirroring {@code status}
     */
    public static ModelStatusDto of(ManageModelLifecycle.Status status) {
        ModelState currentState = status.state();
        Instant loaded = status.loadedAt();
        Instant lastActive = status.lastActivityAt();
        long idleTimeout = status.idleTimeoutSeconds();
        return new ModelStatusDto(currentState, loaded, lastActive, idleTimeout);
    }
}

View File

@@ -1,5 +1,11 @@
package com.trueref.adapter.out.embedding.onnx;
import com.trueref.application.model.InMemoryModelStateEventBus;
import com.trueref.application.model.ModelLifecycleService;
import com.trueref.domain.port.in.ManageModelLifecycle;
import com.trueref.domain.port.in.ObserveJobs;
import com.trueref.domain.port.out.ModelLoader;
import com.trueref.domain.port.out.ModelStateEventBus;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Bean;
@@ -9,7 +15,8 @@ import java.nio.file.Path;
/**
* Spring wiring for the ONNX embedding/reranker stack. Produces the shared {@link GpuSemaphore}
* (permits = {@code session-count}) and the resolved models-home {@link Path}.
* (permits = {@code session-count}), the resolved models-home {@link Path}, and the
* {@link ModelLifecycleService} that manages lazy loading and idle-timeout unloading.
*/
@Configuration
@EnableConfigurationProperties(OnnxProperties.class)
@@ -29,4 +36,27 @@ public class OnnxEmbeddingConfig {
Path home = properties.home();
return home != null ? home : Path.of(trueRefHome).resolve("models");
}
/** Provides the in-memory event bus used to publish model state events. */
@Bean
ModelStateEventBus modelStateEventBus() {
return new InMemoryModelStateEventBus();
}
/** Provides the {@link ModelLoader} that drives the ONNX embedding and reranker services. */
@Bean
ModelLoader onnxModelLoader(OnnxEmbeddingService embeddingService, OnnxRerankerService rerankerService) {
return new OnnxModelLoader(embeddingService, rerankerService);
}
/**
 * Provides the model lifecycle service. The bean's {@code init} method is registered as the
 * init callback and {@code shutdown} as the destroy callback; the idle timeout is taken from
 * {@link OnnxProperties#idleTimeoutSecondsOrDefault()}.
 */
@Bean(initMethod = "init", destroyMethod = "shutdown")
ManageModelLifecycle modelLifecycleService(
ModelLoader modelLoader,
ModelStateEventBus modelStateEventBus,
ObserveJobs observeJobs,
OnnxProperties properties) {
return new ModelLifecycleService(
modelLoader,
modelStateEventBus,
observeJobs,
properties.idleTimeoutSecondsOrDefault());
}
}

View File

@@ -7,9 +7,8 @@ import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import com.trueref.adapter.out.embedding.onnx.HuggingFaceTokenizerWrapper.EncodedBatch;
import com.trueref.domain.error.IngestionFailed;
import com.trueref.domain.error.ModelNotReady;
import com.trueref.domain.port.out.EmbeddingService;
import jakarta.annotation.PostConstruct;
import jakarta.annotation.PreDestroy;
import java.nio.LongBuffer;
import java.nio.file.Path;
import java.util.ArrayList;
@@ -68,8 +67,13 @@ public class OnnxEmbeddingService implements EmbeddingService {
this.modelsHome = modelsHome;
}
@PostConstruct
void start() {
/** Returns {@code true} if the model is currently loaded and accepting requests. */
public boolean isStarted() {
return pool != null && tokenizer != null;
}
/** Loads the ONNX model into memory. Called by {@link OnnxModelLoader} on demand. */
public void start() {
if (properties.sessionCountOrDefault() <= 0) {
log.warn("OnnxEmbeddingService DISABLED (session-count=0); embed() calls will fail");
return;
@@ -104,8 +108,8 @@ public class OnnxEmbeddingService implements EmbeddingService {
usesTokenTypeIds);
}
@PreDestroy
void stop() {
/** Unloads the ONNX model from memory. Called by {@link OnnxModelLoader} on unload. */
public void stop() {
if (pool != null) {
pool.close();
pool = null;
@@ -257,7 +261,7 @@ public class OnnxEmbeddingService implements EmbeddingService {
private static <T> T requireStarted(@org.jspecify.annotations.Nullable T t, String what) {
if (t == null) {
throw new IllegalStateException("OnnxEmbeddingService not started: " + what + " is null");
throw new ModelNotReady(30);
}
return t;
}

View File

@@ -0,0 +1,54 @@
package com.trueref.adapter.out.embedding.onnx;
import com.trueref.domain.port.out.ModelLoader;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Implements {@link ModelLoader} by delegating to the ONNX embedding and reranker services.
 * Both services are loaded/unloaded as a unit. This bean drives their lifecycle instead of
 * {@code @PostConstruct} / {@code @PreDestroy} so that models are loaded lazily on first demand.
 *
 * <p>Loading is all-or-nothing: if either service fails to start, whatever was already loaded is
 * released again before the failure is propagated, so a failed load never leaves one model
 * resident.
 */
public class OnnxModelLoader implements ModelLoader {
    private static final Logger log = LoggerFactory.getLogger(OnnxModelLoader.class);

    private final OnnxEmbeddingService embeddingService;
    private final OnnxRerankerService rerankerService;

    /**
     * Creates a loader for the given pair of services.
     *
     * @param embeddingService the embedder started first and stopped last
     * @param rerankerService the reranker started last and stopped first
     */
    public OnnxModelLoader(OnnxEmbeddingService embeddingService, OnnxRerankerService rerankerService) {
        this.embeddingService = embeddingService;
        this.rerankerService = rerankerService;
    }

    /**
     * Starts the embedder, then the reranker.
     *
     * <p>If either start fails, both services are stopped (best effort) and the original failure
     * is rethrown unchanged.
     */
    @Override
    public void load() {
        try {
            log.info("OnnxModelLoader: loading embedder…");
            embeddingService.start();
            log.info("OnnxModelLoader: loading reranker…");
            rerankerService.start();
        } catch (RuntimeException | Error e) {
            // Roll back so a half-finished load does not keep one model resident.
            log.warn("OnnxModelLoader: load failed, releasing partially loaded models: {}", e.toString());
            stopQuietly("Reranker", rerankerService::stop);
            stopQuietly("Embedder", embeddingService::stop);
            throw e;
        }
        log.info("OnnxModelLoader: both models loaded");
    }

    /**
     * Stops the reranker, then the embedder. A failure while stopping one service is logged and
     * does not prevent the other from being stopped; this method does not propagate stop errors.
     */
    @Override
    public void unload() {
        log.info("OnnxModelLoader: unloading reranker…");
        stopQuietly("Reranker", rerankerService::stop);
        log.info("OnnxModelLoader: unloading embedder…");
        stopQuietly("Embedder", embeddingService::stop);
        log.info("OnnxModelLoader: both models unloaded");
    }

    /**
     * Reports whether both services currently consider themselves started.
     *
     * <p>NOTE(review): a service whose {@code start()} returns early because it is disabled
     * presumably never reports started, in which case this method stays {@code false} — confirm
     * that callers tolerate this.
     *
     * @return {@code true} only when the embedder and the reranker are both started
     */
    @Override
    public boolean isLoaded() {
        return embeddingService.isStarted() && rerankerService.isStarted();
    }

    /** Runs one stop action, logging instead of propagating any exception it throws. */
    private static void stopQuietly(String label, Runnable stopAction) {
        try {
            stopAction.run();
        } catch (Exception e) {
            log.warn("{} stop error (continuing): {}", label, e.toString());
        }
    }
}

View File

@@ -21,6 +21,7 @@ public record OnnxProperties(
@Nullable Integer maxSeqLen,
@Nullable Integer gpuDeviceId,
@Nullable Long gpuMemLimitBytes,
@Nullable Long idleTimeoutSeconds,
@Nullable Path home,
@Nullable Map<String, Map<String, List<String>>> modelSources) {
@@ -37,6 +38,7 @@ public record OnnxProperties(
if (maxSeqLen == null || maxSeqLen <= 0) maxSeqLen = 512;
if (gpuDeviceId == null || gpuDeviceId < 0) gpuDeviceId = 0;
if (gpuMemLimitBytes == null || gpuMemLimitBytes <= 0L) gpuMemLimitBytes = 0L; // 0 = no cap
if (idleTimeoutSeconds == null || idleTimeoutSeconds <= 0L) idleTimeoutSeconds = 300L; // 5 min
if (modelSources == null) modelSources = Map.of();
}
@@ -75,4 +77,8 @@ public record OnnxProperties(
public long gpuMemLimitBytesOrDefault() {
return gpuMemLimitBytes == null ? 0L : gpuMemLimitBytes;
}
/**
 * Returns the idle timeout in seconds as a primitive.
 *
 * @return the configured {@code idleTimeoutSeconds}, or {@code 300} when the component is
 *     {@code null}
 */
public long idleTimeoutSecondsOrDefault() {
return idleTimeoutSeconds == null ? 300L : idleTimeoutSeconds;
}
}

View File

@@ -7,10 +7,9 @@ import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
import com.trueref.adapter.out.embedding.onnx.HuggingFaceTokenizerWrapper.EncodedBatch;
import com.trueref.domain.error.IngestionFailed;
import com.trueref.domain.error.ModelNotReady;
import com.trueref.domain.model.SearchHit;
import com.trueref.domain.port.out.RerankerService;
import jakarta.annotation.PostConstruct;
import jakarta.annotation.PreDestroy;
import java.nio.LongBuffer;
import java.nio.file.Path;
import java.util.ArrayList;
@@ -62,8 +61,13 @@ public class OnnxRerankerService implements RerankerService {
this.modelsHome = modelsHome;
}
@PostConstruct
void start() {
/** Returns {@code true} if the reranker model is currently loaded and accepting requests. */
public boolean isStarted() {
return pool != null && tokenizer != null;
}
/** Loads the reranker model into memory. Called by {@link OnnxModelLoader} on demand. */
public void start() {
if (properties.rerankerSessionCountOrDefault() <= 0) {
log.warn("OnnxRerankerService DISABLED (reranker-session-count=0); rerank() passes through input order");
return;
@@ -96,8 +100,8 @@ public class OnnxRerankerService implements RerankerService {
usesTokenTypeIds);
}
@PreDestroy
void stop() {
/** Unloads the reranker model from memory. Called by {@link OnnxModelLoader} on unload. */
public void stop() {
if (pool != null) {
pool.close();
pool = null;
@@ -252,7 +256,7 @@ public class OnnxRerankerService implements RerankerService {
private static <T> T requireStarted(@org.jspecify.annotations.Nullable T t, String what) {
if (t == null) {
throw new IllegalStateException("OnnxRerankerService not started: " + what + " is null");
throw new ModelNotReady(30);
}
return t;
}

View File

@@ -0,0 +1,44 @@
package com.trueref.adapter.in.rest;
import static org.assertj.core.api.Assertions.assertThat;
import com.trueref.domain.error.ModelNotReady;
import org.junit.jupiter.api.Test;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.ResponseEntity;
/** Verifies how {@link GlobalExceptionHandler} maps {@link ModelNotReady} to an HTTP response. */
class GlobalExceptionHandlerTest {
    private final GlobalExceptionHandler handler = new GlobalExceptionHandler();

    @Test
    void modelNotReadyMapsto503WithRetryAfterHeader() {
        ResponseEntity<ErrorResponse> response = respondTo(30);

        assertThat(response.getStatusCode()).isEqualTo(HttpStatus.SERVICE_UNAVAILABLE);
        assertThat(retryAfterOf(response)).isEqualTo("30");
    }

    @Test
    void modelNotReadyResponseBodyHasModelNotReadyCode() {
        ResponseEntity<ErrorResponse> response = respondTo(15);

        assertThat(response.getBody()).isNotNull();
        ErrorResponse body = response.getBody();
        assertThat(body.code()).isEqualTo("model_not_ready");
        assertThat(body.message()).contains("15");
    }

    @Test
    void retryAfterReflectsExceptionValue() {
        assertThat(retryAfterOf(respondTo(60))).isEqualTo("60");
    }

    /** Runs the handler against a fresh exception carrying the given retry hint. */
    private ResponseEntity<ErrorResponse> respondTo(int retryAfterSeconds) {
        return handler.handleModelNotReady(new ModelNotReady(retryAfterSeconds));
    }

    /** Reads the first {@code Retry-After} header value of a response. */
    private static String retryAfterOf(ResponseEntity<?> response) {
        return response.getHeaders().getFirst(HttpHeaders.RETRY_AFTER);
    }
}

View File

@@ -0,0 +1,71 @@
package com.trueref.adapter.out.embedding.onnx;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.doThrow;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import org.junit.jupiter.api.Test;
/**
 * Unit tests for {@link OnnxModelLoader} using mocked services. The fixtures are instance
 * fields, so every test method gets fresh mocks under JUnit's default per-method lifecycle.
 */
class OnnxModelLoaderTest {
    private final OnnxEmbeddingService embedder = mock(OnnxEmbeddingService.class);
    private final OnnxRerankerService reranker = mock(OnnxRerankerService.class);
    private final OnnxModelLoader loader = new OnnxModelLoader(embedder, reranker);

    @Test
    void loadCallsStartOnBothServices() {
        loader.load();

        verify(embedder).start();
        verify(reranker).start();
    }

    @Test
    void unloadCallsStopOnBothServices() {
        loader.unload();

        verify(embedder).stop();
        verify(reranker).stop();
    }

    @Test
    void unloadContinuesEvenIfRerankerStopThrows() {
        doThrow(new RuntimeException("VRAM error")).when(reranker).stop();

        loader.unload(); // must not throw

        verify(embedder).stop();
    }

    @Test
    void isLoadedReturnsTrueWhenBothServicesStarted() {
        when(embedder.isStarted()).thenReturn(true);
        when(reranker.isStarted()).thenReturn(true);

        assertThat(loader.isLoaded()).isTrue();
    }

    @Test
    void isLoadedReturnsFalseWhenEitherServiceNotStarted() {
        when(embedder.isStarted()).thenReturn(true);
        when(reranker.isStarted()).thenReturn(false);

        assertThat(loader.isLoaded()).isFalse();
    }
}

View File

@@ -0,0 +1,75 @@
package com.trueref.application.model;
import static org.assertj.core.api.Assertions.assertThat;
import com.trueref.domain.model.ModelState;
import com.trueref.domain.model.ModelStateEvent;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.jupiter.api.Test;
/**
 * Unit tests for {@link InMemoryModelStateEventBus}. The bus is an instance field, so each test
 * method works on its own bus under JUnit's default per-method lifecycle.
 */
class InMemoryModelStateEventBusTest {
    private final InMemoryModelStateEventBus bus = new InMemoryModelStateEventBus();

    @Test
    void singleSubscriberReceivesPublishedEvent() {
        List<ModelStateEvent> received = new ArrayList<>();
        bus.subscribe(received::add);

        bus.publish(ModelStateEvent.of(ModelState.LOADING));

        assertThat(received).hasSize(1);
        ModelStateEvent only = received.get(0);
        assertThat(only.state()).isEqualTo(ModelState.LOADING);
    }

    @Test
    void multipleSubscribersAllReceiveEvent() {
        List<ModelStateEvent> first = new ArrayList<>();
        List<ModelStateEvent> second = new ArrayList<>();
        bus.subscribe(first::add);
        bus.subscribe(second::add);

        bus.publish(ModelStateEvent.of(ModelState.LOADED));

        assertThat(first).hasSize(1);
        assertThat(second).hasSize(1);
    }

    @Test
    void closingSubscriptionRemovesIt() throws Exception {
        List<ModelStateEvent> received = new ArrayList<>();
        AutoCloseable subscription = bus.subscribe(received::add);

        subscription.close();
        bus.publish(ModelStateEvent.of(ModelState.UNLOADED));

        assertThat(received).isEmpty();
    }

    @Test
    void exceptionInSubscriberDoesNotBreakOthers() {
        AtomicInteger failingCalls = new AtomicInteger();
        List<ModelStateEvent> healthy = new ArrayList<>();
        bus.subscribe(event -> {
            failingCalls.incrementAndGet();
            throw new RuntimeException("intentional test failure");
        });
        bus.subscribe(healthy::add);

        bus.publish(ModelStateEvent.of(ModelState.LOADING));

        assertThat(failingCalls.get()).isEqualTo(1);
        assertThat(healthy).hasSize(1);
    }

    @Test
    void noSubscribersMeansPublishIsNoop() {
        bus.publish(ModelStateEvent.of(ModelState.UNLOADED)); // must not throw
    }
}

View File

@@ -0,0 +1,227 @@
package com.trueref.application.model;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import com.trueref.domain.error.ModelNotReady;
import com.trueref.domain.model.IngestionJob;
import com.trueref.domain.model.JobId;
import com.trueref.domain.model.JobStatus;
import com.trueref.domain.model.JobType;
import com.trueref.domain.model.ModelState;
import com.trueref.domain.model.ModelStateEvent;
import com.trueref.domain.model.RepositoryId;
import com.trueref.domain.port.in.ManageModelLifecycle;
import com.trueref.domain.port.in.ObserveJobs;
import com.trueref.domain.port.out.ModelLoader;
import com.trueref.domain.port.out.ModelStateEventBus;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Consumer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
/**
 * Unit tests for {@link ModelLifecycleService} using in-process test doubles for the model
 * loader and the job observer.
 *
 * <p>Load and unload complete asynchronously, so tests poll the service state through
 * {@link #awaitState(ModelState)}. Event-collecting lists are thread-safe because events may be
 * published from the service's own executor thread while the test thread reads them.
 */
class ModelLifecycleServiceTest {
    // ---- Test doubles ----
    private final AtomicBoolean loaderLoaded = new AtomicBoolean(false);
    private volatile boolean slowLoad = false;

    /** Loader double that only flips a flag; optionally sleeps 200 ms to simulate a slow load. */
    private final ModelLoader loader = new ModelLoader() {
        @Override
        public void load() {
            if (slowLoad) {
                try {
                    Thread.sleep(200);
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                }
            }
            loaderLoaded.set(true);
        }

        @Override
        public void unload() {
            loaderLoaded.set(false);
        }

        @Override
        public boolean isLoaded() {
            return loaderLoaded.get();
        }
    };

    private final InMemoryModelStateEventBus eventBus = new InMemoryModelStateEventBus();
    private volatile List<IngestionJob> runningJobs = List.of();

    /** Job observer double: reports {@link #runningJobs} for RUNNING queries, nothing otherwise. */
    private final ObserveJobs observeJobs = new ObserveJobs() {
        @Override
        public Optional<IngestionJob> findJob(JobId id) {
            return Optional.empty();
        }

        @Override
        public List<IngestionJob> listJobs(com.trueref.domain.model.RepositoryId r,
                com.trueref.domain.model.VersionId v, JobStatus status, int limit) {
            if (status == JobStatus.RUNNING) {
                return runningJobs;
            }
            return List.of();
        }

        @Override
        public AutoCloseable subscribeJobs(Consumer<IngestionJob> listener) {
            return () -> {};
        }

        @Override
        public AutoCloseable subscribeLogs(JobId jobId, Consumer<com.trueref.domain.model.JobLogEvent> listener) {
            return () -> {};
        }
    };

    private ModelLifecycleService service;

    @BeforeEach
    void setUp() {
        loaderLoaded.set(false);
        slowLoad = false;
        runningJobs = List.of();
        service = new ModelLifecycleService(loader, eventBus, observeJobs, 300L);
        service.init();
    }

    @AfterEach
    void tearDown() {
        service.shutdown();
    }

    // ---- Tests ----
    @Test
    void startsInUnloadedState() {
        assertThat(service.getStatus().state()).isEqualTo(ModelState.UNLOADED);
    }

    @Test
    void ensureReadyOnUnloadedTriggersLoadAndThrows() throws InterruptedException {
        List<ModelStateEvent> events = new java.util.concurrent.CopyOnWriteArrayList<>();
        eventBus.subscribe(events::add);

        assertThatThrownBy(() -> service.ensureReady())
                .isInstanceOf(ModelNotReady.class);

        // Wait for async load to complete (max 2 s in test)
        awaitState(ModelState.LOADED);

        assertThat(service.getStatus().state()).isEqualTo(ModelState.LOADED);
        assertThat(loaderLoaded.get()).isTrue();
        assertThat(events).extracting(ModelStateEvent::state)
                .containsSequence(ModelState.LOADING, ModelState.LOADED);
    }

    @Test
    void ensureReadyOnLoadedIsNoop() throws InterruptedException {
        triggerLoad();

        // Should not throw
        service.ensureReady();
        service.ensureReady();

        assertThat(service.getStatus().state()).isEqualTo(ModelState.LOADED);
    }

    @Test
    void ensureReadyOnLoadingThrowsImmediately() {
        slowLoad = true;

        assertThatThrownBy(() -> service.ensureReady())
                .isInstanceOf(ModelNotReady.class);
        // state is LOADING, second call also throws
        assertThatThrownBy(() -> service.ensureReady())
                .isInstanceOf(ModelNotReady.class);
    }

    @Test
    void forceUnloadFromLoadedState() throws InterruptedException {
        triggerLoad();

        boolean initiated = service.forceUnload(false);
        assertThat(initiated).isTrue();

        // Wait for async unload
        awaitState(ModelState.UNLOADED);

        assertThat(service.getStatus().state()).isEqualTo(ModelState.UNLOADED);
        assertThat(loaderLoaded.get()).isFalse();
    }

    @Test
    void forceUnloadBlockedByRunningJobsWithoutForce() throws InterruptedException {
        triggerLoad();
        runningJobs = List.of(runningJob());

        boolean initiated = service.forceUnload(false);

        assertThat(initiated).isFalse();
        assertThat(service.getStatus().state()).isEqualTo(ModelState.LOADED);
    }

    @Test
    void forceUnloadSucceedsWithForceEvenWhenJobsRunning() throws InterruptedException {
        triggerLoad();
        runningJobs = List.of(runningJob());

        boolean initiated = service.forceUnload(true);
        assertThat(initiated).isTrue();

        awaitState(ModelState.UNLOADED);

        assertThat(service.getStatus().state()).isEqualTo(ModelState.UNLOADED);
    }

    @Test
    void forceUnloadWhenAlreadyUnloadedIsNoop() {
        boolean initiated = service.forceUnload(false);

        assertThat(initiated).isTrue();
        assertThat(service.getStatus().state()).isEqualTo(ModelState.UNLOADED);
    }

    @Test
    void concurrentEnsureReadyDeduplicatesLoad() throws InterruptedException {
        int callCount = 0;
        // Both calls throw ModelNotReady but only one load should be triggered
        List<ModelStateEvent> events = new java.util.concurrent.CopyOnWriteArrayList<>();
        eventBus.subscribe(events::add);

        try {
            service.ensureReady();
        } catch (ModelNotReady ignored) {
            callCount++;
        }
        try {
            service.ensureReady();
        } catch (ModelNotReady ignored) {
            callCount++;
        }

        awaitState(ModelState.LOADED);

        assertThat(callCount).isEqualTo(2);
        // Only one LOADING event should be emitted (deduplicated)
        long loadingCount = events.stream().filter(e -> e.state() == ModelState.LOADING).count();
        assertThat(loadingCount).isEqualTo(1);
    }

    @Test
    void statusIncludesLoadedAtAfterLoad() throws InterruptedException {
        assertThat(service.getStatus().loadedAt()).isNull();

        triggerLoad();

        assertThat(service.getStatus().loadedAt()).isNotNull();
    }

    // ---- helpers ----

    /** Requests a load (ignoring the expected {@link ModelNotReady}) and waits for LOADED. */
    private void triggerLoad() throws InterruptedException {
        try {
            service.ensureReady();
        } catch (ModelNotReady ignored) {
            // expected: the first call only kicks off the asynchronous load
        }
        awaitState(ModelState.LOADED);
    }

    /**
     * Polls the service every 100 ms, at most 20 times (about 2 s), until it reports
     * {@code expected}. Returns silently on timeout; callers assert the state afterwards.
     */
    private void awaitState(ModelState expected) throws InterruptedException {
        for (int attempt = 0; attempt < 20 && service.getStatus().state() != expected; attempt++) {
            Thread.sleep(100);
        }
    }

    /** Builds a RUNNING index job with random ids, used to simulate in-flight ingestion. */
    private IngestionJob runningJob() {
        return new IngestionJob(
                JobId.of(java.util.UUID.randomUUID().toString()),
                RepositoryId.of(java.util.UUID.randomUUID().toString()),
                null,
                JobType.INDEX_VERSION,
                JobStatus.RUNNING,
                Instant.now(),
                null,
                List.of());
    }
}

View File

@@ -0,0 +1,35 @@
package com.trueref.application.model;
import com.trueref.domain.model.ModelStateEvent;
import com.trueref.domain.port.out.ModelStateEventBus;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.Consumer;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * {@link ModelStateEventBus} that keeps its subscribers in a {@link CopyOnWriteArrayList} and
 * delivers each event synchronously on the publishing thread.
 */
public class InMemoryModelStateEventBus implements ModelStateEventBus {
    private static final Logger log = LoggerFactory.getLogger(InMemoryModelStateEventBus.class);

    private final CopyOnWriteArrayList<Consumer<ModelStateEvent>> subscribers =
            new CopyOnWriteArrayList<>();

    /**
     * Delivers {@code event} to every currently registered subscriber, in registration order.
     * A subscriber that throws is logged and unregistered; delivery to the others continues.
     */
    @Override
    public void publish(ModelStateEvent event) {
        subscribers.forEach(listener -> deliver(listener, event));
    }

    /**
     * Registers a subscriber.
     *
     * @return a handle whose {@code close()} unregisters the subscriber again
     */
    @Override
    public AutoCloseable subscribe(Consumer<ModelStateEvent> subscriber) {
        subscribers.add(subscriber);
        return () -> {
            subscribers.remove(subscriber);
        };
    }

    /** Hands one event to one subscriber, dropping the subscriber if it throws. */
    private void deliver(Consumer<ModelStateEvent> listener, ModelStateEvent event) {
        try {
            listener.accept(event);
        } catch (Exception failure) {
            log.warn("Model state subscriber threw during publish; removing: {}", failure.toString());
            subscribers.remove(listener);
        }
    }
}

View File

@@ -0,0 +1,287 @@
package com.trueref.application.model;
import com.trueref.domain.error.ModelNotReady;
import com.trueref.domain.model.IngestionJob;
import com.trueref.domain.model.JobStatus;
import com.trueref.domain.model.ModelState;
import com.trueref.domain.model.ModelStateEvent;
import com.trueref.domain.port.in.ManageModelLifecycle;
import com.trueref.domain.port.in.ObserveJobs;
import com.trueref.domain.port.out.ModelLoader;
import com.trueref.domain.port.out.ModelStateEventBus;
import java.time.Instant;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
import org.jspecify.annotations.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Orchestrates the GPU model lifecycle: lazy loading, auto-unload after idle timeout, and
* force-unload. State transitions are protected by a {@link ReentrantLock}. Heavy load/unload
* operations are executed on a dedicated single-thread platform-thread executor to satisfy CUDA
* OS-thread affinity constraints.
*/
public class ModelLifecycleService implements ManageModelLifecycle {
private static final Logger log = LoggerFactory.getLogger(ModelLifecycleService.class);
/** Default retry hint in seconds when the model is LOADING. Load typically takes 5–30 s. */
private static final int LOADING_RETRY_AFTER_SECONDS = 30;
/** Idle-check interval. One check per minute is sufficient resolution for a 5-min timeout. */
private static final long IDLE_CHECK_INTERVAL_SECONDS = 60L;
private final ModelLoader modelLoader;
private final ModelStateEventBus eventBus;
private final ObserveJobs observeJobs;
private final long idleTimeoutSeconds;
// ---- state ----
private final ReentrantLock lock = new ReentrantLock();
private final AtomicReference<ModelState> state = new AtomicReference<>(ModelState.UNLOADED);
private volatile @Nullable Instant loadedAt = null;
private final AtomicLong lastActivityEpochMs = new AtomicLong(0L);
// ---- executors ----
/** Single platform OS thread for all CUDA load/unload calls. */
private final java.util.concurrent.ExecutorService cudaExecutor;
/** Scheduler that periodically checks for idle timeout. */
private final ScheduledExecutorService idleScheduler;
/** Handle to the in-flight async load Future (guarded by lock). */
private @Nullable Future<?> pendingLoad = null;
public ModelLifecycleService(
ModelLoader modelLoader,
ModelStateEventBus eventBus,
ObserveJobs observeJobs,
long idleTimeoutSeconds) {
this.modelLoader = modelLoader;
this.eventBus = eventBus;
this.observeJobs = observeJobs;
this.idleTimeoutSeconds = idleTimeoutSeconds;
ThreadFactory platformFactory = Thread.ofPlatform()
.name("model-lifecycle-cuda")
.factory();
this.cudaExecutor = Executors.newSingleThreadExecutor(platformFactory);
this.idleScheduler = Executors.newSingleThreadScheduledExecutor(
Thread.ofPlatform().name("model-idle-check").factory());
}
/** Initializes the idle timer and job-event subscription. Called by the Spring bean lifecycle. */
public void init() {
// Subscribe to job events so that finishing a job resets the idle clock.
observeJobs.subscribeJobs(this::onJobEvent);
// Start the periodic idle check.
idleScheduler.scheduleAtFixedRate(
this::checkIdleTimeout,
IDLE_CHECK_INTERVAL_SECONDS,
IDLE_CHECK_INTERVAL_SECONDS,
TimeUnit.SECONDS);
log.info("ModelLifecycleService started; idleTimeoutSeconds={} state=UNLOADED", idleTimeoutSeconds);
}
/** Shuts down schedulers and unloads models if loaded. Called by the Spring bean lifecycle. */
public void shutdown() {
idleScheduler.shutdownNow();
// Unload synchronously on shutdown so VRAM is released cleanly.
if (state.get() == ModelState.LOADED || state.get() == ModelState.LOADING) {
doUnload("JVM shutdown");
}
cudaExecutor.shutdownNow();
}
// ---- ManageModelLifecycle ----
@Override
public void ensureReady() {
// Fast path: already loaded.
if (state.get() == ModelState.LOADED) {
touchActivity();
return;
}
lock.lock();
try {
ModelState current = state.get();
switch (current) {
case LOADED -> {
touchActivity();
return;
}
case LOADING -> throw new ModelNotReady(LOADING_RETRY_AFTER_SECONDS);
case UNLOADING -> throw new ModelNotReady(LOADING_RETRY_AFTER_SECONDS);
case UNLOADED -> {
triggerAsyncLoad("triggered by incoming request");
throw new ModelNotReady(LOADING_RETRY_AFTER_SECONDS);
}
}
} finally {
lock.unlock();
}
}
@Override
public boolean forceUnload(boolean force) {
lock.lock();
try {
ModelState current = state.get();
if (current == ModelState.UNLOADED || current == ModelState.UNLOADING) {
log.info("forceUnload: already {} — no-op", current);
return true;
}
if (!force && hasRunningJobs()) {
log.info("forceUnload: blocked by running ingestion jobs (use force=true to override)");
return false;
}
if (current == ModelState.LOADING && pendingLoad != null) {
pendingLoad.cancel(false); // try to cancel; unload will follow after transition
}
doUnload("force-unload via API");
return true;
} finally {
lock.unlock();
}
}
@Override
public Status getStatus() {
return new Status(state.get(), loadedAt, epochMsToInstant(lastActivityEpochMs.get()), idleTimeoutSeconds);
}
@Override
public AutoCloseable subscribeState(Consumer<ModelStateEvent> subscriber) {
return eventBus.subscribe(subscriber);
}
// ---- idle timer ----
private void checkIdleTimeout() {
if (state.get() != ModelState.LOADED) {
return;
}
long lastMs = lastActivityEpochMs.get();
if (lastMs == 0L) {
return; // no activity yet
}
long idleMs = System.currentTimeMillis() - lastMs;
if (idleMs < idleTimeoutSeconds * 1000L) {
return;
}
lock.lock();
try {
if (state.get() != ModelState.LOADED) {
return; // state changed while we waited for the lock
}
if (hasRunningJobs()) {
log.debug("idle timeout reached but ingestion jobs are running; skipping unload");
return;
}
long nowIdleMs = System.currentTimeMillis() - lastActivityEpochMs.get();
if (nowIdleMs < idleTimeoutSeconds * 1000L) {
return; // activity happened while we waited for the lock
}
doUnload("idle timeout after " + (nowIdleMs / 1000) + " s");
} finally {
lock.unlock();
}
}
// ---- internal helpers (always called while holding lock) ----
/**
* Submits an async load to the CUDA executor. Must be called while holding {@code lock}.
* Transitions state UNLOADED → LOADING.
*/
private void triggerAsyncLoad(String reason) {
if (pendingLoad != null && !pendingLoad.isDone()) {
return; // already loading — deduplicate
}
transition(ModelState.LOADING, "load requested: " + reason);
pendingLoad = cudaExecutor.submit(() -> {
try {
log.info("Loading GPU models…");
modelLoader.load();
lock.lock();
try {
loadedAt = Instant.now();
touchActivity();
transition(ModelState.LOADED, null);
log.info("GPU models loaded successfully");
} finally {
lock.unlock();
}
} catch (Exception e) {
log.error("GPU model load failed", e);
lock.lock();
try {
transition(ModelState.UNLOADED, "load failed: " + e.getMessage());
} finally {
lock.unlock();
}
}
});
}
/**
* Runs an unload synchronously on the calling thread (must already be the CUDA executor thread
* or called during shutdown). For runtime calls it is submitted to the cuda executor.
* Must be called while holding {@code lock}.
*/
private void doUnload(String reason) {
transition(ModelState.UNLOADING, reason);
// Run unload on the CUDA executor so it stays on the same OS thread as the sessions.
cudaExecutor.execute(() -> {
try {
log.info("Unloading GPU models…");
modelLoader.unload();
log.info("GPU models unloaded");
} catch (Exception e) {
log.error("GPU model unload error (models may still be in VRAM)", e);
} finally {
lock.lock();
try {
loadedAt = null;
transition(ModelState.UNLOADED, null);
} finally {
lock.unlock();
}
}
});
}
private void transition(ModelState next, @Nullable String message) {
ModelState prev = state.getAndSet(next);
if (prev != next) {
log.debug("Model state: {} → {}{}", prev, next, message != null ? " [" + message + "]" : "");
eventBus.publish(ModelStateEvent.of(next, message != null ? message : ""));
}
}
private void touchActivity() {
lastActivityEpochMs.set(System.currentTimeMillis());
}
private boolean hasRunningJobs() {
return !observeJobs.listJobs(null, null, JobStatus.RUNNING, 1).isEmpty();
}
private void onJobEvent(IngestionJob job) {
if (job.status() == JobStatus.SUCCEEDED || job.status() == JobStatus.FAILED) {
touchActivity(); // reset idle clock when a job finishes
}
}
private static @Nullable Instant epochMsToInstant(long ms) {
return ms == 0L ? null : Instant.ofEpochMilli(ms);
}
}

View File

@@ -6,6 +6,7 @@ import com.trueref.domain.model.Repository;
import com.trueref.domain.model.SearchHit;
import com.trueref.domain.model.SearchScope;
import com.trueref.domain.model.Version;
import com.trueref.domain.port.in.ManageModelLifecycle;
import com.trueref.domain.port.in.SearchLibraryDocs;
import com.trueref.domain.port.out.ChunkStore;
import com.trueref.domain.port.out.EmbeddingService;
@@ -42,6 +43,7 @@ public final class HybridSearchService implements SearchLibraryDocs {
private final EmbeddingService embedder;
private final RerankerService reranker;
private final RepositoryStore repos;
private final ManageModelLifecycle lifecycle;
private final int rrfK;
private final int rerankTopK;
private final int finalTopK;
@@ -51,6 +53,7 @@ public final class HybridSearchService implements SearchLibraryDocs {
EmbeddingService embedder,
RerankerService reranker,
RepositoryStore repos,
ManageModelLifecycle lifecycle,
int rrfK,
int rerankTopK,
int finalTopK) {
@@ -58,6 +61,7 @@ public final class HybridSearchService implements SearchLibraryDocs {
this.embedder = embedder;
this.reranker = reranker;
this.repos = repos;
this.lifecycle = lifecycle;
this.rrfK = rrfK;
this.rerankTopK = rerankTopK;
this.finalTopK = finalTopK;
@@ -65,6 +69,9 @@ public final class HybridSearchService implements SearchLibraryDocs {
@Override
public Result search(Query q) {
// Ensure models are loaded; throws ModelNotReady (→ HTTP 503) if not.
lifecycle.ensureReady();
if (q.text() == null || q.text().isBlank()) {
throw new InvalidSearchRequest("query text must not be blank");
}

View File

@@ -7,6 +7,7 @@ import com.trueref.application.observability.InMemoryJobEventBus;
import com.trueref.application.observability.JobObservationService;
import com.trueref.application.resolve.LibraryResolver;
import com.trueref.application.search.HybridSearchService;
import com.trueref.domain.port.in.ManageModelLifecycle;
import com.trueref.domain.port.out.ChunkStore;
import com.trueref.domain.port.out.CodeParser;
import com.trueref.domain.port.out.EmbeddingCache;
@@ -76,10 +77,11 @@ public class ApplicationBeans {
EmbeddingService embedder,
RerankerService reranker,
RepositoryStore repos,
ManageModelLifecycle lifecycle,
@Value("${trueref.search.rrf-k:60}") int rrfK,
@Value("${trueref.reranker.top-k:50}") int rerankTopK,
@Value("${trueref.search.final-top-k:20}") int finalTopK) {
return new HybridSearchService(chunks, embedder, reranker, repos, rrfK, rerankTopK, finalTopK);
return new HybridSearchService(chunks, embedder, reranker, repos, lifecycle, rrfK, rerankTopK, finalTopK);
}
@Bean

View File

@@ -0,0 +1,23 @@
package com.trueref.domain.error;
/**
 * Signals that an inference request could not be served because the GPU models are not loaded yet
 * or are in the middle of being unloaded. Intended to be reported to the caller as HTTP 503 with a
 * {@code Retry-After} header, or as an MCP error that invites a retry.
 */
public final class ModelNotReady extends TrueRefException {

    private final int retryAfterSeconds;

    /**
     * Creates the exception with error code {@code model_not_ready}.
     *
     * @param retryAfterSeconds suggested number of seconds to wait before retrying; also embedded
     *     in the exception message
     */
    public ModelNotReady(int retryAfterSeconds) {
        super("model_not_ready", describe(retryAfterSeconds), null);
        this.retryAfterSeconds = retryAfterSeconds;
    }

    /** Returns the suggested wait, in seconds, before the caller retries. */
    public int retryAfterSeconds() {
        return this.retryAfterSeconds;
    }

    /** Builds the human-readable message carried by the exception. */
    private static String describe(int seconds) {
        return "Model is not ready, retry in ~" + seconds + " seconds";
    }
}

View File

@@ -10,7 +10,8 @@ public abstract sealed class TrueRefException extends RuntimeException
VersionNotIndexed,
TagNotFound,
IngestionFailed,
InvalidSearchRequest {
InvalidSearchRequest,
ModelNotReady {
private final String code;

View File

@@ -0,0 +1,13 @@
package com.trueref.domain.model;
/**
 * Life-cycle state of the GPU inference models (embedder + reranker), which are loaded and
 * unloaded together.
 */
public enum ModelState {

    /** Nothing is loaded, so the models occupy no VRAM. */
    UNLOADED,

    /** A load into GPU memory has been started and has not finished yet. */
    LOADING,

    /** The models are in GPU memory and inference requests can be served. */
    LOADED,

    /** The models are in the process of being released from GPU memory. */
    UNLOADING
}

View File

@@ -0,0 +1,18 @@
package com.trueref.domain.model;
import java.time.Instant;
import org.jspecify.annotations.Nullable;
/**
 * Immutable notification that the GPU model lifecycle has entered a new {@link ModelState}.
 *
 * @param state the state that was entered
 * @param ts the moment the event object was created
 * @param message optional free-text detail about the transition; may be {@code null}
 */
public record ModelStateEvent(ModelState state, Instant ts, @Nullable String message) {

    /** Creates an event for {@code state}, timestamped now, without a message. */
    public static ModelStateEvent of(ModelState state) {
        return stamped(state, null);
    }

    /** Creates an event for {@code state}, timestamped now, carrying {@code message}. */
    public static ModelStateEvent of(ModelState state, String message) {
        return stamped(state, message);
    }

    /** Shared factory body: captures the current instant and builds the record. */
    private static ModelStateEvent stamped(ModelState state, @Nullable String message) {
        Instant now = Instant.now();
        return new ModelStateEvent(state, now, message);
    }
}

View File

@@ -0,0 +1,47 @@
package com.trueref.domain.port.in;
import com.trueref.domain.model.ModelState;
import com.trueref.domain.model.ModelStateEvent;
import java.time.Instant;
import java.util.function.Consumer;
import org.jspecify.annotations.Nullable;
/** Use-case port for controlling and observing the lifecycle of the GPU inference models. */
public interface ManageModelLifecycle {

    /**
     * Point-in-time view of the model lifecycle.
     *
     * @param state current lifecycle state
     * @param loadedAt when the models were loaded, or {@code null} if not available
     * @param lastActivityAt when activity was last recorded, or {@code null} if none was recorded
     * @param idleTimeoutSeconds configured idle timeout, in seconds
     */
    record Status(
            ModelState state,
            @Nullable Instant loadedAt,
            @Nullable Instant lastActivityAt,
            long idleTimeoutSeconds) {}

    /**
     * Guards an inference request: returns immediately when the model is LOADED and otherwise
     * throws {@link com.trueref.domain.error.ModelNotReady}.
     *
     * <ul>
     *   <li>LOADED — returns.
     *   <li>UNLOADED — kicks off an asynchronous load, then throws.
     *   <li>LOADING or UNLOADING — throws.
     * </ul>
     *
     * <p>The call also records the current time as the last-activity timestamp that the idle timer
     * uses.
     */
    void ensureReady();

    /**
     * Requests that the GPU models be unloaded. The request is refused while ingestion jobs are
     * running unless {@code force} is {@code true}.
     *
     * @param force when {@code true}, unloads even while jobs are running
     * @return {@code true} if unload was initiated; {@code false} if blocked by running jobs
     */
    boolean forceUnload(boolean force);

    /** Returns the current lifecycle status as an immutable snapshot. */
    Status getStatus();

    /**
     * Subscribes {@code subscriber} to model state-change events.
     *
     * @param subscriber callback invoked with each state-change event
     * @return an {@link AutoCloseable} that removes the subscription when closed
     */
    AutoCloseable subscribeState(Consumer<ModelStateEvent> subscriber);
}

View File

@@ -0,0 +1,18 @@
package com.trueref.domain.port.out;
/**
 * Outbound port that loads and unloads the GPU inference models (embedder + reranker) together,
 * as one lifecycle unit. Implementations are expected to be invoked on a platform OS thread so
 * that CUDA context affinity constraints are respected.
 */
public interface ModelLoader {

    /**
     * Reports whether the models are usable right now.
     *
     * @return {@code true} only if both models are currently loaded and accepting inference
     */
    boolean isLoaded();

    /** Brings both models into GPU memory; blocks until they are ready, or throws on failure. */
    void load();

    /** Frees both models from GPU memory. Idempotent. */
    void unload();
}

View File

@@ -0,0 +1,18 @@
package com.trueref.domain.port.out;
import com.trueref.domain.model.ModelStateEvent;
import java.util.function.Consumer;
/**
 * Outbound port for fanning out {@link ModelStateEvent}s to interested subscribers (e.g. SSE
 * connections).
 */
public interface ModelStateEventBus {

    /**
     * Adds a subscriber that will receive events published from now on.
     *
     * @param subscriber callback invoked with each published event
     * @return an {@link AutoCloseable} that removes the subscription when closed
     */
    AutoCloseable subscribe(Consumer<ModelStateEvent> subscriber);

    /**
     * Delivers a state-change event to all current subscribers.
     *
     * @param event the event to broadcast
     */
    void publish(ModelStateEvent event);
}