refactor: extract embeddings service layer — decouple context embedder from Ollama

- New app/services/embeddings.py: TURNSTONE_EMBED_* env vars, multi-backend support
- embedder.py delegates to service layer; re-exports EMBEDDING_AVAILABLE for compat
- retriever.py updated to use service layer
- Test coverage updated in tests/context/test_embedder.py
This commit is contained in:
pyr0ball 2026-05-25 11:01:25 -07:00
parent 1b109aab55
commit 3e7a1fa064
4 changed files with 460 additions and 71 deletions

View file

@ -1,64 +1,81 @@
"""Ollama embedding client with sqlite-vec storage — BSL licensed."""
"""Context chunk embedding — BSL licensed.
Thin wrapper around app.services.embeddings that handles the DB I/O for
context_chunks. All backend configuration (model, device, backend type) is
delegated to the service layer via TURNSTONE_EMBED_* env vars.
Re-exports EMBEDDING_AVAILABLE so callers that imported it from here continue
to work without changes.
"""
from __future__ import annotations
import logging
import sqlite3
import struct
from pathlib import Path
import httpx
from app.services.embeddings import (
EMBEDDING_AVAILABLE, # re-export for backward compat
get_embedder,
pack_vector,
)
__all__ = ["EMBEDDING_AVAILABLE", "embed_chunks"]
logger = logging.getLogger(__name__)
EMBEDDING_AVAILABLE: bool = False
try:
import sqlite_vec # type: ignore[import] # noqa: F401
EMBEDDING_AVAILABLE = True
logger.debug("sqlite-vec loaded — embedding pipeline enabled")
except ImportError:
logger.debug("sqlite-vec not available — embedding pipeline disabled")
def embed_chunks(
    db_path: Path,
    document_id: str,
    # Legacy params kept so existing positional/keyword callers keep working.
    llm_url: str = "",
    model: str = "",
    timeout: float = 60.0,
) -> int:
    """Embed all un-embedded chunks for *document_id*.

    The embedder is obtained from ``get_embedder()``; its backend is configured
    through the TURNSTONE_EMBED_* environment variables, not through this
    function's arguments.

    Args:
        db_path: SQLite database containing the ``context_chunks`` table.
        document_id: Only chunks of this document with a NULL embedding are
            processed.
        llm_url: Accepted for backward compatibility; not read here.
        model: Accepted for backward compatibility; not read here.
        timeout: Accepted for backward compatibility; not read here.

    Returns:
        Number of chunks whose embedding was written *and committed*. 0 when
        no embedder is available, when nothing needs embedding, or when the
        batch fails (the failure is logged and nothing is persisted).
    """
    embedder = get_embedder()
    if embedder is None:
        return 0

    conn = sqlite3.connect(str(db_path))
    try:
        conn.execute("PRAGMA journal_mode=WAL")
        conn.row_factory = sqlite3.Row
        rows = conn.execute(
            "SELECT id, text FROM context_chunks"
            " WHERE document_id = ? AND embedding IS NULL",
            (document_id,),
        ).fetchall()
        if not rows:
            return 0

        try:
            vectors = embedder.embed_batch([row["text"] for row in rows])
            updates: list[tuple[bytes, str]] = []
            for row, vec in zip(rows, vectors):
                blob = pack_vector(vec)
                if not blob:
                    # An empty BLOB is NOT NULL, so storing it would mark the
                    # chunk as embedded and it would never be retried.
                    logger.warning(
                        "Embedder returned an empty vector for chunk %s — skipped",
                        row["id"],
                    )
                    continue
                updates.append((blob, row["id"]))
            conn.executemany(
                "UPDATE context_chunks SET embedding = ? WHERE id = ?",
                updates,
            )
            conn.commit()
        except Exception as exc:
            # Best-effort: uncommitted updates are discarded when the
            # connection closes, so report that nothing was embedded.
            logger.warning(
                "Batch embedding failed for document %s: %s", document_id, exc
            )
            return 0
    finally:
        # Runs on every path, including errors raised by the SELECT above.
        conn.close()

    count = len(updates)
    logger.debug("Embedded %d chunk(s) for document %s", count, document_id)
    return count

