/**
 * Unit tests for the embedding provider abstraction and EmbeddingService
 * storage logic.
 *
 * Tests use in-memory SQLite and mock providers — no real API calls are made.
 */

import { describe, it, expect, beforeEach, afterEach, vi } from 'vitest';
import Database from 'better-sqlite3';
import { drizzle } from 'drizzle-orm/better-sqlite3';
import { migrate } from 'drizzle-orm/better-sqlite3/migrator';
import { readFileSync } from 'node:fs';
import { join } from 'node:path';
import * as schema from '../db/schema.js';
import { NoopEmbeddingProvider, EmbeddingError, type EmbeddingVector } from './provider.js';
import { OpenAIEmbeddingProvider } from './openai.provider.js';
import { EmbeddingService } from './embedding.service.js';
import {
	createProviderFromConfig,
	defaultEmbeddingConfig,
	EMBEDDING_CONFIG_KEY,
	type EmbeddingConfig
} from './factory.js';
import { createProviderFromProfile } from './registry.js';

// ---------------------------------------------------------------------------
// Test DB helpers
// ---------------------------------------------------------------------------

/**
 * Builds a fresh in-memory SQLite database for a single test.
 *
 * Foreign keys are switched on, the drizzle migrations from `../db/migrations`
 * are applied, and then the statements in `../db/fts.sql` are executed on the
 * raw client.
 *
 * @returns the drizzle wrapper (`db`) and the underlying better-sqlite3 `client`.
 */
function createTestDb() {
	const client = new Database(':memory:');
	client.pragma('foreign_keys = ON');

	const db = drizzle(client, { schema });
	const migrationsFolder = join(import.meta.dirname, '../db/migrations');
	migrate(db, { migrationsFolder });

	const ftsSql = readFileSync(join(import.meta.dirname, '../db/fts.sql'), 'utf-8');
	client.exec(ftsSql);

	return { db, client };
}

/** Drizzle handle type produced by {@link createTestDb}; derived so the two cannot drift. */
type TestDb = ReturnType<typeof createTestDb>['db'];

/** Single timestamp shared by every seeded row. */
const now = new Date();

/**
 * Inserts a repository (if not already present), a document and a snippet,
 * and returns the id of the new snippet.
 *
 * @param db - drizzle handle used for all inserts.
 * @param client - raw better-sqlite3 handle; accepted for call-site symmetry, not used here.
 * @param overrides - column values spread over the default snippet row.
 * @returns the generated snippet id.
 */
function seedSnippet(
	db: TestDb,
	client: Database.Database,
	overrides: Partial<typeof schema.snippets.$inferInsert> = {}
): string {
	const repoId = '/test/embed-repo';

	// Ensure repo exists (ignore if already there).
	try {
		db.insert(schema.repositories)
			.values({
				id: repoId,
				title: 'Embed Repo',
				source: 'github',
				sourceUrl: 'https://github.com/test/embed-repo',
				createdAt: now,
				updatedAt: now
			})
			.run();
	} catch {
		// already exists
	}

	const docId = crypto.randomUUID();
	db.insert(schema.documents)
		.values({
			id: docId,
			repositoryId: repoId,
			filePath: 'README.md',
			checksum: 'abc',
			indexedAt: now
		})
		.run();

	const snippetId = crypto.randomUUID();
	db.insert(schema.snippets)
		.values({
			id: snippetId,
			documentId: docId,
			repositoryId: repoId,
			type: 'info',
			content: 'Hello embedding world',
			title: 'Embed snippet',
			createdAt: now,
			...overrides
		})
		.run();

	return snippetId;
}

// ---------------------------------------------------------------------------
// NoopEmbeddingProvider
// ---------------------------------------------------------------------------

describe('NoopEmbeddingProvider', () => {
	it('returns an empty array for any input', async () => {
		const provider = new NoopEmbeddingProvider();
		const result = await provider.embed(['text1', 'text2']);
		expect(result).toEqual([]);
	});

	it('isAvailable() returns false', async () => {
		const provider = new NoopEmbeddingProvider();
		expect(await provider.isAvailable()).toBe(false);
	});

	it('has the expected name and dimensions', () => {
		const provider = new NoopEmbeddingProvider();
		expect(provider.name).toBe('noop');
		expect(provider.dimensions).toBe(0);
		expect(provider.model).toBe('none');
	});
});

// ---------------------------------------------------------------------------
// OpenAIEmbeddingProvider (with fetch mocking)
// ---------------------------------------------------------------------------

