106 lines
3.1 KiB
TypeScript
106 lines
3.1 KiB
TypeScript
/**
 * EmbeddingService — batches embedding requests and persists the results to
 * the snippet_embeddings table.
 */
|
|
|
|
import type Database from 'better-sqlite3';
|
|
import type { EmbeddingProvider } from './provider.js';
|
|
|
|
/**
 * Columns read from the `snippets` table when building the text to embed.
 * `title` and `breadcrumb` are nullable; `content` is always a string.
 */
interface SnippetRow {
  /** Snippet identifier; written to `snippet_embeddings.snippet_id`. */
  id: string;
  /** Optional title, placed first in the embedded text when present. */
  title: string | null;
  /** Optional breadcrumb, placed after the title when present. */
  breadcrumb: string | null;
  /** Snippet body, placed last in the embedded text. */
  content: string;
}
|
|
|
|
/** Number of snippets sent to the embedding provider (and committed) per batch. */
const BATCH_SIZE = 50;

/** Maximum length, in UTF-16 code units, of the text built for one snippet. */
const TEXT_MAX_CHARS = 2048;
|
|
|
|
/**
 * Embeds snippets through an EmbeddingProvider and stores / loads the
 * resulting vectors in the `snippet_embeddings` table.
 *
 * Vectors are stored as raw bytes taken from the provider's typed array, so
 * the BLOB uses the byte order of the machine that wrote it.
 */
export class EmbeddingService {
  /**
   * Maximum number of IDs bound to a single `IN (...)` lookup. SQLite limits
   * the number of bound parameters per statement, so large ID lists are
   * queried in chunks of this size.
   */
  private static readonly MAX_IDS_PER_QUERY = 500;

  /**
   * @param db - Open better-sqlite3 connection.
   * @param provider - Provider whose `embed` method turns texts into vectors.
   * @param profileId - Profile ID written alongside every stored embedding.
   */
  constructor(
    private readonly db: Database.Database,
    private readonly provider: EmbeddingProvider,
    private readonly profileId: string = 'local-default'
  ) {}

  /**
   * Embed the given snippet IDs and store the results in snippet_embeddings.
   *
   * Only snippets that actually exist in the database are processed, and
   * duplicate IDs are processed once. Rows are written with INSERT OR REPLACE,
   * one transaction per batch of BATCH_SIZE snippets, so batches committed
   * before a failure remain stored.
   *
   * @param snippetIds - Array of snippet UUIDs to embed.
   * @param onProgress - Optional callback invoked after each batch with
   *                     (completedCount, totalCount).
   * @throws Error when the provider returns fewer embeddings than texts; the
   *         affected batch is rolled back.
   */
  async embedSnippets(
    snippetIds: string[],
    onProgress?: (done: number, total: number) => void
  ): Promise<void> {
    if (snippetIds.length === 0) return;

    const snippets = this.loadSnippets(snippetIds);
    if (snippets.length === 0) return;

    const insert = this.db.prepare<[string, string, string, number, Buffer]>(`
      INSERT OR REPLACE INTO snippet_embeddings (snippet_id, profile_id, model, dimensions, embedding, created_at)
      VALUES (?, ?, ?, ?, ?, unixepoch())
    `);

    for (let offset = 0; offset < snippets.length; offset += BATCH_SIZE) {
      const batch = snippets.slice(offset, offset + BATCH_SIZE);
      const embeddings = await this.provider.embed(
        batch.map((row) => EmbeddingService.toEmbeddingText(row))
      );

      // One transaction per batch: either every row of the batch is written or none.
      const persistBatch = this.db.transaction(() => {
        batch.forEach((snippet, index) => {
          const embedding = embeddings[index];
          if (embedding === undefined) {
            throw new Error(
              `Embedding provider returned ${embeddings.length} embeddings for ${batch.length} texts`
            );
          }
          const vector = embedding.values;
          insert.run(
            snippet.id,
            this.profileId,
            embedding.model,
            embedding.dimensions,
            // Respect byteOffset/byteLength: the vector may be a view onto a
            // larger ArrayBuffer, and only its own bytes must be stored.
            Buffer.from(vector.buffer, vector.byteOffset, vector.byteLength)
          );
        });
      });
      persistBatch();

      onProgress?.(offset + batch.length, snippets.length);
    }
  }

  /**
   * Retrieve a stored embedding for a snippet as a Float32Array.
   * Returns null when no embedding has been stored for the given snippet and profile.
   * The returned array owns its memory; it does not alias the database buffer.
   *
   * NOTE(review): the default profile is the literal 'local-default', not the
   * profile this service was constructed with — confirm this is intended.
   *
   * @param snippetId - Snippet UUID
   * @param profileId - Embedding profile ID (default: 'local-default')
   * @throws RangeError when the stored `dimensions` value is not a
   *         non-negative integer or exceeds the size of the stored BLOB.
   */
  getEmbedding(snippetId: string, profileId: string = 'local-default'): Float32Array | null {
    const row = this.db
      .prepare<
        [string, string],
        { embedding: Buffer; dimensions: number }
      >(`SELECT embedding, dimensions FROM snippet_embeddings WHERE snippet_id = ? AND profile_id = ?`)
      .get(snippetId, profileId);

    if (!row) return null;

    const { embedding: blob, dimensions } = row;
    const byteLength = dimensions * Float32Array.BYTES_PER_ELEMENT;
    if (!Number.isInteger(dimensions) || dimensions < 0 || byteLength > blob.byteLength) {
      throw new RangeError(
        `Stored embedding for snippet ${snippetId} has ${blob.byteLength} bytes, ` +
          `which does not match dimensions=${dimensions}`
      );
    }

    // Copy into a fresh buffer that starts at offset 0. A Buffer's byteOffset
    // is not guaranteed to be a multiple of 4, which a Float32Array view requires.
    const bytes = new Uint8Array(byteLength);
    bytes.set(blob.subarray(0, byteLength));
    return new Float32Array(bytes.buffer);
  }

  /**
   * Load the rows for the given IDs, de-duplicating the IDs and querying in
   * chunks so no statement binds more than MAX_IDS_PER_QUERY parameters.
   */
  private loadSnippets(snippetIds: readonly string[]): SnippetRow[] {
    const uniqueIds = Array.from(new Set(snippetIds));
    const rows: SnippetRow[] = [];

    for (let start = 0; start < uniqueIds.length; start += EmbeddingService.MAX_IDS_PER_QUERY) {
      const chunk = uniqueIds.slice(start, start + EmbeddingService.MAX_IDS_PER_QUERY);
      const placeholders = chunk.map(() => '?').join(',');
      const found = this.db
        .prepare<
          string[],
          SnippetRow
        >(`SELECT id, title, breadcrumb, content FROM snippets WHERE id IN (${placeholders})`)
        .all(...chunk);
      rows.push(...found);
    }

    return rows;
  }

  /**
   * Build the text to embed: title, breadcrumb and content joined by
   * newlines (null/empty parts dropped), truncated to TEXT_MAX_CHARS.
   */
  private static toEmbeddingText(row: SnippetRow): string {
    return [row.title, row.breadcrumb, row.content]
      .filter(Boolean)
      .join('\n')
      .slice(0, TEXT_MAX_CHARS);
  }
}
|