View file

@ -1,10 +1,30 @@
"""Context retrieval — structured keyword lookup (Free) + chunk search — MIT licensed."""
"""Context retrieval — structured keyword lookup (Free) + chunk search — MIT licensed.
Two retrieval modes for context_chunks:
Vector search  — cosine similarity over stored embeddings (when available)
Keyword search — LIKE-based fallback when no embedder is configured
Both modes are called from retrieve_context(); the best available mode is used
automatically so callers need not check EMBEDDING_AVAILABLE themselves.
"""
from __future__ import annotations
import logging
import sqlite3
from dataclasses import dataclass, field
from pathlib import Path
import numpy as np
from app.services.embeddings import (
EMBEDDING_AVAILABLE,
cosine_similarity,
get_embedder,
unpack_vector,
)
logger = logging.getLogger(__name__)
@dataclass
class RetrievedContext:
@ -12,6 +32,8 @@ class RetrievedContext:
chunks: list[dict[str, str]] = field(default_factory=list)
# ── Structured fact retrieval (always runs) ───────────────────────────────────
def get_relevant_facts(db_path: Path, query: str) -> list[dict[str, str]]:
"""Keyword match against context_facts. Always runs — Free tier."""
try:
@ -42,8 +64,68 @@ def get_relevant_facts(db_path: Path, query: str) -> list[dict[str, str]]:
return []
def _search_chunks(db_path: Path, query: str) -> list[dict[str, str]]:
"""Keyword search across context_chunks. Fallback when no embeddings."""
# ── Chunk retrieval: vector path ──────────────────────────────────────────────
def _search_chunks_vector(
db_path: Path,
query: str,
top_k: int = 3,
) -> list[dict[str, str]]:
"""Cosine similarity search over embedded context_chunks.
Loads all stored embeddings into memory and scores in-process with numpy.
Skips any chunk whose BLOB dimension does not match the current model dim
(stale embeddings from a previous model they will be re-embedded on the
next document upload).
Returns at most *top_k* results ordered by similarity descending.
"""
embedder = get_embedder()
if embedder is None:
return []
try:
query_vec: np.ndarray = embedder.embed(query)
model_dim: int = embedder.dim
except Exception as exc:
logger.warning("Query embedding failed: %s", exc)
return []
try:
conn = sqlite3.connect(str(db_path))
conn.execute("PRAGMA journal_mode=WAL")
conn.row_factory = sqlite3.Row
rows = conn.execute(
"SELECT cc.id, cc.text, cc.embedding, cd.filename"
" FROM context_chunks cc"
" JOIN context_documents cd ON cc.document_id = cd.id"
" WHERE cc.embedding IS NOT NULL"
).fetchall()
conn.close()
except sqlite3.OperationalError:
return []
scored: list[tuple[float, dict[str, str]]] = []
for row in rows:
blob: bytes = row["embedding"]
# Guard against blobs from a different-dimension model
if len(blob) // 4 != model_dim:
continue
try:
chunk_vec = unpack_vector(blob)
score = cosine_similarity(query_vec, chunk_vec)
scored.append((score, {"text": row["text"], "filename": row["filename"]}))
except Exception:
continue
scored.sort(key=lambda t: t[0], reverse=True)
return [item for _, item in scored[:top_k]]
# ── Chunk retrieval: keyword fallback ─────────────────────────────────────────
def _search_chunks_keyword(db_path: Path, query: str) -> list[dict[str, str]]:
"""LIKE-based keyword search across context_chunks. Fallback when no embedder."""
try:
conn = sqlite3.connect(str(db_path))
conn.execute("PRAGMA journal_mode=WAL")
@ -66,16 +148,29 @@ def _search_chunks(db_path: Path, query: str) -> list[dict[str, str]]:
return []
# ── Public interface ──────────────────────────────────────────────────────────
def retrieve_context(db_path: Path, query: str) -> RetrievedContext:
    """Retrieve structured facts and relevant chunks for a query.

    Facts are always looked up. For chunks, vector search is tried first and
    keyword search is used whenever the vector path yields nothing (no
    embedder, no embedded chunks, or a failure inside the vector search).

    The vector path is attempted unconditionally rather than gated on the
    imported ``EMBEDDING_AVAILABLE`` name: that name is a copy bound at import
    time, while the service module only rebinds its own flag lazily inside
    ``get_embedder()``, so the copy held by this module does not track it.
    ``_search_chunks_vector`` already returns ``[]`` when no embedder exists.
    """
    facts = get_relevant_facts(db_path, query)
    chunks = _search_chunks_vector(db_path, query)
    if not chunks:
        chunks = _search_chunks_keyword(db_path, query)
    return RetrievedContext(facts=facts, chunks=chunks)