describe('OpenAIEmbeddingProvider', () => {
	/** Deterministic vector of length `dim`; `index` shifts the values so vectors differ. */
	function makeFakeEmbedding(dim: number, index = 0): number[] {
		return Array.from({ length: dim }, (_, i) => (i + index + 1) * 0.01);
	}

	/** Stubs global `fetch` with a successful response carrying the given embeddings. */
	function mockFetchSuccess(embeddings: number[][]) {
		const data = embeddings.map((emb, index) => ({ embedding: emb, index }));
		vi.stubGlobal(
			'fetch',
			vi.fn().mockResolvedValue({
				ok: true,
				json: async () => ({ data })
			})
		);
	}

	/** Stubs global `fetch` with a non-ok response using the given HTTP status. */
	function mockFetchFailure(status: number) {
		vi.stubGlobal(
			'fetch',
			vi.fn().mockResolvedValue({
				ok: false,
				status,
				json: async () => ({ error: { message: 'Bad request' } })
			})
		);
	}

	beforeEach(() => {
		vi.restoreAllMocks();
	});

	afterEach(() => {
		// restoreAllMocks() does not undo vi.stubGlobal(); without this the fake
		// `fetch` would stay installed for whatever runs after these tests.
		vi.unstubAllGlobals();
	});

	it('embeds texts and returns Float32Array vectors', async () => {
		const emb = makeFakeEmbedding(4);
		mockFetchSuccess([emb]);

		const provider = new OpenAIEmbeddingProvider({
			baseUrl: 'https://api.openai.com/v1',
			apiKey: 'test-key',
			model: 'text-embedding-3-small'
		});

		const result = await provider.embed(['hello world']);

		expect(result).toHaveLength(1);
		expect(result[0].model).toBe('text-embedding-3-small');
		expect(result[0].dimensions).toBe(4);
		expect(result[0].values).toBeInstanceOf(Float32Array);
		expect(result[0].values[0]).toBeCloseTo(0.01, 5);
	});

	it('batches large input into multiple fetch calls', async () => {
		// Make fetch always succeed with 2 fake embeddings.
		const fetchMock = vi.fn().mockResolvedValue({
			ok: true,
			json: async () => ({
				data: [
					{ embedding: makeFakeEmbedding(2, 0), index: 0 },
					{ embedding: makeFakeEmbedding(2, 1), index: 1 }
				]
			})
		});
		vi.stubGlobal('fetch', fetchMock);

		const provider = new OpenAIEmbeddingProvider({
			baseUrl: 'https://api.openai.com/v1',
			apiKey: 'sk-test',
			model: 'text-embedding-3-small',
			maxBatchSize: 2
		});

		// 4 texts with maxBatchSize=2 → 2 fetch calls.
		const result = await provider.embed(['a', 'b', 'c', 'd']);

		expect(fetchMock).toHaveBeenCalledTimes(2);
		expect(result).toHaveLength(4);
	});

	it('throws EmbeddingError on API failure', async () => {
		mockFetchFailure(400);

		const provider = new OpenAIEmbeddingProvider({
			baseUrl: 'https://api.openai.com/v1',
			apiKey: 'bad-key',
			model: 'text-embedding-3-small'
		});

		await expect(provider.embed(['hello'])).rejects.toThrow(EmbeddingError);
	});

	it('includes dimensions in request body when configured', async () => {
		const fetchMock = vi.fn().mockResolvedValue({
			ok: true,
			json: async () => ({
				data: [{ embedding: [0.1, 0.2], index: 0 }]
			})
		});
		vi.stubGlobal('fetch', fetchMock);

		const provider = new OpenAIEmbeddingProvider({
			baseUrl: 'https://api.openai.com/v1',
			apiKey: 'sk-test',
			model: 'text-embedding-3-small',
			dimensions: 512
		});

		await provider.embed(['test']);

		const callBody = JSON.parse(fetchMock.mock.calls[0][1].body);
		expect(callBody.dimensions).toBe(512);
	});

	it('omits dimensions field when not configured', async () => {
		const fetchMock = vi.fn().mockResolvedValue({
			ok: true,
			json: async () => ({
				data: [{ embedding: [0.1], index: 0 }]
			})
		});
		vi.stubGlobal('fetch', fetchMock);

		const provider = new OpenAIEmbeddingProvider({
			baseUrl: 'https://api.openai.com/v1',
			apiKey: 'sk-test',
			model: 'nomic-embed-text'
		});

		await provider.embed(['test']);

		const callBody = JSON.parse(fetchMock.mock.calls[0][1].body);
		expect(callBody.dimensions).toBeUndefined();
	});
});

