diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..591aa99 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,33 @@ +# Runtime data — large and machine-specific, never needed in the image +data/ +data-onnx-smoke/ +logs/ + +# cuDNN and other native runtime libraries (823 MB) +runtime/ + +# Build outputs (the Dockerfile re-builds from source) +**/target/ +**/build/ +trueref-frontend/web/node_modules/ +trueref-frontend/web/.svelte-kit/ + +# Git and IDE metadata +.git/ +.gitignore +.gitea/ +.vscode/ +.idea/ + +# JVM crash dumps +hs_err_pid*.log +core.* + +# Tests +tests/ + +# Docs +ARCHITECTURE.md +CODE_STYLE.md +FINDINGS.md +README.md diff --git a/.gitignore b/.gitignore index f625a74..27fa581 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,6 @@ target/ build/ -out/ +/out/ .idea/ .vscode/ *.iml diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/cache/disk/DiskEmbeddingCache.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/cache/disk/DiskEmbeddingCache.java new file mode 100644 index 0000000..70bb719 --- /dev/null +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/cache/disk/DiskEmbeddingCache.java @@ -0,0 +1,173 @@ +package com.trueref.adapter.out.cache.disk; + +import com.trueref.domain.port.out.EmbeddingCache; +import jakarta.annotation.PostConstruct; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.file.FileVisitResult; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.SimpleFileVisitor; +import java.nio.file.StandardCopyOption; +import java.nio.file.attribute.BasicFileAttributes; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.concurrent.locks.ReentrantLock; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.stereotype.Component; + +/** + * File-per-hash on-disk 
embedding cache with a fixed-capacity in-memory LRU front. Files are stored + * as raw little-endian {@code float32} arrays under a two-level fanout + * ({@code home///.f32}). Atomic writes via {@code .tmp + ATOMIC_MOVE}. + */ +@Component +class DiskEmbeddingCache implements EmbeddingCache { + + private static final Logger log = LoggerFactory.getLogger(DiskEmbeddingCache.class); + private static final String FILE_SUFFIX = ".f32"; + private static final String TMP_SUFFIX = ".tmp"; + + private final EmbeddingCacheProperties props; + private final Path home; + private final int dimension; + private final int expectedBytes; + private final ReentrantLock lock = new ReentrantLock(); + private final LinkedHashMap lru; + + DiskEmbeddingCache(EmbeddingCacheProperties props) { + this.props = props; + this.home = Objects.requireNonNull(props.home(), "props.home must be resolved by config"); + this.dimension = props.dimension(); + this.expectedBytes = dimension * Float.BYTES; + int cap = props.memoryMaxEntries(); + this.lru = new LinkedHashMap<>(Math.min(cap, 1024), 0.75f, true) { + @Override + protected boolean removeEldestEntry(Map.Entry eldest) { + return size() > cap; + } + }; + } + + @PostConstruct + void init() { + try { + Files.createDirectories(home); + // Sentinel file records the dimension used to build this cache directory. + // If the dimension changes (e.g. model swap), all stale .f32 files are deleted + // so the cache never silently serves wrong-size vectors. 
+ Path sentinel = home.resolve(".dimension"); + if (Files.isRegularFile(sentinel)) { + int storedDim = -1; + try { storedDim = Integer.parseInt(Files.readString(sentinel).strip()); } + catch (NumberFormatException ignored) {} + if (storedDim != dimension) { + log.warn("embedding cache dimension changed {} \u2192 {} \u2014 wiping stale .f32 files under {}", + storedDim, dimension, home); + Files.walkFileTree(home, new SimpleFileVisitor<>() { + @Override + public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IOException { + if (file.getFileName().toString().endsWith(FILE_SUFFIX)) { + Files.deleteIfExists(file); + } + return FileVisitResult.CONTINUE; + } + }); + } + } + Files.writeString(sentinel, String.valueOf(dimension)); + } catch (IOException e) { + throw new UncheckedIOException("failed to initialize embedding cache at " + home, e); + } + log.info("embedding cache initialized at {} (dim={}, lru={})", home, dimension, props.memoryMaxEntries()); + } + + @Override + public Optional get(String contentHash) { + Objects.requireNonNull(contentHash, "contentHash"); + lock.lock(); + try { + float[] hot = lru.get(contentHash); + if (hot != null) { + return Optional.of(hot.clone()); + } + } finally { + lock.unlock(); + } + + Path file = pathFor(contentHash); + if (!Files.isRegularFile(file)) { + return Optional.empty(); + } + byte[] bytes; + try { + bytes = Files.readAllBytes(file); + } catch (IOException e) { + log.warn("failed to read embedding cache file {}: {}", file, e.toString()); + return Optional.empty(); + } + if (bytes.length != expectedBytes) { + log.warn( + "embedding cache file {} has size {} bytes, expected {} (dim={}); treating as miss", + file, + bytes.length, + expectedBytes, + dimension); + return Optional.empty(); + } + float[] vector = new float[dimension]; + ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asFloatBuffer().get(vector); + + lock.lock(); + try { + lru.put(contentHash, vector); + } finally { + lock.unlock(); + 
} + return Optional.of(vector.clone()); + } + + @Override + public void put(String contentHash, float[] vector) { + Objects.requireNonNull(contentHash, "contentHash"); + Objects.requireNonNull(vector, "vector"); + if (vector.length != dimension) { + throw new IllegalArgumentException( + "vector length " + vector.length + " does not match cache dimension " + dimension); + } + + Path target = pathFor(contentHash); + Path tmp = target.resolveSibling(target.getFileName() + TMP_SUFFIX); + try { + Files.createDirectories(target.getParent()); + ByteBuffer buf = ByteBuffer.allocate(expectedBytes).order(ByteOrder.LITTLE_ENDIAN); + buf.asFloatBuffer().put(vector); + Files.write(tmp, buf.array()); + Files.move(tmp, target, StandardCopyOption.ATOMIC_MOVE, StandardCopyOption.REPLACE_EXISTING); + } catch (IOException e) { + throw new UncheckedIOException("failed to write embedding cache file: " + target, e); + } + + float[] cached = vector.clone(); + lock.lock(); + try { + lru.put(contentHash, cached); + } finally { + lock.unlock(); + } + } + + private Path pathFor(String contentHash) { + if (contentHash.length() < 4) { + throw new IllegalArgumentException("contentHash too short (need >=4 chars): " + contentHash); + } + String first = contentHash.substring(0, 2); + String second = contentHash.substring(2, 4); + return home.resolve(first).resolve(second).resolve(contentHash + FILE_SUFFIX); + } +} diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/cache/disk/DiskEmbeddingCacheConfig.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/cache/disk/DiskEmbeddingCacheConfig.java new file mode 100644 index 0000000..0c3348a --- /dev/null +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/cache/disk/DiskEmbeddingCacheConfig.java @@ -0,0 +1,26 @@ +package com.trueref.adapter.out.cache.disk; + +import java.nio.file.Path; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.context.properties.bind.Binder; +import 
org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.core.env.Environment; + +/** + * Wiring for the disk embedding cache. Binds {@code trueref.embedding-cache.*} from the + * environment, defaulting {@code home} to {@code ${trueref.home}/embedding-cache} when absent. + */ +@Configuration +class DiskEmbeddingCacheConfig { + + @Bean + EmbeddingCacheProperties embeddingCacheProperties( + Environment env, @Value("${trueref.home:./data}") Path truerefHome) { + EmbeddingCacheProperties bound = Binder.get(env) + .bind("trueref.embedding-cache", EmbeddingCacheProperties.class) + .orElseGet(EmbeddingCacheProperties::new); + Path resolvedHome = bound.home() != null ? bound.home() : truerefHome.resolve("embedding-cache"); + return new EmbeddingCacheProperties(resolvedHome, bound.memoryMaxEntries(), bound.dimension()); + } +} diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/cache/disk/EmbeddingCacheProperties.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/cache/disk/EmbeddingCacheProperties.java new file mode 100644 index 0000000..a066145 --- /dev/null +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/cache/disk/EmbeddingCacheProperties.java @@ -0,0 +1,28 @@ +package com.trueref.adapter.out.cache.disk; + +import java.nio.file.Path; +import org.jspecify.annotations.Nullable; +import org.springframework.boot.context.properties.ConfigurationProperties; + +/** + * Typed configuration for the disk-backed embedding cache. + * + *

{@code home} is nullable here because the default is computed from {@code trueref.home} in + * {@link DiskEmbeddingCacheConfig}; only the resolved instance reaches {@link DiskEmbeddingCache}. + */ +@ConfigurationProperties("trueref.embedding-cache") +public record EmbeddingCacheProperties(@Nullable Path home, int memoryMaxEntries, int dimension) { + + public EmbeddingCacheProperties { + if (memoryMaxEntries <= 0) { + throw new IllegalArgumentException("memoryMaxEntries must be > 0, got " + memoryMaxEntries); + } + if (dimension <= 0) { + throw new IllegalArgumentException("dimension must be > 0, got " + dimension); + } + } + + public EmbeddingCacheProperties() { + this(null, 4096, 768); + } +} diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/cache/disk/package-info.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/cache/disk/package-info.java new file mode 100644 index 0000000..88b4097 --- /dev/null +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/cache/disk/package-info.java @@ -0,0 +1,7 @@ +/** + * Disk-backed implementation of {@link com.trueref.domain.port.out.EmbeddingCache}: file-per-hash + * storage under {@code $TRUEREF_HOME/embedding-cache} with a hot in-memory LRU. Pure JDK NIO; no + * third-party dependencies. + */ +@org.jspecify.annotations.NullMarked +package com.trueref.adapter.out.cache.disk; diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/GpuSemaphore.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/GpuSemaphore.java new file mode 100644 index 0000000..5894518 --- /dev/null +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/GpuSemaphore.java @@ -0,0 +1,47 @@ +package com.trueref.adapter.out.embedding.onnx; + +import java.util.concurrent.Semaphore; +import java.util.function.Supplier; + +/** + * Thin wrapper around {@link Semaphore} used to gate GPU-bound inference. 
Permits equal the number + * of {@code OrtSession}s per model, so a caller holding a permit is guaranteed to be able to + * borrow one. + */ +final class GpuSemaphore { + + private final Semaphore semaphore; + + GpuSemaphore(int permits) { + if (permits < 0) { + throw new IllegalArgumentException("permits must be >= 0, got " + permits); + } + this.semaphore = new Semaphore(Math.max(permits, 1), true); + } + + void acquire() { + try { + semaphore.acquire(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted while acquiring GPU semaphore", e); + } + } + + void release() { + semaphore.release(); + } + + int availablePermits() { + return semaphore.availablePermits(); + } + + T withPermit(Supplier work) { + acquire(); + try { + return work.get(); + } finally { + release(); + } + } +} diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/HuggingFaceTokenizerWrapper.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/HuggingFaceTokenizerWrapper.java new file mode 100644 index 0000000..3091b30 --- /dev/null +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/HuggingFaceTokenizerWrapper.java @@ -0,0 +1,79 @@ +package com.trueref.adapter.out.embedding.onnx; + +import ai.djl.huggingface.tokenizers.Encoding; +import ai.djl.huggingface.tokenizers.HuggingFaceTokenizer; +import ai.djl.util.PairList; +import com.trueref.domain.error.IngestionFailed; +import java.io.IOException; +import java.nio.file.Path; +import java.util.List; + +/** + * Wraps a {@link HuggingFaceTokenizer} loaded from a local {@code tokenizer.json}. Produces + * batch-padded {@code long[][]} arrays for {@code input_ids}, {@code attention_mask} and {@code + * token_type_ids} suitable for direct feeding into ONNX inputs. Padding is to the longest sample + * in the batch, truncation is to {@code maxSeqLen}. 
+ */ +final class HuggingFaceTokenizerWrapper implements AutoCloseable { + + private final HuggingFaceTokenizer tokenizer; + + HuggingFaceTokenizerWrapper(Path tokenizerJson) { + try { + this.tokenizer = HuggingFaceTokenizer.newInstance(tokenizerJson); + } catch (IOException e) { + throw new IngestionFailed("Failed to load tokenizer from " + tokenizerJson, e); + } + } + + /** Encodes single texts (no pair). */ + EncodedBatch encode(List texts, int maxSeqLen) { + Encoding[] encodings = tokenizer.batchEncode(texts); + return pad(encodings, maxSeqLen); + } + + /** Encodes sentence pairs (query, passage). */ + EncodedBatch encodePairs(List firsts, List seconds, int maxSeqLen) { + if (firsts.size() != seconds.size()) { + throw new IllegalArgumentException( + "Pair lists must match size: " + firsts.size() + " vs " + seconds.size()); + } + PairList pairs = new PairList<>(firsts, seconds); + Encoding[] encodings = tokenizer.batchEncode(pairs); + return pad(encodings, maxSeqLen); + } + + private EncodedBatch pad(Encoding[] encodings, int maxSeqLen) { + int batch = encodings.length; + int seq = 0; + for (Encoding e : encodings) { + int len = Math.min(e.getIds().length, maxSeqLen); + if (len > seq) seq = len; + } + long[][] ids = new long[batch][seq]; + long[][] mask = new long[batch][seq]; + long[][] typeIds = new long[batch][seq]; + for (int i = 0; i < batch; i++) { + long[] srcIds = encodings[i].getIds(); + long[] srcMask = encodings[i].getAttentionMask(); + long[] srcTypes = encodings[i].getTypeIds(); + int copy = Math.min(srcIds.length, seq); + System.arraycopy(srcIds, 0, ids[i], 0, copy); + System.arraycopy(srcMask, 0, mask[i], 0, copy); + if (srcTypes != null) { + int copyTypes = Math.min(srcTypes.length, seq); + System.arraycopy(srcTypes, 0, typeIds[i], 0, copyTypes); + } + // remainder already zero (pad id 0, mask 0, type 0) + } + return new EncodedBatch(ids, mask, typeIds, seq); + } + + @Override + public void close() { + tokenizer.close(); + } + + /** Padded batch 
of token ids and companion masks. Arrays are shape {@code [batch, seq]}. */ + record EncodedBatch(long[][] inputIds, long[][] attentionMask, long[][] tokenTypeIds, int seqLen) {} +} diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/ModelDownloader.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/ModelDownloader.java new file mode 100644 index 0000000..3d0a1c3 --- /dev/null +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/ModelDownloader.java @@ -0,0 +1,252 @@ +package com.trueref.adapter.out.embedding.onnx; + +import com.trueref.domain.error.IngestionFailed; +import java.io.IOException; +import java.net.URI; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardCopyOption; +import java.time.Duration; +import java.util.List; +import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Downloads {@code model.onnx} + {@code tokenizer.json} for a HuggingFace-hosted model into a + * local directory, idempotently. Each model maps to an ordered list of candidate URLs (per file) + * tried in sequence until one returns 200. + * + *

Defaults cover the bge-m3 family by pointing at community ONNX ports, since {@code + * BAAI/bge-m3} keeps {@code model.onnx} under an {@code onnx/} subdirectory and {@code + * BAAI/bge-reranker-v2-m3} doesn't ship an ONNX export. Override via {@code + * trueref.embedding.model-sources..=[urls...]} in YAML. + */ +final class ModelDownloader { + + private static final Logger log = LoggerFactory.getLogger(ModelDownloader.class); + + private static final List REQUIRED_FILES = List.of("model.onnx", "tokenizer.json"); + /** + * Best-effort sidecars (e.g. ONNX external-data files). Missing/404 is not an error — many + * exports fold the weights into model.onnx itself. + */ + private static final List OPTIONAL_FILES = List.of("model.onnx.data", "model.onnx_data"); + + /** Built-in fallbacks: {modelAlias → {fileName → [candidate URLs in priority order]}}. */ + static final Map>> BUILT_IN_SOURCES = Map.of( + "bge-m3", + Map.of( + "model.onnx", + List.of( + "https://huggingface.co/BAAI/bge-m3/resolve/main/onnx/model.onnx", + "https://huggingface.co/aapot/bge-m3-onnx/resolve/main/model.onnx"), + "model.onnx_data", + List.of( + "https://huggingface.co/BAAI/bge-m3/resolve/main/onnx/model.onnx_data", + "https://huggingface.co/aapot/bge-m3-onnx/resolve/main/model.onnx_data"), + "tokenizer.json", + List.of( + "https://huggingface.co/BAAI/bge-m3/resolve/main/tokenizer.json", + "https://huggingface.co/aapot/bge-m3-onnx/resolve/main/tokenizer.json")), + "bge-reranker-v2-m3", + Map.of( + "model.onnx", + List.of( + "https://huggingface.co/EmbeddedLLM/bge-reranker-v2-m3-onnx-o3-cpu/resolve/main/model.onnx", + "https://huggingface.co/celinehoang/bge-reranker-v2-m3-onnx/resolve/main/model.onnx"), + "model.onnx_data", + List.of( + "https://huggingface.co/celinehoang/bge-reranker-v2-m3-onnx/resolve/main/model.onnx_data"), + "model.onnx.data", + List.of( + "https://huggingface.co/EmbeddedLLM/bge-reranker-v2-m3-onnx-o3-cpu/resolve/main/model.onnx.data"), + "tokenizer.json", + List.of( + 
"https://huggingface.co/BAAI/bge-reranker-v2-m3/resolve/main/tokenizer.json", + "https://huggingface.co/EmbeddedLLM/bge-reranker-v2-m3-onnx-o3-cpu/resolve/main/tokenizer.json")), + // ── New default models (≤1 GB VRAM combined, 5-8x faster than bge-m3) ────────────── + // Embedder: BAAI/bge-base-en-v1.5 — 768-dim, 512 tok, ~436 MB ONNX, ~500 MB VRAM + "bge-base-en-v1.5", + Map.of( + "model.onnx", + List.of( + "https://huggingface.co/BAAI/bge-base-en-v1.5/resolve/main/onnx/model.onnx"), + "tokenizer.json", + List.of( + "https://huggingface.co/BAAI/bge-base-en-v1.5/resolve/main/tokenizer.json")), + // Reranker: cross-encoder/ms-marco-MiniLM-L6-v2 — 22M params, ~91 MB ONNX, ~100 MB VRAM + "ms-marco-MiniLM-L6-v2", + Map.of( + "model.onnx", + List.of( + "https://huggingface.co/cross-encoder/ms-marco-MiniLM-L6-v2/resolve/main/onnx/model.onnx", + "https://huggingface.co/Xenova/ms-marco-MiniLM-L-6-v2/resolve/main/onnx/model.onnx"), + "tokenizer.json", + List.of( + "https://huggingface.co/cross-encoder/ms-marco-MiniLM-L6-v2/resolve/main/tokenizer.json", + "https://huggingface.co/Xenova/ms-marco-MiniLM-L-6-v2/resolve/main/tokenizer.json"))); + + private final HttpClient http; + private final Map>> overrides; + + ModelDownloader() { + this(Map.of()); + } + + ModelDownloader(Map>> overrides) { + this.http = HttpClient.newBuilder() + .version(HttpClient.Version.HTTP_2) + .followRedirects(HttpClient.Redirect.NORMAL) + .connectTimeout(Duration.ofSeconds(30)) + .build(); + this.overrides = overrides; + } + + /** Ensures {@code model.onnx} and {@code tokenizer.json} exist under {@code modelDir}. 
*/ + void ensureModel(String modelName, Path modelDir) { + try { + Files.createDirectories(modelDir); + } catch (IOException e) { + throw new IngestionFailed("Failed to create model dir " + modelDir, e); + } + for (String file : REQUIRED_FILES) { + Path target = modelDir.resolve(file); + if (Files.isRegularFile(target)) { + try { + if (Files.size(target) > 0L) { + log.debug("Model file present: {}", target); + continue; + } + } catch (IOException ignored) { + // fall through and re-download + } + } + download(modelName, file, target, /*required=*/ true); + } + for (String file : OPTIONAL_FILES) { + Path target = modelDir.resolve(file); + if (Files.isRegularFile(target)) { + try { + if (Files.size(target) > 0L) { + ensureDataAlias(target); + continue; + } + } catch (IOException ignored) { + // try again + } + } + try { + download(modelName, file, target, /*required=*/ false); + ensureDataAlias(target); + } catch (IngestionFailed e) { + log.info("Optional sidecar {} not available for {}: {}", file, modelName, e.getMessage()); + } + } + } + + /** + * ONNX external-data references inside {@code model.onnx} can spell the sidecar either + * {@code model.onnx_data} (BAAI exports) or {@code model.onnx.data} (community exports). + * Materialize both names so ORT resolves the file regardless of which spelling the graph uses. 
+ */ + private static void ensureDataAlias(Path dataFile) { + String name = dataFile.getFileName().toString(); + Path alias; + if ("model.onnx_data".equals(name)) { + alias = dataFile.resolveSibling("model.onnx.data"); + } else if ("model.onnx.data".equals(name)) { + alias = dataFile.resolveSibling("model.onnx_data"); + } else { + return; + } + if (Files.exists(alias)) return; + try { + Files.createLink(alias, dataFile); + log.info("Created hardlink alias {} -> {}", alias.getFileName(), dataFile.getFileName()); + } catch (UnsupportedOperationException | IOException e) { + try { + Files.copy(dataFile, alias, StandardCopyOption.REPLACE_EXISTING); + log.info("Copied {} -> {} (hardlink unsupported: {})", dataFile.getFileName(), + alias.getFileName(), e.toString()); + } catch (IOException copyErr) { + log.warn("Could not create alias {}: {}", alias, copyErr.toString()); + } + } + } + + private void download(String modelName, String file, Path target, boolean required) { + List urls = candidateUrls(modelName, file); + if (urls.isEmpty()) { + if (!required) { + throw new IngestionFailed("no URLs configured for optional " + file, null); + } + throw new IngestionFailed( + "No download URLs configured for model=" + modelName + " file=" + file + + ". 
Add an entry under trueref.embedding.model-sources..", + null); + } + Path tmp; + try { + tmp = Files.createTempFile(target.getParent(), file + ".", ".part"); + } catch (IOException e) { + throw new IngestionFailed("Failed to create temp file for " + file, e); + } + IngestionFailed lastError = null; + for (String url : urls) { + URI uri = URI.create(url); + log.info("Downloading {} for model {} from {}", file, modelName, uri); + try { + HttpRequest req = HttpRequest.newBuilder(uri) + .timeout(Duration.ofMinutes(30)) + .GET() + .build(); + HttpResponse res = http.send(req, HttpResponse.BodyHandlers.ofFile(tmp)); + if (res.statusCode() / 100 != 2) { + log.warn("HTTP {} for {} — trying next candidate", res.statusCode(), uri); + lastError = new IngestionFailed("HTTP " + res.statusCode() + " from " + uri, null); + continue; + } + long size = Files.size(tmp); + if (size <= 0L) { + lastError = new IngestionFailed("Empty body from " + uri, null); + continue; + } + Files.move(tmp, target, StandardCopyOption.REPLACE_EXISTING, StandardCopyOption.ATOMIC_MOVE); + log.info("Downloaded {} ({} bytes) -> {}", file, size, target); + return; + } catch (IOException | InterruptedException e) { + if (e instanceof InterruptedException) { + Thread.currentThread().interrupt(); + } + log.warn("Download failed from {}: {}", uri, e.toString()); + lastError = new IngestionFailed("Failed to download " + uri, e); + } + } + try { + Files.deleteIfExists(tmp); + } catch (IOException ignored) { + // best-effort cleanup + } + throw lastError != null + ? 
lastError + : new IngestionFailed("All download attempts failed for " + modelName + "/" + file, null); + } + + private List candidateUrls(String modelName, String file) { + Map> per = overrides.get(modelName); + if (per != null) { + List u = per.get(file); + if (u != null && !u.isEmpty()) return u; + } + Map> builtIn = BUILT_IN_SOURCES.get(modelName); + if (builtIn != null) { + List u = builtIn.get(file); + if (u != null) return u; + } + return List.of(); + } +} diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxEmbeddingConfig.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxEmbeddingConfig.java new file mode 100644 index 0000000..95816b6 --- /dev/null +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxEmbeddingConfig.java @@ -0,0 +1,32 @@ +package com.trueref.adapter.out.embedding.onnx; + +import org.springframework.beans.factory.annotation.Value; +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; + +import java.nio.file.Path; + +/** + * Spring wiring for the ONNX embedding/reranker stack. Produces the shared {@link GpuSemaphore} + * (permits = {@code session-count}) and the resolved models-home {@link Path}. + */ +@Configuration +@EnableConfigurationProperties(OnnxProperties.class) +public class OnnxEmbeddingConfig { + + @Bean + GpuSemaphore gpuSemaphore(OnnxProperties properties) { + return new GpuSemaphore(properties.sessionCountOrDefault()); + } + + /** + * Resolves the directory containing model subfolders. Defaults to {@code ${trueref.home}/models} + * when {@link OnnxProperties#home()} is unset. + */ + @Bean("onnxModelsHome") + Path onnxModelsHome(OnnxProperties properties, @Value("${trueref.home:./data}") String trueRefHome) { + Path home = properties.home(); + return home != null ? 
home : Path.of(trueRefHome).resolve("models"); + } +} diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxEmbeddingService.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxEmbeddingService.java new file mode 100644 index 0000000..242aa24 --- /dev/null +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxEmbeddingService.java @@ -0,0 +1,264 @@ +package com.trueref.adapter.out.embedding.onnx; + +import ai.onnxruntime.OnnxTensor; +import ai.onnxruntime.OnnxValue; +import ai.onnxruntime.OrtEnvironment; +import ai.onnxruntime.OrtException; +import ai.onnxruntime.OrtSession; +import com.trueref.adapter.out.embedding.onnx.HuggingFaceTokenizerWrapper.EncodedBatch; +import com.trueref.domain.error.IngestionFailed; +import com.trueref.domain.port.out.EmbeddingService; +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import java.nio.LongBuffer; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.stereotype.Component; + +/** + * ONNX Runtime-backed {@link EmbeddingService} for BAAI/bge-base-en-v1.5 (768-dim dense). Loads + * the model from {@code ${trueref.home}/models//} on startup (auto-downloading if absent). + * Exposes a pool of {@link OrtSession}s gated by the shared {@link GpuSemaphore} so that + * GPU-bound batches never exceed the configured concurrency. + */ +@Component +@Qualifier("bge-m3") +public class OnnxEmbeddingService implements EmbeddingService { + + private static final Logger log = LoggerFactory.getLogger(OnnxEmbeddingService.class); + + /** + * Default hidden size. Matches bge-base-en-v1.5 (768) and other 768-dim models. 
A model swap + * to a different dimensionality just needs this constant updated (and the Lucene index rebuilt). + */ + private static final int DEFAULT_EMBEDDING_DIM = 768; + + private static final String IN_INPUT_IDS = "input_ids"; + private static final String IN_ATTENTION_MASK = "attention_mask"; + private static final String IN_TOKEN_TYPE_IDS = "token_type_ids"; + + // Output names we try in priority order for pre-pooled dense vectors. + private static final List POOLED_OUTPUT_NAMES = + List.of("sentence_embedding", "dense_vecs", "pooler_output"); + private static final String FALLBACK_HIDDEN_NAME = "last_hidden_state"; + + private final OnnxProperties properties; + private final GpuSemaphore gpuSemaphore; + private final Path modelsHome; + + private @org.jspecify.annotations.Nullable OnnxSessionPool pool; + private @org.jspecify.annotations.Nullable HuggingFaceTokenizerWrapper tokenizer; + private @org.jspecify.annotations.Nullable String pooledOutputName; + private boolean usesTokenTypeIds; + + public OnnxEmbeddingService( + OnnxProperties properties, + GpuSemaphore gpuSemaphore, + @Qualifier("onnxModelsHome") Path modelsHome) { + this.properties = properties; + this.gpuSemaphore = gpuSemaphore; + this.modelsHome = modelsHome; + } + + @PostConstruct + void start() { + if (properties.sessionCountOrDefault() <= 0) { + log.warn("OnnxEmbeddingService DISABLED (session-count=0); embed() calls will fail"); + return; + } + String model = properties.modelOrDefault(); + Path modelDir = modelsHome.resolve(model); + new ModelDownloader(properties.modelSources()).ensureModel(model, modelDir); + Path modelPath = modelDir.resolve("model.onnx"); + Path tokenizerPath = modelDir.resolve("tokenizer.json"); + + this.tokenizer = new HuggingFaceTokenizerWrapper(tokenizerPath); + this.pool = new OnnxSessionPool( + OrtEnvironment.getEnvironment(), + modelPath, + properties.sessionCountOrDefault(), + properties.providersOrDefault(), + properties.gpuDeviceIdOrDefault(), + 
properties.gpuMemLimitBytesOrDefault()); + + Set outputs = pool.outputNames(); + this.pooledOutputName = POOLED_OUTPUT_NAMES.stream() + .filter(outputs::contains) + .findFirst() + .orElse(null); + // token_type_ids is optional — some BERT-family exports include it, some don't. + this.usesTokenTypeIds = pool.inputNames().contains(IN_TOKEN_TYPE_IDS); + log.info( + "OnnxEmbeddingService ready: model={} sessions={} pooledOutput={} useTokenTypeIds={}", + model, + properties.sessionCountOrDefault(), + pooledOutputName == null ? "" : pooledOutputName, + usesTokenTypeIds); + } + + @PreDestroy + void stop() { + if (pool != null) { + pool.close(); + pool = null; + } + if (tokenizer != null) { + tokenizer.close(); + tokenizer = null; + } + } + + @Override + public int dimension() { + return DEFAULT_EMBEDDING_DIM; + } + + @Override + public List embed(List texts) { + if (texts.isEmpty()) { + return List.of(); + } + int batchSize = properties.batchSizeOrDefault(); + int maxSeqLen = properties.maxSeqLenOrDefault(); + List out = new ArrayList<>(texts.size()); + for (int start = 0; start < texts.size(); start += batchSize) { + int end = Math.min(start + batchSize, texts.size()); + List slice = texts.subList(start, end); + out.addAll(embedBatch(slice, maxSeqLen)); + } + return out; + } + + private List embedBatch(List batch, int maxSeqLen) { + HuggingFaceTokenizerWrapper tok = requireStarted(tokenizer, "tokenizer"); + OnnxSessionPool sessions = requireStarted(pool, "session pool"); + EncodedBatch enc = tok.encode(batch, maxSeqLen); + gpuSemaphore.acquire(); + OrtSession session = null; + try { + session = sessions.borrow(); + return runEmbed(session, enc); + } finally { + if (session != null) { + sessions.release(session); + } + gpuSemaphore.release(); + } + } + + /** + * Runs a single ONNX embedding inference batch against the supplied session. + * + *

<p><b>Virtual-thread pinning — intentional anti-pattern.</b><br>
+ * The entire body is wrapped in {@code synchronized (session)} to deliberately pin the + * calling virtual thread to its carrier OS thread for the duration of the CUDA operation. + * + *

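This is normally an anti-pattern with Project Loom: on JDK 21-23, {@code synchronized} prevents the + * virtual thread from unmounting, wasting a carrier-thread slot and limiting throughput (JDK 24+, JEP 491, no longer pins on {@code synchronized}, so this workaround must be re-validated before a JDK upgrade). + * Here the trade-off is consciously accepted because: + *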

    + *
  1. CUDA contexts are per-OS-thread. If a virtual thread is allowed to unmount between + * tensor allocation and {@link OrtSession#run} — or mid-execution — it may remount on + * a different carrier OS thread. ONNX Runtime then finds a mismatched or uninitialized + * CUDA context, resulting in a {@code SIGSEGV} inside {@code libonnxruntime.so} + * (observed in production: {@code hs_err_pid649935.log}).
  2. + *
  3. This reasoning assumes the pool is sized to 1 session ({@code session-count=1}; note that the {@code OnnxProperties} default is 2, so verify the deployed configuration), so at most one virtual + * thread blocks here at a time. The {@link GpuSemaphore} already serialises access; + * the carrier-thread occupancy cost is therefore bounded to exactly one carrier for + * the duration of one batch inference (~10-100 ms).
  4. + *
  5. The proper long-term fix is to run ONNX inference inside a + * {@link java.util.concurrent.ExecutorService} backed by platform (OS) threads, keeping + * virtual threads entirely outside the CUDA call. That refactor is tracked as a + * follow-up; until then, pinning is the safest and simplest workaround.
  6. + *
+ */ + private List runEmbed(OrtSession session, EncodedBatch enc) { + synchronized (session) { + OrtEnvironment env = OrtEnvironment.getEnvironment(); + int batch = enc.inputIds().length; + long[] shape = {batch, enc.seqLen()}; + + Map inputs = new HashMap<>(); + try { + OnnxTensor idsT = OnnxTensor.createTensor(env, LongBuffer.wrap(flatten(enc.inputIds())), shape); + inputs.put(IN_INPUT_IDS, idsT); + OnnxTensor maskT = OnnxTensor.createTensor(env, LongBuffer.wrap(flatten(enc.attentionMask())), shape); + inputs.put(IN_ATTENTION_MASK, maskT); + if (usesTokenTypeIds) { + OnnxTensor typeT = + OnnxTensor.createTensor(env, LongBuffer.wrap(flatten(enc.tokenTypeIds())), shape); + inputs.put(IN_TOKEN_TYPE_IDS, typeT); + } + + try (OrtSession.Result result = session.run(inputs)) { + return extractEmbeddings(result, enc.attentionMask()); + } + } catch (OrtException e) { + throw new IngestionFailed("ONNX embedding inference failed", e); + } finally { + for (OnnxTensor t : inputs.values()) { + t.close(); + } + } + } + } + + private List extractEmbeddings(OrtSession.Result result, long[][] attentionMask) + throws OrtException { + if (pooledOutputName != null) { + Optional v = result.get(pooledOutputName); + if (v.isPresent() && v.get() instanceof OnnxTensor t) { + Object val = t.getValue(); + if (val instanceof float[][] arr) { + return normalizeRows(arr); + } + } + } + Optional hidden = result.get(FALLBACK_HIDDEN_NAME); + if (hidden.isEmpty() || !(hidden.get() instanceof OnnxTensor t)) { + throw new IllegalStateException( + "Model output contains neither a known pooled vector nor '" + FALLBACK_HIDDEN_NAME + "'"); + } + Object val = t.getValue(); + if (!(val instanceof float[][][] arr)) { + throw new IllegalStateException( + "Unexpected last_hidden_state shape for bge-m3: " + val.getClass().getName()); + } + float[][] pooled = PoolingMath.meanPool(arr, attentionMask); + return normalizeRows(pooled); + } + + private static List normalizeRows(float[][] rows) { + List out = new 
ArrayList<>(rows.length); + for (float[] row : rows) { + PoolingMath.l2NormalizeInPlace(row); + out.add(row); + } + return out; + } + + private static long[] flatten(long[][] rows) { + int batch = rows.length; + int seq = batch == 0 ? 0 : rows[0].length; + long[] flat = new long[batch * seq]; + for (int i = 0; i < batch; i++) { + System.arraycopy(rows[i], 0, flat, i * seq, seq); + } + return flat; + } + + private static T requireStarted(@org.jspecify.annotations.Nullable T t, String what) { + if (t == null) { + throw new IllegalStateException("OnnxEmbeddingService not started: " + what + " is null"); + } + return t; + } +} diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxProperties.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxProperties.java new file mode 100644 index 0000000..4642042 --- /dev/null +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxProperties.java @@ -0,0 +1,78 @@ +package com.trueref.adapter.out.embedding.onnx; + +import java.nio.file.Path; +import java.util.List; +import java.util.Map; +import org.jspecify.annotations.Nullable; +import org.springframework.boot.context.properties.ConfigurationProperties; + +/** + * Typed configuration for the ONNX embedding / reranker stack. Bound from the {@code + * trueref.embedding} prefix in {@code application.yml}. 
+ */ +@ConfigurationProperties("trueref.embedding") +public record OnnxProperties( + @Nullable String model, + @Nullable String rerankerModel, + @Nullable List onnxProviders, + @Nullable Integer sessionCount, + @Nullable Integer rerankerSessionCount, + @Nullable Integer batchSize, + @Nullable Integer maxSeqLen, + @Nullable Integer gpuDeviceId, + @Nullable Long gpuMemLimitBytes, + @Nullable Path home, + @Nullable Map>> modelSources) { + + public OnnxProperties { + if (model == null || model.isBlank()) model = "bge-base-en-v1.5"; + if (rerankerModel == null || rerankerModel.isBlank()) rerankerModel = "ms-marco-MiniLM-L6-v2"; + if (onnxProviders == null || onnxProviders.isEmpty()) onnxProviders = List.of("cuda", "directml", "cpu"); + if (sessionCount == null) sessionCount = 2; + if (sessionCount < 0) sessionCount = 0; + // Reranker runs sequentially on top-K docs; 1 session avoids VRAM pressure from 4 concurrent pools. + if (rerankerSessionCount == null) rerankerSessionCount = 1; + if (rerankerSessionCount < 0) rerankerSessionCount = 0; + if (batchSize == null || batchSize <= 0) batchSize = 32; + if (maxSeqLen == null || maxSeqLen <= 0) maxSeqLen = 512; + if (gpuDeviceId == null || gpuDeviceId < 0) gpuDeviceId = 0; + if (gpuMemLimitBytes == null || gpuMemLimitBytes <= 0L) gpuMemLimitBytes = 0L; // 0 = no cap + if (modelSources == null) modelSources = Map.of(); + } + + public String modelOrDefault() { + return model == null ? "bge-base-en-v1.5" : model; + } + + public String rerankerModelOrDefault() { + return rerankerModel == null ? "ms-marco-MiniLM-L6-v2" : rerankerModel; + } + + public List providersOrDefault() { + return onnxProviders == null ? List.of("cuda", "directml", "cpu") : onnxProviders; + } + + public int sessionCountOrDefault() { + return sessionCount == null ? 2 : sessionCount; + } + + public int rerankerSessionCountOrDefault() { + return rerankerSessionCount == null ? 
1 : rerankerSessionCount; + } + + public int batchSizeOrDefault() { + return batchSize == null ? 32 : batchSize; + } + + public int maxSeqLenOrDefault() { + return maxSeqLen == null ? 512 : maxSeqLen; + } + + public int gpuDeviceIdOrDefault() { + return gpuDeviceId == null ? 0 : gpuDeviceId; + } + + public long gpuMemLimitBytesOrDefault() { + return gpuMemLimitBytes == null ? 0L : gpuMemLimitBytes; + } +} diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxRerankerService.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxRerankerService.java new file mode 100644 index 0000000..75ad4b2 --- /dev/null +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxRerankerService.java @@ -0,0 +1,259 @@ +package com.trueref.adapter.out.embedding.onnx; + +import ai.onnxruntime.OnnxTensor; +import ai.onnxruntime.OnnxValue; +import ai.onnxruntime.OrtEnvironment; +import ai.onnxruntime.OrtException; +import ai.onnxruntime.OrtSession; +import com.trueref.adapter.out.embedding.onnx.HuggingFaceTokenizerWrapper.EncodedBatch; +import com.trueref.domain.error.IngestionFailed; +import com.trueref.domain.model.SearchHit; +import com.trueref.domain.port.out.RerankerService; +import jakarta.annotation.PostConstruct; +import jakarta.annotation.PreDestroy; +import java.nio.LongBuffer; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.stereotype.Component; + +/** + * ONNX Runtime-backed {@link RerankerService} for BAAI/bge-reranker-v2-m3. 
Scores {@code (query, + * passage)} pairs with a cross-encoder, applies sigmoid on the raw logit, and returns candidates + * sorted descending by probability. + */ +@Component +public class OnnxRerankerService implements RerankerService { + + private static final Logger log = LoggerFactory.getLogger(OnnxRerankerService.class); + + /** Safety cap matching the task spec — never rerank more than this many candidates. */ + private static final int MAX_CANDIDATES = 1000; + + private static final String IN_INPUT_IDS = "input_ids"; + private static final String IN_ATTENTION_MASK = "attention_mask"; + private static final String IN_TOKEN_TYPE_IDS = "token_type_ids"; + private static final String OUT_LOGITS = "logits"; + + private final OnnxProperties properties; + private final GpuSemaphore gpuSemaphore; + private final Path modelsHome; + + private @org.jspecify.annotations.Nullable OnnxSessionPool pool; + private @org.jspecify.annotations.Nullable HuggingFaceTokenizerWrapper tokenizer; + private @org.jspecify.annotations.Nullable String logitsOutputName; + private boolean usesTokenTypeIds; + + public OnnxRerankerService( + OnnxProperties properties, + GpuSemaphore gpuSemaphore, + @Qualifier("onnxModelsHome") Path modelsHome) { + this.properties = properties; + this.gpuSemaphore = gpuSemaphore; + this.modelsHome = modelsHome; + } + + @PostConstruct + void start() { + if (properties.rerankerSessionCountOrDefault() <= 0) { + log.warn("OnnxRerankerService DISABLED (reranker-session-count=0); rerank() passes through input order"); + return; + } + String model = properties.rerankerModelOrDefault(); + Path modelDir = modelsHome.resolve(model); + new ModelDownloader(properties.modelSources()).ensureModel(model, modelDir); + Path modelPath = modelDir.resolve("model.onnx"); + Path tokenizerPath = modelDir.resolve("tokenizer.json"); + + this.tokenizer = new HuggingFaceTokenizerWrapper(tokenizerPath); + this.pool = new OnnxSessionPool( + OrtEnvironment.getEnvironment(), + modelPath, 
+ properties.rerankerSessionCountOrDefault(), + properties.providersOrDefault(), + properties.gpuDeviceIdOrDefault(), + properties.gpuMemLimitBytesOrDefault()); + + Set outputs = pool.outputNames(); + // Most bge-reranker-v2-m3 ONNX exports expose a single output named "logits". If that is + // missing we fall back to the first declared output. + this.logitsOutputName = outputs.contains(OUT_LOGITS) ? OUT_LOGITS : outputs.stream().findFirst().orElse(null); + this.usesTokenTypeIds = pool.inputNames().contains(IN_TOKEN_TYPE_IDS); + log.info( + "OnnxRerankerService ready: model={} sessions={} logitsOutput={} useTokenTypeIds={}", + model, + properties.rerankerSessionCountOrDefault(), + logitsOutputName, + usesTokenTypeIds); + } + + @PreDestroy + void stop() { + if (pool != null) { + pool.close(); + pool = null; + } + if (tokenizer != null) { + tokenizer.close(); + tokenizer = null; + } + } + + @Override + public List rerank(String query, List candidates) { + if (candidates.isEmpty()) { + return List.of(); + } + if (pool == null) { // disabled mode + return candidates; + } + int limit = Math.min(candidates.size(), MAX_CANDIDATES); + List slice = candidates.subList(0, limit); + + int batchSize = properties.batchSizeOrDefault(); + int maxSeqLen = properties.maxSeqLenOrDefault(); + double[] scores = new double[slice.size()]; + for (int start = 0; start < slice.size(); start += batchSize) { + int end = Math.min(start + batchSize, slice.size()); + List batch = slice.subList(start, end); + double[] batchScores = scoreBatch(query, batch, maxSeqLen); + System.arraycopy(batchScores, 0, scores, start, batchScores.length); + } + + List rescored = new ArrayList<>(slice.size()); + for (int i = 0; i < slice.size(); i++) { + SearchHit c = slice.get(i); + rescored.add(new SearchHit( + c.chunkId(), + c.repoId(), + c.versionId(), + c.repoName(), + c.tag(), + c.filePath(), + c.startLine(), + c.endLine(), + c.language(), + c.symbol(), + c.content(), + scores[i])); + } + 
rescored.sort(Comparator.comparingDouble(SearchHit::score).reversed()); + return Collections.unmodifiableList(rescored); + } + + private double[] scoreBatch(String query, List batch, int maxSeqLen) { + HuggingFaceTokenizerWrapper tok = requireStarted(tokenizer, "tokenizer"); + OnnxSessionPool sessions = requireStarted(pool, "session pool"); + List firsts = new ArrayList<>(batch.size()); + List seconds = new ArrayList<>(batch.size()); + for (SearchHit c : batch) { + firsts.add(query); + seconds.add(c.content()); + } + EncodedBatch enc = tok.encodePairs(firsts, seconds, maxSeqLen); + + gpuSemaphore.acquire(); + OrtSession session = null; + try { + session = sessions.borrow(); + return runRerank(session, enc); + } finally { + if (session != null) { + sessions.release(session); + } + gpuSemaphore.release(); + } + } + + /** + * Runs a single ONNX reranker inference batch against the supplied session. + * + *

Virtual-thread pinning — intentional anti-pattern.
+ * See {@link OnnxEmbeddingService} ({@code runEmbed}) for the + * full rationale. The same constraint applies here: CUDA contexts are bound to OS threads, + * and allowing the virtual thread to unmount mid-inference corrupts the context and causes + * {@code SIGSEGV} in {@code libonnxruntime.so}. {@code synchronized (session)} is the + * deliberate, bounded workaround until inference is moved to a platform-thread executor. + */ + private double[] runRerank(OrtSession session, EncodedBatch enc) { + synchronized (session) { + OrtEnvironment env = OrtEnvironment.getEnvironment(); + int batch = enc.inputIds().length; + long[] shape = {batch, enc.seqLen()}; + Map inputs = new HashMap<>(); + try { + inputs.put( + IN_INPUT_IDS, + OnnxTensor.createTensor(env, LongBuffer.wrap(flatten(enc.inputIds())), shape)); + inputs.put( + IN_ATTENTION_MASK, + OnnxTensor.createTensor(env, LongBuffer.wrap(flatten(enc.attentionMask())), shape)); + if (usesTokenTypeIds) { + inputs.put( + IN_TOKEN_TYPE_IDS, + OnnxTensor.createTensor(env, LongBuffer.wrap(flatten(enc.tokenTypeIds())), shape)); + } + try (OrtSession.Result result = session.run(inputs)) { + return extractScores(result, batch); + } + } catch (OrtException e) { + throw new IngestionFailed("ONNX reranker inference failed", e); + } finally { + for (OnnxTensor t : inputs.values()) { + t.close(); + } + } + } // end synchronized + } + + private double[] extractScores(OrtSession.Result result, int batch) throws OrtException { + String name = logitsOutputName; + Optional v = name == null ? 
Optional.empty() : result.get(name); + if (v.isEmpty() || !(v.get() instanceof OnnxTensor t)) { + throw new IllegalStateException( + "Reranker output '" + name + "' missing or not a tensor"); + } + Object raw = t.getValue(); + double[] out = new double[batch]; + if (raw instanceof float[][] arr) { + // Shape [batch, 1] (or [batch, n] where we take the first column as the positive logit) + for (int i = 0; i < batch; i++) { + out[i] = PoolingMath.sigmoid(arr[i][0]); + } + } else if (raw instanceof float[] arr) { + for (int i = 0; i < batch; i++) { + out[i] = PoolingMath.sigmoid(arr[i]); + } + } else { + throw new IllegalStateException( + "Unexpected reranker logits shape: " + raw.getClass().getName()); + } + return out; + } + + private static long[] flatten(long[][] rows) { + int batch = rows.length; + int seq = batch == 0 ? 0 : rows[0].length; + long[] flat = new long[batch * seq]; + for (int i = 0; i < batch; i++) { + System.arraycopy(rows[i], 0, flat, i * seq, seq); + } + return flat; + } + + private static T requireStarted(@org.jspecify.annotations.Nullable T t, String what) { + if (t == null) { + throw new IllegalStateException("OnnxRerankerService not started: " + what + " is null"); + } + return t; + } +} diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxSessionPool.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxSessionPool.java new file mode 100644 index 0000000..9338e53 --- /dev/null +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/OnnxSessionPool.java @@ -0,0 +1,198 @@ +package com.trueref.adapter.out.embedding.onnx; + +import ai.onnxruntime.OrtEnvironment; +import ai.onnxruntime.OrtException; +import ai.onnxruntime.OrtSession; +import ai.onnxruntime.providers.OrtCUDAProviderOptions; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.concurrent.ArrayBlockingQueue; +import 
java.util.concurrent.BlockingQueue; +import java.util.concurrent.atomic.AtomicBoolean; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Blocking pool of {@link OrtSession}s for a single ONNX model file. Sessions are pre-built at + * construction. {@link #borrow()} blocks when all sessions are in use; {@link #release(OrtSession)} + * returns one to the queue. Closed by {@link #close()} which shuts down every pooled session. + */ +final class OnnxSessionPool implements AutoCloseable { + + private static final Logger log = LoggerFactory.getLogger(OnnxSessionPool.class); + + private final OrtEnvironment env; + private final Path modelPath; + private final int size; + private final int gpuDeviceId; + private final long gpuMemLimitBytes; + private final BlockingQueue idle; + private final List allSessions; + private final Set outputNames; + private final Set inputNames; + private final AtomicBoolean closed = new AtomicBoolean(false); + + OnnxSessionPool(OrtEnvironment env, Path modelPath, int size, List providers) { + this(env, modelPath, size, providers, 0, 0L); + } + + OnnxSessionPool( + OrtEnvironment env, Path modelPath, int size, List providers, int gpuDeviceId) { + this(env, modelPath, size, providers, gpuDeviceId, 0L); + } + + OnnxSessionPool( + OrtEnvironment env, + Path modelPath, + int size, + List providers, + int gpuDeviceId, + long gpuMemLimitBytes) { + if (size <= 0) { + throw new IllegalArgumentException("size must be > 0, got " + size); + } + if (gpuDeviceId < 0) { + throw new IllegalArgumentException("gpuDeviceId must be >= 0, got " + gpuDeviceId); + } + if (gpuMemLimitBytes < 0L) { + throw new IllegalArgumentException("gpuMemLimitBytes must be >= 0, got " + gpuMemLimitBytes); + } + this.env = env; + this.modelPath = modelPath; + this.size = size; + this.gpuDeviceId = gpuDeviceId; + this.gpuMemLimitBytes = gpuMemLimitBytes; + this.idle = new ArrayBlockingQueue<>(size); + this.allSessions = new ArrayList<>(size); + try { + for (int i 
= 0; i < size; i++) { + OrtSession s = env.createSession(modelPath.toString(), buildOptions(providers)); + allSessions.add(s); + idle.add(s); + } + this.outputNames = Set.copyOf(allSessions.get(0).getOutputNames()); + this.inputNames = Set.copyOf(allSessions.get(0).getInputNames()); + // Belt-and-suspenders: release CUDA memory even when Spring's @PreDestroy + // is skipped (e.g. the previous run was SIGKILL'd by the OOM killer). + // A normal Spring shutdown calls close() explicitly before this hook fires, + // and close() is idempotent, so double-invocation is safe. + Thread hook = new Thread(this::close, "onnx-pool-shutdown-" + modelPath.getFileName()); + hook.setDaemon(false); + Runtime.getRuntime().addShutdownHook(hook); + log.info( + "Initialized ONNX session pool: model={} sessions={} inputs={} outputs={}", + modelPath, + size, + inputNames, + outputNames); + } catch (OrtException e) { + for (OrtSession s : allSessions) { + try { + s.close(); + } catch (OrtException ignored) { + // best-effort cleanup during failed init + } + } + throw new IllegalStateException("Failed to create ONNX session for " + modelPath, e); + } + } + + private OrtSession.SessionOptions buildOptions(List providers) throws OrtException { + OrtSession.SessionOptions opts = new OrtSession.SessionOptions(); + // Disable ORT memory-pattern optimization so the BFC arena does NOT try to + // pre-compute and pre-allocate the model's full static memory layout during + // session init. Without this, ORT requests a large contiguous CUDA block + // upfront, which fails even when plenty of VRAM is free. + opts.setMemoryPatternOptimization(false); + for (String raw : providers) { + String p = raw == null ? "" : raw.trim().toLowerCase(); + switch (p) { + case "cuda" -> { + try { + OrtCUDAProviderOptions cudaOpts = new OrtCUDAProviderOptions(gpuDeviceId); + // Grow arena only by exactly what's requested, never exponentially. 
+ cudaOpts.add("arena_extend_strategy", "kSameAsRequested"); + // Use the default CUDA stream for D2H/H2D copies; avoids subtle + // stream-ordering races between embedding and reranker pools. + cudaOpts.add("do_copy_in_default_stream", "1"); + if (gpuMemLimitBytes > 0L) { + cudaOpts.add("gpu_mem_limit", Long.toString(gpuMemLimitBytes)); + } + opts.addCUDA(cudaOpts); + log.info( + "ONNX provider CUDA enabled for {} (device {}, memLimitBytes={})", + modelPath.getFileName(), + gpuDeviceId, + gpuMemLimitBytes > 0L ? gpuMemLimitBytes : "unbounded"); + } catch (OrtException | RuntimeException | LinkageError e) { + log.warn( + "CUDA execution provider unavailable for {} (device {}): {}", + modelPath.getFileName(), + gpuDeviceId, + e.getMessage()); + } + } + case "directml" -> { + // DirectML requires the onnxruntime-directml classifier which we don't ship; + // silently skip on non-Windows / non-DirectML builds. + log.debug("DirectML provider requested but not bundled; skipping"); + } + case "cpu" -> { + // CPU is always enabled by ORT as a fallback; no explicit registration needed. 
+ log.debug("CPU provider is always available as fallback"); + } + default -> log.warn("Unknown ONNX provider '{}' — ignored", raw); + } + } + return opts; + } + + OrtSession borrow() { + if (closed.get()) { + throw new IllegalStateException("Session pool is closed: " + modelPath); + } + try { + return idle.take(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException("Interrupted while borrowing ONNX session", e); + } + } + + void release(OrtSession session) { + if (closed.get()) { + return; + } + idle.offer(session); + } + + Set inputNames() { + return inputNames; + } + + Set outputNames() { + return outputNames; + } + + int size() { + return size; + } + + @Override + public void close() { + if (!closed.compareAndSet(false, true)) { + return; // already closed — guard against concurrent shutdown hooks + } + idle.clear(); + for (OrtSession s : allSessions) { + try { + s.close(); + } catch (OrtException e) { + log.warn("Error closing ONNX session for {}: {}", modelPath, e.getMessage()); + } + } + allSessions.clear(); + } +} diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/PoolingMath.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/PoolingMath.java new file mode 100644 index 0000000..096d3a7 --- /dev/null +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/PoolingMath.java @@ -0,0 +1,64 @@ +package com.trueref.adapter.out.embedding.onnx; + +/** Pure, stateless math helpers used by {@link OnnxEmbeddingService}. */ +final class PoolingMath { + + private PoolingMath() {} + + /** + * Mean-pools a {@code [batch, seq, hidden]} tensor weighted by an attention mask, producing a + * {@code [batch, hidden]} result. Masked positions contribute nothing and are excluded from the + * denominator. Samples with an all-zero mask yield a zero row. 
+ */ + static float[][] meanPool(float[][][] lastHidden, long[][] attentionMask) { + int batch = lastHidden.length; + if (batch == 0) { + return new float[0][]; + } + int seq = lastHidden[0].length; + int hidden = seq == 0 ? 0 : lastHidden[0][0].length; + float[][] out = new float[batch][hidden]; + for (int i = 0; i < batch; i++) { + double denom = 0.0; + for (int j = 0; j < seq; j++) { + long m = attentionMask[i][j]; + if (m == 0L) continue; + denom += 1.0; + float[] row = lastHidden[i][j]; + for (int d = 0; d < hidden; d++) { + out[i][d] += row[d]; + } + } + if (denom > 0.0) { + float inv = (float) (1.0 / denom); + for (int d = 0; d < hidden; d++) { + out[i][d] *= inv; + } + } + } + return out; + } + + /** L2-normalizes {@code v} in place. Zero-norm vectors are left as zeros. */ + static void l2NormalizeInPlace(float[] v) { + double sum = 0.0; + for (float x : v) { + sum += (double) x * x; + } + if (sum <= 0.0) return; + float inv = (float) (1.0 / Math.sqrt(sum)); + for (int i = 0; i < v.length; i++) { + v[i] *= inv; + } + } + + /** Returns 1/(1+e^-x). */ + static double sigmoid(double x) { + if (x >= 0.0) { + double z = Math.exp(-x); + return 1.0 / (1.0 + z); + } + double z = Math.exp(x); + return z / (1.0 + z); + } +} diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/package-info.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/package-info.java new file mode 100644 index 0000000..ff797c4 --- /dev/null +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/embedding/onnx/package-info.java @@ -0,0 +1,9 @@ +/** + * ONNX-Runtime-backed embedding and reranker adapters. Provides {@code OnnxEmbeddingService} + * (BAAI/bge-m3, 1024-dim dense) and {@code OnnxRerankerService} (BAAI/bge-reranker-v2-m3). Models + * live under {@code ${trueref.home}/models//} and are auto-downloaded from HuggingFace on + * first run when absent. 
GPU access is gated by a shared {@link + * com.trueref.adapter.out.embedding.onnx.GpuSemaphore}. + */ +@org.jspecify.annotations.NullMarked +package com.trueref.adapter.out.embedding.onnx; diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/git/jgit/JGitClient.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/git/jgit/JGitClient.java new file mode 100644 index 0000000..444cb3b --- /dev/null +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/git/jgit/JGitClient.java @@ -0,0 +1,408 @@ +package com.trueref.adapter.out.git.jgit; + +import com.trueref.domain.error.IngestionFailed; +import com.trueref.domain.error.TagNotFound; +import com.trueref.domain.port.out.GitClient; +import java.io.IOException; +import java.nio.file.FileVisitResult; +import java.nio.file.Files; +import java.nio.file.NoSuchFileException; +import java.nio.file.Path; +import java.nio.file.SimpleFileVisitor; +import java.nio.file.attribute.BasicFileAttributes; +import java.nio.file.attribute.PosixFilePermission; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.EnumSet; +import java.util.List; +import java.util.Set; +import org.eclipse.jgit.api.Git; +import org.eclipse.jgit.api.ListBranchCommand; +import org.eclipse.jgit.api.TransportConfigCallback; +import org.eclipse.jgit.api.errors.GitAPIException; +import org.eclipse.jgit.errors.NoRemoteRepositoryException; +import org.eclipse.jgit.lib.Constants; +import org.eclipse.jgit.lib.FileMode; +import org.eclipse.jgit.lib.ObjectId; +import org.eclipse.jgit.lib.ObjectLoader; +import org.eclipse.jgit.lib.ObjectReader; +import org.eclipse.jgit.lib.Ref; +import org.eclipse.jgit.lib.Repository; +import org.eclipse.jgit.revwalk.RevCommit; +import org.eclipse.jgit.revwalk.RevObject; +import org.eclipse.jgit.revwalk.RevTag; +import org.eclipse.jgit.revwalk.RevTree; +import org.eclipse.jgit.revwalk.RevWalk; +import org.eclipse.jgit.transport.SshTransport; +import org.eclipse.jgit.transport.TagOpt; 
+import org.eclipse.jgit.transport.sshd.SshdSessionFactory; +import org.eclipse.jgit.treewalk.CanonicalTreeParser; +import org.eclipse.jgit.treewalk.TreeWalk; +import org.jspecify.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.stereotype.Component; + +/** + * JGit implementation of {@link GitClient}. + * + *

Authentication: relies on the host ssh agent / ssh config for SSH remotes and the system git + * credential helper for HTTPS — no in-app credential management (see ARCHITECTURE §13). + * + *

Worktrees: implemented by copying tree contents at the requested commit into a sibling + * directory (repoPath.resolveSibling("trueref-worktrees")/<repo>-<sha>-<nanos>) + * rather than using JGit's multi-worktree support — this avoids touching the source repo HEAD and + * is safe under concurrent ingestion jobs. + * + *

Thread-safety: each method opens and closes its own {@link Git}/{@link Repository} instance; + * no shared mutable state. + */ +@Component +public final class JGitClient implements GitClient { + + private static final Logger log = LoggerFactory.getLogger(JGitClient.class); + private static final String WORKTREES_DIR = "trueref-worktrees"; + + /** + * SSH transport callback that points JGit's internal Apache MINA SSHD client at the user's + * real {@code ~/.ssh} directory (keys, known_hosts, config). Without this, the embedded SSHD + * finds no keys and fails with "publickey: no keys to try". + * + *

We also override {@code getDefaultIdentities()} to enumerate all private key + * files present in {@code ~/.ssh} — including non-default names like {@code id_rsa_github} — + * rather than only the four filenames MINA SSHD probes by default. + */ + private static final TransportConfigCallback SSH_CALLBACK; + + static { + java.io.File sshDir = new java.io.File(System.getProperty("user.home"), ".ssh"); + SshdSessionFactory factory = new SshdSessionFactory(null, null) { + @Override + public java.io.File getSshDirectory() { + return sshDir; + } + + @Override + protected java.util.List getDefaultIdentities(java.io.File dir) { + // Return every file in ~/.ssh that has a paired .pub — these are private keys. + java.io.File[] files = dir.listFiles(); + if (files == null) return super.getDefaultIdentities(dir); + java.util.List keys = new java.util.ArrayList<>(); + for (java.io.File f : files) { + String name = f.getName(); + // A private key has a corresponding .pub file and no .pub suffix itself. 
+ if (!name.endsWith(".pub") && new java.io.File(dir, name + ".pub").exists()) { + keys.add(f.toPath()); + } + } + if (keys.isEmpty()) return super.getDefaultIdentities(dir); + return keys; + } + }; + SSH_CALLBACK = transport -> { + if (transport instanceof SshTransport sshTransport) { + sshTransport.setSshSessionFactory(factory); + } + }; + log.info("JGitClient SSH factory configured: sshDir={}", sshDir); + } + + @Override + public void cloneRepo(String remoteUrl, Path localPath) { + if (Files.isDirectory(localPath.resolve(".git"))) { + log.debug("clone skipped, repo already exists at {}", localPath); + return; + } + log.info("cloning {} into {}", remoteUrl, localPath); + try { + Files.createDirectories(localPath); + try (Git git = Git.cloneRepository() + .setURI(remoteUrl) + .setDirectory(localPath.toFile()) + .setCloneAllBranches(true) + .setNoCheckout(false) + .setTransportConfigCallback(SSH_CALLBACK) + .call()) { + log.info("cloned {} ({} branches)", remoteUrl, countBranches(git)); + } + } catch (GitAPIException | IOException e) { + throw new IngestionFailed("failed to clone " + remoteUrl + " into " + localPath, e); + } + } + + @Override + public void fetch(Path localPath) { + log.info("fetching {}", localPath); + try (Git git = Git.open(localPath.toFile())) { + if (git.getRepository().getConfig().getSubsections("remote").isEmpty()) { + log.debug("fetch skipped, repo {} has no configured remote", localPath); + return; + } + git.fetch() + .setRemoveDeletedRefs(true) + .setTagOpt(TagOpt.FETCH_TAGS) + .setTransportConfigCallback(SSH_CALLBACK) + .call(); + } catch (NoRemoteRepositoryException e) { + log.debug("fetch skipped for {} — no remote repository", localPath); + } catch (GitAPIException | IOException e) { + throw new IngestionFailed("failed to fetch " + localPath, e); + } + } + + @Override + public List listTags(Path localPath) { + try (Git git = Git.open(localPath.toFile())) { + Repository repo = git.getRepository(); + List refs = 
repo.getRefDatabase().getRefsByPrefix(Constants.R_TAGS); + List tags = new ArrayList<>(refs.size()); + try (RevWalk walk = new RevWalk(repo)) { + for (Ref ref : refs) { + String name = ref.getName().substring(Constants.R_TAGS.length()); + try { + RevObject obj = walk.parseAny(ref.getObjectId()); + RevCommit commit; + long epoch; + if (obj instanceof RevTag annotated) { + RevObject peeled = walk.peel(annotated); + if (!(peeled instanceof RevCommit c)) { + log.debug("skipping non-commit tag {} in {}", name, localPath); + continue; + } + commit = c; + epoch = annotated.getTaggerIdent() != null + ? annotated.getTaggerIdent().getWhenAsInstant().getEpochSecond() + : commit.getCommitTime(); + } else if (obj instanceof RevCommit c) { + commit = c; + epoch = commit.getCommitTime(); + } else { + log.debug("skipping tag {} pointing to {} in {}", name, obj.getType(), localPath); + continue; + } + tags.add(new TagInfo(name, commit.getName(), epoch)); + } catch (IOException e) { + log.debug("failed to resolve tag {} in {}: {}", name, localPath, e.toString()); + } + } + } + tags.sort(Comparator.comparingLong(TagInfo::taggerEpochSeconds).reversed()); + return tags; + } catch (IOException e) { + throw new IngestionFailed("failed to list tags in " + localPath, e); + } + } + + @Override + public String resolveRef(Path localPath, String ref) { + try (Git git = Git.open(localPath.toFile())) { + ObjectId id = git.getRepository().resolve(ref); + if (id == null) { + throw new TagNotFound(localPath.getFileName().toString(), ref); + } + return id.getName(); + } catch (IOException e) { + throw new IngestionFailed("failed to resolve ref " + ref + " in " + localPath, e); + } + } + + @Override + public Path checkoutWorktree(Path repoPath, String ref) { + try (Git git = Git.open(repoPath.toFile())) { + Repository repo = git.getRepository(); + ObjectId commitId = repo.resolve(ref); + if (commitId == null) { + throw new TagNotFound(repoPath.getFileName().toString(), ref); + } + String shortSha = 
commitId.getName().substring(0, Math.min(10, commitId.getName().length())); + String repoName = repoPath.getFileName() != null ? repoPath.getFileName().toString() : "repo"; + Path worktreesRoot = repoPath.resolveSibling(WORKTREES_DIR); + Files.createDirectories(worktreesRoot); + Path target = worktreesRoot.resolve(repoName + "-" + shortSha + "-" + System.nanoTime()); + Files.createDirectories(target); + try (RevWalk walk = new RevWalk(repo)) { + RevCommit commit = walk.parseCommit(commitId); + RevTree tree = commit.getTree(); + materializeTree(repo, tree, target); + } + log.debug("materialized worktree for {}@{} at {}", repoPath, ref, target); + return target; + } catch (IOException e) { + throw new IngestionFailed("failed to checkout worktree " + ref + " in " + repoPath, e); + } + } + + @Override + public void removeWorktree(Path repoPath, Path worktree) { + try { + if (!Files.exists(worktree)) { + return; + } + Files.walkFileTree(worktree, new SimpleFileVisitor<>() { + @Override + public FileVisitResult visitFile(Path file, BasicFileAttributes attrs) throws IOException { + Files.deleteIfExists(file); + return FileVisitResult.CONTINUE; + } + + @Override + public FileVisitResult postVisitDirectory(Path dir, @Nullable IOException exc) throws IOException { + Files.deleteIfExists(dir); + return FileVisitResult.CONTINUE; + } + }); + } catch (NoSuchFileException ignored) { + // already gone + } catch (IOException e) { + throw new IngestionFailed("failed to remove worktree " + worktree, e); + } + } + + @Override + public List diff(Path repoPath, @Nullable String baseRef, String headRef) { + try (Git git = Git.open(repoPath.toFile())) { + Repository repo = git.getRepository(); + ObjectId headId = repo.resolve(headRef); + if (headId == null) { + throw new TagNotFound(repoPath.getFileName().toString(), headRef); + } + if (baseRef == null) { + return listHeadTree(repo, headId); + } + ObjectId baseId = repo.resolve(baseRef); + if (baseId == null) { + throw new 
TagNotFound(repoPath.getFileName().toString(), baseRef); + } + try (RevWalk walk = new RevWalk(repo); + ObjectReader reader = repo.newObjectReader()) { + RevCommit baseCommit = walk.parseCommit(baseId); + RevCommit headCommit = walk.parseCommit(headId); + CanonicalTreeParser oldTree = new CanonicalTreeParser(); + oldTree.reset(reader, baseCommit.getTree().getId()); + CanonicalTreeParser newTree = new CanonicalTreeParser(); + newTree.reset(reader, headCommit.getTree().getId()); + List entries = git.diff() + .setOldTree(oldTree) + .setNewTree(newTree) + .setShowNameAndStatusOnly(true) + .call(); + // Rename detection must be applied on top of the raw entries. + org.eclipse.jgit.diff.RenameDetector rd = new org.eclipse.jgit.diff.RenameDetector(repo); + rd.addAll(entries); + List detected = rd.compute(); + List out = new ArrayList<>(detected.size()); + for (org.eclipse.jgit.diff.DiffEntry e : detected) { + out.add(mapDiffEntry(e)); + } + return out; + } + } catch (GitAPIException | IOException e) { + throw new IngestionFailed( + "failed to diff " + (baseRef == null ? "" : baseRef) + ".." 
+ headRef + " in " + repoPath, + e); + } + } + + // ---------- helpers ---------- + + private static int countBranches(Git git) { + try { + return git.branchList().setListMode(ListBranchCommand.ListMode.ALL).call().size(); + } catch (GitAPIException e) { + return -1; + } + } + + private static List listHeadTree(Repository repo, ObjectId headId) throws IOException { + List out = new ArrayList<>(); + try (RevWalk walk = new RevWalk(repo); + TreeWalk tw = new TreeWalk(repo)) { + RevCommit commit = walk.parseCommit(headId); + tw.addTree(commit.getTree()); + tw.setRecursive(true); + while (tw.next()) { + if (tw.getFileMode(0) == FileMode.GITLINK) { + continue; + } + out.add(new DiffEntry(tw.getPathString(), null, DiffEntry.ChangeType.ADDED)); + } + } + return out; + } + + private static DiffEntry mapDiffEntry(org.eclipse.jgit.diff.DiffEntry e) { + DiffEntry.ChangeType change = + switch (e.getChangeType()) { + case ADD -> DiffEntry.ChangeType.ADDED; + case MODIFY -> DiffEntry.ChangeType.MODIFIED; + case DELETE -> DiffEntry.ChangeType.DELETED; + case RENAME -> DiffEntry.ChangeType.RENAMED; + case COPY -> DiffEntry.ChangeType.COPIED; + }; + String path = + switch (e.getChangeType()) { + case DELETE -> e.getOldPath(); + default -> e.getNewPath(); + }; + String oldPath = + switch (e.getChangeType()) { + case RENAME, COPY -> e.getOldPath(); + default -> null; + }; + return new DiffEntry(path, oldPath, change); + } + + private static void materializeTree(Repository repo, RevTree tree, Path target) throws IOException { + try (TreeWalk tw = new TreeWalk(repo)) { + tw.addTree(tree); + tw.setRecursive(true); + while (tw.next()) { + FileMode mode = tw.getFileMode(0); + if (mode == FileMode.GITLINK) { + continue; + } + String path = tw.getPathString(); + Path dest = target.resolve(path).normalize(); + if (!dest.startsWith(target)) { + log.debug("skipping entry outside target: {}", path); + continue; + } + if (dest.getParent() != null) { + Files.createDirectories(dest.getParent()); + 
} + ObjectId blobId = tw.getObjectId(0); + ObjectLoader loader = repo.open(blobId); + if (mode == FileMode.SYMLINK) { + String linkTarget = new String(loader.getBytes()); + try { + Files.deleteIfExists(dest); + Files.createSymbolicLink(dest, Path.of(linkTarget)); + } catch (UnsupportedOperationException | IOException ex) { + // Filesystem does not support symlinks — fall back to writing the link text. + Files.writeString(dest, linkTarget); + } + continue; + } + try (var out = Files.newOutputStream(dest)) { + loader.copyTo(out); + } + if (mode == FileMode.EXECUTABLE_FILE) { + applyExecutableBit(dest); + } + } + } + } + + private static void applyExecutableBit(Path file) { + try { + Set perms = EnumSet.copyOf(Files.getPosixFilePermissions(file)); + perms.add(PosixFilePermission.OWNER_EXECUTE); + perms.add(PosixFilePermission.GROUP_EXECUTE); + perms.add(PosixFilePermission.OTHERS_EXECUTE); + Files.setPosixFilePermissions(file, perms); + } catch (UnsupportedOperationException | IOException ex) { + // Non-POSIX FS (e.g. Windows NTFS without WSL) — executable bit is not representable. + } + } + +} diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/git/jgit/package-info.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/git/jgit/package-info.java new file mode 100644 index 0000000..1460390 --- /dev/null +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/git/jgit/package-info.java @@ -0,0 +1,6 @@ +/** + * JGit-backed implementation of {@link com.trueref.domain.port.out.GitClient}. Relies on the host + * environment for authentication (ssh agent / git credential helpers) — see ARCHITECTURE §13. 
+ */ +@org.jspecify.annotations.NullMarked +package com.trueref.adapter.out.git.jgit; diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/parsing/heuristic/BraceBalancedStrategy.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/parsing/heuristic/BraceBalancedStrategy.java new file mode 100644 index 0000000..7587682 --- /dev/null +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/parsing/heuristic/BraceBalancedStrategy.java @@ -0,0 +1,185 @@ +package com.trueref.adapter.out.parsing.heuristic; + +import com.trueref.domain.port.out.CodeParser.ParsedChunk; +import java.util.ArrayList; +import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import org.jspecify.annotations.Nullable; + +/** + * Splitter for C-family brace-balanced languages (Java, Kotlin, Scala, C/C++, C#, JS/TS, Go, Rust, + * Swift, PHP). Walks characters tracking brace depth, line/block comments, and string literals + * (including JS/TS template literals via backtick). Emits a chunk each time brace depth returns to + * zero; top-level statements preceding the first opening brace are emitted as a {@code "

"} + * preamble chunk. + */ +final class BraceBalancedStrategy { + + private static final int LARGE = 300; + private static final int WINDOW = 200; + private static final int OVERLAP = 20; + + private static final Pattern SYMBOL = Pattern.compile( + "\\b(?:class|interface|enum|record|trait|struct|function|fn|func|def|sub|module|impl)\\s+([A-Za-z_][A-Za-z_0-9]*)"); + + private BraceBalancedStrategy() {} + + static List parse(List lines, String language) { + if (lines.isEmpty()) { + return List.of(); + } + + Scanner s = new Scanner(); + int n = lines.size(); + int regionStart = 1; + int constructStart = -1; + int lastCodeLine = 0; + boolean emittedAny = false; + List out = new ArrayList<>(); + + for (int i = 1; i <= n; i++) { + String line = lines.get(i - 1); + s.scanLine(line); + if (s.hadCodeContent) { + lastCodeLine = i; + } + + if (s.transitionedUp && constructStart < 0) { + // New top-level construct begins in this line. + int cs; + if (s.contentBeforeOpen) { + cs = i; + } else if (lastCodeLine >= regionStart && lastCodeLine < i) { + cs = lastCodeLine; + } else { + cs = i; + } + constructStart = cs; + if (!emittedAny && cs > regionStart) { + // Preamble: top-level statements before the first opening brace. + out.addAll(Chunks.splitLong( + lines, regionStart, cs - 1, language, "
", LARGE, WINDOW, OVERLAP)); + emittedAny = true; + regionStart = cs; + } + } + + if (s.depth == 0 && constructStart > 0) { + int cs = constructStart; + int ce = i; + int effectiveStart = Math.min(cs, regionStart); + String symbol = extractSymbol(lines, cs); + out.addAll(Chunks.splitLong(lines, effectiveStart, ce, language, symbol, LARGE, WINDOW, OVERLAP)); + emittedAny = true; + regionStart = ce + 1; + constructStart = -1; + } + } + + // Trailing unterminated construct or header-only file. + if (regionStart <= n) { + if (constructStart > 0) { + String symbol = extractSymbol(lines, constructStart); + out.addAll(Chunks.splitLong(lines, regionStart, n, language, symbol, LARGE, WINDOW, OVERLAP)); + } else if (!emittedAny) { + out.addAll(Chunks.splitLong(lines, regionStart, n, language, "
", LARGE, WINDOW, OVERLAP)); + } else { + out.addAll(Chunks.splitLong(lines, regionStart, n, language, null, LARGE, WINDOW, OVERLAP)); + } + } + return out; + } + + private static @Nullable String extractSymbol(List lines, int signatureLine) { + Matcher m = SYMBOL.matcher(lines.get(signatureLine - 1)); + if (m.find()) { + return m.group(1); + } + if (signatureLine >= 2) { + m = SYMBOL.matcher(lines.get(signatureLine - 2)); + if (m.find()) { + return m.group(1); + } + } + return null; + } + + /** Char-by-char scanner with persistent block-comment state across lines. */ + private static final class Scanner { + int depth; + boolean inBlockComment; + // Per-line outputs: + boolean transitionedUp; + boolean contentBeforeOpen; + boolean hadCodeContent; + + void scanLine(String line) { + transitionedUp = false; + contentBeforeOpen = false; + hadCodeContent = false; + char stringDelim = 0; // 0 = not in string. Single-line only (simplification). + int len = line.length(); + int i = 0; + while (i < len) { + char c = line.charAt(i); + if (inBlockComment) { + if (c == '*' && i + 1 < len && line.charAt(i + 1) == '/') { + inBlockComment = false; + i += 2; + } else { + i++; + } + continue; + } + if (stringDelim != 0) { + if (c == '\\' && i + 1 < len) { + i += 2; + continue; + } + if (c == stringDelim) { + stringDelim = 0; + } + i++; + continue; + } + if (c == '/' && i + 1 < len && line.charAt(i + 1) == '/') { + break; // line comment until EOL + } + if (c == '/' && i + 1 < len && line.charAt(i + 1) == '*') { + inBlockComment = true; + i += 2; + continue; + } + if (c == '"' || c == '\'' || c == '`') { + stringDelim = c; + hadCodeContent = true; + i++; + continue; + } + if (c == '{') { + if (depth == 0 && !transitionedUp) { + transitionedUp = true; + contentBeforeOpen = hadCodeContent; + } + hadCodeContent = true; + depth++; + i++; + continue; + } + if (c == '}') { + hadCodeContent = true; + if (depth > 0) { + depth--; + } + i++; + continue; + } + if (!Character.isWhitespace(c)) { 
+ hadCodeContent = true; + } + i++; + } + } + } +} diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/parsing/heuristic/Chunks.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/parsing/heuristic/Chunks.java new file mode 100644 index 0000000..82aeee4 --- /dev/null +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/parsing/heuristic/Chunks.java @@ -0,0 +1,64 @@ +package com.trueref.adapter.out.parsing.heuristic; + +import com.trueref.domain.port.out.CodeParser.ParsedChunk; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import org.jspecify.annotations.Nullable; + +/** Shared helpers for strategies: trim-to-chunk and size-based re-splitting. */ +final class Chunks { + + private Chunks() {} + + /** + * Build a chunk from an inclusive 1-based line range, trimming leading/trailing blank lines. If + * the range is all blank, returns empty. + */ + static Optional trimmed( + List lines, int startLine, int endLine, String language, @Nullable String symbol) { + int s = startLine; + int e = endLine; + while (s <= e && lines.get(s - 1).isBlank()) { + s++; + } + while (e >= s && lines.get(e - 1).isBlank()) { + e--; + } + if (s > e) { + return Optional.empty(); + } + String content = String.join("\n", lines.subList(s - 1, e)); + return Optional.of(new ParsedChunk(content, language, symbol, s, e)); + } + + /** + * Emit a single chunk if the range fits within {@code threshold} lines, otherwise slide a + * window of {@code window} lines with {@code overlap} overlap across the range. Chunks still + * have blank-line trimming applied. 
+ */ + static List splitLong( + List lines, + int startLine, + int endLine, + String language, + @Nullable String symbol, + int threshold, + int window, + int overlap) { + int size = endLine - startLine + 1; + if (size <= threshold) { + return trimmed(lines, startLine, endLine, language, symbol).map(List::of).orElse(List.of()); + } + int step = Math.max(1, window - overlap); + List out = new ArrayList<>(); + for (int s = startLine; s <= endLine; s += step) { + int e = Math.min(s + window - 1, endLine); + trimmed(lines, s, e, language, symbol).ifPresent(out::add); + if (e == endLine) { + break; + } + } + return out; + } +} diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/parsing/heuristic/HeuristicCodeParser.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/parsing/heuristic/HeuristicCodeParser.java new file mode 100644 index 0000000..c9314ea --- /dev/null +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/parsing/heuristic/HeuristicCodeParser.java @@ -0,0 +1,63 @@ +package com.trueref.adapter.out.parsing.heuristic; + +import com.trueref.domain.port.out.CodeParser; +import java.io.IOException; +import java.nio.charset.MalformedInputException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.Set; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.stereotype.Component; + +/** + * Universal-fallback {@link CodeParser} using pure-Java heuristic strategies (see FINDINGS §F11). + * {@link #supports(Path)} always returns {@code true} — this adapter accepts every text file and + * delegates to a strategy chosen by file extension. Binary files fail UTF-8 decoding and yield an + * empty chunk list. 
+ */ +@Component +public final class HeuristicCodeParser implements CodeParser { + + private static final Logger log = LoggerFactory.getLogger(HeuristicCodeParser.class); + + private static final Set INDENT_LANGS = Set.of("python", "ruby", "yaml"); + private static final Set BRACE_LANGS = Set.of( + "java", "kotlin", "scala", "c", "cpp", "csharp", "javascript", "typescript", "go", "rust", "swift", "php"); + + @Override + public boolean supports(Path file) { + return true; + } + + @Override + public List parse(Path file, String repoRelativePath) { + String content; + try { + content = Files.readString(file, StandardCharsets.UTF_8); + } catch (MalformedInputException e) { + log.debug("skipping binary/non-utf8 file {}", repoRelativePath); + return List.of(); + } catch (IOException e) { + log.warn("failed to read {}: {}", repoRelativePath, e.toString()); + return List.of(); + } + + // Preserve all lines (including trailing empties) using split with limit=-1. + List lines = List.of(content.split("\\R", -1)); + String language = LanguageDetector.detect(file); + + if ("markdown".equals(language)) { + return MarkdownStrategy.parse(lines, language); + } + if (INDENT_LANGS.contains(language)) { + return IndentBasedStrategy.parse(lines, language); + } + if (BRACE_LANGS.contains(language)) { + return BraceBalancedStrategy.parse(lines, language); + } + return SlidingWindowStrategy.parse(lines, language); + } +} diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/parsing/heuristic/IndentBasedStrategy.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/parsing/heuristic/IndentBasedStrategy.java new file mode 100644 index 0000000..161b37c --- /dev/null +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/parsing/heuristic/IndentBasedStrategy.java @@ -0,0 +1,103 @@ +package com.trueref.adapter.out.parsing.heuristic; + +import com.trueref.domain.port.out.CodeParser.ParsedChunk; +import java.util.ArrayList; +import java.util.List; +import 
java.util.regex.Matcher; +import java.util.regex.Pattern; +import org.jspecify.annotations.Nullable; + +/** + * Indent-driven splitter for Python/Ruby/YAML. Top-level constructs are recognized at column 0 via + * keyword (def/class/module/function) or YAML key ({@code key:}). Each top-level construct plus its + * indented body becomes a chunk. + */ +final class IndentBasedStrategy { + + private static final int LARGE_SEGMENT = 200; + private static final int WINDOW = 100; + private static final int OVERLAP = 10; + + private static final Pattern PY_RUBY_DECL = + Pattern.compile("^(?:def|class|module|function)\\s+([A-Za-z_][A-Za-z_0-9]*)\\b.*"); + private static final Pattern YAML_KEY = Pattern.compile("^([A-Za-z_][A-Za-z_0-9-]*)\\s*:.*"); + + private IndentBasedStrategy() {} + + static List parse(List lines, String language) { + if (lines.isEmpty()) { + return List.of(); + } + + boolean yaml = "yaml".equals(language); + List segs = new ArrayList<>(); + int n = lines.size(); + + int curStart = 1; + String curSymbol = null; + + for (int i = 1; i <= n; i++) { + String line = lines.get(i - 1); + if (line.isBlank() || isComment(line, yaml)) { + continue; + } + if (leadingIndent(line) > 0) { + continue; + } + String sym = topLevelSymbol(line, yaml); + if (sym == null) { + // Non-construct top-level content (e.g. stray python statement, yaml list root): keep accumulating. + continue; + } + // Start of a new top-level construct. 
+ if (i > curStart) { + segs.add(new Segment(curStart, i - 1, curSymbol)); + } + curStart = i; + curSymbol = sym; + } + segs.add(new Segment(curStart, n, curSymbol)); + + List out = new ArrayList<>(); + for (Segment s : segs) { + out.addAll( + Chunks.splitLong(lines, s.startLine, s.endLine, language, s.symbol, LARGE_SEGMENT, WINDOW, OVERLAP)); + } + return out; + } + + private static int leadingIndent(String line) { + int i = 0; + while (i < line.length()) { + char c = line.charAt(i); + if (c == ' ' || c == '\t') { + i++; + } else { + break; + } + } + return i; + } + + private static boolean isComment(String line, boolean yaml) { + String t = line.stripLeading(); + return t.startsWith("#") || (yaml && t.startsWith("---")); + } + + private static @Nullable String topLevelSymbol(String line, boolean yaml) { + if (yaml) { + Matcher m = YAML_KEY.matcher(line); + if (m.matches()) { + return m.group(1); + } + return null; + } + Matcher m = PY_RUBY_DECL.matcher(line); + if (m.matches()) { + return m.group(1); + } + return null; + } + + private record Segment(int startLine, int endLine, @Nullable String symbol) {} +} diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/parsing/heuristic/LanguageDetector.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/parsing/heuristic/LanguageDetector.java new file mode 100644 index 0000000..34b0aef --- /dev/null +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/parsing/heuristic/LanguageDetector.java @@ -0,0 +1,61 @@ +package com.trueref.adapter.out.parsing.heuristic; + +import java.nio.file.Path; +import java.util.Locale; +import java.util.Map; + +/** Extension-based language detector. No content sniffing. 
*/ +final class LanguageDetector { + + private static final Map EXT = Map.ofEntries( + Map.entry("java", "java"), + Map.entry("kt", "kotlin"), + Map.entry("kts", "kotlin"), + Map.entry("scala", "scala"), + Map.entry("py", "python"), + Map.entry("js", "javascript"), + Map.entry("mjs", "javascript"), + Map.entry("cjs", "javascript"), + Map.entry("ts", "typescript"), + Map.entry("tsx", "typescript"), + Map.entry("go", "go"), + Map.entry("rs", "rust"), + Map.entry("c", "c"), + Map.entry("h", "c"), + Map.entry("cpp", "cpp"), + Map.entry("cc", "cpp"), + Map.entry("cxx", "cpp"), + Map.entry("hpp", "cpp"), + Map.entry("hh", "cpp"), + Map.entry("hxx", "cpp"), + Map.entry("cs", "csharp"), + Map.entry("rb", "ruby"), + Map.entry("swift", "swift"), + Map.entry("php", "php"), + Map.entry("md", "markdown"), + Map.entry("markdown", "markdown"), + Map.entry("yml", "yaml"), + Map.entry("yaml", "yaml"), + Map.entry("json", "json"), + Map.entry("sql", "sql"), + Map.entry("html", "html"), + Map.entry("htm", "html"), + Map.entry("css", "css"), + Map.entry("scss", "css"), + Map.entry("sass", "css"), + Map.entry("sh", "shell"), + Map.entry("bash", "shell"), + Map.entry("zsh", "shell")); + + private LanguageDetector() {} + + static String detect(Path file) { + String name = file.getFileName().toString(); + int dot = name.lastIndexOf('.'); + if (dot < 0 || dot == name.length() - 1) { + return "text"; + } + String ext = name.substring(dot + 1).toLowerCase(Locale.ROOT); + return EXT.getOrDefault(ext, "text"); + } +} diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/parsing/heuristic/MarkdownStrategy.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/parsing/heuristic/MarkdownStrategy.java new file mode 100644 index 0000000..4888a07 --- /dev/null +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/parsing/heuristic/MarkdownStrategy.java @@ -0,0 +1,82 @@ +package com.trueref.adapter.out.parsing.heuristic; + +import 
com.trueref.domain.port.out.CodeParser.ParsedChunk; +import java.util.ArrayList; +import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import org.jspecify.annotations.Nullable; + +/** + * Split Markdown on ATX ({@code #..######}) and Setext ({@code ===}/{@code ---} underline) + * headings. Sections longer than 200 lines are further split via sliding windows. + */ +final class MarkdownStrategy { + + private static final int LARGE_SECTION = 200; + private static final int WINDOW = 100; + private static final int OVERLAP = 10; + + private static final Pattern ATX = Pattern.compile("^\\s{0,3}(#{1,6})\\s+(.*?)\\s*#*\\s*$"); + private static final Pattern SETEXT_UNDERLINE = Pattern.compile("^\\s{0,3}(=+|-+)\\s*$"); + + private MarkdownStrategy() {} + + static List parse(List lines, String language) { + if (lines.isEmpty()) { + return List.of(); + } + + List
sections = splitIntoSections(lines); + List out = new ArrayList<>(); + for (Section sec : sections) { + out.addAll(Chunks.splitLong( + lines, sec.startLine, sec.endLine, language, sec.symbol, LARGE_SECTION, WINDOW, OVERLAP)); + } + return out; + } + + private static List
splitIntoSections(List lines) { + // Compute heading line indices (1-based) with titles. + List headingLines = new ArrayList<>(); // [lineIdx, titleLineIdx] + List titles = new ArrayList<>(); + + int n = lines.size(); + for (int i = 1; i <= n; i++) { + String line = lines.get(i - 1); + Matcher atx = ATX.matcher(line); + if (atx.matches()) { + headingLines.add(new int[] {i, i}); + titles.add(atx.group(2).trim()); + continue; + } + if (i < n) { + String next = lines.get(i); // line i+1 (1-based) + if (!line.isBlank() && SETEXT_UNDERLINE.matcher(next).matches()) { + // Section header spans lines i..i+1; section content starts at i (title line). + headingLines.add(new int[] {i, i}); + titles.add(line.trim()); + } + } + } + + List
sections = new ArrayList<>(); + if (headingLines.isEmpty()) { + sections.add(new Section(1, n, null)); + return sections; + } + + int firstHeading = headingLines.get(0)[0]; + if (firstHeading > 1) { + sections.add(new Section(1, firstHeading - 1, null)); + } + for (int h = 0; h < headingLines.size(); h++) { + int start = headingLines.get(h)[0]; + int end = (h + 1 < headingLines.size()) ? headingLines.get(h + 1)[0] - 1 : n; + sections.add(new Section(start, end, titles.get(h))); + } + return sections; + } + + private record Section(int startLine, int endLine, @Nullable String symbol) {} +} diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/parsing/heuristic/SlidingWindowStrategy.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/parsing/heuristic/SlidingWindowStrategy.java new file mode 100644 index 0000000..a258728 --- /dev/null +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/parsing/heuristic/SlidingWindowStrategy.java @@ -0,0 +1,21 @@ +package com.trueref.adapter.out.parsing.heuristic; + +import com.trueref.domain.port.out.CodeParser.ParsedChunk; +import java.util.List; + +/** Universal fallback: fixed-size line windows with overlap. Symbol is always null. 
*/ +final class SlidingWindowStrategy { + + static final int DEFAULT_WINDOW = 80; + static final int DEFAULT_OVERLAP = 10; + + private SlidingWindowStrategy() {} + + static List parse(List lines, String language) { + if (lines.isEmpty()) { + return List.of(); + } + return Chunks.splitLong( + lines, 1, lines.size(), language, null, /* threshold */ DEFAULT_WINDOW, DEFAULT_WINDOW, DEFAULT_OVERLAP); + } +} diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/parsing/heuristic/package-info.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/parsing/heuristic/package-info.java new file mode 100644 index 0000000..31fddfb --- /dev/null +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/parsing/heuristic/package-info.java @@ -0,0 +1,8 @@ +/** + * Pure-Java heuristic {@link com.trueref.domain.port.out.CodeParser} adapter. Chooses a splitting + * strategy per file based on extension (see {@link + * com.trueref.adapter.out.parsing.heuristic.LanguageDetector}) and falls back to a sliding-window + * splitter for unknown formats. See FINDINGS §F11 for rationale. 
+ */ +@org.jspecify.annotations.NullMarked +package com.trueref.adapter.out.parsing.heuristic; diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/persistence/h2/H2JobStore.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/persistence/h2/H2JobStore.java new file mode 100644 index 0000000..2305ca1 --- /dev/null +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/persistence/h2/H2JobStore.java @@ -0,0 +1,240 @@ +package com.trueref.adapter.out.persistence.h2; + +import com.trueref.domain.model.IngestionJob; +import com.trueref.domain.model.JobId; +import com.trueref.domain.model.JobStage; +import com.trueref.domain.model.JobStatus; +import com.trueref.domain.model.JobType; +import com.trueref.domain.model.RepositoryId; +import com.trueref.domain.model.VersionId; +import com.trueref.domain.port.out.JobStore; +import java.sql.Timestamp; +import java.time.Instant; +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import org.jspecify.annotations.Nullable; +import org.springframework.jdbc.core.RowMapper; +import org.springframework.jdbc.core.simple.JdbcClient; +import org.springframework.stereotype.Component; + +@Component +class H2JobStore implements JobStore { + + private final JdbcClient jdbc; + + H2JobStore(JdbcClient jdbc) { + this.jdbc = jdbc; + } + + private static final RowMapper JOB_ROW = (rs, i) -> { + Timestamp s = rs.getTimestamp("started_at"); + Timestamp f = rs.getTimestamp("finished_at"); + String vid = rs.getString("version_id"); + return new IngestionJob( + JobId.of(rs.getString("id")), + RepositoryId.of(rs.getString("repo_id")), + vid == null ? null : VersionId.of(vid), + JobType.valueOf(rs.getString("type")), + JobStatus.valueOf(rs.getString("status")), + s == null ? null : s.toInstant(), + f == null ? 
null : f.toInstant(), + List.of()); + }; + + private static final RowMapper STAGE_ROW = (rs, i) -> { + Timestamp s = rs.getTimestamp("started_at"); + Timestamp f = rs.getTimestamp("finished_at"); + return new JobStage( + JobId.of(rs.getString("job_id")), + JobStage.StageName.valueOf(rs.getString("name")), + JobStage.StageStatus.valueOf(rs.getString("status")), + s == null ? null : s.toInstant(), + f == null ? null : f.toInstant(), + rs.getLong("items_processed"), + rs.getLong("items_total"), + rs.getLong("bytes_processed"), + rs.getString("error_message")); + }; + + @Override + public IngestionJob save(IngestionJob job) { + int updated = jdbc.sql( + """ + UPDATE ingestion_jobs SET status = :status, started_at = :s, finished_at = :f + WHERE id = :id + """) + .param("id", job.id().toString()) + .param("status", job.status().name()) + .param("s", job.startedAt() == null ? null : Timestamp.from(job.startedAt())) + .param("f", job.finishedAt() == null ? null : Timestamp.from(job.finishedAt())) + .update(); + if (updated == 0) { + jdbc.sql( + """ + INSERT INTO ingestion_jobs + (id, repo_id, version_id, type, status, started_at, finished_at, created_at) + VALUES (:id, :repo, :version, :type, :status, :s, :f, :now) + """) + .param("id", job.id().toString()) + .param("repo", job.repoId().toString()) + .param("version", job.versionId() == null ? null : job.versionId().toString()) + .param("type", job.type().name()) + .param("status", job.status().name()) + .param("s", job.startedAt() == null ? null : Timestamp.from(job.startedAt())) + .param("f", job.finishedAt() == null ? 
null : Timestamp.from(job.finishedAt()))
+                    .param("now", Timestamp.from(Instant.now()))
+                    .update();
+        }
+        for (JobStage st : job.stages()) {
+            upsertStage(st);
+        }
+        return job;
+    }
+
+    /**
+     * Loads one job together with its stages.
+     *
+     * @param id the job id
+     * @return the job, or empty when no {@code ingestion_jobs} row has that id
+     */
+    @Override
+    public Optional<IngestionJob> findById(JobId id) {
+        Optional<IngestionJob> bare = jdbc.sql("SELECT * FROM ingestion_jobs WHERE id = :id")
+                .param("id", id.toString())
+                .query(JOB_ROW)
+                .optional();
+        return bare.map(this::hydrateStages);
+    }
+
+    /** Returns every job whose status is {@code RUNNING}, each with its stages loaded. */
+    @Override
+    public List<IngestionJob> findRunning() {
+        List<IngestionJob> jobs = jdbc.sql("SELECT * FROM ingestion_jobs WHERE status = 'RUNNING'")
+                .query(JOB_ROW)
+                .list();
+        return jobs.stream().map(this::hydrateStages).toList();
+    }
+
+    /**
+     * Lists jobs newest-first (by {@code created_at}), optionally filtered.
+     *
+     * @param repoId restrict to this repository; {@code null} means no restriction
+     * @param versionId restrict to this version; {@code null} means no restriction
+     * @param status restrict to this status; {@code null} means no restriction
+     * @param limit maximum number of rows; values below 1 are raised to 1
+     * @return matching jobs with their stages loaded
+     */
+    @Override
+    public List<IngestionJob> find(
+            @Nullable RepositoryId repoId,
+            @Nullable VersionId versionId,
+            @Nullable JobStatus status,
+            int limit) {
+        // Only fixed SQL fragments are appended; every value travels as a named parameter.
+        StringBuilder where = new StringBuilder();
+        List<Object[]> bindings = new ArrayList<>();
+        if (repoId != null) {
+            where.append(" AND repo_id = :repo");
+            bindings.add(new Object[] {"repo", repoId.toString()});
+        }
+        if (versionId != null) {
+            where.append(" AND version_id = :version");
+            bindings.add(new Object[] {"version", versionId.toString()});
+        }
+        if (status != null) {
+            where.append(" AND status = :status");
+            bindings.add(new Object[] {"status", status.name()});
+        }
+        String finalSql = "SELECT * FROM ingestion_jobs WHERE 1=1" + where + " ORDER BY created_at DESC LIMIT :lim";
+        var stmt = jdbc.sql(finalSql).param("lim", Math.max(1, limit));
+        for (Object[] b : bindings) {
+            stmt = stmt.param((String) b[0], b[1]);
+        }
+        List<IngestionJob> jobs = stmt.query(JOB_ROW).list();
+        return jobs.stream().map(this::hydrateStages).toList();
+    }
+
+    /**
+     * Sets a job's status. A {@code null} timestamp argument leaves the stored column
+     * untouched (the SQL wraps both in {@code COALESCE}), so this method cannot clear a
+     * timestamp.
+     *
+     * @param id the job to update
+     * @param status the new status
+     * @param startedAt new start time, or {@code null} to keep the stored one
+     * @param finishedAt new finish time, or {@code null} to keep the stored one
+     */
+    @Override
+    public void updateStatus(
+            JobId id, JobStatus status, @Nullable Instant startedAt, @Nullable Instant finishedAt) {
+        jdbc.sql(
+                        """
+                        UPDATE ingestion_jobs SET
+                            status = :s,
+                            started_at = COALESCE(:start, started_at),
+                            finished_at = COALESCE(:finish, finished_at)
+                        WHERE id = :id
+                        """)
+                .param("id", id.toString())
+                .param("s", status.name())
+                .param("start", startedAt == null ? null : Timestamp.from(startedAt))
+                .param("finish", finishedAt == null ? null : Timestamp.from(finishedAt))
+                .update();
+    }
+
+    /**
+     * Update-then-insert upsert of one stage, keyed by {@code (job_id, name)}. Unlike
+     * {@code updateStatus}, the UPDATE overwrites every mutable column, including with NULLs.
+     *
+     * @param st the stage to persist
+     */
+    @Override
+    public void upsertStage(JobStage st) {
+        int updated = jdbc.sql(
+                        """
+                        UPDATE job_stages SET
+                            status = :status,
+                            started_at = :s,
+                            finished_at = :f,
+                            items_processed = :ip,
+                            items_total = :it,
+                            bytes_processed = :bp,
+                            error_message = :err
+                        WHERE job_id = :job AND name = :name
+                        """)
+                .param("job", st.jobId().toString())
+                .param("name", st.name().name())
+                .param("status", st.status().name())
+                .param("s", st.startedAt() == null ? null : Timestamp.from(st.startedAt()))
+                .param("f", st.finishedAt() == null ? null : Timestamp.from(st.finishedAt()))
+                .param("ip", st.itemsProcessed())
+                .param("it", st.itemsTotal())
+                .param("bp", st.bytesProcessed())
+                .param("err", st.errorMessage())
+                .update();
+        if (updated == 0) {
+            jdbc.sql(
+                            """
+                            INSERT INTO job_stages
+                            (job_id, name, status, started_at, finished_at, items_processed, items_total, bytes_processed, error_message)
+                            VALUES (:job, :name, :status, :s, :f, :ip, :it, :bp, :err)
+                            """)
+                    .param("job", st.jobId().toString())
+                    .param("name", st.name().name())
+                    .param("status", st.status().name())
+                    .param("s", st.startedAt() == null ? null : Timestamp.from(st.startedAt()))
+                    .param("f", st.finishedAt() == null ?
null : Timestamp.from(st.finishedAt()))
+                    .param("ip", st.itemsProcessed())
+                    .param("it", st.itemsTotal())
+                    .param("bp", st.bytesProcessed())
+                    .param("err", st.errorMessage())
+                    .update();
+        }
+    }
+
+    /** Returns a copy of {@code j} whose stages are loaded from {@code job_stages}, ordered by name. */
+    private IngestionJob hydrateStages(IngestionJob j) {
+        List<JobStage> stages = jdbc.sql("SELECT * FROM job_stages WHERE job_id = :id ORDER BY name")
+                .param("id", j.id().toString())
+                .query(STAGE_ROW)
+                .list();
+        return new IngestionJob(
+                j.id(), j.repoId(), j.versionId(), j.type(), j.status(), j.startedAt(), j.finishedAt(), stages);
+    }
+
+    /**
+     * Marks every {@code RUNNING} or {@code QUEUED} job as {@code FAILED}, after failing the
+     * {@code RUNNING} stages that belong to those jobs.
+     *
+     * @param finishedAt written as {@code finished_at} on the failed stages, and on jobs that
+     *     do not already have one
+     * @return the number of job rows updated (stage rows are not counted)
+     */
+    @Override
+    public int failStaleJobs(Instant finishedAt) {
+        Timestamp ts = Timestamp.from(finishedAt);
+        // Fail the RUNNING stages of stale (RUNNING or QUEUED) jobs first (FK child before parent).
+        // NOTE(review): stages in any other non-terminal status are left as they are - confirm
+        // that is intended.
+        jdbc.sql("""
+                UPDATE job_stages SET
+                    status = 'FAILED',
+                    finished_at = :ts,
+                    error_message = COALESCE(error_message, 'interrupted by server restart')
+                WHERE status = 'RUNNING'
+                  AND job_id IN (
+                    SELECT id FROM ingestion_jobs WHERE status IN ('RUNNING', 'QUEUED')
+                  )
+                """)
+                .param("ts", ts)
+                .update();
+        // Now fail the jobs themselves.
+        return jdbc.sql("""
+                UPDATE ingestion_jobs SET
+                    status = 'FAILED',
+                    finished_at = COALESCE(finished_at, :ts)
+                WHERE status IN ('RUNNING', 'QUEUED')
+                """)
+                .param("ts", ts)
+                .update();
+    }
+}
diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/persistence/h2/H2RepositoryStore.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/persistence/h2/H2RepositoryStore.java
new file mode 100644
index 0000000..f64d90b
--- /dev/null
+++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/persistence/h2/H2RepositoryStore.java
@@ -0,0 +1,256 @@
+package com.trueref.adapter.out.persistence.h2;
+
+import com.trueref.domain.model.Repository;
+import com.trueref.domain.model.RepositoryId;
+import com.trueref.domain.model.Version;
+import com.trueref.domain.model.VersionId;
+import com.trueref.domain.model.VersionStatus;
+import com.trueref.domain.port.out.RepositoryStore;
+import java.sql.Timestamp;
+import java.time.Duration;
+import java.time.Instant;
+import java.util.List;
+import java.util.Optional;
+import org.jspecify.annotations.Nullable;
+import org.springframework.jdbc.core.RowMapper;
+import org.springframework.jdbc.core.simple.JdbcClient;
+import org.springframework.stereotype.Component;
+
+/**
+ * {@link RepositoryStore} backed by {@link JdbcClient}. Repositories are stored in
+ * {@code repositories} and their versions in {@code versions}.
+ */
+@Component
+class H2RepositoryStore implements RepositoryStore {
+
+    private final JdbcClient jdbc;
+
+    H2RepositoryStore(JdbcClient jdbc) {
+        this.jdbc = jdbc;
+    }
+
+    /**
+     * Maps a {@code repositories} row. The {@code ignore_globs} and
+     * {@code version_mapping_rules} columns are decoded through {@link JsonCodec}.
+     * {@code created_at} and {@code updated_at} are dereferenced without a null check.
+     */
+    private static final RowMapper<Repository> REPO_MAPPER = (rs, i) -> new Repository(
+            RepositoryId.of(rs.getString("id")),
+            rs.getString("name"),
+            rs.getString("remote_url"),
+            rs.getString("local_path"),
+            rs.getBoolean("managed_clone"),
+            JsonCodec.readStrings(rs.getString("ignore_globs")),
+            rs.getLong("max_file_size_bytes"),
+            Duration.ofSeconds(rs.getLong("poll_interval_seconds")),
+            rs.getInt("tag_cap"),
+            JsonCodec.readTagPatterns(rs.getString("version_mapping_rules")),
+            rs.getTimestamp("created_at").toInstant(),
+            rs.getTimestamp("updated_at").toInstant());
+
+    private static final
RowMapper<Version> VERSION_MAPPER = (rs, i) -> {
+        // Maps a versions row; a NULL indexed_at becomes a null Instant.
+        Timestamp idx = rs.getTimestamp("indexed_at");
+        return new Version(
+                VersionId.of(rs.getString("id")),
+                RepositoryId.of(rs.getString("repo_id")),
+                rs.getString("tag"),
+                rs.getString("commit_sha"),
+                VersionStatus.valueOf(rs.getString("status")),
+                idx == null ? null : idx.toInstant(),
+                rs.getInt("chunk_count"),
+                rs.getString("error_message"));
+    };
+
+    /**
+     * Update-then-insert upsert of a repository keyed by {@code id}. The UPDATE never
+     * rewrites {@code created_at}; it is written only by the INSERT.
+     *
+     * @param r the repository to persist
+     * @return the same {@code r} instance
+     */
+    @Override
+    public Repository save(Repository r) {
+        int updated = jdbc.sql(
+                        """
+                        UPDATE repositories SET
+                            name = :name,
+                            remote_url = :remoteUrl,
+                            local_path = :localPath,
+                            managed_clone = :managedClone,
+                            ignore_globs = :ignoreGlobs,
+                            max_file_size_bytes = :maxFileSizeBytes,
+                            poll_interval_seconds = :pollSec,
+                            tag_cap = :tagCap,
+                            version_mapping_rules = :rules,
+                            updated_at = :updatedAt
+                        WHERE id = :id
+                        """)
+                .param("id", r.id().toString())
+                .param("name", r.name())
+                .param("remoteUrl", r.remoteUrl())
+                .param("localPath", r.localPath())
+                .param("managedClone", r.managedClone())
+                .param("ignoreGlobs", JsonCodec.writeStrings(r.ignoreGlobs()))
+                .param("maxFileSizeBytes", r.maxFileSizeBytes())
+                .param("pollSec", r.pollInterval().toSeconds())
+                .param("tagCap", r.tagCap())
+                .param("rules", JsonCodec.writeTagPatterns(r.versionMappingRules()))
+                .param("updatedAt", Timestamp.from(r.updatedAt()))
+                .update();
+        if (updated == 0) {
+            jdbc.sql(
+                            """
+                            INSERT INTO repositories (
+                                id, name, remote_url, local_path, managed_clone, ignore_globs,
+                                max_file_size_bytes, poll_interval_seconds, tag_cap,
+                                version_mapping_rules, created_at, updated_at)
+                            VALUES (
+                                :id, :name, :remoteUrl, :localPath, :managedClone, :ignoreGlobs,
+                                :maxFileSizeBytes, :pollSec, :tagCap, :rules, :createdAt, :updatedAt)
+                            """)
+                    .param("id", r.id().toString())
+                    .param("name", r.name())
+                    .param("remoteUrl", r.remoteUrl())
+                    .param("localPath", r.localPath())
+                    .param("managedClone", r.managedClone())
+                    .param("ignoreGlobs", JsonCodec.writeStrings(r.ignoreGlobs()))
+                    .param("maxFileSizeBytes", r.maxFileSizeBytes())
+                    .param("pollSec", r.pollInterval().toSeconds())
+                    .param("tagCap", r.tagCap())
+                    .param("rules", JsonCodec.writeTagPatterns(r.versionMappingRules()))
+                    .param("createdAt", Timestamp.from(r.createdAt()))
+                    .param("updatedAt", Timestamp.from(r.updatedAt()))
+                    .update();
+        }
+        return r;
+    }
+
+    /** Looks a repository up by id; empty when no row matches. */
+    @Override
+    public Optional<Repository> findById(RepositoryId id) {
+        return jdbc.sql("SELECT * FROM repositories WHERE id = :id")
+                .param("id", id.toString())
+                .query(REPO_MAPPER)
+                .optional();
+    }
+
+    /** Looks a repository up by exact name; empty when no row matches. */
+    @Override
+    public Optional<Repository> findByName(String name) {
+        return jdbc.sql("SELECT * FROM repositories WHERE name = :name")
+                .param("name", name)
+                .query(REPO_MAPPER)
+                .optional();
+    }
+
+    /** Returns all repositories ordered by name. */
+    @Override
+    public List<Repository> findAll() {
+        return jdbc.sql("SELECT * FROM repositories ORDER BY name")
+                .query(REPO_MAPPER)
+                .list();
+    }
+
+    /** Deletes the repository row with the given id; this method issues no other statement. */
+    @Override
+    public void delete(RepositoryId id) {
+        jdbc.sql("DELETE FROM repositories WHERE id = :id")
+                .param("id", id.toString())
+                .update();
+    }
+
+    /**
+     * Update-then-insert upsert of a version keyed by {@code id}. The UPDATE leaves
+     * {@code repo_id} and {@code tag} as stored; they are written only by the INSERT.
+     *
+     * @param v the version to persist
+     * @return the same {@code v} instance
+     */
+    @Override
+    public Version saveVersion(Version v) {
+        int updated = jdbc.sql(
+                        """
+                        UPDATE versions SET
+                            commit_sha = :sha,
+                            status = :status,
+                            indexed_at = :indexedAt,
+                            chunk_count = :chunkCount,
+                            error_message = :err
+                        WHERE id = :id
+                        """)
+                .param("id", v.id().toString())
+                .param("sha", v.commitSha())
+                .param("status", v.status().name())
+                .param("indexedAt", v.indexedAt() == null ? null : Timestamp.from(v.indexedAt()))
+                .param("chunkCount", v.chunkCount())
+                .param("err", v.errorMessage())
+                .update();
+        if (updated == 0) {
+            jdbc.sql(
+                            """
+                            INSERT INTO versions (id, repo_id, tag, commit_sha, status, indexed_at, chunk_count, error_message)
+                            VALUES (:id, :repo, :tag, :sha, :status, :indexedAt, :chunkCount, :err)
+                            """)
+                    .param("id", v.id().toString())
+                    .param("repo", v.repoId().toString())
+                    .param("tag", v.tag())
+                    .param("sha", v.commitSha())
+                    .param("status", v.status().name())
+                    .param("indexedAt", v.indexedAt() == null ? null : Timestamp.from(v.indexedAt()))
+                    .param("chunkCount", v.chunkCount())
+                    .param("err", v.errorMessage())
+                    .update();
+        }
+        return v;
+    }
+
+    /** Looks a version up by id; empty when no row matches. */
+    @Override
+    public Optional<Version> findVersion(VersionId id) {
+        return jdbc.sql("SELECT * FROM versions WHERE id = :id")
+                .param("id", id.toString())
+                .query(VERSION_MAPPER)
+                .optional();
+    }
+
+    /** Looks a version up by its repository and exact tag; empty when no row matches. */
+    @Override
+    public Optional<Version> findVersionByTag(RepositoryId repoId, String tag) {
+        return jdbc.sql("SELECT * FROM versions WHERE repo_id = :r AND tag = :t")
+                .param("r", repoId.toString())
+                .param("t", tag)
+                .query(VERSION_MAPPER)
+                .optional();
+    }
+
+    /** Returns the versions of one repository ordered by tag (plain SQL ordering of the column). */
+    @Override
+    public List<Version> findVersionsByRepo(RepositoryId repoId) {
+        return jdbc.sql("SELECT * FROM versions WHERE repo_id = :r ORDER BY tag")
+                .param("r", repoId.toString())
+                .query(VERSION_MAPPER)
+                .list();
+    }
+
+    /**
+     * Returns versions in the given status, in no specified order.
+     *
+     * @param repoId restrict to this repository; {@code null} searches all repositories
+     * @param status the status to match
+     */
+    @Override
+    public List<Version> findVersionsByStatus(@Nullable RepositoryId repoId, VersionStatus status) {
+        if (repoId == null) {
+            return jdbc.sql("SELECT * FROM versions WHERE status = :s")
+                    .param("s", status.name())
+                    .query(VERSION_MAPPER)
+                    .list();
+        }
+        return jdbc.sql("SELECT * FROM versions WHERE repo_id = :r AND status = :s")
+                .param("r", repoId.toString())
+                .param("s", status.name())
+                .query(VERSION_MAPPER)
+                .list();
+    }
+
+    /**
+     * Overwrites a version's status and error message. A {@code null} message clears any
+     * stored one.
+     */
+    @Override
+    public void updateVersionStatus(VersionId id, VersionStatus status, @Nullable String errorMessage) {
+        jdbc.sql("UPDATE versions SET status = :s, error_message = :err WHERE id = :id")
+                .param("id", id.toString())
+                .param("s", status.name())
+                .param("err", errorMessage)
+                .update();
+    }
+
+    /**
+     * Marks a version {@code INDEXED}: stamps {@code indexed_at} with the current time, stores
+     * the chunk count and clears the error message.
+     */
+    @Override
+    public void markVersionIndexed(VersionId id, int chunkCount) {
+        jdbc.sql(
+                        """
+                        UPDATE versions SET status = 'INDEXED',
+                            indexed_at = :now,
+                            chunk_count = :c,
+                            error_message = NULL
+                        WHERE id = :id
+                        """)
+                .param("id", id.toString())
+                .param("now", Timestamp.from(Instant.now()))
+                .param("c", chunkCount)
+                .update();
+    }
+
+    /**
+     * Marks every version still in {@code INDEXING} as {@code FAILED} with the given message.
+     *
+     * @return the number of version rows updated
+     */
+    @Override
+    public int failStaleIndexingVersions(String errorMessage) {
+        return
jdbc.sql("""
+                UPDATE versions SET
+                    status = 'FAILED',
+                    error_message = :err
+                WHERE status = 'INDEXING'
+                """)
+                .param("err", errorMessage)
+                .update();
+    }
+}
diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/persistence/h2/JsonCodec.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/persistence/h2/JsonCodec.java
new file mode 100644
index 0000000..e717e87
--- /dev/null
+++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/persistence/h2/JsonCodec.java
@@ -0,0 +1,78 @@
+package com.trueref.adapter.out.persistence.h2;
+
+import com.fasterxml.jackson.core.type.TypeReference;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.trueref.domain.model.TagPattern;
+import java.util.List;
+
+/** JSON helpers for serializing list-typed columns. */
+final class JsonCodec {
+
+    static final ObjectMapper MAPPER = new ObjectMapper();
+
+    // Anonymous subclasses capture the element type so Jackson can rebuild typed lists.
+    private static final TypeReference<List<String>> STRING_LIST = new TypeReference<>() {};
+    private static final TypeReference<List<TagPatternDto>> TAG_PATTERN_LIST = new TypeReference<>() {};
+
+    private JsonCodec() {}
+
+    /**
+     * Serializes a string list to a JSON array.
+     *
+     * @throws IllegalStateException wrapping any serialization failure
+     */
+    static String writeStrings(List<String> v) {
+        try {
+            return MAPPER.writeValueAsString(v);
+        } catch (Exception e) {
+            throw new IllegalStateException(e);
+        }
+    }
+
+    /**
+     * Parses a JSON array of strings. A {@code null} or blank input yields an empty list.
+     *
+     * @throws IllegalStateException wrapping any parse failure
+     */
+    static List<String> readStrings(String json) {
+        if (json == null || json.isBlank()) return List.of();
+        try {
+            return MAPPER.readValue(json, STRING_LIST);
+        } catch (Exception e) {
+            throw new IllegalStateException(e);
+        }
+    }
+
+    /**
+     * Serializes tag patterns as a JSON array of {@link TagPatternDto} objects.
+     *
+     * @throws IllegalStateException wrapping any serialization failure
+     */
+    static String writeTagPatterns(List<TagPattern> v) {
+        try {
+            List<TagPatternDto> dtos = v.stream().map(TagPatternDto::from).toList();
+            return MAPPER.writeValueAsString(dtos);
+        } catch (Exception e) {
+            throw new IllegalStateException(e);
+        }
+    }
+
+    /**
+     * Parses a JSON array of {@link TagPatternDto} objects back into tag patterns. A
+     * {@code null} or blank input yields an empty list.
+     *
+     * @throws IllegalStateException wrapping any parse failure, including an unknown
+     *     discriminator
+     */
+    static List<TagPattern> readTagPatterns(String json) {
+        if (json == null || json.isBlank()) return List.of();
+        try {
+            return MAPPER.readValue(json, TAG_PATTERN_LIST).stream()
+                    .map(TagPatternDto::toModel)
+                    .toList();
+        } catch (Exception e) {
+            throw new IllegalStateException(e);
+        }
+    }
+
+    /** Discriminated DTO
for {@link TagPattern}. */
+    record TagPatternDto(String type, String template) {
+        /** Builds a DTO that carries only a discriminator (no template). */
+        private static TagPatternDto untemplated(String discriminator) {
+            return new TagPatternDto(discriminator, null);
+        }
+
+        /** Converts a pattern to its DTO; only {@code Custom} carries a template. */
+        static TagPatternDto from(TagPattern p) {
+            return switch (p) {
+                case TagPattern.Exact exact -> untemplated("EXACT");
+                case TagPattern.VPrefix vPrefix -> untemplated("V_PREFIX");
+                case TagPattern.ReleasePrefix releasePrefix -> untemplated("RELEASE_PREFIX");
+                case TagPattern.SemverFuzzy semverFuzzy -> untemplated("SEMVER_FUZZY");
+                case TagPattern.Custom custom -> new TagPatternDto("CUSTOM", custom.template());
+            };
+        }
+
+        /** Rebuilds the pattern named by {@code type}; an unrecognized value is rejected. */
+        TagPattern toModel() {
+            switch (type) {
+                case "EXACT":
+                    return new TagPattern.Exact();
+                case "V_PREFIX":
+                    return new TagPattern.VPrefix();
+                case "RELEASE_PREFIX":
+                    return new TagPattern.ReleasePrefix();
+                case "SEMVER_FUZZY":
+                    return new TagPattern.SemverFuzzy();
+                case "CUSTOM":
+                    return new TagPattern.Custom(template);
+                default:
+                    throw new IllegalStateException("unknown TagPattern type: " + type);
+            }
+        }
+    }
+}
diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/persistence/h2/package-info.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/persistence/h2/package-info.java
new file mode 100644
index 0000000..ddce9ea
--- /dev/null
+++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/persistence/h2/package-info.java
@@ -0,0 +1,6 @@
+/**
+ * H2 persistence adapter (driven port). Uses Spring {@code JdbcClient} with explicit row mappers,
+ * no JPA. Schema is managed by Flyway migrations under {@code db/migration}.
+ */ +@org.jspecify.annotations.NullMarked +package com.trueref.adapter.out.persistence.h2; diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/vectorstore/lucene/LuceneChunkStore.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/vectorstore/lucene/LuceneChunkStore.java new file mode 100644 index 0000000..cafcf0c --- /dev/null +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/vectorstore/lucene/LuceneChunkStore.java @@ -0,0 +1,425 @@ +package com.trueref.adapter.out.vectorstore.lucene; + +import com.trueref.domain.model.Chunk; +import com.trueref.domain.model.ChunkId; +import com.trueref.domain.model.ChunkVersion; +import com.trueref.domain.model.Embedding; +import com.trueref.domain.model.RepositoryId; +import com.trueref.domain.model.SearchHit; +import com.trueref.domain.model.SearchScope; +import com.trueref.domain.model.VersionId; +import com.trueref.domain.port.out.ChunkStore; +import jakarta.annotation.PreDestroy; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import org.apache.lucene.analysis.Analyzer; +import org.apache.lucene.analysis.standard.StandardAnalyzer; +import org.apache.lucene.document.Document; +import org.apache.lucene.document.Field.Store; +import org.apache.lucene.document.KnnFloatVectorField; +import org.apache.lucene.document.NumericDocValuesField; +import org.apache.lucene.document.StoredField; +import org.apache.lucene.document.StringField; +import org.apache.lucene.document.TextField; +import org.apache.lucene.index.IndexWriter; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.index.IndexWriterConfig.OpenMode; +import 
org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.StoredFields; +import org.apache.lucene.index.Term; +import org.apache.lucene.index.VectorSimilarityFunction; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.CollectorManager; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.SearcherFactory; +import org.apache.lucene.search.SearcherManager; +import org.apache.lucene.search.SimpleCollector; +import org.apache.lucene.search.TermInSetQuery; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.store.FSDirectory; +import org.apache.lucene.util.BytesRef; +import org.apache.lucene.util.QueryBuilder; +import org.jspecify.annotations.Nullable; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.springframework.beans.factory.annotation.Value; +import org.springframework.stereotype.Component; + +/** + * Single-index Lucene 10 implementation of {@link ChunkStore}. The index co-locates two doctypes + * (chunks + links) distinguished by the {@code doc_type} field so BM25, HNSW kNN and membership + * joins live in one writer/searcher pair. 
+ */
+@Component
+final class LuceneChunkStore implements ChunkStore {
+
+    private static final Logger log = LoggerFactory.getLogger(LuceneChunkStore.class);
+
+    // Field names. doc_type and chunk_id are written on both chunk and link documents.
+    private static final String F_DOC_TYPE = "doc_type";
+    private static final String F_CHUNK_ID = "chunk_id";
+    // Written on chunk documents only.
+    private static final String F_CONTENT_HASH = "content_hash";
+    private static final String F_LANGUAGE = "language";
+    private static final String F_SYMBOL = "symbol";
+    private static final String F_TOKEN_COUNT = "token_count";
+    private static final String F_CONTENT = "content";
+    private static final String F_VECTOR = "vector";
+    // Written on link documents only.
+    private static final String F_VERSION_ID = "version_id";
+    private static final String F_FILE_PATH = "file_path";
+    private static final String F_START_LINE = "start_line";
+    private static final String F_END_LINE = "end_line";
+
+    // Values of the doc_type discriminator.
+    private static final String TYPE_CHUNK = "chunk";
+    private static final String TYPE_LINK = "link";
+
+    private final FSDirectory directory;
+    private final Analyzer analyzer;
+    private final IndexWriter writer;
+    private final SearcherManager searcherManager;
+    private final QueryBuilder queryBuilder;
+
+    /**
+     * Opens (creating it if needed) the index directory and the writer/searcher pair.
+     *
+     * @param props typed settings; a non-null {@code home} is used as the index directory
+     * @param trueRefHome base directory; {@code lucene} beneath it is used when
+     *     {@code props.home()} is null
+     * @throws IOException if the directory cannot be created or the index cannot be opened
+     */
+    LuceneChunkStore(LuceneProperties props, @Value("${trueref.home:./data}") Path trueRefHome)
+            throws IOException {
+        Path home = props.home() != null ?
props.home() : trueRefHome.resolve("lucene");
+        Files.createDirectories(home);
+        this.directory = FSDirectory.open(home);
+        this.analyzer = new StandardAnalyzer();
+        IndexWriterConfig cfg = new IndexWriterConfig(analyzer);
+        cfg.setOpenMode(OpenMode.CREATE_OR_APPEND);
+        this.writer = new IndexWriter(directory, cfg);
+        this.searcherManager = new SearcherManager(writer, new SearcherFactory());
+        this.queryBuilder = new QueryBuilder(analyzer);
+        log.info("Lucene index opened at {}", home);
+    }
+
+    /**
+     * Closes the searcher manager, then the writer, then the directory. An
+     * {@link IOException} from one step is logged and does not prevent the following steps.
+     */
+    @PreDestroy
+    void close() {
+        try {
+            searcherManager.close();
+        } catch (IOException e) {
+            log.warn("failed to close SearcherManager", e);
+        }
+        try {
+            writer.close();
+        } catch (IOException e) {
+            log.warn("failed to close IndexWriter", e);
+        }
+        try {
+            directory.close();
+        } catch (IOException e) {
+            log.warn("failed to close FSDirectory", e);
+        }
+    }
+
+    /**
+     * Finds a chunk document by its content hash.
+     *
+     * @param contentHash the hash to match exactly
+     * @return the first matching chunk, or empty when none is indexed
+     */
+    @Override
+    public Optional<Chunk> findByContentHash(String contentHash) {
+        Query q = new BooleanQuery.Builder()
+                .add(new TermQuery(new Term(F_DOC_TYPE, TYPE_CHUNK)), BooleanClause.Occur.MUST)
+                .add(new TermQuery(new Term(F_CONTENT_HASH, contentHash)), BooleanClause.Occur.MUST)
+                .build();
+        return withSearcher(searcher -> {
+            TopDocs top = searcher.search(q, 1);
+            if (top.scoreDocs.length == 0) {
+                return Optional.empty();
+            }
+            Document d = searcher.storedFields().document(top.scoreDocs[0].doc);
+            return Optional.of(toChunk(d));
+        });
+    }
+
+    /**
+     * Adds the chunk document, replacing any existing chunk document with the same id. Link
+     * documents for that chunk id are left in place.
+     *
+     * @param chunk the chunk to index
+     * @param embedding the vector indexed for kNN search
+     * @return the same {@code chunk} instance
+     * @throws UncheckedIOException if the index write fails
+     */
+    @Override
+    public Chunk upsertChunk(Chunk chunk, Embedding embedding) {
+        String chunkId = chunk.id().toString();
+        Document doc = new Document();
+        doc.add(new StringField(F_DOC_TYPE, TYPE_CHUNK, Store.YES));
+        doc.add(new StringField(F_CHUNK_ID, chunkId, Store.YES));
+        doc.add(new StringField(F_CONTENT_HASH, chunk.contentHash(), Store.YES));
+        doc.add(new StringField(F_LANGUAGE, chunk.language(), Store.YES));
+        if (chunk.symbol() != null) {
+            doc.add(new StringField(F_SYMBOL, chunk.symbol(), Store.YES));
+        }
+        doc.add(new StoredField(F_TOKEN_COUNT, chunk.tokenCount()));
+        doc.add(new NumericDocValuesField(F_TOKEN_COUNT, chunk.tokenCount()));
+        doc.add(new TextField(F_CONTENT, chunk.content(), Store.YES));
+        doc.add(new KnnFloatVectorField(F_VECTOR, embedding.vector(), VectorSimilarityFunction.COSINE));
+        // Link documents index the same chunk_id term, so a term-keyed update would delete
+        // them together with the old chunk document. Restrict the delete to doc_type=chunk.
+        Query previousChunkDoc = new BooleanQuery.Builder()
+                .add(new TermQuery(new Term(F_DOC_TYPE, TYPE_CHUNK)), BooleanClause.Occur.MUST)
+                .add(new TermQuery(new Term(F_CHUNK_ID, chunkId)), BooleanClause.Occur.MUST)
+                .build();
+        try {
+            writer.updateDocuments(previousChunkDoc, List.of(doc));
+        } catch (IOException e) {
+            throw new UncheckedIOException(e);
+        }
+        log.debug("upserted chunk {} (hash {})", chunk.id(), chunk.contentHash());
+        return chunk;
+    }
+
+    /**
+     * Adds one link document per entry. Existing link documents are not checked or replaced,
+     * so linking the same entry twice stores it twice.
+     *
+     * @param links the chunk-to-version memberships to add; an empty collection is a no-op
+     * @throws UncheckedIOException if the index write fails
+     */
+    @Override
+    public void linkChunks(Collection<ChunkVersion> links) {
+        if (links.isEmpty()) {
+            return;
+        }
+        List<Document> docs = new ArrayList<>(links.size());
+        for (ChunkVersion cv : links) {
+            Document d = new Document();
+            d.add(new StringField(F_DOC_TYPE, TYPE_LINK, Store.YES));
+            d.add(new StringField(F_CHUNK_ID, cv.chunkId().toString(), Store.YES));
+            d.add(new StringField(F_VERSION_ID, cv.versionId().toString(), Store.YES));
+            d.add(new StringField(F_FILE_PATH, cv.filePath(), Store.YES));
+            d.add(new StoredField(F_START_LINE, cv.startLine()));
+            d.add(new NumericDocValuesField(F_START_LINE, cv.startLine()));
+            d.add(new StoredField(F_END_LINE, cv.endLine()));
+            d.add(new NumericDocValuesField(F_END_LINE, cv.endLine()));
+            docs.add(d);
+        }
+        try {
+            writer.addDocuments(docs);
+        } catch (IOException e) {
+            throw new UncheckedIOException(e);
+        }
+        log.debug("added {} link docs", docs.size());
+    }
+
+    /**
+     * Deletes every link document of the given version. Chunk documents are not touched.
+     *
+     * @throws UncheckedIOException if the index write fails
+     */
+    @Override
+    public void unlinkVersion(VersionId versionId) {
+        Query q = new BooleanQuery.Builder()
+                .add(new TermQuery(new Term(F_DOC_TYPE, TYPE_LINK)), BooleanClause.Occur.MUST)
+                .add(new TermQuery(new Term(F_VERSION_ID, versionId.toString())), BooleanClause.Occur.MUST)
+                .build();
+        try {
+            writer.deleteDocuments(q);
+        } catch (IOException e) {
+            throw new UncheckedIOException(e);
+        }
+        log.debug("unlinked version {}", versionId);
+    }
+
+    @Override
+    public Set<ChunkId> chunkIdsForVersion(VersionId versionId) {
+        Query q = new BooleanQuery.Builder()
+                .add(new TermQuery(new Term(F_DOC_TYPE, TYPE_LINK)),
BooleanClause.Occur.MUST)
+                .add(new TermQuery(new Term(F_VERSION_ID, versionId.toString())), BooleanClause.Occur.MUST)
+                .build();
+        // Concurrent set: the collector callback may be invoked from more than one thread.
+        Set<ChunkId> ids = ConcurrentHashMap.newKeySet();
+        withSearcher(searcher -> {
+            forEachDoc(searcher, q, (doc, sf) -> {
+                Document d = sf.document(doc, Set.of(F_CHUNK_ID));
+                String v = d.get(F_CHUNK_ID);
+                if (v != null) {
+                    ids.add(ChunkId.of(v));
+                }
+            });
+            return null;
+        });
+        return ids;
+    }
+
+    /**
+     * Full-text search over chunk content, restricted to chunks linked to the scope.
+     *
+     * @param queryText free text, analyzed with the index analyzer
+     * @param scope the repository versions to search
+     * @param topK maximum number of hits
+     * @return scored hits; empty when the scope links no chunks or the text yields no query
+     */
+    @Override
+    public List<SearchHit> bm25Search(String queryText, SearchScope scope, int topK) {
+        return withSearcher(searcher -> {
+            ScopeIndex scopeIndex = collectScope(scope, searcher);
+            if (scopeIndex.chunkIds().isEmpty()) {
+                return List.of();
+            }
+            Query textQ = queryBuilder.createBooleanQuery(F_CONTENT, queryText);
+            if (textQ == null) {
+                return List.of();
+            }
+            Query q = new BooleanQuery.Builder()
+                    .add(new TermQuery(new Term(F_DOC_TYPE, TYPE_CHUNK)), BooleanClause.Occur.MUST)
+                    .add(new TermInSetQuery(F_CHUNK_ID, scopeIndex.chunkIds()), BooleanClause.Occur.MUST)
+                    .add(textQ, BooleanClause.Occur.MUST)
+                    .build();
+            TopDocs top = searcher.search(q, topK);
+            return toHits(searcher, top, scopeIndex);
+        });
+    }
+
+    /**
+     * kNN vector search, restricted to chunks linked to the scope.
+     *
+     * @param queryVector the query embedding
+     * @param scope the repository versions to search
+     * @param topK number of neighbours requested and maximum number of hits
+     * @return scored hits; empty when the scope links no chunks
+     */
+    @Override
+    public List<SearchHit> denseSearch(float[] queryVector, SearchScope scope, int topK) {
+        return withSearcher(searcher -> {
+            ScopeIndex scopeIndex = collectScope(scope, searcher);
+            if (scopeIndex.chunkIds().isEmpty()) {
+                return List.of();
+            }
+            Query filter = new BooleanQuery.Builder()
+                    .add(new TermQuery(new Term(F_DOC_TYPE, TYPE_CHUNK)), BooleanClause.Occur.MUST)
+                    .add(new TermInSetQuery(F_CHUNK_ID, scopeIndex.chunkIds()), BooleanClause.Occur.MUST)
+                    .build();
+            Query q = new KnnFloatVectorQuery(F_VECTOR, queryVector, topK, filter);
+            TopDocs top = searcher.search(q, topK);
+            return toHits(searcher, top, scopeIndex);
+        });
+    }
+
+    /**
+     * Commits the writer and then refreshes the searcher manager, holding the writer's monitor
+     * for both steps.
+     *
+     * @throws UncheckedIOException if the commit or the refresh fails
+     */
+    @Override
+    public void commit() {
+        synchronized (writer) {
+            try {
+                writer.commit();
+                searcherManager.maybeRefreshBlocking();
+            } catch (IOException e) {
+                throw new UncheckedIOException(e);
+            }
+        }
+        log.debug("committed Lucene index");
+    }
+
+    // --- helpers ----------------------------------------------------------------------------
+
+    /**
+     * Resolves a scope to the chunk ids linked to any of its versions, plus one representative
+     * link per chunk. Whichever link document is processed first for a chunk becomes its
+     * representative.
+     */
+    private ScopeIndex collectScope(SearchScope scope, IndexSearcher searcher) throws IOException {
+        Map<String, RepositoryId> versionToRepo = new HashMap<>();
+        BooleanQuery.Builder versionOr = new BooleanQuery.Builder();
+        for (SearchScope.RepoVersionRef r : scope.refs()) {
+            String vid = r.versionId().toString();
+            versionToRepo.put(vid, r.repoId());
+            versionOr.add(new TermQuery(new Term(F_VERSION_ID, vid)), BooleanClause.Occur.SHOULD);
+        }
+        Query outer = new BooleanQuery.Builder()
+                .add(new TermQuery(new Term(F_DOC_TYPE, TYPE_LINK)), BooleanClause.Occur.MUST)
+                .add(versionOr.build(), BooleanClause.Occur.MUST)
+                .build();
+
+        Set<BytesRef> chunkIds = ConcurrentHashMap.newKeySet();
+        Map<ChunkId, LinkDoc> rep = new ConcurrentHashMap<>();
+        forEachDoc(searcher, outer, (doc, sf) -> {
+            Document d = sf.document(doc);
+            String cid = d.get(F_CHUNK_ID);
+            String vid = d.get(F_VERSION_ID);
+            if (cid == null || vid == null) {
+                return;
+            }
+            chunkIds.add(new BytesRef(cid));
+            ChunkId chunkId = ChunkId.of(cid);
+            // Cheap pre-check avoids building a LinkDoc; putIfAbsent settles any race.
+            if (!rep.containsKey(chunkId)) {
+                int start = d.getField(F_START_LINE).numericValue().intValue();
+                int end = d.getField(F_END_LINE).numericValue().intValue();
+                rep.putIfAbsent(
+                        chunkId,
+                        new LinkDoc(
+                                versionToRepo.get(vid),
+                                VersionId.of(vid),
+                                d.get(F_FILE_PATH),
+                                start,
+                                end));
+            }
+        });
+        return new ScopeIndex(chunkIds, rep);
+    }
+
+    /**
+     * Converts scored chunk documents into hits, taking location data from the chunk's
+     * representative link. Documents lacking a chunk id or a representative link are skipped.
+     */
+    private List<SearchHit> toHits(IndexSearcher searcher, TopDocs top, ScopeIndex idx)
+            throws IOException {
+        List<SearchHit> hits = new ArrayList<>(top.scoreDocs.length);
+        StoredFields sf = searcher.storedFields();
+        for (ScoreDoc sd : top.scoreDocs) {
+            Document d = sf.document(sd.doc);
+            String cidStr = d.get(F_CHUNK_ID);
+            if (cidStr == null) {
+                continue;
+            }
+            ChunkId cid = ChunkId.of(cidStr);
+            LinkDoc link = idx.representative().get(cid);
+            if (link == null) {
+                continue;
+            }
+            hits.add(new SearchHit(
+                    cid,
+                    link.repoId(),
+                    link.versionId(),
+                    "", // repo_name — enriched by application layer
+                    "", // tag — enriched by application layer
+                    link.filePath(),
+                    link.startLine(),
+                    link.endLine(),
+                    d.get(F_LANGUAGE),
+                    d.get(F_SYMBOL),
+                    d.get(F_CONTENT),
+                    sd.score));
+        }
+        return hits;
+    }
+
+    /** Rebuilds a {@link Chunk} from the stored fields of a chunk document. */
+    private static Chunk toChunk(Document d) {
+        int tokens = d.getField(F_TOKEN_COUNT).numericValue().intValue();
+        return new Chunk(
+                ChunkId.of(d.get(F_CHUNK_ID)),
+                d.get(F_CONTENT_HASH),
+                d.get(F_CONTENT),
+                d.get(F_LANGUAGE),
+                d.get(F_SYMBOL),
+                tokens);
+    }
+
+    /**
+     * Runs a task against a searcher acquired after a refresh attempt, always releasing the
+     * searcher afterwards.
+     *
+     * @throws UncheckedIOException wrapping any {@link IOException} from the refresh or the task
+     */
+    private <T> T withSearcher(SearcherTask<T> task) {
+        try {
+            searcherManager.maybeRefresh();
+            IndexSearcher s = searcherManager.acquire();
+            try {
+                return task.run(s);
+            } finally {
+                searcherManager.release(s);
+            }
+        } catch (IOException e) {
+            throw new UncheckedIOException(e);
+        }
+    }
+
+    /** A unit of search work that may throw {@link IOException}. */
+    @FunctionalInterface
+    private interface SearcherTask<T> {
+        T run(IndexSearcher searcher) throws IOException;
+    }
+
+    /** Receives a segment-local doc id together with that segment's stored-fields reader. */
+    @FunctionalInterface
+    private interface DocConsumer {
+        void accept(int doc, StoredFields sf) throws IOException;
+    }
+
+    /** Invokes {@code consumer} once for every document matching {@code query}, without scoring. */
+    private static void forEachDoc(IndexSearcher searcher, Query query, DocConsumer consumer)
+            throws IOException {
+        searcher.search(query, new CollectorManager<SimpleCollector, Void>() {
+            @Override
+            public SimpleCollector newCollector() {
+                return new SimpleCollector() {
+                    @Nullable private StoredFields sf;
+
+                    @Override
+                    public ScoreMode scoreMode() {
+                        return ScoreMode.COMPLETE_NO_SCORES;
+                    }
+
+                    @Override
+                    protected void doSetNextReader(LeafReaderContext ctx) throws IOException {
+                        // collect() receives ids relative to this segment, so bind its reader here.
+                        this.sf = ctx.reader().storedFields();
+                    }
+
+                    @Override
+                    public void collect(int doc) throws IOException {
+                        consumer.accept(doc, sf);
+                    }
+                };
+            }
+
+            @Override
+            public Void reduce(Collection<SimpleCollector> collectors) {
+                return null;
+            }
+        });
+    }
+
+    /** Chunk ids (as index terms) in a scope, and one representative link per chunk. */
+    private record ScopeIndex(Set<BytesRef> chunkIds, Map<ChunkId, LinkDoc> representative) {}
+
+    /** Location data read from a link document. */
+    private record LinkDoc(
+            RepositoryId repoId, VersionId versionId, String filePath, int startLine, int endLine) {}
+}
diff --git
a/trueref-adapters/src/main/java/com/trueref/adapter/out/vectorstore/lucene/LuceneConfig.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/vectorstore/lucene/LuceneConfig.java new file mode 100644 index 0000000..250df96 --- /dev/null +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/vectorstore/lucene/LuceneConfig.java @@ -0,0 +1,12 @@ +package com.trueref.adapter.out.vectorstore.lucene; + +import org.springframework.boot.context.properties.EnableConfigurationProperties; +import org.springframework.context.annotation.Configuration; + +/** + * Enables {@link LuceneProperties} binding. The {@link LuceneChunkStore} bean itself is picked up + * via {@code @Component} scan. + */ +@Configuration +@EnableConfigurationProperties(LuceneProperties.class) +public class LuceneConfig {} diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/vectorstore/lucene/LuceneProperties.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/vectorstore/lucene/LuceneProperties.java new file mode 100644 index 0000000..239841d --- /dev/null +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/vectorstore/lucene/LuceneProperties.java @@ -0,0 +1,12 @@ +package com.trueref.adapter.out.vectorstore.lucene; + +import java.nio.file.Path; +import org.jspecify.annotations.Nullable; +import org.springframework.boot.context.properties.ConfigurationProperties; + +/** + * Typed configuration bound to {@code trueref.lucene.*}. When {@link #home()} is null, + * {@link LuceneChunkStore} falls back to {@code ${trueref.home:./data}/lucene}. 
+ */ +@ConfigurationProperties("trueref.lucene") +public record LuceneProperties(@Nullable Path home) {} diff --git a/trueref-adapters/src/main/java/com/trueref/adapter/out/vectorstore/lucene/package-info.java b/trueref-adapters/src/main/java/com/trueref/adapter/out/vectorstore/lucene/package-info.java new file mode 100644 index 0000000..872edce --- /dev/null +++ b/trueref-adapters/src/main/java/com/trueref/adapter/out/vectorstore/lucene/package-info.java @@ -0,0 +1,9 @@ +/** + * Lucene 10 adapter implementing the {@link com.trueref.domain.port.out.ChunkStore} SPI. Holds + * both BM25-tokenized text and dense HNSW kNN vectors in a single index under + * {@code $TRUEREF_HOME/lucene}. + */ +@NullMarked +package com.trueref.adapter.out.vectorstore.lucene; + +import org.jspecify.annotations.NullMarked; diff --git a/trueref-adapters/src/test/java/com/trueref/adapter/out/cache/disk/DiskEmbeddingCacheTest.java b/trueref-adapters/src/test/java/com/trueref/adapter/out/cache/disk/DiskEmbeddingCacheTest.java new file mode 100644 index 0000000..56bf0c3 --- /dev/null +++ b/trueref-adapters/src/test/java/com/trueref/adapter/out/cache/disk/DiskEmbeddingCacheTest.java @@ -0,0 +1,102 @@ +package com.trueref.adapter.out.cache.disk; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Arrays; +import java.util.Optional; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +class DiskEmbeddingCacheTest { + + private static final int DIM = 8; + + private static DiskEmbeddingCache newCache(Path home) { + EmbeddingCacheProperties props = new EmbeddingCacheProperties(home, 16, DIM); + DiskEmbeddingCache cache = new DiskEmbeddingCache(props); + cache.init(); + return cache; + } + + private static String hash(String s) { + // pad short identifiers to a 
64-char hex-shaped string + StringBuilder b = new StringBuilder(s); + while (b.length() < 64) { + b.append('0'); + } + return b.substring(0, 64); + } + + @Test + void roundTripsVector(@TempDir Path home) { + DiskEmbeddingCache cache = newCache(home); + String h = hash("abcd"); + float[] vec = new float[] {1f, -2f, 3.5f, 0f, 7.25f, -0.125f, 42f, 100f}; + + cache.put(h, vec); + + Optional<float[]> got = cache.get(h); + assertThat(got).isPresent(); + assertThat(Arrays.equals(got.get(), vec)).isTrue(); + } + + /** After one successful read, a second {@code get} must still return the vector once the backing file has been deleted. */ @Test + void secondGetIsServedFromMemoryAfterFileDeletion(@TempDir Path home) throws Exception { + DiskEmbeddingCache cache = newCache(home); + String h = hash("beef"); + float[] vec = new float[DIM]; + for (int i = 0; i < DIM; i++) { + vec[i] = i * 1.5f; + } + cache.put(h, vec); + + // first get populates LRU (after a fresh read path); second get should hit memory. + Optional<float[]> first = cache.get(h); + assertThat(first).isPresent(); + + // Delete on-disk file; LRU still holds it. + Path placed = home.resolve(h.substring(0, 2)).resolve(h.substring(2, 4)).resolve(h + ".f32"); + Files.delete(placed); + assertThat(Files.exists(placed)).isFalse(); + + Optional<float[]> second = cache.get(h); + assertThat(second).isPresent(); + assertThat(Arrays.equals(second.get(), vec)).isTrue(); + } + + @Test + void wrongDimensionFileReturnsEmpty(@TempDir Path home) throws Exception { + DiskEmbeddingCache cache = newCache(home); + String h = hash("dead"); + + // Write a stub file with the wrong byte count directly. 
+ Path target = home.resolve(h.substring(0, 2)).resolve(h.substring(2, 4)).resolve(h + ".f32"); + Files.createDirectories(target.getParent()); + ByteBuffer buf = + ByteBuffer.allocate((DIM - 1) * Float.BYTES).order(ByteOrder.LITTLE_ENDIAN); + for (int i = 0; i < DIM - 1; i++) { + buf.putFloat(i); + } + Files.write(target, buf.array()); + + assertThat(cache.get(h)).isEmpty(); + } + + @Test + void putRejectsWrongDimension(@TempDir Path home) { + DiskEmbeddingCache cache = newCache(home); + assertThatThrownBy(() -> cache.put(hash("aabb"), new float[DIM + 1])) + .isInstanceOf(IllegalArgumentException.class); + } + + @Test + void missReturnsEmpty(@TempDir Path home) { + DiskEmbeddingCache cache = newCache(home); + assertThat(cache.get(hash("ffff"))).isEmpty(); + } +} diff --git a/trueref-adapters/src/test/java/com/trueref/adapter/out/embedding/onnx/PoolingMathTest.java b/trueref-adapters/src/test/java/com/trueref/adapter/out/embedding/onnx/PoolingMathTest.java new file mode 100644 index 0000000..e89afae --- /dev/null +++ b/trueref-adapters/src/test/java/com/trueref/adapter/out/embedding/onnx/PoolingMathTest.java @@ -0,0 +1,50 @@ +package com.trueref.adapter.out.embedding.onnx; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.within; + +import org.junit.jupiter.api.Test; + +class PoolingMathTest { + + @Test + void l2NormalizeScalesToUnitLength() { + float[] v = {3.0f, 4.0f}; + PoolingMath.l2NormalizeInPlace(v); + assertThat(v[0]).isCloseTo(0.6f, within(1e-6f)); + assertThat(v[1]).isCloseTo(0.8f, within(1e-6f)); + } + + @Test + void l2NormalizeLeavesZeroVectorAsZero() { + float[] v = {0.0f, 0.0f, 0.0f}; + PoolingMath.l2NormalizeInPlace(v); + assertThat(v).containsExactly(0.0f, 0.0f, 0.0f); + } + + @Test + void meanPoolAveragesAcrossUnmaskedPositions() { + // batch=1, seq=3, hidden=2. Mask out position 2. 
+ float[][][] hidden = {{{1.0f, 2.0f}, {3.0f, 4.0f}, {100.0f, 100.0f}}}; + long[][] mask = {{1L, 1L, 0L}}; + float[][] pooled = PoolingMath.meanPool(hidden, mask); + assertThat(pooled).hasDimensions(1, 2); + assertThat(pooled[0][0]).isCloseTo(2.0f, within(1e-6f)); + assertThat(pooled[0][1]).isCloseTo(3.0f, within(1e-6f)); + } + + @Test + void meanPoolReturnsZeroWhenMaskIsAllZero() { + float[][][] hidden = {{{5.0f, 5.0f}}}; + long[][] mask = {{0L}}; + float[][] pooled = PoolingMath.meanPool(hidden, mask); + assertThat(pooled[0]).containsExactly(0.0f, 0.0f); + } + + @Test + void sigmoidHandlesLargeMagnitudesWithoutOverflow() { + assertThat(PoolingMath.sigmoid(0.0)).isCloseTo(0.5, within(1e-9)); + assertThat(PoolingMath.sigmoid(1000.0)).isCloseTo(1.0, within(1e-9)); + assertThat(PoolingMath.sigmoid(-1000.0)).isCloseTo(0.0, within(1e-9)); + } +} diff --git a/trueref-adapters/src/test/java/com/trueref/adapter/out/parsing/heuristic/HeuristicCodeParserTest.java b/trueref-adapters/src/test/java/com/trueref/adapter/out/parsing/heuristic/HeuristicCodeParserTest.java new file mode 100644 index 0000000..f3f0e3e --- /dev/null +++ b/trueref-adapters/src/test/java/com/trueref/adapter/out/parsing/heuristic/HeuristicCodeParserTest.java @@ -0,0 +1,141 @@ +package com.trueref.adapter.out.parsing.heuristic; + +import static org.assertj.core.api.Assertions.assertThat; + +import com.trueref.domain.port.out.CodeParser.ParsedChunk; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +class HeuristicCodeParserTest { + + private final HeuristicCodeParser parser = new HeuristicCodeParser(); + + @Test + void supportsEveryFile(@TempDir Path tmp) throws Exception { + Path file = write(tmp, "anything.xyz", "hello\n"); + assertThat(parser.supports(file)).isTrue(); + } + + @Test + void parsesJavaIntoPreambleAndConstructs(@TempDir Path tmp) throws 
Exception { + String src = + """ + package com.example; + + import java.util.List; + + class Foo { + void bar() { + System.out.println("hi {"); + } + } + + class Baz { + int x; + } + """; + Path file = write(tmp, "Sample.java", src); + List chunks = parser.parse(file, "Sample.java"); + + assertThat(chunks).isNotEmpty(); + assertThat(chunks).allSatisfy(c -> { + assertThat(c.language()).isEqualTo("java"); + assertThat(c.startLine()).isPositive(); + assertThat(c.endLine()).isGreaterThanOrEqualTo(c.startLine()); + }); + assertThat(chunks).anyMatch(c -> "
".equals(c.symbol())); + assertThat(chunks).anyMatch(c -> "Foo".equals(c.symbol())); + assertThat(chunks).anyMatch(c -> "Baz".equals(c.symbol())); + } + + @Test + void parsesMarkdownOnHeadings(@TempDir Path tmp) throws Exception { + String src = + """ + Intro paragraph without heading. + + # First + + Body 1. + + ## Sub + + Body 2. + + Underlined + ========== + + Body 3. + """; + Path file = write(tmp, "doc.md", src); + List chunks = parser.parse(file, "doc.md"); + + assertThat(chunks).hasSizeGreaterThanOrEqualTo(4); + assertThat(chunks.get(0).symbol()).isNull(); + List symbols = chunks.stream().map(ParsedChunk::symbol).toList(); + assertThat(symbols).contains("First", "Sub", "Underlined"); + assertThat(chunks).allSatisfy(c -> assertThat(c.language()).isEqualTo("markdown")); + } + + @Test + void parsesPythonIndentConstructs(@TempDir Path tmp) throws Exception { + String src = + """ + import os + + def foo(): + return 1 + + class Bar: + def baz(self): + return 2 + """; + Path file = write(tmp, "mod.py", src); + List chunks = parser.parse(file, "mod.py"); + + assertThat(chunks).isNotEmpty(); + List symbols = chunks.stream().map(ParsedChunk::symbol).toList(); + assertThat(symbols).contains("foo", "Bar"); + assertThat(chunks).allSatisfy(c -> assertThat(c.language()).isEqualTo("python")); + } + + @Test + void slidingWindowFallbackForUnknownExtension(@TempDir Path tmp) throws Exception { + StringBuilder sb = new StringBuilder(); + for (int i = 1; i <= 200; i++) { + sb.append("line ").append(i).append('\n'); + } + Path file = write(tmp, "big.xyz", sb.toString()); + List chunks = parser.parse(file, "big.xyz"); + + assertThat(chunks).hasSizeGreaterThan(1); + assertThat(chunks).allSatisfy(c -> { + assertThat(c.language()).isEqualTo("text"); + assertThat(c.symbol()).isNull(); + assertThat(c.endLine() - c.startLine() + 1).isLessThanOrEqualTo(80); + }); + // Overlap check: second chunk should start before previous chunk ends. 
+ ParsedChunk first = chunks.get(0); + ParsedChunk second = chunks.get(1); + assertThat(second.startLine()).isLessThanOrEqualTo(first.endLine()); + } + + @Test + void binaryFileYieldsEmptyList(@TempDir Path tmp) throws Exception { + Path file = tmp.resolve("blob.bin"); + byte[] bad = new byte[] {(byte) 0xC3, (byte) 0x28, (byte) 0xA0, (byte) 0xA1}; // invalid UTF-8 + Files.write(file, bad); + List chunks = parser.parse(file, "blob.bin"); + assertThat(chunks).isEmpty(); + } + + private static Path write(Path dir, String name, String content) throws Exception { + Path p = dir.resolve(name); + Files.writeString(p, content, StandardCharsets.UTF_8); + return p; + } +} diff --git a/trueref-domain/src/main/java/com/trueref/domain/port/out/ChunkStore.java b/trueref-domain/src/main/java/com/trueref/domain/port/out/ChunkStore.java new file mode 100644 index 0000000..10bd925 --- /dev/null +++ b/trueref-domain/src/main/java/com/trueref/domain/port/out/ChunkStore.java @@ -0,0 +1,44 @@ +package com.trueref.domain.port.out; + +import com.trueref.domain.model.Chunk; +import com.trueref.domain.model.ChunkId; +import com.trueref.domain.model.ChunkVersion; +import com.trueref.domain.model.Embedding; +import com.trueref.domain.model.SearchHit; +import com.trueref.domain.model.SearchScope; +import com.trueref.domain.model.VersionId; +import java.util.Collection; +import java.util.List; +import java.util.Optional; +import java.util.Set; + +/** + * Combined chunk metadata + vector + lexical store. The Lucene adapter implements this with a + * single index that holds both BM25-tokenized text and an HNSW kNN field. + */ +public interface ChunkStore { + + /** Looks up an existing chunk by content hash so callers can avoid re-embedding. */ + Optional findByContentHash(String contentHash); + + /** Persists a brand-new chunk and its embedding. Idempotent on {@code contentHash}. 
*/ + Chunk upsertChunk(Chunk chunk, Embedding embedding); + + /** Adds membership rows linking chunks to a version's source files. */ + void linkChunks(Collection links); + + /** Removes all chunk-version links for the given version. */ + void unlinkVersion(VersionId versionId); + + /** Returns the set of {@link ChunkId}s reachable from the given version. */ + Set chunkIdsForVersion(VersionId versionId); + + /** BM25 lexical search restricted to the given scope. */ + List bm25Search(String queryText, SearchScope scope, int topK); + + /** HNSW dense kNN search restricted to the given scope. */ + List denseSearch(float[] queryVector, SearchScope scope, int topK); + + /** Forces a Lucene commit. Called at the end of an indexing job. */ + void commit(); +} diff --git a/trueref-domain/src/main/java/com/trueref/domain/port/out/CodeParser.java b/trueref-domain/src/main/java/com/trueref/domain/port/out/CodeParser.java new file mode 100644 index 0000000..a4e3932 --- /dev/null +++ b/trueref-domain/src/main/java/com/trueref/domain/port/out/CodeParser.java @@ -0,0 +1,24 @@ +package com.trueref.domain.port.out; + +import java.nio.file.Path; +import java.util.List; +import org.jspecify.annotations.Nullable; + +/** + * Parses a single source file into a list of code chunks. Implementations use tree-sitter for + * supported languages and a sliding-window text splitter as fallback. + */ +public interface CodeParser { + + /** True if at least one grammar can handle the given file. 
 */ + boolean supports(Path file); + + /** Parses {@code file} into a list of chunks. NOTE(review): {@code repoRelativePath} is presumably the same file's path relative to the repository root — confirm against callers. */ List parse(Path file, String repoRelativePath); + + /** One parsed chunk: its text, a language tag, an optional symbol name and a start/end line pair. NOTE(review): line numbers look 1-based and inclusive — confirm. */ record ParsedChunk( + String content, + String language, + @Nullable String symbol, + int startLine, + int endLine) {} +} diff --git a/trueref-domain/src/main/java/com/trueref/domain/port/out/EmbeddingCache.java b/trueref-domain/src/main/java/com/trueref/domain/port/out/EmbeddingCache.java new file mode 100644 index 0000000..a76b0dd --- /dev/null +++ b/trueref-domain/src/main/java/com/trueref/domain/port/out/EmbeddingCache.java @@ -0,0 +1,11 @@ +package com.trueref.domain.port.out; + +import java.util.Optional; + +/** Persistent on-disk cache of embedding vectors keyed by content hash. */ +public interface EmbeddingCache { + + /** Returns the vector cached under {@code contentHash}, or an empty {@link Optional} when there is none. */ Optional get(String contentHash); + + /** Stores {@code vector} under {@code contentHash}. NOTE(review): whether a wrong-length vector is rejected is implementation-defined — confirm per adapter. */ void put(String contentHash, float[] vector); +} diff --git a/trueref-domain/src/main/java/com/trueref/domain/port/out/EmbeddingService.java b/trueref-domain/src/main/java/com/trueref/domain/port/out/EmbeddingService.java new file mode 100644 index 0000000..fa76209 --- /dev/null +++ b/trueref-domain/src/main/java/com/trueref/domain/port/out/EmbeddingService.java @@ -0,0 +1,13 @@ +package com.trueref.domain.port.out; + +import java.util.List; + +/** Generates dense embedding vectors. Implementations are expected to be batch-friendly. */ +public interface EmbeddingService { + + /** Embedding dimensionality of the underlying model. */ + int dimension(); + + /** Embeds a batch of texts. Implementations should call out to GPU through a semaphore. 
*/ + List embed(List texts); +} diff --git a/trueref-domain/src/main/java/com/trueref/domain/port/out/GitClient.java b/trueref-domain/src/main/java/com/trueref/domain/port/out/GitClient.java new file mode 100644 index 0000000..bf3c335 --- /dev/null +++ b/trueref-domain/src/main/java/com/trueref/domain/port/out/GitClient.java @@ -0,0 +1,50 @@ +package com.trueref.domain.port.out; + +import java.nio.file.Path; +import java.util.List; +import org.jspecify.annotations.Nullable; + +/** Git operations the application needs. Implementations: JGit. */ +public interface GitClient { + + /** Clones a remote repository to a local directory (no-op if it already exists). */ + void cloneRepo(String remoteUrl, Path localPath); + + /** Runs git fetch on an existing local repository (no-op for non-managed repos that lack a remote). */ + void fetch(Path localPath); + + /** Lists tags currently present in the local repository. */ + List listTags(Path localPath); + + /** Resolves a ref (tag or branch) to its commit SHA. */ + String resolveRef(Path localPath, String ref); + + /** + * Checks the given ref out into a transient worktree directory the caller is responsible for + * cleaning up. Returns the worktree root. + */ + Path checkoutWorktree(Path repoPath, String ref); + + /** Removes a worktree previously created by {@link #checkoutWorktree}. */ + void removeWorktree(Path repoPath, Path worktree); + + /** + * Returns the list of files changed between two commits, classified by status. 
+ * + * @param baseRef the previously indexed tag/commit (may be null when there is no parent) + */ + List diff(Path repoPath, @Nullable String baseRef, String headRef); + + record TagInfo(String name, String commitSha, long taggerEpochSeconds) {} + + record DiffEntry(String path, @Nullable String oldPath, ChangeType change) { + + public enum ChangeType { + ADDED, + MODIFIED, + DELETED, + RENAMED, + COPIED + } + } +} diff --git a/trueref-domain/src/main/java/com/trueref/domain/port/out/JobEventBus.java b/trueref-domain/src/main/java/com/trueref/domain/port/out/JobEventBus.java new file mode 100644 index 0000000..bd954bb --- /dev/null +++ b/trueref-domain/src/main/java/com/trueref/domain/port/out/JobEventBus.java @@ -0,0 +1,21 @@ +package com.trueref.domain.port.out; + +import com.trueref.domain.model.IngestionJob; +import com.trueref.domain.model.JobId; +import com.trueref.domain.model.JobLogEvent; +import java.util.function.Consumer; + +/** + * In-process pub/sub bus for ingestion observability events. The application publishes; REST/SSE + * adapters subscribe to fan events out to UI clients. 
+ */ +public interface JobEventBus { + + /** Publishes {@code job} on the bus. */ void publishJob(IngestionJob job); + + /** Publishes a job log event on the bus. */ void publishLog(JobLogEvent event); + + /** Registers {@code listener} for published jobs. NOTE(review): closing the returned handle presumably cancels the subscription — confirm. */ AutoCloseable subscribeJobs(Consumer listener); + + /** Registers {@code listener} for log events. NOTE(review): presumably only events belonging to {@code jobId} are delivered, and closing the returned handle cancels the subscription — confirm. */ AutoCloseable subscribeLogs(JobId jobId, Consumer listener); +} diff --git a/trueref-domain/src/main/java/com/trueref/domain/port/out/JobStore.java b/trueref-domain/src/main/java/com/trueref/domain/port/out/JobStore.java new file mode 100644 index 0000000..b73bd92 --- /dev/null +++ b/trueref-domain/src/main/java/com/trueref/domain/port/out/JobStore.java @@ -0,0 +1,41 @@ +package com.trueref.domain.port.out; + +import com.trueref.domain.model.IngestionJob; +import com.trueref.domain.model.JobId; +import com.trueref.domain.model.JobStage; +import com.trueref.domain.model.JobStatus; +import com.trueref.domain.model.RepositoryId; +import com.trueref.domain.model.VersionId; +import java.time.Instant; +import java.util.List; +import java.util.Optional; +import org.jspecify.annotations.Nullable; + +/** Persistence SPI for ingestion jobs and their stages. */ +public interface JobStore { + + /** Persists {@code job}. NOTE(review): the returned instance is presumably the stored state — confirm. */ IngestionJob save(IngestionJob job); + + /** Looks up a job by its id; empty when no such job is stored. */ Optional findById(JobId id); + + /** Returns the running jobs. NOTE(review): presumably those with status RUNNING — confirm. */ List findRunning(); + + /** Queries jobs. NOTE(review): a null filter argument presumably means "no restriction" and {@code limit} caps the result size — confirm. */ List find( + @Nullable RepositoryId repoId, @Nullable VersionId versionId, @Nullable JobStatus status, int limit); + + /** Sets the status of job {@code id} together with its start/finish timestamps. NOTE(review): the effect of a null timestamp is not visible here — confirm. */ void updateStatus( + JobId id, + JobStatus status, + @Nullable Instant startedAt, + @Nullable Instant finishedAt); + + /** Inserts or updates {@code stage}. */ void upsertStage(JobStage stage); + + /** + * Marks all RUNNING and QUEUED jobs as FAILED and their RUNNING stages as FAILED. + * Called once on startup to clear jobs that were interrupted by a previous crash or restart. 
+ * + * @return the number of jobs that were transitioned to FAILED + */ + int failStaleJobs(Instant finishedAt); +} diff --git a/trueref-domain/src/main/java/com/trueref/domain/port/out/RepositoryStore.java b/trueref-domain/src/main/java/com/trueref/domain/port/out/RepositoryStore.java new file mode 100644 index 0000000..5ad5d70 --- /dev/null +++ b/trueref-domain/src/main/java/com/trueref/domain/port/out/RepositoryStore.java @@ -0,0 +1,47 @@ +package com.trueref.domain.port.out; + +import com.trueref.domain.model.Repository; +import com.trueref.domain.model.RepositoryId; +import com.trueref.domain.model.Version; +import com.trueref.domain.model.VersionId; +import com.trueref.domain.model.VersionStatus; +import java.util.List; +import java.util.Optional; +import org.jspecify.annotations.Nullable; + +/** Persistence SPI for repositories and their versions. */ +public interface RepositoryStore { + + Repository save(Repository repo); + + Optional findById(RepositoryId id); + + Optional findByName(String name); + + List findAll(); + + void delete(RepositoryId id); + + Version saveVersion(Version version); + + Optional findVersion(VersionId id); + + Optional findVersionByTag(RepositoryId repoId, String tag); + + List findVersionsByRepo(RepositoryId repoId); + + List findVersionsByStatus(@Nullable RepositoryId repoId, VersionStatus status); + + void updateVersionStatus(VersionId id, VersionStatus status, @Nullable String errorMessage); + + /** Updates {@code chunkCount} and sets {@code indexedAt = now()}. */ + void markVersionIndexed(VersionId id, int chunkCount); + + /** + * Marks all INDEXING versions as FAILED. + * Called once on startup to clear versions whose indexing job was interrupted. 
+ * + * @return the number of versions transitioned to FAILED + */ + int failStaleIndexingVersions(String errorMessage); +} diff --git a/trueref-domain/src/main/java/com/trueref/domain/port/out/RerankerService.java b/trueref-domain/src/main/java/com/trueref/domain/port/out/RerankerService.java new file mode 100644 index 0000000..1c484a9 --- /dev/null +++ b/trueref-domain/src/main/java/com/trueref/domain/port/out/RerankerService.java @@ -0,0 +1,11 @@ +package com.trueref.domain.port.out; + +import com.trueref.domain.model.SearchHit; +import java.util.List; + +/** Cross-encoder reranker. Re-scores a candidate list against a query. */ +public interface RerankerService { + + /** Returns the candidates re-sorted by cross-encoder score, with score replaced. */ + List rerank(String query, List candidates); +} diff --git a/trueref-domain/src/main/java/com/trueref/domain/port/out/package-info.java b/trueref-domain/src/main/java/com/trueref/domain/port/out/package-info.java new file mode 100644 index 0000000..eb58db7 --- /dev/null +++ b/trueref-domain/src/main/java/com/trueref/domain/port/out/package-info.java @@ -0,0 +1,6 @@ +/** + * Driven ports — SPIs implemented by adapters (persistence, vector store, embedding service, git, + * parser, etc.) and called by the application layer. + */ +@org.jspecify.annotations.NullMarked +package com.trueref.domain.port.out;