Files
trueref/trueref-application/src/main/java/com/trueref/application/search/HybridSearchService.java
moze 5c6085df99
All checks were successful
Build and publish Docker image / Build and push CPU image (push) Successful in 2m33s
Build and publish Docker image / Build and push GPU image (push) Successful in 3m15s
feat: GPU model lazy-load/unload lifecycle management
- Domain: add ModelState, ModelStateEvent, ModelNotReady, ManageModelLifecycle
  (in-port), ModelLoader and ModelStateEventBus (out-ports)
- Application: InMemoryModelStateEventBus; ModelLifecycleService — state
  machine (ReentrantLock), lazy load on first request, idle-timeout auto-unload
  (configurable via trueref.embedding.idle-timeout-seconds, default 300 s),
  job-guard (skips unload while ingestion running), platform-thread CUDA executor
- Adapters: OnnxModelLoader wires embedder + reranker start/stop; remove
  @PostConstruct/@PreDestroy from OnnxEmbeddingService and OnnxRerankerService;
  requireStarted() now throws ModelNotReady instead of IllegalStateException
- REST: GET /api/model/status, POST /api/model/unload (409 when jobs running,
  force=true to override), GET /api/model/status/stream (SSE)
- GlobalExceptionHandler: ModelNotReady -> 503 + Retry-After header
- HybridSearchService: calls lifecycle.ensureReady() before every search so
  both REST and MCP paths get ModelNotReady (-> 503 / MCP error) when unloaded
- TrueRefMcpTools: catches ModelNotReady, returns retry hint in MCP error text
- Tests: InMemoryModelStateEventBusTest, ModelLifecycleServiceTest (10 cases),
  OnnxModelLoaderTest, GlobalExceptionHandlerTest — all 41 tests green

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
2026-05-09 15:44:33 +02:00