// ---------------------------------------------------------------------------
// Migration Tests — embedding_profiles table
// ---------------------------------------------------------------------------

describe('Migration — embedding_profiles', () => {
	it('creates the embedding_profiles table', () => {
		const { client } = createTestDb();
		const tables = client
			.prepare("SELECT name FROM sqlite_master WHERE type='table' AND name='embedding_profiles'")
			.all();
		expect(tables).toHaveLength(1);
	});

	it('seeds the default local profile', () => {
		const { client } = createTestDb();
		// Only the columns asserted below are described; the row may carry more.
		const row = client
			.prepare("SELECT * FROM embedding_profiles WHERE id = 'local-default'")
			.get() as
			| { is_default: number; provider_kind: string; model: string; dimensions: number }
			| undefined;

		expect(row).toBeDefined();
		expect(row?.is_default).toBe(1);
		expect(row?.provider_kind).toBe('local-transformers');
		expect(row?.model).toBe('Xenova/all-MiniLM-L6-v2');
		expect(row?.dimensions).toBe(384);
	});
});

// ---------------------------------------------------------------------------
// Provider Registry Tests
// ---------------------------------------------------------------------------

describe('Provider Registry', () => {
	it('creates LocalEmbeddingProvider for local-transformers', () => {
		const profile: schema.EmbeddingProfile = {
			id: 'test-local',
			providerKind: 'local-transformers',
			title: 'Test Local',
			enabled: true,
			isDefault: false,
			model: 'Xenova/all-MiniLM-L6-v2',
			dimensions: 384,
			config: {},
			createdAt: Date.now(),
			updatedAt: Date.now()
		};

		const provider = createProviderFromProfile(profile);

		expect(provider.name).toBe('local');
		expect(provider.model).toBe('Xenova/all-MiniLM-L6-v2');
		expect(provider.dimensions).toBe(384);
	});

	it('creates OpenAIEmbeddingProvider for openai-compatible', () => {
		const profile: schema.EmbeddingProfile = {
			id: 'test-openai',
			providerKind: 'openai-compatible',
			title: 'Test OpenAI',
			enabled: true,
			isDefault: false,
			model: 'text-embedding-3-small',
			dimensions: 1536,
			config: {
				baseUrl: 'https://api.openai.com/v1',
				apiKey: 'test-key',
				model: 'text-embedding-3-small'
			},
			createdAt: Date.now(),
			updatedAt: Date.now()
		};

		const provider = createProviderFromProfile(profile);

		expect(provider.name).toBe('openai');
		expect(provider.model).toBe('text-embedding-3-small');
	});

	it('returns NoopEmbeddingProvider for unknown providerKind', () => {
		const profile: schema.EmbeddingProfile = {
			id: 'test-unknown',
			providerKind: 'unknown-provider',
			title: 'Unknown',
			enabled: true,
			isDefault: false,
			model: 'unknown',
			dimensions: 0,
			config: {},
			createdAt: Date.now(),
			updatedAt: Date.now()
		};

		const provider = createProviderFromProfile(profile);

		expect(provider.name).toBe('noop');
	});
});

// ---------------------------------------------------------------------------
// EmbeddingService — storage logic
// ---------------------------------------------------------------------------

