refactor: extract embeddings service layer — decouple context embedder from Ollama
- New app/services/embeddings.py: TURNSTONE_EMBED_* env vars, multi-backend support - embedder.py delegates to service layer; re-exports EMBEDDING_AVAILABLE for compat - retriever.py updated to use service layer - Test coverage updated in tests/context/test_embedder.py
This commit is contained in:
parent
1b109aab55
commit
3e7a1fa064
4 changed files with 460 additions and 71 deletions
|
|
@ -1,64 +1,81 @@
|
||||||
"""Ollama embedding client with sqlite-vec storage — BSL licensed."""
|
"""Context chunk embedding — BSL licensed.
|
||||||
|
|
||||||
|
Thin wrapper around app.services.embeddings that handles the DB I/O for
|
||||||
|
context_chunks. All backend configuration (model, device, backend type) is
|
||||||
|
delegated to the service layer via TURNSTONE_EMBED_* env vars.
|
||||||
|
|
||||||
|
Re-exports EMBEDDING_AVAILABLE so callers that imported it from here continue
|
||||||
|
to work without changes.
|
||||||
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import struct
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import httpx
|
from app.services.embeddings import (
|
||||||
|
EMBEDDING_AVAILABLE, # re-export for backward compat
|
||||||
|
get_embedder,
|
||||||
|
pack_vector,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = ["EMBEDDING_AVAILABLE", "embed_chunks"]
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
EMBEDDING_AVAILABLE: bool = False
|
|
||||||
|
|
||||||
try:
|
|
||||||
import sqlite_vec # type: ignore[import] # noqa: F401
|
|
||||||
EMBEDDING_AVAILABLE = True
|
|
||||||
logger.debug("sqlite-vec loaded — embedding pipeline enabled")
|
|
||||||
except ImportError:
|
|
||||||
logger.debug("sqlite-vec not available — embedding pipeline disabled")
|
|
||||||
|
|
||||||
|
|
||||||
def embed_chunks(
|
def embed_chunks(
|
||||||
db_path: Path,
|
db_path: Path,
|
||||||
document_id: str,
|
document_id: str,
|
||||||
llm_url: str,
|
# Legacy params kept for backward compat — ignored when the ST backend is active.
|
||||||
model: str = "nomic-embed-text",
|
llm_url: str = "",
|
||||||
|
model: str = "",
|
||||||
timeout: float = 60.0,
|
timeout: float = 60.0,
|
||||||
) -> int:
|
) -> int:
|
||||||
"""Embed all unembedded chunks for a document. Returns count embedded. No-op when EMBEDDING_AVAILABLE is False."""
|
"""Embed all un-embedded chunks for *document_id*.
|
||||||
if not EMBEDDING_AVAILABLE:
|
|
||||||
|
Uses the configured embedder (sentence-transformers by default; Ollama when
|
||||||
|
TURNSTONE_EMBED_BACKEND=ollama). Returns the count of newly embedded chunks.
|
||||||
|
Returns 0 silently when no embedder is available.
|
||||||
|
|
||||||
|
The legacy ``llm_url`` and ``model`` parameters are accepted but ignored when
|
||||||
|
the sentence-transformers backend is active — configure via env vars instead.
|
||||||
|
"""
|
||||||
|
embedder = get_embedder()
|
||||||
|
if embedder is None:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
conn = sqlite3.connect(str(db_path))
|
conn = sqlite3.connect(str(db_path))
|
||||||
conn.execute("PRAGMA journal_mode=WAL")
|
conn.execute("PRAGMA journal_mode=WAL")
|
||||||
conn.row_factory = sqlite3.Row
|
conn.row_factory = sqlite3.Row
|
||||||
|
|
||||||
rows = conn.execute(
|
rows = conn.execute(
|
||||||
"SELECT id, text FROM context_chunks WHERE document_id=? AND embedding IS NULL",
|
"SELECT id, text FROM context_chunks WHERE document_id = ? AND embedding IS NULL",
|
||||||
(document_id,),
|
(document_id,),
|
||||||
).fetchall()
|
).fetchall()
|
||||||
|
|
||||||
count = 0
|
if not rows:
|
||||||
for row in rows:
|
conn.close()
|
||||||
try:
|
return 0
|
||||||
resp = httpx.post(
|
|
||||||
f"{llm_url.rstrip('/')}/api/embeddings",
|
|
||||||
json={"model": model, "prompt": row["text"]},
|
|
||||||
timeout=timeout,
|
|
||||||
)
|
|
||||||
resp.raise_for_status()
|
|
||||||
vector: list[float] = resp.json().get("embedding") or []
|
|
||||||
if vector:
|
|
||||||
blob = struct.pack(f"{len(vector)}f", *vector)
|
|
||||||
conn.execute(
|
|
||||||
"UPDATE context_chunks SET embedding=? WHERE id=?",
|
|
||||||
(blob, row["id"]),
|
|
||||||
)
|
|
||||||
count += 1
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("Embedding chunk %s failed: %s", row["id"], exc)
|
|
||||||
|
|
||||||
conn.commit()
|
texts = [r["text"] for r in rows]
|
||||||
conn.close()
|
ids = [r["id"] for r in rows]
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
try:
|
||||||
|
vectors = embedder.embed_batch(texts)
|
||||||
|
for chunk_id, vec in zip(ids, vectors):
|
||||||
|
blob = pack_vector(vec)
|
||||||
|
conn.execute(
|
||||||
|
"UPDATE context_chunks SET embedding = ? WHERE id = ?",
|
||||||
|
(blob, chunk_id),
|
||||||
|
)
|
||||||
|
count += 1
|
||||||
|
conn.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Batch embedding failed for document %s: %s", document_id, exc)
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
logger.debug("Embedded %d chunk(s) for document %s", count, document_id)
|
||||||
return count
|
return count
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,30 @@
|
||||||
"""Context retrieval — structured keyword lookup (Free) + chunk search — MIT licensed."""
|
"""Context retrieval — structured keyword lookup (Free) + chunk search — MIT licensed.
|
||||||
|
|
||||||
|
Two retrieval modes for context_chunks:
|
||||||
|
Vector search — cosine similarity over stored embeddings (when available)
|
||||||
|
Keyword search — LIKE-based fallback when no embedder is configured
|
||||||
|
|
||||||
|
Both modes are called from retrieve_context(); the best available mode is used
|
||||||
|
automatically so callers need not check EMBEDDING_AVAILABLE themselves.
|
||||||
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
import sqlite3
|
import sqlite3
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from app.services.embeddings import (
|
||||||
|
EMBEDDING_AVAILABLE,
|
||||||
|
cosine_similarity,
|
||||||
|
get_embedder,
|
||||||
|
unpack_vector,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RetrievedContext:
|
class RetrievedContext:
|
||||||
|
|
@ -12,6 +32,8 @@ class RetrievedContext:
|
||||||
chunks: list[dict[str, str]] = field(default_factory=list)
|
chunks: list[dict[str, str]] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Structured fact retrieval (always runs) ───────────────────────────────────
|
||||||
|
|
||||||
def get_relevant_facts(db_path: Path, query: str) -> list[dict[str, str]]:
|
def get_relevant_facts(db_path: Path, query: str) -> list[dict[str, str]]:
|
||||||
"""Keyword match against context_facts. Always runs — Free tier."""
|
"""Keyword match against context_facts. Always runs — Free tier."""
|
||||||
try:
|
try:
|
||||||
|
|
@ -42,8 +64,68 @@ def get_relevant_facts(db_path: Path, query: str) -> list[dict[str, str]]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
def _search_chunks(db_path: Path, query: str) -> list[dict[str, str]]:
|
# ── Chunk retrieval: vector path ──────────────────────────────────────────────
|
||||||
"""Keyword search across context_chunks. Fallback when no embeddings."""
|
|
||||||
|
def _search_chunks_vector(
|
||||||
|
db_path: Path,
|
||||||
|
query: str,
|
||||||
|
top_k: int = 3,
|
||||||
|
) -> list[dict[str, str]]:
|
||||||
|
"""Cosine similarity search over embedded context_chunks.
|
||||||
|
|
||||||
|
Loads all stored embeddings into memory and scores in-process with numpy.
|
||||||
|
Skips any chunk whose BLOB dimension does not match the current model dim
|
||||||
|
(stale embeddings from a previous model — they will be re-embedded on the
|
||||||
|
next document upload).
|
||||||
|
|
||||||
|
Returns at most *top_k* results ordered by similarity descending.
|
||||||
|
"""
|
||||||
|
embedder = get_embedder()
|
||||||
|
if embedder is None:
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
query_vec: np.ndarray = embedder.embed(query)
|
||||||
|
model_dim: int = embedder.dim
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Query embedding failed: %s", exc)
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
conn = sqlite3.connect(str(db_path))
|
||||||
|
conn.execute("PRAGMA journal_mode=WAL")
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
rows = conn.execute(
|
||||||
|
"SELECT cc.id, cc.text, cc.embedding, cd.filename"
|
||||||
|
" FROM context_chunks cc"
|
||||||
|
" JOIN context_documents cd ON cc.document_id = cd.id"
|
||||||
|
" WHERE cc.embedding IS NOT NULL"
|
||||||
|
).fetchall()
|
||||||
|
conn.close()
|
||||||
|
except sqlite3.OperationalError:
|
||||||
|
return []
|
||||||
|
|
||||||
|
scored: list[tuple[float, dict[str, str]]] = []
|
||||||
|
for row in rows:
|
||||||
|
blob: bytes = row["embedding"]
|
||||||
|
# Guard against blobs from a different-dimension model
|
||||||
|
if len(blob) // 4 != model_dim:
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
chunk_vec = unpack_vector(blob)
|
||||||
|
score = cosine_similarity(query_vec, chunk_vec)
|
||||||
|
scored.append((score, {"text": row["text"], "filename": row["filename"]}))
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
scored.sort(key=lambda t: t[0], reverse=True)
|
||||||
|
return [item for _, item in scored[:top_k]]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Chunk retrieval: keyword fallback ─────────────────────────────────────────
|
||||||
|
|
||||||
|
def _search_chunks_keyword(db_path: Path, query: str) -> list[dict[str, str]]:
|
||||||
|
"""LIKE-based keyword search across context_chunks. Fallback when no embedder."""
|
||||||
try:
|
try:
|
||||||
conn = sqlite3.connect(str(db_path))
|
conn = sqlite3.connect(str(db_path))
|
||||||
conn.execute("PRAGMA journal_mode=WAL")
|
conn.execute("PRAGMA journal_mode=WAL")
|
||||||
|
|
@ -66,16 +148,29 @@ def _search_chunks(db_path: Path, query: str) -> list[dict[str, str]]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
# ── Public interface ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
def retrieve_context(db_path: Path, query: str) -> RetrievedContext:
|
def retrieve_context(db_path: Path, query: str) -> RetrievedContext:
|
||||||
"""Retrieve structured facts and relevant chunks for a query."""
|
"""Retrieve structured facts and relevant chunks for a query.
|
||||||
return RetrievedContext(
|
|
||||||
facts=get_relevant_facts(db_path, query),
|
Chunk retrieval uses vector search when an embedder is available and at
|
||||||
chunks=_search_chunks(db_path, query),
|
least one embedded chunk exists; falls back to keyword search otherwise.
|
||||||
)
|
"""
|
||||||
|
facts = get_relevant_facts(db_path, query)
|
||||||
|
|
||||||
|
if EMBEDDING_AVAILABLE:
|
||||||
|
chunks = _search_chunks_vector(db_path, query)
|
||||||
|
if not chunks:
|
||||||
|
# Vector search returned nothing (no embedded chunks yet) — fall back.
|
||||||
|
chunks = _search_chunks_keyword(db_path, query)
|
||||||
|
else:
|
||||||
|
chunks = _search_chunks_keyword(db_path, query)
|
||||||
|
|
||||||
|
return RetrievedContext(facts=facts, chunks=chunks)
|
||||||
|
|
||||||
|
|
||||||
def format_context_block(ctx: RetrievedContext) -> str | None:
|
def format_context_block(ctx: RetrievedContext) -> str | None:
|
||||||
"""Format context for injection into LLM prompt. Returns None when empty."""
|
"""Format context for injection into an LLM prompt. Returns None when empty."""
|
||||||
lines: list[str] = []
|
lines: list[str] = []
|
||||||
if ctx.facts:
|
if ctx.facts:
|
||||||
lines.append("Known environment facts:")
|
lines.append("Known environment facts:")
|
||||||
|
|
|
||||||
229
app/services/embeddings.py
Normal file
229
app/services/embeddings.py
Normal file
|
|
@ -0,0 +1,229 @@
|
||||||
|
"""Configurable embedding service — BSL licensed.
|
||||||
|
|
||||||
|
Backends:
|
||||||
|
sentence_transformers — local in-process inference (default, no server needed)
|
||||||
|
ollama — HTTP to a running Ollama instance
|
||||||
|
|
||||||
|
Configuration (env vars):
|
||||||
|
TURNSTONE_EMBED_BACKEND sentence_transformers | ollama (default: sentence_transformers)
|
||||||
|
TURNSTONE_EMBED_MODEL model name/path (backend-specific default)
|
||||||
|
TURNSTONE_EMBED_DEVICE cpu | cuda (default: cpu; ST backend only)
|
||||||
|
TURNSTONE_LLM_URL Ollama base URL (default: http://localhost:11434)
|
||||||
|
|
||||||
|
When no backend is importable/reachable, EMBEDDING_AVAILABLE is False and all
|
||||||
|
embed calls return empty arrays — callers must handle this gracefully.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import struct
|
||||||
|
from typing import Protocol, runtime_checkable
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Public availability flag ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
EMBEDDING_AVAILABLE: bool = False
|
||||||
|
|
||||||
|
# ── Config ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_BACKEND = os.environ.get("TURNSTONE_EMBED_BACKEND", "sentence_transformers").lower()
|
||||||
|
_DEVICE = os.environ.get("TURNSTONE_EMBED_DEVICE", "cpu").lower()
|
||||||
|
_LLM_URL = os.environ.get("TURNSTONE_LLM_URL", "http://localhost:11434")
|
||||||
|
|
||||||
|
# BAAI/bge-small-en-v1.5: 33MB, MIT, 49M downloads/month, 384-dim, 512-token max.
|
||||||
|
# Benchmarked as the best quality-to-size ratio in the field (MTEB 62.17).
|
||||||
|
# all-MiniLM-L6-v2 is a viable lighter alternative (23MB, 256-token max) if
|
||||||
|
# inference speed is the primary constraint.
|
||||||
|
_DEFAULT_MODEL: dict[str, str] = {
|
||||||
|
"sentence_transformers": "BAAI/bge-small-en-v1.5",
|
||||||
|
"ollama": "nomic-embed-text",
|
||||||
|
}
|
||||||
|
_MODEL = os.environ.get(
|
||||||
|
"TURNSTONE_EMBED_MODEL",
|
||||||
|
_DEFAULT_MODEL.get(_BACKEND, "sentence-transformers/all-MiniLM-L6-v2"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Protocol ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
|
class Embedder(Protocol):
|
||||||
|
"""Minimal interface all embedding backends must satisfy."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dim(self) -> int:
|
||||||
|
"""Embedding dimension produced by this model."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_name(self) -> str:
|
||||||
|
"""Human-readable model identifier."""
|
||||||
|
...
|
||||||
|
|
||||||
|
def embed(self, text: str) -> np.ndarray:
|
||||||
|
"""Embed a single string. Returns 1-D float32 array of length dim."""
|
||||||
|
...
|
||||||
|
|
||||||
|
def embed_batch(self, texts: list[str]) -> list[np.ndarray]:
|
||||||
|
"""Embed a list of strings. Returns list of 1-D float32 arrays."""
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
# ── sentence-transformers backend ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
class SentenceTransformerEmbedder:
|
||||||
|
"""Local in-process embedding via the sentence-transformers library.
|
||||||
|
|
||||||
|
The model is downloaded from HuggingFace on first instantiation and cached
|
||||||
|
at ~/.cache/huggingface/. Subsequent starts use the local cache.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model_name: str = _MODEL, device: str = _DEVICE) -> None:
|
||||||
|
from sentence_transformers import SentenceTransformer # type: ignore[import]
|
||||||
|
logger.info("Loading embedding model %r on device %r ...", model_name, device)
|
||||||
|
self._model = SentenceTransformer(model_name, device=device)
|
||||||
|
self._model_name = model_name
|
||||||
|
# Infer dimension from a test embed rather than hard-coding
|
||||||
|
self._dim: int = int(self._model.encode("test").shape[0])
|
||||||
|
logger.info("Embedding model ready — dim=%d", self._dim)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dim(self) -> int:
|
||||||
|
return self._dim
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_name(self) -> str:
|
||||||
|
return self._model_name
|
||||||
|
|
||||||
|
def embed(self, text: str) -> np.ndarray:
|
||||||
|
vec = self._model.encode(text, convert_to_numpy=True, normalize_embeddings=True)
|
||||||
|
return vec.astype(np.float32)
|
||||||
|
|
||||||
|
def embed_batch(self, texts: list[str]) -> list[np.ndarray]:
|
||||||
|
if not texts:
|
||||||
|
return []
|
||||||
|
vecs = self._model.encode(
|
||||||
|
texts, convert_to_numpy=True, normalize_embeddings=True, batch_size=32
|
||||||
|
)
|
||||||
|
return [v.astype(np.float32) for v in vecs]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Ollama backend ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class OllamaEmbedder:
|
||||||
|
"""HTTP embedding via a running Ollama instance."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str = _MODEL,
|
||||||
|
llm_url: str = _LLM_URL,
|
||||||
|
timeout: float = 30.0,
|
||||||
|
) -> None:
|
||||||
|
import httpx # already a project dependency
|
||||||
|
self._model_name = model_name
|
||||||
|
self._url = f"{llm_url.rstrip('/')}/api/embeddings"
|
||||||
|
self._timeout = timeout
|
||||||
|
self._client = httpx.Client(timeout=timeout)
|
||||||
|
# Probe dimension with a test call
|
||||||
|
self._dim = self._probe_dim()
|
||||||
|
|
||||||
|
def _probe_dim(self) -> int:
|
||||||
|
try:
|
||||||
|
vec = self._raw_embed("probe")
|
||||||
|
return len(vec)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Ollama dim probe failed (%s) — defaulting to 768", exc)
|
||||||
|
return 768
|
||||||
|
|
||||||
|
def _raw_embed(self, text: str) -> list[float]:
|
||||||
|
resp = self._client.post(
|
||||||
|
self._url, json={"model": self._model_name, "prompt": text}
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json().get("embedding") or []
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dim(self) -> int:
|
||||||
|
return self._dim
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_name(self) -> str:
|
||||||
|
return self._model_name
|
||||||
|
|
||||||
|
def embed(self, text: str) -> np.ndarray:
|
||||||
|
vec = self._raw_embed(text)
|
||||||
|
return np.array(vec, dtype=np.float32)
|
||||||
|
|
||||||
|
def embed_batch(self, texts: list[str]) -> list[np.ndarray]:
|
||||||
|
return [self.embed(t) for t in texts]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Singleton factory ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_embedder: Embedder | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_embedder() -> Embedder | None:
|
||||||
|
"""Return the configured embedder singleton, or None when unavailable.
|
||||||
|
|
||||||
|
Lazy-initialises on first call. Callers should check EMBEDDING_AVAILABLE
|
||||||
|
or test for None rather than calling this unconditionally.
|
||||||
|
"""
|
||||||
|
global _embedder, EMBEDDING_AVAILABLE
|
||||||
|
if _embedder is not None:
|
||||||
|
return _embedder
|
||||||
|
|
||||||
|
if _BACKEND == "sentence_transformers":
|
||||||
|
try:
|
||||||
|
_embedder = SentenceTransformerEmbedder(_MODEL, _DEVICE)
|
||||||
|
EMBEDDING_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
logger.warning(
|
||||||
|
"sentence-transformers not installed — embeddings disabled. "
|
||||||
|
"Install with: pip install sentence-transformers"
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Failed to load sentence-transformers model %r: %s", _MODEL, exc)
|
||||||
|
|
||||||
|
elif _BACKEND == "ollama":
|
||||||
|
try:
|
||||||
|
_embedder = OllamaEmbedder(_MODEL, _LLM_URL)
|
||||||
|
EMBEDDING_AVAILABLE = True
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Ollama embedder init failed: %s", exc)
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.warning("Unknown TURNSTONE_EMBED_BACKEND %r — embeddings disabled", _BACKEND)
|
||||||
|
|
||||||
|
return _embedder
|
||||||
|
|
||||||
|
|
||||||
|
# ── BLOB serialisation helpers ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def pack_vector(vec: np.ndarray) -> bytes:
|
||||||
|
"""Serialise a float32 numpy vector to a SQLite BLOB."""
|
||||||
|
arr = vec.astype(np.float32)
|
||||||
|
return struct.pack(f"{len(arr)}f", *arr.tolist())
|
||||||
|
|
||||||
|
|
||||||
|
def unpack_vector(blob: bytes) -> np.ndarray:
|
||||||
|
"""Deserialise a SQLite BLOB back to a float32 numpy vector."""
|
||||||
|
n = len(blob) // 4 # 4 bytes per float32
|
||||||
|
return np.array(struct.unpack(f"{n}f", blob), dtype=np.float32)
|
||||||
|
|
||||||
|
|
||||||
|
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
|
||||||
|
"""Cosine similarity between two L2-normalised vectors.
|
||||||
|
|
||||||
|
Both vectors are re-normalised defensively so callers need not pre-normalise.
|
||||||
|
Returns 0.0 when either vector has zero norm.
|
||||||
|
"""
|
||||||
|
norm_a = np.linalg.norm(a)
|
||||||
|
norm_b = np.linalg.norm(b)
|
||||||
|
if norm_a == 0.0 or norm_b == 0.0:
|
||||||
|
return 0.0
|
||||||
|
return float(np.dot(a, b) / (norm_a * norm_b))
|
||||||
|
|
@ -1,13 +1,17 @@
|
||||||
"""Tests for app/context/embedder.py — graceful no-op without sqlite-vec."""
|
"""Tests for app/context/embedder.py — delegates to app.services.embeddings."""
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
import struct
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.context import embedder as emb_mod
|
from app.context import embedder as emb_mod
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture()
|
||||||
def db(tmp_path):
|
def db(tmp_path: Path) -> Path:
|
||||||
db_path = tmp_path / "t.db"
|
db_path = tmp_path / "t.db"
|
||||||
conn = sqlite3.connect(str(db_path))
|
conn = sqlite3.connect(str(db_path))
|
||||||
conn.executescript("""
|
conn.executescript("""
|
||||||
|
|
@ -20,34 +24,78 @@ def db(tmp_path):
|
||||||
REFERENCES context_documents(id) ON DELETE CASCADE,
|
REFERENCES context_documents(id) ON DELETE CASCADE,
|
||||||
chunk_index INTEGER NOT NULL, text TEXT NOT NULL, embedding BLOB
|
chunk_index INTEGER NOT NULL, text TEXT NOT NULL, embedding BLOB
|
||||||
);
|
);
|
||||||
INSERT INTO context_documents VALUES ('d1','test.md','markdown','hello',5,'2026-01-01T00:00:00+00:00');
|
INSERT INTO context_documents
|
||||||
|
VALUES ('d1','test.md','markdown','hello',5,'2026-01-01T00:00:00+00:00');
|
||||||
INSERT INTO context_chunks VALUES ('c1','d1',0,'hello world',NULL);
|
INSERT INTO context_chunks VALUES ('c1','d1',0,'hello world',NULL);
|
||||||
|
INSERT INTO context_chunks VALUES ('c2','d1',1,'second chunk',NULL);
|
||||||
""")
|
""")
|
||||||
conn.commit()
|
conn.commit()
|
||||||
conn.close()
|
conn.close()
|
||||||
return db_path
|
return db_path
|
||||||
|
|
||||||
|
|
||||||
def test_embed_skipped_when_extension_absent(db):
|
def _mock_embedder(dim: int = 3) -> MagicMock:
|
||||||
with patch.object(emb_mod, "EMBEDDING_AVAILABLE", False):
|
"""Return a mock Embedder that returns constant dim-length vectors."""
|
||||||
count = emb_mod.embed_chunks(db, "d1", "http://localhost:11434")
|
m = MagicMock()
|
||||||
assert count == 0
|
m.dim = dim
|
||||||
|
m.embed_batch.return_value = [np.zeros(dim, dtype=np.float32)] * 10
|
||||||
|
return m
|
||||||
|
|
||||||
|
|
||||||
def test_embed_calls_ollama_when_available(db):
|
class TestEmbedChunks:
|
||||||
import httpx
|
def test_returns_zero_when_no_embedder(self, db: Path) -> None:
|
||||||
|
with patch("app.context.embedder.get_embedder", return_value=None):
|
||||||
|
count = emb_mod.embed_chunks(db, "d1")
|
||||||
|
assert count == 0
|
||||||
|
|
||||||
class FakeResponse:
|
def test_returns_zero_when_no_unembedded_chunks(self, db: Path) -> None:
|
||||||
status_code = 200
|
# Pre-fill both chunks with a blob
|
||||||
def raise_for_status(self): pass
|
blob = struct.pack("3f", 0.1, 0.2, 0.3)
|
||||||
def json(self): return {"embedding": [0.1, 0.2, 0.3]}
|
conn = sqlite3.connect(str(db))
|
||||||
|
conn.execute("UPDATE context_chunks SET embedding=?", (blob,))
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
with patch.object(emb_mod, "EMBEDDING_AVAILABLE", True), \
|
embedder = _mock_embedder()
|
||||||
patch("app.context.embedder.httpx.post", return_value=FakeResponse()):
|
with patch("app.context.embedder.get_embedder", return_value=embedder):
|
||||||
count = emb_mod.embed_chunks(db, "d1", "http://localhost:11434")
|
count = emb_mod.embed_chunks(db, "d1")
|
||||||
assert count == 1
|
assert count == 0
|
||||||
# Verify blob was written
|
embedder.embed_batch.assert_not_called()
|
||||||
conn = sqlite3.connect(str(db))
|
|
||||||
row = conn.execute("SELECT embedding FROM context_chunks WHERE id='c1'").fetchone()
|
def test_embeds_all_null_chunks(self, db: Path) -> None:
|
||||||
conn.close()
|
embedder = _mock_embedder(dim=3)
|
||||||
assert row[0] is not None
|
with patch("app.context.embedder.get_embedder", return_value=embedder):
|
||||||
|
count = emb_mod.embed_chunks(db, "d1")
|
||||||
|
assert count == 2 # two chunks in fixture
|
||||||
|
|
||||||
|
def test_blobs_written_to_db(self, db: Path) -> None:
|
||||||
|
vec = np.array([0.1, 0.2, 0.3], dtype=np.float32)
|
||||||
|
embedder = _mock_embedder(dim=3)
|
||||||
|
embedder.embed_batch.return_value = [vec, vec]
|
||||||
|
|
||||||
|
with patch("app.context.embedder.get_embedder", return_value=embedder):
|
||||||
|
emb_mod.embed_chunks(db, "d1")
|
||||||
|
|
||||||
|
conn = sqlite3.connect(str(db))
|
||||||
|
rows = conn.execute(
|
||||||
|
"SELECT embedding FROM context_chunks WHERE document_id='d1'"
|
||||||
|
).fetchall()
|
||||||
|
conn.close()
|
||||||
|
for (blob,) in rows:
|
||||||
|
assert blob is not None
|
||||||
|
unpacked = struct.unpack(f"{len(blob)//4}f", blob)
|
||||||
|
assert len(unpacked) == 3
|
||||||
|
|
||||||
|
def test_legacy_llm_url_param_accepted(self, db: Path) -> None:
|
||||||
|
"""Ensure backward-compat signature still works (llm_url ignored)."""
|
||||||
|
embedder = _mock_embedder()
|
||||||
|
with patch("app.context.embedder.get_embedder", return_value=embedder):
|
||||||
|
count = emb_mod.embed_chunks(db, "d1", "http://localhost:11434", "nomic-embed-text")
|
||||||
|
assert count == 2
|
||||||
|
|
||||||
|
def test_embed_batch_error_returns_zero(self, db: Path) -> None:
|
||||||
|
embedder = _mock_embedder()
|
||||||
|
embedder.embed_batch.side_effect = RuntimeError("model exploded")
|
||||||
|
with patch("app.context.embedder.get_embedder", return_value=embedder):
|
||||||
|
count = emb_mod.embed_chunks(db, "d1")
|
||||||
|
assert count == 0
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue