- Domain: add ModelState, ModelStateEvent, ModelNotReady, ManageModelLifecycle (in-port), ModelLoader and ModelStateEventBus (out-ports) - Application: InMemoryModelStateEventBus; ModelLifecycleService — state machine (ReentrantLock), lazy load on first request, idle-timeout auto-unload (configurable via trueref.embedding.idle-timeout-seconds, default 300 s), job-guard (skips unload while ingestion running), platform-thread CUDA executor - Adapters: OnnxModelLoader wires embedder + reranker start/stop; remove @PostConstruct/@PreDestroy from OnnxEmbeddingService and OnnxRerankerService; requireStarted() now throws ModelNotReady instead of IllegalStateException - REST: GET /api/model/status, POST /api/model/unload (409 when jobs running, force=true to override), GET /api/model/status/stream (SSE) - GlobalExceptionHandler: ModelNotReady -> 503 + Retry-After header - HybridSearchService: calls lifecycle.ensureReady() before every search so both REST and MCP paths get ModelNotReady (-> 503 / MCP error) when unloaded - TrueRefMcpTools: catches ModelNotReady, returns retry hint in MCP error text - Tests: InMemoryModelStateEventBusTest, ModelLifecycleServiceTest (10 cases), OnnxModelLoaderTest, GlobalExceptionHandlerTest — all 41 tests green Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
246 lines
10 KiB
Java
246 lines
10 KiB
Java
package com.trueref.application.search;
|
||
|
||
import com.trueref.domain.error.InvalidSearchRequest;
import com.trueref.domain.model.ChunkId;
import com.trueref.domain.model.Repository;
import com.trueref.domain.model.SearchHit;
import com.trueref.domain.model.SearchScope;
import com.trueref.domain.model.Version;
import com.trueref.domain.port.in.ManageModelLifecycle;
import com.trueref.domain.port.in.SearchLibraryDocs;
import com.trueref.domain.port.out.ChunkStore;
import com.trueref.domain.port.out.EmbeddingService;
import com.trueref.domain.port.out.RepositoryStore;
import com.trueref.domain.port.out.RerankerService;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
|
||
|
||
/**
|
||
* Hybrid search: BM25 + dense kNN fused by Reciprocal Rank Fusion (RRF), then reranked by a
|
||
* cross-encoder, then packed to a token budget.
|
||
*/
|
||
public final class HybridSearchService implements SearchLibraryDocs {
|
||
|
||
private static final Logger log = LoggerFactory.getLogger(HybridSearchService.class);
|
||
|
||
/**
|
||
* Matches camelCase identifiers that are likely to be Phaser API method/class names (≥6 chars,
|
||
* must contain at least one uppercase letter after the first char, not all-caps).
|
||
* Examples: setCollideWorldBounds, createBitmapMask, addOverlap.
|
||
*/
|
||
private static final Pattern CAMEL_IDENT = Pattern.compile(
|
||
"\\b([a-z][a-zA-Z0-9]{5,})(?=\\b)");
|
||
|
||
private final ChunkStore chunks;
|
||
private final EmbeddingService embedder;
|
||
private final RerankerService reranker;
|
||
private final RepositoryStore repos;
|
||
private final ManageModelLifecycle lifecycle;
|
||
private final int rrfK;
|
||
private final int rerankTopK;
|
||
private final int finalTopK;
|
||
|
||
public HybridSearchService(
|
||
ChunkStore chunks,
|
||
EmbeddingService embedder,
|
||
RerankerService reranker,
|
||
RepositoryStore repos,
|
||
ManageModelLifecycle lifecycle,
|
||
int rrfK,
|
||
int rerankTopK,
|
||
int finalTopK) {
|
||
this.chunks = chunks;
|
||
this.embedder = embedder;
|
||
this.reranker = reranker;
|
||
this.repos = repos;
|
||
this.lifecycle = lifecycle;
|
||
this.rrfK = rrfK;
|
||
this.rerankTopK = rerankTopK;
|
||
this.finalTopK = finalTopK;
|
||
}
|
||
|
||
@Override
|
||
public Result search(Query q) {
|
||
// Ensure models are loaded; throws ModelNotReady (→ HTTP 503) if not.
|
||
lifecycle.ensureReady();
|
||
|
||
if (q.text() == null || q.text().isBlank()) {
|
||
throw new InvalidSearchRequest("query text must not be blank");
|
||
}
|
||
if (q.scope().refs().isEmpty()) {
|
||
throw new InvalidSearchRequest("search scope must not be empty");
|
||
}
|
||
|
||
String text = rewrite(q.text(), q.topic());
|
||
// Augment BM25 query with camelCase identifiers found in the text so that the exact
|
||
// method-name chunk scores higher in BM25 even when it competes with generic mentions.
|
||
String bm25Text = augmentWithCamelIdents(text);
|
||
|
||
List<SearchHit> bm25 = chunks.bm25Search(bm25Text, q.scope(), rerankTopK);
|
||
float[] vec = embedder.embed(List.of(text)).get(0);
|
||
List<SearchHit> dense = chunks.denseSearch(vec, q.scope(), rerankTopK);
|
||
|
||
List<SearchHit> fused = rrf(bm25, dense);
|
||
if (fused.size() > rerankTopK) fused = fused.subList(0, rerankTopK);
|
||
|
||
// Demote changelog / synthetic-skill / docs paths before the reranker sees them so that
|
||
// authoritative source-code chunks aren't squeezed out by historical migration notes.
|
||
List<SearchHit> biased = applyFilePathBias(fused);
|
||
|
||
// Enrich with repo name + tag (ChunkStore leaves these empty).
|
||
List<SearchHit> enriched = enrich(biased);
|
||
|
||
List<SearchHit> reranked = reranker.rerank(text, enriched);
|
||
|
||
List<SearchHit> packed = packByTokenBudget(reranked, q.tokensBudget(), q.maxHits() > 0 ? q.maxHits() : finalTopK);
|
||
int totalTokens = packed.stream().mapToInt(h -> estimateTokens(h.content())).sum();
|
||
return new Result(packed, totalTokens);
|
||
}
|
||
|
||
/* ------------------------------------------------------------------ */
|
||
|
||
private String rewrite(String text, String topic) {
|
||
String base = text.trim();
|
||
if (topic != null && !topic.isBlank()) {
|
||
return base + " " + topic.trim();
|
||
}
|
||
return base;
|
||
}
|
||
|
||
/**
|
||
* Returns a copy of {@code text} with each camelCase identifier repeated at the end (once).
|
||
* This lifts their BM25 term-frequency contribution without altering the semantic meaning
|
||
* used for the dense embedding query.
|
||
*
|
||
* <p>Example: "how to use setCollideWorldBounds" →
|
||
* "how to use setCollideWorldBounds setCollideWorldBounds"
|
||
*/
|
||
private static String augmentWithCamelIdents(String text) {
|
||
Matcher m = CAMEL_IDENT.matcher(text);
|
||
StringBuilder extra = new StringBuilder();
|
||
while (m.find()) {
|
||
String ident = m.group(1);
|
||
// Only repeat identifiers that contain at least one uppercase letter
|
||
// (filters out short common words like "should", "create").
|
||
if (!ident.equals(ident.toLowerCase())) {
|
||
extra.append(' ').append(ident);
|
||
}
|
||
}
|
||
return extra.isEmpty() ? text : text + extra;
|
||
}
|
||
|
||
/**
|
||
* Applies a path-based multiplier to RRF scores before handing candidates to the reranker.
|
||
* Changelogs and synthetic skill docs are semantically relevant but tend to outrank the
|
||
* authoritative source-code chunks when the query mentions API migration or breaking changes.
|
||
* Demoting them here keeps them retrievable while giving source files priority.
|
||
*
|
||
* <p>Multipliers (tuned against the phaser_rag_eval suite):
|
||
* <ul>
|
||
* <li>{@code changelog/} → ×0.50 — migration notes, not current API reference
|
||
* <li>{@code skills/} / {@code SKILL.md} → ×0.60 — synthetic summaries, not authoritative
|
||
* <li>{@code docs/} → ×0.75 — curated docs; useful but prefer source JSDoc
|
||
* <li>everything else (source, tests, configs) → ×1.0
|
||
* </ul>
|
||
*/
|
||
private static List<SearchHit> applyFilePathBias(List<SearchHit> hits) {
|
||
boolean anyChanged = false;
|
||
List<SearchHit> out = new ArrayList<>(hits.size());
|
||
for (SearchHit h : hits) {
|
||
double mult = filePathMultiplier(h.filePath());
|
||
if (mult == 1.0) {
|
||
out.add(h);
|
||
} else {
|
||
out.add(new SearchHit(
|
||
h.chunkId(), h.repoId(), h.versionId(), h.repoName(), h.tag(),
|
||
h.filePath(), h.startLine(), h.endLine(), h.language(), h.symbol(),
|
||
h.content(), h.score() * mult));
|
||
anyChanged = true;
|
||
}
|
||
}
|
||
if (!anyChanged) return hits;
|
||
out.sort(Comparator.comparingDouble(SearchHit::score).reversed());
|
||
return out;
|
||
}
|
||
|
||
private static double filePathMultiplier(String filePath) {
|
||
if (filePath == null || filePath.isEmpty()) return 1.0;
|
||
String lp = filePath.toLowerCase();
|
||
if (lp.startsWith("changelog/") || lp.contains("/changelog/")) return 0.50;
|
||
if (lp.contains("/skills/") || lp.endsWith("skill.md")) return 0.60;
|
||
if (lp.startsWith("docs/") || lp.contains("/docs/")) return 0.75;
|
||
return 1.0;
|
||
}
|
||
|
||
private List<SearchHit> rrf(List<SearchHit> a, List<SearchHit> b) {
|
||
Map<ChunkId, Double> scores = new HashMap<>();
|
||
Map<ChunkId, SearchHit> firstSeen = new HashMap<>();
|
||
addRankContribution(a, scores, firstSeen);
|
||
addRankContribution(b, scores, firstSeen);
|
||
return scores.entrySet().stream()
|
||
.sorted(Map.Entry.<ChunkId, Double>comparingByValue().reversed())
|
||
.map(e -> {
|
||
SearchHit h = firstSeen.get(e.getKey());
|
||
return new SearchHit(
|
||
h.chunkId(), h.repoId(), h.versionId(), h.repoName(), h.tag(),
|
||
h.filePath(), h.startLine(), h.endLine(), h.language(), h.symbol(),
|
||
h.content(), e.getValue());
|
||
})
|
||
.toList();
|
||
}
|
||
|
||
private void addRankContribution(List<SearchHit> hits, Map<ChunkId, Double> scores, Map<ChunkId, SearchHit> seen) {
|
||
for (int rank = 0; rank < hits.size(); rank++) {
|
||
SearchHit h = hits.get(rank);
|
||
scores.merge(h.chunkId(), 1.0 / (rrfK + rank + 1.0), Double::sum);
|
||
seen.putIfAbsent(h.chunkId(), h);
|
||
}
|
||
}
|
||
|
||
private List<SearchHit> enrich(List<SearchHit> hits) {
|
||
Map<String, String> repoNameByRepoId = new HashMap<>();
|
||
Map<String, String> tagByVersionId = new HashMap<>();
|
||
List<SearchHit> out = new ArrayList<>(hits.size());
|
||
for (SearchHit h : hits) {
|
||
String repoName = repoNameByRepoId.computeIfAbsent(
|
||
h.repoId().toString(),
|
||
k -> repos.findById(h.repoId()).map(Repository::name).orElse("?"));
|
||
String tag = tagByVersionId.computeIfAbsent(
|
||
h.versionId().toString(),
|
||
k -> repos.findVersion(h.versionId()).map(Version::tag).orElse("?"));
|
||
out.add(new SearchHit(
|
||
h.chunkId(), h.repoId(), h.versionId(),
|
||
repoName, tag,
|
||
h.filePath(), h.startLine(), h.endLine(), h.language(), h.symbol(),
|
||
h.content(), h.score()));
|
||
}
|
||
return out;
|
||
}
|
||
|
||
private List<SearchHit> packByTokenBudget(List<SearchHit> ranked, int tokenBudget, int maxHits) {
|
||
List<SearchHit> out = new ArrayList<>();
|
||
int used = 0;
|
||
for (SearchHit h : ranked) {
|
||
if (out.size() >= maxHits) break;
|
||
int t = estimateTokens(h.content());
|
||
if (used + t > tokenBudget && !out.isEmpty()) break;
|
||
out.add(h);
|
||
used += t;
|
||
}
|
||
return out;
|
||
}
|
||
|
||
/** 4 chars ≈ 1 token — same rule of thumb Context7 uses for packing. */
|
||
private static int estimateTokens(String s) {
|
||
return Math.max(1, s.length() / 4);
|
||
}
|
||
}
|