describe('EmbeddingService', () => {
	let client: Database.Database;
	let db: TestDb;

	beforeEach(() => {
		({ client, db } = createTestDb());
	});

	/**
	 * Mock provider that is always available and returns, for every input text,
	 * the same vector `[0, 0.1, 0.2, …]` of length `dim` tagged with `modelName`.
	 */
	function makeProvider(dim: number, modelName = 'test-model') {
		return {
			name: 'mock',
			dimensions: dim,
			model: modelName,
			async embed(texts: string[]): Promise<EmbeddingVector[]> {
				return texts.map(() => ({
					values: new Float32Array(Array.from({ length: dim }, (_, i) => i * 0.1)),
					dimensions: dim,
					model: modelName
				}));
			},
			async isAvailable() {
				return true;
			}
		};
	}

	it('stores embeddings in snippet_embeddings table', async () => {
		const snippetId = seedSnippet(db, client);
		const provider = makeProvider(4);
		const service = new EmbeddingService(client, provider, 'local-default');

		await service.embedSnippets([snippetId]);

		const rows = client
			.prepare('SELECT * FROM snippet_embeddings WHERE snippet_id = ? AND profile_id = ?')
			.all(snippetId, 'local-default');
		expect(rows).toHaveLength(1);

		const row = rows[0] as {
			model: string;
			dimensions: number;
			embedding: Buffer;
			profile_id: string;
		};
		expect(row.model).toBe('test-model');
		expect(row.dimensions).toBe(4);
		expect(row.profile_id).toBe('local-default');
		expect(row.embedding).toBeInstanceOf(Buffer);
	});

	it('stores embeddings as retrievable Float32Array blobs', async () => {
		const snippetId = seedSnippet(db, client);
		const provider = makeProvider(3);
		const service = new EmbeddingService(client, provider, 'local-default');

		await service.embedSnippets([snippetId]);

		const embedding = service.getEmbedding(snippetId, 'local-default');
		expect(embedding).toBeInstanceOf(Float32Array);
		expect(embedding).toHaveLength(3);
		expect(embedding![0]).toBeCloseTo(0.0, 5);
		expect(embedding![1]).toBeCloseTo(0.1, 5);
		expect(embedding![2]).toBeCloseTo(0.2, 5);
	});

	it('is idempotent — re-embedding replaces the existing row', async () => {
		const snippetId = seedSnippet(db, client);
		const provider = makeProvider(2);
		const service = new EmbeddingService(client, provider);

		await service.embedSnippets([snippetId]);
		await service.embedSnippets([snippetId]);

		const rows = client
			.prepare('SELECT COUNT(*) as cnt FROM snippet_embeddings WHERE snippet_id = ?')
			.get(snippetId) as { cnt: number };
		expect(rows.cnt).toBe(1);
	});

	it('calls onProgress after each batch', async () => {
		const ids: string[] = [];
		for (let i = 0; i < 3; i++) {
			ids.push(seedSnippet(db, client));
		}

		const provider = makeProvider(2);
		const service = new EmbeddingService(client, provider);

		const progress: Array<[number, number]> = [];
		await service.embedSnippets(ids, (done, total) => {
			progress.push([done, total]);
		});

		// The number of callbacks depends on the service's batch size, so only
		// require at least one call and that the last one reports 3 of 3 done.
		expect(progress.length).toBeGreaterThan(0);
		expect(progress[progress.length - 1][0]).toBe(3);
		expect(progress[progress.length - 1][1]).toBe(3);
	});

	it('handles empty snippetIds gracefully', async () => {
		const provider = makeProvider(4);
		const service = new EmbeddingService(client, provider);

		// Should not throw.
		await expect(service.embedSnippets([])).resolves.toBeUndefined();
	});

	it('returns null from getEmbedding when no embedding exists', () => {
		const provider = makeProvider(4);
		const service = new EmbeddingService(client, provider);

		const result = service.getEmbedding('nonexistent-id');
		expect(result).toBeNull();
	});

	it('ignores snippet IDs that do not exist in the database', async () => {
		const provider = makeProvider(4);
		const service = new EmbeddingService(client, provider);

		// Should complete without error.
		await expect(service.embedSnippets(['ghost-id-1', 'ghost-id-2'])).resolves.toBeUndefined();

		const rows = client.prepare('SELECT COUNT(*) as cnt FROM snippet_embeddings').get() as {
			cnt: number;
		};
		expect(rows.cnt).toBe(0);
	});
});

// ---------------------------------------------------------------------------
// Factory
// ---------------------------------------------------------------------------

describe('createProviderFromConfig', () => {
	it('returns NoopEmbeddingProvider for provider=none', () => {
		const provider = createProviderFromConfig({ provider: 'none' });
		expect(provider.name).toBe('noop');
	});

	it('returns OpenAIEmbeddingProvider for provider=openai', () => {
		const provider = createProviderFromConfig({
			provider: 'openai',
			openai: {
				baseUrl: 'https://api.openai.com/v1',
				apiKey: 'sk-test',
				model: 'text-embedding-3-small'
			}
		});
		expect(provider.name).toBe('openai');
	});

	it('returns LocalEmbeddingProvider for provider=local', () => {
		const provider = createProviderFromConfig({ provider: 'local' });
		expect(provider.name).toBe('local');
	});

	it('throws when openai provider is selected without config', () => {
		// The cast deliberately bypasses the type check to exercise the runtime guard.
		expect(() => createProviderFromConfig({ provider: 'openai' } as EmbeddingConfig)).toThrow();
	});

	it('defaultEmbeddingConfig returns provider=none', () => {
		expect(defaultEmbeddingConfig().provider).toBe('none');
	});

	it('EMBEDDING_CONFIG_KEY is the expected settings key', () => {
		expect(EMBEDDING_CONFIG_KEY).toBe('embedding_config');
	});
});