def format_context_block(ctx: RetrievedContext) -> str | None:
"""Format context for injection into LLM prompt. Returns None when empty."""
"""Format context for injection into an LLM prompt. Returns None when empty."""
lines: list[str] = []
if ctx.facts:
lines.append("Known environment facts:")

229
app/services/embeddings.py Normal file
View file

@ -0,0 +1,229 @@
"""Configurable embedding service — BSL licensed.
Backends:
sentence_transformers — local in-process inference (default, no server needed)
ollama — HTTP to a running Ollama instance
Configuration (env vars):
TURNSTONE_EMBED_BACKEND — sentence_transformers | ollama (default: sentence_transformers)
TURNSTONE_EMBED_MODEL — model name/path (backend-specific default)
TURNSTONE_EMBED_DEVICE — cpu | cuda (default: cpu; ST backend only)
TURNSTONE_LLM_URL — Ollama base URL (default: http://localhost:11434)
When no backend is importable or can be initialised, get_embedder() returns
None and EMBEDDING_AVAILABLE remains False — callers must handle this gracefully.
"""
from __future__ import annotations
import logging
import os
import struct
from typing import Protocol, runtime_checkable
import numpy as np
logger = logging.getLogger(__name__)
# ── Public availability flag ──────────────────────────────────────────────────
EMBEDDING_AVAILABLE: bool = False
# ── Config ────────────────────────────────────────────────────────────────────
_BACKEND = os.environ.get("TURNSTONE_EMBED_BACKEND", "sentence_transformers").lower()
_DEVICE = os.environ.get("TURNSTONE_EMBED_DEVICE", "cpu").lower()
_LLM_URL = os.environ.get("TURNSTONE_LLM_URL", "http://localhost:11434")
# BAAI/bge-small-en-v1.5: 33MB, MIT, 49M downloads/month, 384-dim, 512-token max.
# Benchmarked as the best quality-to-size ratio in the field (MTEB 62.17).
# all-MiniLM-L6-v2 is a viable lighter alternative (23MB, 256-token max) if
# inference speed is the primary constraint.
_DEFAULT_MODEL: dict[str, str] = {
"sentence_transformers": "BAAI/bge-small-en-v1.5",
"ollama": "nomic-embed-text",
}
_MODEL = os.environ.get(
"TURNSTONE_EMBED_MODEL",
_DEFAULT_MODEL.get(_BACKEND, "sentence-transformers/all-MiniLM-L6-v2"),
)
# ── Protocol ──────────────────────────────────────────────────────────────────
@runtime_checkable
class Embedder(Protocol):
    """Structural contract shared by every embedding backend."""

    @property
    def dim(self) -> int:
        """Length of the vectors this backend produces."""
        ...

    @property
    def model_name(self) -> str:
        """Identifier of the underlying model, suitable for display."""
        ...

    def embed(self, text: str) -> np.ndarray:
        """Embed one string into a 1-D float32 array of length ``dim``."""
        ...

    def embed_batch(self, texts: list[str]) -> list[np.ndarray]:
        """Embed several strings; one 1-D float32 array per input string."""
        ...