246 lines
10 KiB
Java
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
package com.trueref.application.search;
import com.trueref.domain.error.InvalidSearchRequest;
import com.trueref.domain.model.ChunkId;
import com.trueref.domain.model.Repository;
import com.trueref.domain.model.SearchHit;
import com.trueref.domain.model.SearchScope;
import com.trueref.domain.model.Version;
import com.trueref.domain.port.in.ManageModelLifecycle;
import com.trueref.domain.port.in.SearchLibraryDocs;
import com.trueref.domain.port.out.ChunkStore;
import com.trueref.domain.port.out.EmbeddingService;
import com.trueref.domain.port.out.RepositoryStore;
import com.trueref.domain.port.out.RerankerService;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* Hybrid search: BM25 + dense kNN fused by Reciprocal Rank Fusion (RRF), then reranked by a
* cross-encoder, then packed to a token budget.
*/
public final class HybridSearchService implements SearchLibraryDocs {
private static final Logger log = LoggerFactory.getLogger(HybridSearchService.class);
/**
* Matches candidate identifiers for BM25 augmentation: one ASCII lowercase letter followed by at
* least five ASCII letters or digits (so ≥6 chars in total), delimited by {@code \b} word
* boundaries. Group 1 is the identifier itself.
*
* <p>The pattern alone does NOT require an uppercase letter, so plain lowercase words such as
* "should" also match; {@link #augmentWithCamelIdents(String)} discards every match that
* contains no uppercase character, which is what leaves only camelCase tokens. A token that
* starts with an uppercase letter is not matched because the first character must be
* {@code [a-z]}. The trailing {@code (?=\b)} lookahead is equivalent to a plain {@code \b}.
*
* <p>Intended for identifiers that are likely to be Phaser API method names.
* Examples: setCollideWorldBounds, createBitmapMask, addOverlap.
*/
private static final Pattern CAMEL_IDENT = Pattern.compile(
"\\b([a-z][a-zA-Z0-9]{5,})(?=\\b)");
// Chunk index queried for both the BM25 and the dense kNN candidate lists.
private final ChunkStore chunks;
// Produces the query embedding that is passed to the dense search.
private final EmbeddingService embedder;
// Invoked on the fused, enriched candidates together with the rewritten query text.
private final RerankerService reranker;
// Looked up in enrich() to resolve repository names and version tags.
private final RepositoryStore repos;
// Checked via ensureReady() in search() before any retrieval work is done.
private final ManageModelLifecycle lifecycle;
// Constant k in the RRF term 1 / (k + rank + 1), where rank starts at 0.
private final int rrfK;
// Candidate count requested from each retriever; also the cap applied to the fused list.
private final int rerankTopK;
// Fallback hit limit used when the query's maxHits() is not positive.
private final int finalTopK;
/**
 * Creates the search service with its collaborators and tuning parameters.
 *
 * <p>Collaborators are null-checked here so that a wiring mistake fails at construction time
 * instead of surfacing as a {@link NullPointerException} in the middle of the first search.
 *
 * @param chunks chunk index used for the BM25 and dense candidate searches
 * @param embedder service that embeds the query text for the dense search
 * @param reranker service that reranks the fused candidates
 * @param repos store used to resolve repository names and version tags
 * @param lifecycle model lifecycle port whose {@code ensureReady()} is called before searching
 * @param rrfK constant {@code k} in the RRF term {@code 1 / (k + rank + 1)}
 * @param rerankTopK number of candidates requested from each retriever and kept after fusion
 * @param finalTopK hit limit used when the query does not specify a positive {@code maxHits}
 * @throws NullPointerException if any collaborator is {@code null}
 */
public HybridSearchService(
        ChunkStore chunks,
        EmbeddingService embedder,
        RerankerService reranker,
        RepositoryStore repos,
        ManageModelLifecycle lifecycle,
        int rrfK,
        int rerankTopK,
        int finalTopK) {
    this.chunks = Objects.requireNonNull(chunks, "chunks");
    this.embedder = Objects.requireNonNull(embedder, "embedder");
    this.reranker = Objects.requireNonNull(reranker, "reranker");
    this.repos = Objects.requireNonNull(repos, "repos");
    this.lifecycle = Objects.requireNonNull(lifecycle, "lifecycle");
    this.rrfK = rrfK;
    this.rerankTopK = rerankTopK;
    this.finalTopK = finalTopK;
}
/**
 * Executes a hybrid search for the given query.
 *
 * <p>Pipeline: validate the request, make sure the models are ready, append the optional topic
 * to the query text, fetch BM25 and dense candidates, fuse them with RRF, cap the fused list at
 * {@code rerankTopK}, apply the file-path bias, enrich with repository name and tag, rerank, and
 * finally pack the hits into the token budget.
 *
 * @param q the search request; its text must be non-blank and its scope must contain a ref
 * @return the packed hits together with their estimated total token count
 * @throws InvalidSearchRequest if the query text is null/blank or the scope is null/empty
 */
@Override
public Result search(Query q) {
    // Validate before touching the model lifecycle: a malformed request can never succeed, so
    // it should be rejected as invalid right away instead of reaching ensureReady() and being
    // answered with "model not ready" while the models are unloaded.
    String rawText = q.text();
    if (rawText == null || rawText.isBlank()) {
        throw new InvalidSearchRequest("query text must not be blank");
    }
    if (q.scope() == null || q.scope().refs().isEmpty()) {
        throw new InvalidSearchRequest("search scope must not be empty");
    }
    // Ensure models are loaded; throws ModelNotReady (→ HTTP 503) if not.
    lifecycle.ensureReady();

    String text = rewrite(rawText, q.topic());
    // Augment BM25 query with camelCase identifiers found in the text so that the exact
    // method-name chunk scores higher in BM25 even when it competes with generic mentions.
    // The dense query keeps the un-augmented text.
    String bm25Text = augmentWithCamelIdents(text);
    List<SearchHit> bm25 = chunks.bm25Search(bm25Text, q.scope(), rerankTopK);
    float[] vec = embedder.embed(List.of(text)).get(0);
    List<SearchHit> dense = chunks.denseSearch(vec, q.scope(), rerankTopK);

    List<SearchHit> fused = rrf(bm25, dense);
    // NOTE(review): this cut happens before the path bias is applied, so the bias only reorders
    // the surviving candidates and cannot bring back a chunk dropped here — confirm intended.
    if (fused.size() > rerankTopK) {
        fused = fused.subList(0, rerankTopK);
    }
    // Demote changelog / synthetic-skill / docs paths before the reranker sees them so that
    // authoritative source-code chunks aren't squeezed out by historical migration notes.
    List<SearchHit> biased = applyFilePathBias(fused);
    // Enrich with repo name + tag (ChunkStore leaves these empty).
    List<SearchHit> enriched = enrich(biased);
    List<SearchHit> reranked = reranker.rerank(text, enriched);

    // A non-positive maxHits means "not specified": fall back to the configured default.
    int hitLimit = q.maxHits() > 0 ? q.maxHits() : finalTopK;
    List<SearchHit> packed = packByTokenBudget(reranked, q.tokensBudget(), hitLimit);
    int totalTokens = packed.stream().mapToInt(h -> estimateTokens(h.content())).sum();
    return new Result(packed, totalTokens);
}
/* ------------------------------------------------------------------ */
private String rewrite(String text, String topic) {
String base = text.trim();
if (topic != null && !topic.isBlank()) {
return base + " " + topic.trim();
}
return base;
}
/**
 * Builds the BM25 query text by appending every camelCase identifier found in {@code text} to
 * the end of it, separated by single spaces. Each occurrence is appended once, so an identifier
 * that appears twice in the input is appended twice. The input is returned unchanged when no
 * identifier qualifies.
 *
 * <p>The extra copies raise the identifiers' BM25 term frequency; the caller keeps using the
 * original text for the dense embedding query.
 *
 * <p>Example: "how to use setCollideWorldBounds" →
 * "how to use setCollideWorldBounds setCollideWorldBounds"
 *
 * @param text the rewritten query text
 * @return {@code text}, followed by the qualifying identifiers if there are any
 */
private static String augmentWithCamelIdents(String text) {
    List<String> camelIdents = CAMEL_IDENT.matcher(text)
            .results()
            .map(match -> match.group(1))
            // CAMEL_IDENT also matches all-lowercase words ("should", "create"); keep only
            // tokens that change when lowercased, i.e. contain an uppercase letter.
            .filter(ident -> !ident.equals(ident.toLowerCase()))
            .toList();
    if (camelIdents.isEmpty()) {
        return text;
    }
    return text + " " + String.join(" ", camelIdents);
}
/**
 * Rescales the fused (RRF) scores by a per-path multiplier before the candidates go to the
 * reranker, then re-sorts them by score in descending order.
 *
 * <p>Rationale: changelogs and synthetic skill docs are semantically relevant but tend to
 * outrank the authoritative source-code chunks when a query mentions API migration or breaking
 * changes. Demoting them keeps them retrievable while giving source files priority.
 *
 * <p>The multiplier for each hit comes from {@link #filePathMultiplier(String)} (tuned against
 * the phaser_rag_eval suite):
 * <ul>
 *   <li>changelog directories → ×0.50 — migration notes, not current API reference
 *   <li>skills directories / {@code SKILL.md} → ×0.60 — synthetic summaries, not authoritative
 *   <li>docs directories → ×0.75 — curated docs; useful but prefer source JSDoc
 *   <li>everything else (source, tests, configs) → ×1.0
 * </ul>
 *
 * @param hits fused candidates in their current order
 * @return the very same {@code hits} list when no multiplier differed from 1.0; otherwise a new
 *     list with rescored copies of the demoted hits, sorted by score descending
 */
private static List<SearchHit> applyFilePathBias(List<SearchHit> hits) {
    List<SearchHit> rescored = new ArrayList<>(hits.size());
    int demotedCount = 0;
    for (SearchHit hit : hits) {
        double factor = filePathMultiplier(hit.filePath());
        if (factor != 1.0) {
            // SearchHit is rebuilt field-by-field; only the score differs from the original.
            double scaledScore = hit.score() * factor;
            rescored.add(new SearchHit(
                    hit.chunkId(), hit.repoId(), hit.versionId(), hit.repoName(), hit.tag(),
                    hit.filePath(), hit.startLine(), hit.endLine(), hit.language(), hit.symbol(),
                    hit.content(), scaledScore));
            demotedCount++;
            continue;
        }
        rescored.add(hit);
    }
    if (demotedCount == 0) {
        // Nothing was demoted: hand back the caller's list untouched (and unsorted).
        return hits;
    }
    rescored.sort(Comparator.comparingDouble(SearchHit::score).reversed());
    return rescored;
}
/**
 * Returns the score multiplier for a hit based on its file path. Matching is case-insensitive
 * and expects {@code /} as the path separator.
 *
 * <ul>
 *   <li>path starts with {@code changelog/} or contains {@code /changelog/} → 0.50
 *   <li>path starts with {@code skills/}, contains {@code /skills/}, or ends with
 *       {@code skill.md} → 0.60
 *   <li>path starts with {@code docs/} or contains {@code /docs/} → 0.75
 *   <li>anything else, including a {@code null} or empty path → 1.0
 * </ul>
 *
 * <p>The rules are checked in the order listed; the first match wins.
 *
 * @param filePath the hit's file path; may be {@code null} or empty
 * @return the multiplier to apply to the hit's score
 */
private static double filePathMultiplier(String filePath) {
    if (filePath == null || filePath.isEmpty()) return 1.0;
    // Locale.ROOT makes the case folding independent of the JVM default locale; with a Turkish
    // default, "SKILL.md".toLowerCase() yields "skıll.md" (dotless i) and the rule is missed.
    String lp = filePath.toLowerCase(Locale.ROOT);
    if (lp.startsWith("changelog/") || lp.contains("/changelog/")) return 0.50;
    // A root-level "skills/" directory is handled like root-level "changelog/" and "docs/".
    if (lp.startsWith("skills/") || lp.contains("/skills/") || lp.endsWith("skill.md")) return 0.60;
    if (lp.startsWith("docs/") || lp.contains("/docs/")) return 0.75;
    return 1.0;
}
/**
 * Fuses two ranked hit lists with Reciprocal Rank Fusion.
 *
 * <p>Every chunk receives {@code 1 / (rrfK + rank + 1)} (rank is 0-based) from each list it
 * appears in, and the contributions are summed. The result is sorted by fused score in
 * descending order. Each returned hit is a copy of the first hit seen for that chunk
 * ({@code a} is scanned before {@code b}) with its score replaced by the fused score.
 *
 * <p>Equal scores are common: a chunk at rank {@code i} in only one list ties with the chunk at
 * rank {@code i} in only the other list. The scores are therefore accumulated in a
 * {@link LinkedHashMap} so that the stable sort breaks ties by first-seen order instead of by
 * hash-table iteration order, which would otherwise also decide which chunks survive the
 * caller's truncation of the fused list.
 *
 * @param a first ranked list (best first)
 * @param b second ranked list (best first)
 * @return fused hits, best first, as an unmodifiable list
 */
private List<SearchHit> rrf(List<SearchHit> a, List<SearchHit> b) {
    // Insertion-ordered so the entry stream is ordered and sorted() is stable on ties.
    Map<ChunkId, Double> scores = new LinkedHashMap<>();
    Map<ChunkId, SearchHit> firstSeen = new HashMap<>();
    addRankContribution(a, scores, firstSeen);
    addRankContribution(b, scores, firstSeen);
    return scores.entrySet().stream()
            .sorted(Map.Entry.<ChunkId, Double>comparingByValue().reversed())
            .map(e -> {
                SearchHit h = firstSeen.get(e.getKey());
                return new SearchHit(
                        h.chunkId(), h.repoId(), h.versionId(), h.repoName(), h.tag(),
                        h.filePath(), h.startLine(), h.endLine(), h.language(), h.symbol(),
                        h.content(), e.getValue());
            })
            .toList();
}
/**
 * Adds one ranked list's RRF contributions to the running totals.
 *
 * <p>The hit at 0-based position {@code rank} adds {@code 1 / (rrfK + rank + 1)} to its chunk's
 * entry in {@code scores}; the hit itself is recorded in {@code seen} unless that chunk already
 * has an entry there.
 *
 * @param hits ranked hits, best first
 * @param scores accumulated fused score per chunk; updated in place
 * @param seen first hit encountered per chunk; updated in place
 */
private void addRankContribution(List<SearchHit> hits, Map<ChunkId, Double> scores, Map<ChunkId, SearchHit> seen) {
    int rank = 0;
    for (SearchHit hit : hits) {
        ChunkId id = hit.chunkId();
        double reciprocalRank = 1.0 / (rrfK + rank + 1.0);
        scores.merge(id, reciprocalRank, Double::sum);
        seen.putIfAbsent(id, hit);
        rank++;
    }
}
/**
 * Returns copies of the given hits with the repository name and version tag filled in.
 *
 * <p>Names and tags are resolved through {@code repos} and cached for the duration of this call
 * (keyed by the id's string form), so each distinct repository and version is looked up at most
 * once. When a lookup yields nothing, the placeholder {@code "?"} is used.
 *
 * @param hits hits whose repo name and tag should be replaced
 * @return a new list, in the same order, with {@code repoName} and {@code tag} replaced
 */
private List<SearchHit> enrich(List<SearchHit> hits) {
    Map<String, String> nameCache = new HashMap<>();
    Map<String, String> tagCache = new HashMap<>();
    List<SearchHit> enriched = new ArrayList<>(hits.size());
    for (SearchHit hit : hits) {
        String repoKey = hit.repoId().toString();
        String resolvedName = nameCache.get(repoKey);
        if (resolvedName == null) {
            // Cached values are never null ("?" is the fallback), so null means "not cached".
            resolvedName = repos.findById(hit.repoId()).map(Repository::name).orElse("?");
            nameCache.put(repoKey, resolvedName);
        }

        String versionKey = hit.versionId().toString();
        String resolvedTag = tagCache.get(versionKey);
        if (resolvedTag == null) {
            resolvedTag = repos.findVersion(hit.versionId()).map(Version::tag).orElse("?");
            tagCache.put(versionKey, resolvedTag);
        }

        enriched.add(new SearchHit(
                hit.chunkId(), hit.repoId(), hit.versionId(),
                resolvedName, resolvedTag,
                hit.filePath(), hit.startLine(), hit.endLine(), hit.language(), hit.symbol(),
                hit.content(), hit.score()));
    }
    return enriched;
}
/**
 * Takes hits from the front of {@code ranked} until either {@code maxHits} hits have been taken
 * or the next hit would push the estimated token total past {@code tokenBudget}.
 *
 * <p>Packing stops at the first hit that does not fit; later, smaller hits are not considered.
 * The first hit is always accepted regardless of its size, so a non-empty input with
 * {@code maxHits > 0} never produces an empty result.
 *
 * @param ranked hits in final ranking order, best first
 * @param tokenBudget maximum estimated token total for the returned hits
 * @param maxHits maximum number of hits to return
 * @return the selected prefix of {@code ranked} as a new list
 */
private List<SearchHit> packByTokenBudget(List<SearchHit> ranked, int tokenBudget, int maxHits) {
    List<SearchHit> selected = new ArrayList<>();
    int spent = 0;
    for (SearchHit candidate : ranked) {
        if (selected.size() >= maxHits) {
            break;
        }
        int cost = estimateTokens(candidate.content());
        boolean fitsBudget = spent + cost <= tokenBudget;
        if (fitsBudget || selected.isEmpty()) {
            selected.add(candidate);
            spent += cost;
        } else {
            break;
        }
    }
    return selected;
}
/**
 * Estimates the token count of a string as one token per four characters (integer division),
 * never less than 1 — so the empty string also counts as one token. Per the original note this
 * is the same rule of thumb Context7 uses for packing.
 *
 * @param s the text to measure; must not be {@code null}
 * @return the estimated number of tokens, at least 1
 */
private static int estimateTokens(String s) {
    int wholeQuads = s.length() / 4;
    return wholeQuads < 1 ? 1 : wholeQuads;
}
}