refactor: extract embeddings service layer — decouple context embedder from Ollama
- New app/services/embeddings.py: TURNSTONE_EMBED_* env vars, multi-backend support - embedder.py delegates to service layer; re-exports EMBEDDING_AVAILABLE for compat - retriever.py updated to use service layer - Test coverage updated in tests/context/test_embedder.py
This commit is contained in:
parent
6fec294a53
commit
f7bcc6c9b7
4 changed files with 460 additions and 71 deletions
|
|
@ -1,64 +1,81 @@
|
|||
"""Ollama embedding client with sqlite-vec storage — BSL licensed."""
|
||||
"""Context chunk embedding — BSL licensed.
|
||||
|
||||
Thin wrapper around app.services.embeddings that handles the DB I/O for
|
||||
context_chunks. All backend configuration (model, device, backend type) is
|
||||
delegated to the service layer via TURNSTONE_EMBED_* env vars.
|
||||
|
||||
Re-exports EMBEDDING_AVAILABLE so callers that imported it from here continue
|
||||
to work without changes.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sqlite3
|
||||
import struct
|
||||
from pathlib import Path
|
||||
|
||||
import httpx
|
||||
from app.services.embeddings import (
|
||||
EMBEDDING_AVAILABLE, # re-export for backward compat
|
||||
get_embedder,
|
||||
pack_vector,
|
||||
)
|
||||
|
||||
__all__ = ["EMBEDDING_AVAILABLE", "embed_chunks"]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
EMBEDDING_AVAILABLE: bool = False
|
||||
|
||||
try:
|
||||
import sqlite_vec # type: ignore[import] # noqa: F401
|
||||
EMBEDDING_AVAILABLE = True
|
||||
logger.debug("sqlite-vec loaded — embedding pipeline enabled")
|
||||
except ImportError:
|
||||
logger.debug("sqlite-vec not available — embedding pipeline disabled")
|
||||
|
||||
|
||||
def embed_chunks(
|
||||
db_path: Path,
|
||||
document_id: str,
|
||||
llm_url: str,
|
||||
model: str = "nomic-embed-text",
|
||||
# Legacy params kept for backward compat — ignored when the ST backend is active.
|
||||
llm_url: str = "",
|
||||
model: str = "",
|
||||
timeout: float = 60.0,
|
||||
) -> int:
|
||||
"""Embed all unembedded chunks for a document. Returns count embedded. No-op when EMBEDDING_AVAILABLE is False."""
|
||||
if not EMBEDDING_AVAILABLE:
|
||||
"""Embed all un-embedded chunks for *document_id*.
|
||||
|
||||
Uses the configured embedder (sentence-transformers by default; Ollama when
|
||||
TURNSTONE_EMBED_BACKEND=ollama). Returns the count of newly embedded chunks.
|
||||
Returns 0 silently when no embedder is available.
|
||||
|
||||
The legacy ``llm_url`` and ``model`` parameters are accepted but ignored when
|
||||
the sentence-transformers backend is active — configure via env vars instead.
|
||||
"""
|
||||
embedder = get_embedder()
|
||||
if embedder is None:
|
||||
return 0
|
||||
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.row_factory = sqlite3.Row
|
||||
|
||||
rows = conn.execute(
|
||||
"SELECT id, text FROM context_chunks WHERE document_id = ? AND embedding IS NULL",
|
||||
(document_id,),
|
||||
).fetchall()
|
||||
|
||||
if not rows:
|
||||
conn.close()
|
||||
return 0
|
||||
|
||||
texts = [r["text"] for r in rows]
|
||||
ids = [r["id"] for r in rows]
|
||||
|
||||
count = 0
|
||||
for row in rows:
|
||||
try:
|
||||
resp = httpx.post(
|
||||
f"{llm_url.rstrip('/')}/api/embeddings",
|
||||
json={"model": model, "prompt": row["text"]},
|
||||
timeout=timeout,
|
||||
)
|
||||
resp.raise_for_status()
|
||||
vector: list[float] = resp.json().get("embedding") or []
|
||||
if vector:
|
||||
blob = struct.pack(f"{len(vector)}f", *vector)
|
||||
vectors = embedder.embed_batch(texts)
|
||||
for chunk_id, vec in zip(ids, vectors):
|
||||
blob = pack_vector(vec)
|
||||
conn.execute(
|
||||
"UPDATE context_chunks SET embedding = ? WHERE id = ?",
|
||||
(blob, row["id"]),
|
||||
(blob, chunk_id),
|
||||
)
|
||||
count += 1
|
||||
except Exception as exc:
|
||||
logger.warning("Embedding chunk %s failed: %s", row["id"], exc)
|
||||
|
||||
conn.commit()
|
||||
except Exception as exc:
|
||||
logger.warning("Batch embedding failed for document %s: %s", document_id, exc)
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
logger.debug("Embedded %d chunk(s) for document %s", count, document_id)
|
||||
return count
|
||||
|
|
|
|||
|
|
@ -1,10 +1,30 @@
|
|||
"""Context retrieval — structured keyword lookup (Free) + chunk search — MIT licensed."""
|
||||
"""Context retrieval — structured keyword lookup (Free) + chunk search — MIT licensed.
|
||||
|
||||
Two retrieval modes for context_chunks:
|
||||
Vector search — cosine similarity over stored embeddings (when available)
|
||||
Keyword search — LIKE-based fallback when no embedder is configured
|
||||
|
||||
Both modes are called from retrieve_context(); the best available mode is used
|
||||
automatically so callers need not check EMBEDDING_AVAILABLE themselves.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sqlite3
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from app.services.embeddings import (
|
||||
EMBEDDING_AVAILABLE,
|
||||
cosine_similarity,
|
||||
get_embedder,
|
||||
unpack_vector,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RetrievedContext:
|
||||
|
|
@ -12,6 +32,8 @@ class RetrievedContext:
|
|||
chunks: list[dict[str, str]] = field(default_factory=list)
|
||||
|
||||
|
||||
# ── Structured fact retrieval (always runs) ───────────────────────────────────
|
||||
|
||||
def get_relevant_facts(db_path: Path, query: str) -> list[dict[str, str]]:
|
||||
"""Keyword match against context_facts. Always runs — Free tier."""
|
||||
try:
|
||||
|
|
@ -42,8 +64,68 @@ def get_relevant_facts(db_path: Path, query: str) -> list[dict[str, str]]:
|
|||
return []
|
||||
|
||||
|
||||
def _search_chunks(db_path: Path, query: str) -> list[dict[str, str]]:
|
||||
"""Keyword search across context_chunks. Fallback when no embeddings."""
|
||||
# ── Chunk retrieval: vector path ──────────────────────────────────────────────
|
||||
|
||||
def _search_chunks_vector(
|
||||
db_path: Path,
|
||||
query: str,
|
||||
top_k: int = 3,
|
||||
) -> list[dict[str, str]]:
|
||||
"""Cosine similarity search over embedded context_chunks.
|
||||
|
||||
Loads all stored embeddings into memory and scores in-process with numpy.
|
||||
Skips any chunk whose BLOB dimension does not match the current model dim
|
||||
(stale embeddings from a previous model — they will be re-embedded on the
|
||||
next document upload).
|
||||
|
||||
Returns at most *top_k* results ordered by similarity descending.
|
||||
"""
|
||||
embedder = get_embedder()
|
||||
if embedder is None:
|
||||
return []
|
||||
|
||||
try:
|
||||
query_vec: np.ndarray = embedder.embed(query)
|
||||
model_dim: int = embedder.dim
|
||||
except Exception as exc:
|
||||
logger.warning("Query embedding failed: %s", exc)
|
||||
return []
|
||||
|
||||
try:
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.row_factory = sqlite3.Row
|
||||
rows = conn.execute(
|
||||
"SELECT cc.id, cc.text, cc.embedding, cd.filename"
|
||||
" FROM context_chunks cc"
|
||||
" JOIN context_documents cd ON cc.document_id = cd.id"
|
||||
" WHERE cc.embedding IS NOT NULL"
|
||||
).fetchall()
|
||||
conn.close()
|
||||
except sqlite3.OperationalError:
|
||||
return []
|
||||
|
||||
scored: list[tuple[float, dict[str, str]]] = []
|
||||
for row in rows:
|
||||
blob: bytes = row["embedding"]
|
||||
# Guard against blobs from a different-dimension model
|
||||
if len(blob) // 4 != model_dim:
|
||||
continue
|
||||
try:
|
||||
chunk_vec = unpack_vector(blob)
|
||||
score = cosine_similarity(query_vec, chunk_vec)
|
||||
scored.append((score, {"text": row["text"], "filename": row["filename"]}))
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
scored.sort(key=lambda t: t[0], reverse=True)
|
||||
return [item for _, item in scored[:top_k]]
|
||||
|
||||
|
||||
# ── Chunk retrieval: keyword fallback ─────────────────────────────────────────
|
||||
|
||||
def _search_chunks_keyword(db_path: Path, query: str) -> list[dict[str, str]]:
|
||||
"""LIKE-based keyword search across context_chunks. Fallback when no embedder."""
|
||||
try:
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
|
|
@ -66,16 +148,29 @@ def _search_chunks(db_path: Path, query: str) -> list[dict[str, str]]:
|
|||
return []
|
||||
|
||||
|
||||
# ── Public interface ──────────────────────────────────────────────────────────
|
||||
|
||||
def retrieve_context(db_path: Path, query: str) -> RetrievedContext:
|
||||
"""Retrieve structured facts and relevant chunks for a query."""
|
||||
return RetrievedContext(
|
||||
facts=get_relevant_facts(db_path, query),
|
||||
chunks=_search_chunks(db_path, query),
|
||||
)
|
||||
"""Retrieve structured facts and relevant chunks for a query.
|
||||
|
||||
Chunk retrieval uses vector search when an embedder is available and at
|
||||
least one embedded chunk exists; falls back to keyword search otherwise.
|
||||
"""
|
||||
facts = get_relevant_facts(db_path, query)
|
||||
|
||||
if EMBEDDING_AVAILABLE:
|
||||
chunks = _search_chunks_vector(db_path, query)
|
||||
if not chunks:
|
||||
# Vector search returned nothing (no embedded chunks yet) — fall back.
|
||||
chunks = _search_chunks_keyword(db_path, query)
|
||||
else:
|
||||
chunks = _search_chunks_keyword(db_path, query)
|
||||
|
||||
return RetrievedContext(facts=facts, chunks=chunks)
|
||||
|
||||
|
||||
def format_context_block(ctx: RetrievedContext) -> str | None:
|
||||
"""Format context for injection into LLM prompt. Returns None when empty."""
|
||||
"""Format context for injection into an LLM prompt. Returns None when empty."""
|
||||
lines: list[str] = []
|
||||
if ctx.facts:
|
||||
lines.append("Known environment facts:")
|
||||
|
|
|
|||
229
app/services/embeddings.py
Normal file
229
app/services/embeddings.py
Normal file
|
|
@ -0,0 +1,229 @@
|
|||
"""Configurable embedding service — BSL licensed.
|
||||
|
||||
Backends:
|
||||
sentence_transformers — local in-process inference (default, no server needed)
|
||||
ollama — HTTP to a running Ollama instance
|
||||
|
||||
Configuration (env vars):
|
||||
TURNSTONE_EMBED_BACKEND sentence_transformers | ollama (default: sentence_transformers)
|
||||
TURNSTONE_EMBED_MODEL model name/path (backend-specific default)
|
||||
TURNSTONE_EMBED_DEVICE cpu | cuda (default: cpu; ST backend only)
|
||||
TURNSTONE_LLM_URL Ollama base URL (default: http://localhost:11434)
|
||||
|
||||
When no backend is importable/reachable, EMBEDDING_AVAILABLE is False and all
|
||||
embed calls return empty arrays — callers must handle this gracefully.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import struct
|
||||
from typing import Protocol, runtime_checkable
|
||||
|
||||
import numpy as np
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Public availability flag ──────────────────────────────────────────────────
|
||||
|
||||
EMBEDDING_AVAILABLE: bool = False
|
||||
|
||||
# ── Config ────────────────────────────────────────────────────────────────────
|
||||
|
||||
_BACKEND = os.environ.get("TURNSTONE_EMBED_BACKEND", "sentence_transformers").lower()
|
||||
_DEVICE = os.environ.get("TURNSTONE_EMBED_DEVICE", "cpu").lower()
|
||||
_LLM_URL = os.environ.get("TURNSTONE_LLM_URL", "http://localhost:11434")
|
||||
|
||||
# BAAI/bge-small-en-v1.5: 33MB, MIT, 49M downloads/month, 384-dim, 512-token max.
|
||||
# Benchmarked as the best quality-to-size ratio in the field (MTEB 62.17).
|
||||
# all-MiniLM-L6-v2 is a viable lighter alternative (23MB, 256-token max) if
|
||||
# inference speed is the primary constraint.
|
||||
_DEFAULT_MODEL: dict[str, str] = {
|
||||
"sentence_transformers": "BAAI/bge-small-en-v1.5",
|
||||
"ollama": "nomic-embed-text",
|
||||
}
|
||||
_MODEL = os.environ.get(
|
||||
"TURNSTONE_EMBED_MODEL",
|
||||
_DEFAULT_MODEL.get(_BACKEND, "sentence-transformers/all-MiniLM-L6-v2"),
|
||||
)
|
||||
|
||||
|
||||
# ── Protocol ──────────────────────────────────────────────────────────────────
|
||||
|
||||
@runtime_checkable
|
||||
class Embedder(Protocol):
|
||||
"""Minimal interface all embedding backends must satisfy."""
|
||||
|
||||
@property
|
||||
def dim(self) -> int:
|
||||
"""Embedding dimension produced by this model."""
|
||||
...
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
"""Human-readable model identifier."""
|
||||
...
|
||||
|
||||
def embed(self, text: str) -> np.ndarray:
|
||||
"""Embed a single string. Returns 1-D float32 array of length dim."""
|
||||
...
|
||||
|
||||
def embed_batch(self, texts: list[str]) -> list[np.ndarray]:
|
||||
"""Embed a list of strings. Returns list of 1-D float32 arrays."""
|
||||
...
|
||||
|
||||
|
||||
# ── sentence-transformers backend ─────────────────────────────────────────────
|
||||
|
||||
class SentenceTransformerEmbedder:
|
||||
"""Local in-process embedding via the sentence-transformers library.
|
||||
|
||||
The model is downloaded from HuggingFace on first instantiation and cached
|
||||
at ~/.cache/huggingface/. Subsequent starts use the local cache.
|
||||
"""
|
||||
|
||||
def __init__(self, model_name: str = _MODEL, device: str = _DEVICE) -> None:
|
||||
from sentence_transformers import SentenceTransformer # type: ignore[import]
|
||||
logger.info("Loading embedding model %r on device %r ...", model_name, device)
|
||||
self._model = SentenceTransformer(model_name, device=device)
|
||||
self._model_name = model_name
|
||||
# Infer dimension from a test embed rather than hard-coding
|
||||
self._dim: int = int(self._model.encode("test").shape[0])
|
||||
logger.info("Embedding model ready — dim=%d", self._dim)
|
||||
|
||||
@property
|
||||
def dim(self) -> int:
|
||||
return self._dim
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
return self._model_name
|
||||
|
||||
def embed(self, text: str) -> np.ndarray:
|
||||
vec = self._model.encode(text, convert_to_numpy=True, normalize_embeddings=True)
|
||||
return vec.astype(np.float32)
|
||||
|
||||
def embed_batch(self, texts: list[str]) -> list[np.ndarray]:
|
||||
if not texts:
|
||||
return []
|
||||
vecs = self._model.encode(
|
||||
texts, convert_to_numpy=True, normalize_embeddings=True, batch_size=32
|
||||
)
|
||||
return [v.astype(np.float32) for v in vecs]
|
||||
|
||||
|
||||
# ── Ollama backend ────────────────────────────────────────────────────────────
|
||||
|
||||
class OllamaEmbedder:
|
||||
"""HTTP embedding via a running Ollama instance."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str = _MODEL,
|
||||
llm_url: str = _LLM_URL,
|
||||
timeout: float = 30.0,
|
||||
) -> None:
|
||||
import httpx # already a project dependency
|
||||
self._model_name = model_name
|
||||
self._url = f"{llm_url.rstrip('/')}/api/embeddings"
|
||||
self._timeout = timeout
|
||||
self._client = httpx.Client(timeout=timeout)
|
||||
# Probe dimension with a test call
|
||||
self._dim = self._probe_dim()
|
||||
|
||||
def _probe_dim(self) -> int:
|
||||
try:
|
||||
vec = self._raw_embed("probe")
|
||||
return len(vec)
|
||||
except Exception as exc:
|
||||
logger.warning("Ollama dim probe failed (%s) — defaulting to 768", exc)
|
||||
return 768
|
||||
|
||||
def _raw_embed(self, text: str) -> list[float]:
|
||||
resp = self._client.post(
|
||||
self._url, json={"model": self._model_name, "prompt": text}
|
||||
)
|
||||
resp.raise_for_status()
|
||||
return resp.json().get("embedding") or []
|
||||
|
||||
@property
|
||||
def dim(self) -> int:
|
||||
return self._dim
|
||||
|
||||
@property
|
||||
def model_name(self) -> str:
|
||||
return self._model_name
|
||||
|
||||
def embed(self, text: str) -> np.ndarray:
|
||||
vec = self._raw_embed(text)
|
||||
return np.array(vec, dtype=np.float32)
|
||||
|
||||
def embed_batch(self, texts: list[str]) -> list[np.ndarray]:
|
||||
return [self.embed(t) for t in texts]
|
||||
|
||||
|
||||
# ── Singleton factory ─────────────────────────────────────────────────────────
|
||||
|
||||
_embedder: Embedder | None = None
|
||||
|
||||
|
||||
def get_embedder() -> Embedder | None:
|
||||
"""Return the configured embedder singleton, or None when unavailable.
|
||||
|
||||
Lazy-initialises on first call. Callers should check EMBEDDING_AVAILABLE
|
||||
or test for None rather than calling this unconditionally.
|
||||
"""
|
||||
global _embedder, EMBEDDING_AVAILABLE
|
||||
if _embedder is not None:
|
||||
return _embedder
|
||||
|
||||
if _BACKEND == "sentence_transformers":
|
||||
try:
|
||||
_embedder = SentenceTransformerEmbedder(_MODEL, _DEVICE)
|
||||
EMBEDDING_AVAILABLE = True
|
||||
except ImportError:
|
||||
logger.warning(
|
||||
"sentence-transformers not installed — embeddings disabled. "
|
||||
"Install with: pip install sentence-transformers"
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to load sentence-transformers model %r: %s", _MODEL, exc)
|
||||
|
||||
elif _BACKEND == "ollama":
|
||||
try:
|
||||
_embedder = OllamaEmbedder(_MODEL, _LLM_URL)
|
||||
EMBEDDING_AVAILABLE = True
|
||||
except Exception as exc:
|
||||
logger.warning("Ollama embedder init failed: %s", exc)
|
||||
|
||||
else:
|
||||
logger.warning("Unknown TURNSTONE_EMBED_BACKEND %r — embeddings disabled", _BACKEND)
|
||||
|
||||
return _embedder
|
||||
|
||||
|
||||
# ── BLOB serialisation helpers ────────────────────────────────────────────────
|
||||
|
||||
def pack_vector(vec: np.ndarray) -> bytes:
|
||||
"""Serialise a float32 numpy vector to a SQLite BLOB."""
|
||||
arr = vec.astype(np.float32)
|
||||
return struct.pack(f"{len(arr)}f", *arr.tolist())
|
||||
|
||||
|
||||
def unpack_vector(blob: bytes) -> np.ndarray:
|
||||
"""Deserialise a SQLite BLOB back to a float32 numpy vector."""
|
||||
n = len(blob) // 4 # 4 bytes per float32
|
||||
return np.array(struct.unpack(f"{n}f", blob), dtype=np.float32)
|
||||
|
||||
|
||||
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
|
||||
"""Cosine similarity between two L2-normalised vectors.
|
||||
|
||||
Both vectors are re-normalised defensively so callers need not pre-normalise.
|
||||
Returns 0.0 when either vector has zero norm.
|
||||
"""
|
||||
norm_a = np.linalg.norm(a)
|
||||
norm_b = np.linalg.norm(b)
|
||||
if norm_a == 0.0 or norm_b == 0.0:
|
||||
return 0.0
|
||||
return float(np.dot(a, b) / (norm_a * norm_b))
|
||||
|
|
@ -1,13 +1,17 @@
|
|||
"""Tests for app/context/embedder.py — graceful no-op without sqlite-vec."""
|
||||
"""Tests for app/context/embedder.py — delegates to app.services.embeddings."""
|
||||
import sqlite3
|
||||
import struct
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from app.context import embedder as emb_mod
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db(tmp_path):
|
||||
@pytest.fixture()
|
||||
def db(tmp_path: Path) -> Path:
|
||||
db_path = tmp_path / "t.db"
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
conn.executescript("""
|
||||
|
|
@ -20,34 +24,78 @@ def db(tmp_path):
|
|||
REFERENCES context_documents(id) ON DELETE CASCADE,
|
||||
chunk_index INTEGER NOT NULL, text TEXT NOT NULL, embedding BLOB
|
||||
);
|
||||
INSERT INTO context_documents VALUES ('d1','test.md','markdown','hello',5,'2026-01-01T00:00:00+00:00');
|
||||
INSERT INTO context_documents
|
||||
VALUES ('d1','test.md','markdown','hello',5,'2026-01-01T00:00:00+00:00');
|
||||
INSERT INTO context_chunks VALUES ('c1','d1',0,'hello world',NULL);
|
||||
INSERT INTO context_chunks VALUES ('c2','d1',1,'second chunk',NULL);
|
||||
""")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return db_path
|
||||
|
||||
|
||||
def test_embed_skipped_when_extension_absent(db):
|
||||
with patch.object(emb_mod, "EMBEDDING_AVAILABLE", False):
|
||||
count = emb_mod.embed_chunks(db, "d1", "http://localhost:11434")
|
||||
def _mock_embedder(dim: int = 3) -> MagicMock:
|
||||
"""Return a mock Embedder that returns constant dim-length vectors."""
|
||||
m = MagicMock()
|
||||
m.dim = dim
|
||||
m.embed_batch.return_value = [np.zeros(dim, dtype=np.float32)] * 10
|
||||
return m
|
||||
|
||||
|
||||
class TestEmbedChunks:
|
||||
def test_returns_zero_when_no_embedder(self, db: Path) -> None:
|
||||
with patch("app.context.embedder.get_embedder", return_value=None):
|
||||
count = emb_mod.embed_chunks(db, "d1")
|
||||
assert count == 0
|
||||
|
||||
|
||||
def test_embed_calls_ollama_when_available(db):
|
||||
import httpx
|
||||
|
||||
class FakeResponse:
|
||||
status_code = 200
|
||||
def raise_for_status(self): pass
|
||||
def json(self): return {"embedding": [0.1, 0.2, 0.3]}
|
||||
|
||||
with patch.object(emb_mod, "EMBEDDING_AVAILABLE", True), \
|
||||
patch("app.context.embedder.httpx.post", return_value=FakeResponse()):
|
||||
count = emb_mod.embed_chunks(db, "d1", "http://localhost:11434")
|
||||
assert count == 1
|
||||
# Verify blob was written
|
||||
def test_returns_zero_when_no_unembedded_chunks(self, db: Path) -> None:
|
||||
# Pre-fill both chunks with a blob
|
||||
blob = struct.pack("3f", 0.1, 0.2, 0.3)
|
||||
conn = sqlite3.connect(str(db))
|
||||
row = conn.execute("SELECT embedding FROM context_chunks WHERE id='c1'").fetchone()
|
||||
conn.execute("UPDATE context_chunks SET embedding=?", (blob,))
|
||||
conn.commit()
|
||||
conn.close()
|
||||
assert row[0] is not None
|
||||
|
||||
embedder = _mock_embedder()
|
||||
with patch("app.context.embedder.get_embedder", return_value=embedder):
|
||||
count = emb_mod.embed_chunks(db, "d1")
|
||||
assert count == 0
|
||||
embedder.embed_batch.assert_not_called()
|
||||
|
||||
def test_embeds_all_null_chunks(self, db: Path) -> None:
|
||||
embedder = _mock_embedder(dim=3)
|
||||
with patch("app.context.embedder.get_embedder", return_value=embedder):
|
||||
count = emb_mod.embed_chunks(db, "d1")
|
||||
assert count == 2 # two chunks in fixture
|
||||
|
||||
def test_blobs_written_to_db(self, db: Path) -> None:
|
||||
vec = np.array([0.1, 0.2, 0.3], dtype=np.float32)
|
||||
embedder = _mock_embedder(dim=3)
|
||||
embedder.embed_batch.return_value = [vec, vec]
|
||||
|
||||
with patch("app.context.embedder.get_embedder", return_value=embedder):
|
||||
emb_mod.embed_chunks(db, "d1")
|
||||
|
||||
conn = sqlite3.connect(str(db))
|
||||
rows = conn.execute(
|
||||
"SELECT embedding FROM context_chunks WHERE document_id='d1'"
|
||||
).fetchall()
|
||||
conn.close()
|
||||
for (blob,) in rows:
|
||||
assert blob is not None
|
||||
unpacked = struct.unpack(f"{len(blob)//4}f", blob)
|
||||
assert len(unpacked) == 3
|
||||
|
||||
def test_legacy_llm_url_param_accepted(self, db: Path) -> None:
|
||||
"""Ensure backward-compat signature still works (llm_url ignored)."""
|
||||
embedder = _mock_embedder()
|
||||
with patch("app.context.embedder.get_embedder", return_value=embedder):
|
||||
count = emb_mod.embed_chunks(db, "d1", "http://localhost:11434", "nomic-embed-text")
|
||||
assert count == 2
|
||||
|
||||
def test_embed_batch_error_returns_zero(self, db: Path) -> None:
|
||||
embedder = _mock_embedder()
|
||||
embedder.embed_batch.side_effect = RuntimeError("model exploded")
|
||||
with patch("app.context.embedder.get_embedder", return_value=embedder):
|
||||
count = emb_mod.embed_chunks(db, "d1")
|
||||
assert count == 0
|
||||
|
|
|
|||
Loading…
Reference in a new issue