# ── sentence-transformers backend ─────────────────────────────────────────────
class SentenceTransformerEmbedder:
    """In-process embedding backend built on the sentence-transformers library.

    The library is imported lazily inside ``__init__`` so that this module can
    be imported without it. NOTE(review): the model is presumably downloaded
    and cached by the library on first use — not controlled by this class.
    """

    def __init__(self, model_name: str = _MODEL, device: str = _DEVICE) -> None:
        from sentence_transformers import SentenceTransformer  # type: ignore[import]

        logger.info("Loading embedding model %r on device %r ...", model_name, device)
        self._model_name = model_name
        self._model = SentenceTransformer(model_name, device=device)
        # Measure the output size with a throw-away encode instead of
        # hard-coding a per-model dimension.
        probe = self._model.encode("test")
        self._dim: int = int(probe.shape[0])
        logger.info("Embedding model ready — dim=%d", self._dim)

    @property
    def dim(self) -> int:
        """Vector length measured when the model was loaded."""
        return self._dim

    @property
    def model_name(self) -> str:
        """Model name/path this instance was created with."""
        return self._model_name

    def embed(self, text: str) -> np.ndarray:
        """Encode *text* with normalisation requested; result cast to float32."""
        encoded = self._model.encode(
            text,
            convert_to_numpy=True,
            normalize_embeddings=True,
        )
        return encoded.astype(np.float32)

    def embed_batch(self, texts: list[str]) -> list[np.ndarray]:
        """Encode *texts* in batches of 32; an empty input yields ``[]``."""
        if texts:
            matrix = self._model.encode(
                texts,
                convert_to_numpy=True,
                normalize_embeddings=True,
                batch_size=32,
            )
            return [row.astype(np.float32) for row in matrix]
        return []
# ── Ollama backend ────────────────────────────────────────────────────────────
class OllamaEmbedder:
    """HTTP embedding via a running Ollama instance.

    Posts ``{"model": ..., "prompt": ...}`` to ``<llm_url>/api/embeddings`` and
    reads the ``"embedding"`` field of the JSON response. The vector dimension
    is probed once at construction; when the probe fails or returns no vector
    a fallback of 768 is used and then corrected from the first real
    embedding that comes back with a different length.
    """

    def __init__(
        self,
        model_name: str = _MODEL,
        llm_url: str = _LLM_URL,
        timeout: float = 30.0,
    ) -> None:
        import httpx  # already a project dependency

        self._model_name = model_name
        self._url = f"{llm_url.rstrip('/')}/api/embeddings"
        self._timeout = timeout
        self._client = httpx.Client(timeout=timeout)
        # Probe dimension with a test call
        self._dim = self._probe_dim()

    def _probe_dim(self) -> int:
        """Return the length of a probe embedding, or 768 when probing fails."""
        try:
            vec = self._raw_embed("probe")
            if not vec:
                # A response without an embedding would otherwise give
                # dim == 0, which makes every stored vector look stale.
                raise ValueError("response contained no embedding")
            return len(vec)
        except Exception as exc:
            logger.warning("Ollama dim probe failed (%s) — defaulting to 768", exc)
            return 768

    def _raw_embed(self, text: str) -> list[float]:
        """POST *text* and return the raw embedding list (``[]`` if absent).

        Raises whatever the HTTP client raises, including for non-2xx
        responses via ``raise_for_status()``.
        """
        resp = self._client.post(
            self._url, json={"model": self._model_name, "prompt": text}
        )
        resp.raise_for_status()
        return resp.json().get("embedding") or []

    @property
    def dim(self) -> int:
        """Current best-known embedding dimension."""
        return self._dim

    @property
    def model_name(self) -> str:
        """Model name sent to the Ollama endpoint."""
        return self._model_name

    def embed(self, text: str) -> np.ndarray:
        """Embed *text*; returns a float32 array (empty if none was returned)."""
        vec = self._raw_embed(text)
        if vec and len(vec) != self._dim:
            # The start-up probe may have fallen back to 768 while the server
            # was unreachable; trust the size of a real response instead.
            self._dim = len(vec)
        return np.array(vec, dtype=np.float32)

    def embed_batch(self, texts: list[str]) -> list[np.ndarray]:
        """Embed each string with its own request, preserving input order."""
        return [self.embed(t) for t in texts]
