turnstone/tests/context/test_embedder.py
pyr0ball f7bcc6c9b7 refactor: extract embeddings service layer — decouple context embedder from Ollama
- New app/services/embeddings.py: TURNSTONE_EMBED_* env vars, multi-backend support
- embedder.py delegates to service layer; re-exports EMBEDDING_AVAILABLE for compat
- retriever.py updated to use service layer
- Test coverage updated in tests/context/test_embedder.py
2026-05-25 11:01:25 -07:00

101 lines
3.9 KiB
Python

"""Tests for app/context/embedder.py — delegates to app.services.embeddings."""
import sqlite3
import struct
from pathlib import Path
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from app.context import embedder as emb_mod
@pytest.fixture()
def db(tmp_path: Path) -> Path:
    """Create a throwaway SQLite database seeded for the embedder tests.

    The file ``t.db`` is created under pytest's ``tmp_path`` with the
    ``context_documents`` and ``context_chunks`` tables, one document
    (``d1``) and two chunks (``c1``, ``c2``) whose ``embedding`` is NULL.

    Returns:
        Path to the populated database file.
    """
    db_path = tmp_path / "t.db"
    conn = sqlite3.connect(str(db_path))
    try:
        conn.executescript("""
            CREATE TABLE context_documents (
                id TEXT PRIMARY KEY, filename TEXT NOT NULL, doc_type TEXT NOT NULL,
                full_text TEXT NOT NULL, file_size INTEGER, uploaded_at TEXT NOT NULL
            );
            CREATE TABLE context_chunks (
                id TEXT PRIMARY KEY, document_id TEXT NOT NULL
                    REFERENCES context_documents(id) ON DELETE CASCADE,
                chunk_index INTEGER NOT NULL, text TEXT NOT NULL, embedding BLOB
            );
            INSERT INTO context_documents
                VALUES ('d1','test.md','markdown','hello',5,'2026-01-01T00:00:00+00:00');
            INSERT INTO context_chunks VALUES ('c1','d1',0,'hello world',NULL);
            INSERT INTO context_chunks VALUES ('c2','d1',1,'second chunk',NULL);
        """)
        conn.commit()
    finally:
        # Always release the handle, even if the schema script fails, so a
        # broken setup does not leave the temp database file open.
        conn.close()
    return db_path
def _mock_embedder(dim: int = 3, count: int = 10) -> MagicMock:
    """Return a mock Embedder that returns constant dim-length vectors.

    Args:
        dim: Length of each vector; also set as the mock's ``dim`` attribute.
        count: Number of vectors in the canned ``embed_batch`` result.

    Returns:
        A ``MagicMock`` whose ``embed_batch`` returns ``count`` all-zero
        float32 vectors, regardless of the arguments it is called with.
    """
    mock = MagicMock()
    mock.dim = dim
    # Build each vector separately: ``[vec] * count`` would repeat one shared
    # ndarray, so an in-place write to any entry would show up in all of them.
    mock.embed_batch.return_value = [
        np.zeros(dim, dtype=np.float32) for _ in range(count)
    ]
    return mock
class TestEmbedChunks:
    """Tests for ``embed_chunks`` with the embeddings service patched out.

    Each test patches ``app.context.embedder.get_embedder`` so no real
    embedding backend is used. The ``db`` fixture supplies a SQLite file
    holding document ``d1`` and two chunks whose embeddings are NULL.
    """

    def test_returns_zero_when_no_embedder(self, db: Path) -> None:
        """When ``get_embedder`` returns None, the reported count is 0."""
        with patch("app.context.embedder.get_embedder", return_value=None):
            count = emb_mod.embed_chunks(db, "d1")
        assert count == 0

    def test_returns_zero_when_no_unembedded_chunks(self, db: Path) -> None:
        """With every chunk already embedded, nothing is sent to the embedder."""
        # Pre-fill both chunks with a blob
        blob = struct.pack("3f", 0.1, 0.2, 0.3)
        conn = sqlite3.connect(str(db))
        try:
            conn.execute("UPDATE context_chunks SET embedding=?", (blob,))
            conn.commit()
        finally:
            # Close even if the UPDATE fails so the temp database is released.
            conn.close()
        embedder = _mock_embedder()
        with patch("app.context.embedder.get_embedder", return_value=embedder):
            count = emb_mod.embed_chunks(db, "d1")
        assert count == 0
        embedder.embed_batch.assert_not_called()

    def test_embeds_all_null_chunks(self, db: Path) -> None:
        """Both NULL-embedding chunks of the fixture are counted as embedded."""
        embedder = _mock_embedder(dim=3)
        with patch("app.context.embedder.get_embedder", return_value=embedder):
            count = emb_mod.embed_chunks(db, "d1")
        assert count == 2  # two chunks in fixture

    def test_blobs_written_to_db(self, db: Path) -> None:
        """After embedding, every chunk row holds a blob of three 4-byte floats."""
        vec = np.array([0.1, 0.2, 0.3], dtype=np.float32)
        embedder = _mock_embedder(dim=3)
        embedder.embed_batch.return_value = [vec, vec]
        with patch("app.context.embedder.get_embedder", return_value=embedder):
            emb_mod.embed_chunks(db, "d1")
        conn = sqlite3.connect(str(db))
        try:
            rows = conn.execute(
                "SELECT embedding FROM context_chunks WHERE document_id='d1'"
            ).fetchall()
        finally:
            conn.close()
        # Guard against the per-row loop below passing on an empty result.
        assert len(rows) == 2
        for (blob,) in rows:
            assert blob is not None
            unpacked = struct.unpack(f"{len(blob)//4}f", blob)
            assert len(unpacked) == 3

    def test_legacy_llm_url_param_accepted(self, db: Path) -> None:
        """Ensure backward-compat signature still works (llm_url ignored)."""
        embedder = _mock_embedder()
        with patch("app.context.embedder.get_embedder", return_value=embedder):
            count = emb_mod.embed_chunks(
                db, "d1", "http://localhost:11434", "nomic-embed-text"
            )
        assert count == 2

    def test_embed_batch_error_returns_zero(self, db: Path) -> None:
        """A ``RuntimeError`` from ``embed_batch`` yields a count of 0."""
        embedder = _mock_embedder()
        embedder.embed_batch.side_effect = RuntimeError("model exploded")
        with patch("app.context.embedder.get_embedder", return_value=embedder):
            count = emb_mod.embed_chunks(db, "d1")
        assert count == 0