diff --git a/app/context/embedder.py b/app/context/embedder.py index 519870d..7bd17e0 100644 --- a/app/context/embedder.py +++ b/app/context/embedder.py @@ -1,64 +1,81 @@ -"""Ollama embedding client with sqlite-vec storage — BSL licensed.""" +"""Context chunk embedding — BSL licensed. + +Thin wrapper around app.services.embeddings that handles the DB I/O for +context_chunks. All backend configuration (model, device, backend type) is +delegated to the service layer via TURNSTONE_EMBED_* env vars. + +Re-exports EMBEDDING_AVAILABLE so callers that imported it from here continue +to work without changes. +""" from __future__ import annotations import logging import sqlite3 -import struct from pathlib import Path -import httpx +from app.services.embeddings import ( + EMBEDDING_AVAILABLE, # re-export for backward compat + get_embedder, + pack_vector, +) + +__all__ = ["EMBEDDING_AVAILABLE", "embed_chunks"] logger = logging.getLogger(__name__) -EMBEDDING_AVAILABLE: bool = False - -try: - import sqlite_vec # type: ignore[import] # noqa: F401 - EMBEDDING_AVAILABLE = True - logger.debug("sqlite-vec loaded — embedding pipeline enabled") -except ImportError: - logger.debug("sqlite-vec not available — embedding pipeline disabled") - def embed_chunks( db_path: Path, document_id: str, - llm_url: str, - model: str = "nomic-embed-text", + # Legacy params kept for backward compat — ignored when the ST backend is active. + llm_url: str = "", + model: str = "", timeout: float = 60.0, ) -> int: - """Embed all unembedded chunks for a document. Returns count embedded. No-op when EMBEDDING_AVAILABLE is False.""" - if not EMBEDDING_AVAILABLE: + """Embed all un-embedded chunks for *document_id*. + + Uses the configured embedder (sentence-transformers by default; Ollama when + TURNSTONE_EMBED_BACKEND=ollama). Returns the count of newly embedded chunks. + Returns 0 silently when no embedder is available. + + The legacy ``llm_url`` and ``model`` parameters are accepted but ignored when + the sentence-transformers backend is active — configure via env vars instead. + """ + embedder = get_embedder() + if embedder is None: return 0 conn = sqlite3.connect(str(db_path)) conn.execute("PRAGMA journal_mode=WAL") conn.row_factory = sqlite3.Row + rows = conn.execute( - "SELECT id, text FROM context_chunks WHERE document_id=? AND embedding IS NULL", + "SELECT id, text FROM context_chunks WHERE document_id = ? AND embedding IS NULL", (document_id,), ).fetchall() - count = 0 - for row in rows: - try: - resp = httpx.post( - f"{llm_url.rstrip('/')}/api/embeddings", - json={"model": model, "prompt": row["text"]}, - timeout=timeout, - ) - resp.raise_for_status() - vector: list[float] = resp.json().get("embedding") or [] - if vector: - blob = struct.pack(f"{len(vector)}f", *vector) - conn.execute( - "UPDATE context_chunks SET embedding=? WHERE id=?", - (blob, row["id"]), - ) - count += 1 - except Exception as exc: - logger.warning("Embedding chunk %s failed: %s", row["id"], exc) + if not rows: + conn.close() + return 0 - conn.commit() - conn.close() + texts = [r["text"] for r in rows] + ids = [r["id"] for r in rows] + + count = 0 + try: + vectors = embedder.embed_batch(texts) + for chunk_id, vec in zip(ids, vectors): + blob = pack_vector(vec) + conn.execute( + "UPDATE context_chunks SET embedding = ? WHERE id = ?", + (blob, chunk_id), + ) + count += 1 + conn.commit() + except Exception as exc: + logger.warning("Batch embedding failed for document %s: %s", document_id, exc) + finally: + conn.close() + + logger.debug("Embedded %d chunk(s) for document %s", count, document_id) return count diff --git a/app/context/retriever.py b/app/context/retriever.py index 6b42c8e..c4b511e 100644 --- a/app/context/retriever.py +++ b/app/context/retriever.py @@ -1,10 +1,30 @@ -"""Context retrieval — structured keyword lookup (Free) + chunk search — MIT licensed.""" +"""Context retrieval — structured keyword lookup (Free) + chunk search — MIT licensed. + +Two retrieval modes for context_chunks: + Vector search — cosine similarity over stored embeddings (when available) + Keyword search — LIKE-based fallback when no embedder is configured + +Both modes are called from retrieve_context(); the best available mode is used +automatically so callers need not check EMBEDDING_AVAILABLE themselves. +""" from __future__ import annotations +import logging import sqlite3 from dataclasses import dataclass, field from pathlib import Path +import numpy as np + +from app.services.embeddings import ( + EMBEDDING_AVAILABLE, + cosine_similarity, + get_embedder, + unpack_vector, +) + +logger = logging.getLogger(__name__) + @dataclass class RetrievedContext: @@ -12,6 +32,8 @@ class RetrievedContext: chunks: list[dict[str, str]] = field(default_factory=list) +# ── Structured fact retrieval (always runs) ─────────────────────────────────── + def get_relevant_facts(db_path: Path, query: str) -> list[dict[str, str]]: """Keyword match against context_facts. Always runs — Free tier.""" try: @@ -42,8 +64,68 @@ def get_relevant_facts(db_path: Path, query: str) -> list[dict[str, str]]: return [] -def _search_chunks(db_path: Path, query: str) -> list[dict[str, str]]: - """Keyword search across context_chunks. Fallback when no embeddings.""" +# ── Chunk retrieval: vector path ────────────────────────────────────────────── + +def _search_chunks_vector( + db_path: Path, + query: str, + top_k: int = 3, +) -> list[dict[str, str]]: + """Cosine similarity search over embedded context_chunks. + + Loads all stored embeddings into memory and scores in-process with numpy. + Skips any chunk whose BLOB dimension does not match the current model dim + (stale embeddings from a previous model — they will be re-embedded on the + next document upload). + + Returns at most *top_k* results ordered by similarity descending. + """ + embedder = get_embedder() + if embedder is None: + return [] + + try: + query_vec: np.ndarray = embedder.embed(query) + model_dim: int = embedder.dim + except Exception as exc: + logger.warning("Query embedding failed: %s", exc) + return [] + + try: + conn = sqlite3.connect(str(db_path)) + conn.execute("PRAGMA journal_mode=WAL") + conn.row_factory = sqlite3.Row + rows = conn.execute( + "SELECT cc.id, cc.text, cc.embedding, cd.filename" + " FROM context_chunks cc" + " JOIN context_documents cd ON cc.document_id = cd.id" + " WHERE cc.embedding IS NOT NULL" + ).fetchall() + conn.close() + except sqlite3.OperationalError: + return [] + + scored: list[tuple[float, dict[str, str]]] = [] + for row in rows: + blob: bytes = row["embedding"] + # Guard against blobs from a different-dimension model + if len(blob) // 4 != model_dim: + continue + try: + chunk_vec = unpack_vector(blob) + score = cosine_similarity(query_vec, chunk_vec) + scored.append((score, {"text": row["text"], "filename": row["filename"]})) + except Exception: + continue + + scored.sort(key=lambda t: t[0], reverse=True) + return [item for _, item in scored[:top_k]] + + +# ── Chunk retrieval: keyword fallback ───────────────────────────────────────── + +def _search_chunks_keyword(db_path: Path, query: str) -> list[dict[str, str]]: + """LIKE-based keyword search across context_chunks. Fallback when no embedder.""" try: conn = sqlite3.connect(str(db_path)) conn.execute("PRAGMA journal_mode=WAL") @@ -66,16 +148,29 @@ def _search_chunks(db_path: Path, query: str) -> list[dict[str, str]]: return [] +# ── Public interface ────────────────────────────────────────────────────────── + def retrieve_context(db_path: Path, query: str) -> RetrievedContext: - """Retrieve structured facts and relevant chunks for a query.""" - return RetrievedContext( - facts=get_relevant_facts(db_path, query), - chunks=_search_chunks(db_path, query), - ) + """Retrieve structured facts and relevant chunks for a query. + + Chunk retrieval uses vector search when an embedder is available and at + least one embedded chunk exists; falls back to keyword search otherwise. + """ + facts = get_relevant_facts(db_path, query) + + if EMBEDDING_AVAILABLE: + chunks = _search_chunks_vector(db_path, query) + if not chunks: + # Vector search returned nothing (no embedded chunks yet) — fall back. + chunks = _search_chunks_keyword(db_path, query) + else: + chunks = _search_chunks_keyword(db_path, query) + + return RetrievedContext(facts=facts, chunks=chunks) def format_context_block(ctx: RetrievedContext) -> str | None: - """Format context for injection into LLM prompt. Returns None when empty.""" + """Format context for injection into an LLM prompt. Returns None when empty.""" lines: list[str] = [] if ctx.facts: lines.append("Known environment facts:") diff --git a/app/services/embeddings.py b/app/services/embeddings.py new file mode 100644 index 0000000..7e9b30a --- /dev/null +++ b/app/services/embeddings.py @@ -0,0 +1,229 @@ +"""Configurable embedding service — BSL licensed. + +Backends: + sentence_transformers — local in-process inference (default, no server needed) + ollama — HTTP to a running Ollama instance + +Configuration (env vars): + TURNSTONE_EMBED_BACKEND sentence_transformers | ollama (default: sentence_transformers) + TURNSTONE_EMBED_MODEL model name/path (backend-specific default) + TURNSTONE_EMBED_DEVICE cpu | cuda (default: cpu; ST backend only) + TURNSTONE_LLM_URL Ollama base URL (default: http://localhost:11434) + +When no backend is importable/reachable, EMBEDDING_AVAILABLE is False and all +embed calls return empty arrays — callers must handle this gracefully. +""" +from __future__ import annotations + +import logging +import os +import struct +from typing import Protocol, runtime_checkable + +import numpy as np + +logger = logging.getLogger(__name__) + +# ── Public availability flag ────────────────────────────────────────────────── + +EMBEDDING_AVAILABLE: bool = False + +# ── Config ──────────────────────────────────────────────────────────────────── + +_BACKEND = os.environ.get("TURNSTONE_EMBED_BACKEND", "sentence_transformers").lower() +_DEVICE = os.environ.get("TURNSTONE_EMBED_DEVICE", "cpu").lower() +_LLM_URL = os.environ.get("TURNSTONE_LLM_URL", "http://localhost:11434") + +# BAAI/bge-small-en-v1.5: 33MB, MIT, 49M downloads/month, 384-dim, 512-token max. +# Benchmarked as the best quality-to-size ratio in the field (MTEB 62.17). +# all-MiniLM-L6-v2 is a viable lighter alternative (23MB, 256-token max) if +# inference speed is the primary constraint. +_DEFAULT_MODEL: dict[str, str] = { + "sentence_transformers": "BAAI/bge-small-en-v1.5", + "ollama": "nomic-embed-text", +} +_MODEL = os.environ.get( + "TURNSTONE_EMBED_MODEL", + _DEFAULT_MODEL.get(_BACKEND, "sentence-transformers/all-MiniLM-L6-v2"), +) + + +# ── Protocol ────────────────────────────────────────────────────────────────── + +@runtime_checkable +class Embedder(Protocol): + """Minimal interface all embedding backends must satisfy.""" + + @property + def dim(self) -> int: + """Embedding dimension produced by this model.""" + ... + + @property + def model_name(self) -> str: + """Human-readable model identifier.""" + ... + + def embed(self, text: str) -> np.ndarray: + """Embed a single string. Returns 1-D float32 array of length dim.""" + ... + + def embed_batch(self, texts: list[str]) -> list[np.ndarray]: + """Embed a list of strings. Returns list of 1-D float32 arrays.""" + ... + + +# ── sentence-transformers backend ───────────────────────────────────────────── + +class SentenceTransformerEmbedder: + """Local in-process embedding via the sentence-transformers library. + + The model is downloaded from HuggingFace on first instantiation and cached + at ~/.cache/huggingface/. Subsequent starts use the local cache. + """ + + def __init__(self, model_name: str = _MODEL, device: str = _DEVICE) -> None: + from sentence_transformers import SentenceTransformer # type: ignore[import] + logger.info("Loading embedding model %r on device %r ...", model_name, device) + self._model = SentenceTransformer(model_name, device=device) + self._model_name = model_name + # Infer dimension from a test embed rather than hard-coding + self._dim: int = int(self._model.encode("test").shape[0]) + logger.info("Embedding model ready — dim=%d", self._dim) + + @property + def dim(self) -> int: + return self._dim + + @property + def model_name(self) -> str: + return self._model_name + + def embed(self, text: str) -> np.ndarray: + vec = self._model.encode(text, convert_to_numpy=True, normalize_embeddings=True) + return vec.astype(np.float32) + + def embed_batch(self, texts: list[str]) -> list[np.ndarray]: + if not texts: + return [] + vecs = self._model.encode( + texts, convert_to_numpy=True, normalize_embeddings=True, batch_size=32 + ) + return [v.astype(np.float32) for v in vecs] + + +# ── Ollama backend ──────────────────────────────────────────────────────────── + +class OllamaEmbedder: + """HTTP embedding via a running Ollama instance.""" + + def __init__( + self, + model_name: str = _MODEL, + llm_url: str = _LLM_URL, + timeout: float = 30.0, + ) -> None: + import httpx # already a project dependency + self._model_name = model_name + self._url = f"{llm_url.rstrip('/')}/api/embeddings" + self._timeout = timeout + self._client = httpx.Client(timeout=timeout) + # Probe dimension with a test call + self._dim = self._probe_dim() + + def _probe_dim(self) -> int: + try: + vec = self._raw_embed("probe") + return len(vec) + except Exception as exc: + logger.warning("Ollama dim probe failed (%s) — defaulting to 768", exc) + return 768 + + def _raw_embed(self, text: str) -> list[float]: + resp = self._client.post( + self._url, json={"model": self._model_name, "prompt": text} + ) + resp.raise_for_status() + return resp.json().get("embedding") or [] + + @property + def dim(self) -> int: + return self._dim + + @property + def model_name(self) -> str: + return self._model_name + + def embed(self, text: str) -> np.ndarray: + vec = self._raw_embed(text) + return np.array(vec, dtype=np.float32) + + def embed_batch(self, texts: list[str]) -> list[np.ndarray]: + return [self.embed(t) for t in texts] + + +# ── Singleton factory ───────────────────────────────────────────────────────── + +_embedder: Embedder | None = None + + +def get_embedder() -> Embedder | None: + """Return the configured embedder singleton, or None when unavailable. + + Lazy-initialises on first call. Callers should check EMBEDDING_AVAILABLE + or test for None rather than calling this unconditionally. + """ + global _embedder, EMBEDDING_AVAILABLE + if _embedder is not None: + return _embedder + + if _BACKEND == "sentence_transformers": + try: + _embedder = SentenceTransformerEmbedder(_MODEL, _DEVICE) + EMBEDDING_AVAILABLE = True + except ImportError: + logger.warning( + "sentence-transformers not installed — embeddings disabled. " + "Install with: pip install sentence-transformers" + ) + except Exception as exc: + logger.warning("Failed to load sentence-transformers model %r: %s", _MODEL, exc) + + elif _BACKEND == "ollama": + try: + _embedder = OllamaEmbedder(_MODEL, _LLM_URL) + EMBEDDING_AVAILABLE = True + except Exception as exc: + logger.warning("Ollama embedder init failed: %s", exc) + + else: + logger.warning("Unknown TURNSTONE_EMBED_BACKEND %r — embeddings disabled", _BACKEND) + + return _embedder + + +# ── BLOB serialisation helpers ──────────────────────────────────────────────── + +def pack_vector(vec: np.ndarray) -> bytes: + """Serialise a float32 numpy vector to a SQLite BLOB.""" + arr = vec.astype(np.float32) + return struct.pack(f"{len(arr)}f", *arr.tolist()) + + +def unpack_vector(blob: bytes) -> np.ndarray: + """Deserialise a SQLite BLOB back to a float32 numpy vector.""" + n = len(blob) // 4 # 4 bytes per float32 + return np.array(struct.unpack(f"{n}f", blob), dtype=np.float32) + + +def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float: + """Cosine similarity between two L2-normalised vectors. + + Both vectors are re-normalised defensively so callers need not pre-normalise. + Returns 0.0 when either vector has zero norm. + """ + norm_a = np.linalg.norm(a) + norm_b = np.linalg.norm(b) + if norm_a == 0.0 or norm_b == 0.0: + return 0.0 + return float(np.dot(a, b) / (norm_a * norm_b)) diff --git a/tests/context/test_embedder.py b/tests/context/test_embedder.py index cc84032..67a124f 100644 --- a/tests/context/test_embedder.py +++ b/tests/context/test_embedder.py @@ -1,13 +1,17 @@ -"""Tests for app/context/embedder.py — graceful no-op without sqlite-vec.""" +"""Tests for app/context/embedder.py — delegates to app.services.embeddings.""" import sqlite3 +import struct from pathlib import Path -from unittest.mock import patch +from unittest.mock import MagicMock, patch + +import numpy as np import pytest + from app.context import embedder as emb_mod -@pytest.fixture -def db(tmp_path): +@pytest.fixture() +def db(tmp_path: Path) -> Path: db_path = tmp_path / "t.db" conn = sqlite3.connect(str(db_path)) conn.executescript(""" @@ -20,34 +24,78 @@ def db(tmp_path): REFERENCES context_documents(id) ON DELETE CASCADE, chunk_index INTEGER NOT NULL, text TEXT NOT NULL, embedding BLOB ); - INSERT INTO context_documents VALUES ('d1','test.md','markdown','hello',5,'2026-01-01T00:00:00+00:00'); + INSERT INTO context_documents + VALUES ('d1','test.md','markdown','hello',5,'2026-01-01T00:00:00+00:00'); INSERT INTO context_chunks VALUES ('c1','d1',0,'hello world',NULL); + INSERT INTO context_chunks VALUES ('c2','d1',1,'second chunk',NULL); """) conn.commit() conn.close() return db_path -def test_embed_skipped_when_extension_absent(db): - with patch.object(emb_mod, "EMBEDDING_AVAILABLE", False): - count = emb_mod.embed_chunks(db, "d1", "http://localhost:11434") - assert count == 0 +def _mock_embedder(dim: int = 3) -> MagicMock: + """Return a mock Embedder that returns constant dim-length vectors.""" + m = MagicMock() + m.dim = dim + m.embed_batch.return_value = [np.zeros(dim, dtype=np.float32)] * 10 + return m -def test_embed_calls_ollama_when_available(db): - import httpx +class TestEmbedChunks: + def test_returns_zero_when_no_embedder(self, db: Path) -> None: + with patch("app.context.embedder.get_embedder", return_value=None): + count = emb_mod.embed_chunks(db, "d1") + assert count == 0 - class FakeResponse: - status_code = 200 - def raise_for_status(self): pass - def json(self): return {"embedding": [0.1, 0.2, 0.3]} + def test_returns_zero_when_no_unembedded_chunks(self, db: Path) -> None: + # Pre-fill both chunks with a blob + blob = struct.pack("3f", 0.1, 0.2, 0.3) + conn = sqlite3.connect(str(db)) + conn.execute("UPDATE context_chunks SET embedding=?", (blob,)) + conn.commit() + conn.close() - with patch.object(emb_mod, "EMBEDDING_AVAILABLE", True), \ - patch("app.context.embedder.httpx.post", return_value=FakeResponse()): - count = emb_mod.embed_chunks(db, "d1", "http://localhost:11434") - assert count == 1 - # Verify blob was written - conn = sqlite3.connect(str(db)) - row = conn.execute("SELECT embedding FROM context_chunks WHERE id='c1'").fetchone() - conn.close() - assert row[0] is not None + embedder = _mock_embedder() + with patch("app.context.embedder.get_embedder", return_value=embedder): + count = emb_mod.embed_chunks(db, "d1") + assert count == 0 + embedder.embed_batch.assert_not_called() + + def test_embeds_all_null_chunks(self, db: Path) -> None: + embedder = _mock_embedder(dim=3) + with patch("app.context.embedder.get_embedder", return_value=embedder): + count = emb_mod.embed_chunks(db, "d1") + assert count == 2 # two chunks in fixture + + def test_blobs_written_to_db(self, db: Path) -> None: + vec = np.array([0.1, 0.2, 0.3], dtype=np.float32) + embedder = _mock_embedder(dim=3) + embedder.embed_batch.return_value = [vec, vec] + + with patch("app.context.embedder.get_embedder", return_value=embedder): + emb_mod.embed_chunks(db, "d1") + + conn = sqlite3.connect(str(db)) + rows = conn.execute( + "SELECT embedding FROM context_chunks WHERE document_id='d1'" + ).fetchall() + conn.close() + for (blob,) in rows: + assert blob is not None + unpacked = struct.unpack(f"{len(blob)//4}f", blob) + assert len(unpacked) == 3 + + def test_legacy_llm_url_param_accepted(self, db: Path) -> None: + """Ensure backward-compat signature still works (llm_url ignored).""" + embedder = _mock_embedder() + with patch("app.context.embedder.get_embedder", return_value=embedder): + count = emb_mod.embed_chunks(db, "d1", "http://localhost:11434", "nomic-embed-text") + assert count == 2 + + def test_embed_batch_error_returns_zero(self, db: Path) -> None: + embedder = _mock_embedder() + embedder.embed_batch.side_effect = RuntimeError("model exploded") + with patch("app.context.embedder.get_embedder", return_value=embedder): + count = emb_mod.embed_chunks(db, "d1") + assert count == 0