# ── Singleton factory ─────────────────────────────────────────────────────────
_embedder: Embedder | None = None
def get_embedder() -> Embedder | None:
    """Return the process-wide embedder, creating it on first use.

    Builds the backend selected by ``_BACKEND`` and caches it in the
    module-level ``_embedder``; ``EMBEDDING_AVAILABLE`` is set to True only
    after a backend was constructed. Construction failures are logged and
    leave the cache empty, so ``None`` is returned and a later call tries
    again.
    """
    global _embedder, EMBEDDING_AVAILABLE

    if _embedder is None:
        if _BACKEND == "ollama":
            try:
                _embedder = OllamaEmbedder(_MODEL, _LLM_URL)
            except Exception as exc:
                logger.warning("Ollama embedder init failed: %s", exc)
            else:
                EMBEDDING_AVAILABLE = True
        elif _BACKEND == "sentence_transformers":
            try:
                _embedder = SentenceTransformerEmbedder(_MODEL, _DEVICE)
            except ImportError:
                logger.warning(
                    "sentence-transformers not installed — embeddings disabled. "
                    "Install with: pip install sentence-transformers"
                )
            except Exception as exc:
                logger.warning(
                    "Failed to load sentence-transformers model %r: %s", _MODEL, exc
                )
            else:
                EMBEDDING_AVAILABLE = True
        else:
            logger.warning(
                "Unknown TURNSTONE_EMBED_BACKEND %r — embeddings disabled", _BACKEND
            )

    return _embedder
# ── BLOB serialisation helpers ────────────────────────────────────────────────
def pack_vector(vec: np.ndarray) -> bytes:
    """Serialise a vector to a SQLite BLOB of native-order float32 values.

    The input is converted to float32 (any array-like is accepted) and its
    raw buffer is returned directly, which yields the same bytes as
    ``struct.pack(f"{n}f", ...)`` without a round-trip through Python floats.
    An empty vector produces ``b""``.
    """
    return np.asarray(vec, dtype=np.float32).tobytes()
def unpack_vector(blob: bytes) -> np.ndarray:
    """Deserialise a SQLite BLOB back to a float32 numpy vector.

    Reads the buffer as native-order float32 values in one step and returns
    an independent, writable copy.

    Raises:
        struct.error: If the BLOB length is not a multiple of 4 bytes (kept as
            the error type so existing handlers continue to match).
    """
    if len(blob) % 4:
        raise struct.error(
            f"embedding blob of {len(blob)} bytes is not a whole number of float32 values"
        )
    # frombuffer yields a read-only view onto *blob*; copy to detach from it.
    return np.frombuffer(blob, dtype=np.float32).copy()
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Return the cosine of the angle between vectors *a* and *b*.

    The dot product is divided by both Euclidean norms, so the inputs do not
    need to be unit-length. If either norm is zero the result is 0.0.
    """
    len_a = np.linalg.norm(a)
    len_b = np.linalg.norm(b)
    if not (len_a != 0.0 and len_b != 0.0):
        return 0.0
    dot = np.dot(a, b)
    return float(dot / (len_a * len_b))

View file

@ -1,13 +1,17 @@
"""Tests for app/context/embedder.py — graceful no-op without sqlite-vec."""
"""Tests for app/context/embedder.py — delegates to app.services.embeddings."""
import sqlite3
import struct
from pathlib import Path
from unittest.mock import patch
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from app.context import embedder as emb_mod
@pytest.fixture
def db(tmp_path):
@pytest.fixture()
def db(tmp_path: Path) -> Path:
db_path = tmp_path / "t.db"
conn = sqlite3.connect(str(db_path))
conn.executescript("""
@ -20,34 +24,78 @@ def db(tmp_path):
REFERENCES context_documents(id) ON DELETE CASCADE,
chunk_index INTEGER NOT NULL, text TEXT NOT NULL, embedding BLOB
);
INSERT INTO context_documents VALUES ('d1','test.md','markdown','hello',5,'2026-01-01T00:00:00+00:00');
INSERT INTO context_documents
VALUES ('d1','test.md','markdown','hello',5,'2026-01-01T00:00:00+00:00');
INSERT INTO context_chunks VALUES ('c1','d1',0,'hello world',NULL);
INSERT INTO context_chunks VALUES ('c2','d1',1,'second chunk',NULL);
""")
conn.commit()
conn.close()
return db_path
def test_embed_skipped_when_extension_absent(db):
with patch.object(emb_mod, "EMBEDDING_AVAILABLE", False):
count = emb_mod.embed_chunks(db, "d1", "http://localhost:11434")
def _mock_embedder(dim: int = 3) -> MagicMock:
"""Return a mock Embedder that returns constant dim-length vectors."""
m = MagicMock()
m.dim = dim
m.embed_batch.return_value = [np.zeros(dim, dtype=np.float32)] * 10
return m
class TestEmbedChunks:
def test_returns_zero_when_no_embedder(self, db: Path) -> None:
with patch("app.context.embedder.get_embedder", return_value=None):
count = emb_mod.embed_chunks(db, "d1")
assert count == 0
def test_embed_calls_ollama_when_available(db):
import httpx
class FakeResponse:
status_code = 200
def raise_for_status(self): pass
def json(self): return {"embedding": [0.1, 0.2, 0.3]}
with patch.object(emb_mod, "EMBEDDING_AVAILABLE", True), \
patch("app.context.embedder.httpx.post", return_value=FakeResponse()):
count = emb_mod.embed_chunks(db, "d1", "http://localhost:11434")
assert count == 1
# Verify blob was written
def test_returns_zero_when_no_unembedded_chunks(self, db: Path) -> None:
# Pre-fill both chunks with a blob
blob = struct.pack("3f", 0.1, 0.2, 0.3)
conn = sqlite3.connect(str(db))
row = conn.execute("SELECT embedding FROM context_chunks WHERE id='c1'").fetchone()
conn.execute("UPDATE context_chunks SET embedding=?", (blob,))
conn.commit()
conn.close()
assert row[0] is not None
embedder = _mock_embedder()
with patch("app.context.embedder.get_embedder", return_value=embedder):
count = emb_mod.embed_chunks(db, "d1")
assert count == 0
embedder.embed_batch.assert_not_called()
def test_embeds_all_null_chunks(self, db: Path) -> None:
embedder = _mock_embedder(dim=3)
with patch("app.context.embedder.get_embedder", return_value=embedder):
count = emb_mod.embed_chunks(db, "d1")
assert count == 2 # two chunks in fixture
def test_blobs_written_to_db(self, db: Path) -> None:
vec = np.array([0.1, 0.2, 0.3], dtype=np.float32)
embedder = _mock_embedder(dim=3)
embedder.embed_batch.return_value = [vec, vec]
with patch("app.context.embedder.get_embedder", return_value=embedder):
emb_mod.embed_chunks(db, "d1")
conn = sqlite3.connect(str(db))
rows = conn.execute(
"SELECT embedding FROM context_chunks WHERE document_id='d1'"
).fetchall()
conn.close()
for (blob,) in rows:
assert blob is not None
unpacked = struct.unpack(f"{len(blob)//4}f", blob)
assert len(unpacked) == 3
def test_legacy_llm_url_param_accepted(self, db: Path) -> None:
"""Ensure backward-compat signature still works (llm_url ignored)."""
embedder = _mock_embedder()
with patch("app.context.embedder.get_embedder", return_value=embedder):
count = emb_mod.embed_chunks(db, "d1", "http://localhost:11434", "nomic-embed-text")
assert count == 2
def test_embed_batch_error_returns_zero(self, db: Path) -> None:
embedder = _mock_embedder()
embedder.embed_batch.side_effect = RuntimeError("model exploded")
with patch("app.context.embedder.get_embedder", return_value=embedder):
count = emb_mod.embed_chunks(db, "d1")
assert count == 0