turnstone/app/context/embedder.py
pyr0ball 5f32a6678d refactor: extract embeddings service layer — decouple context embedder from Ollama
- New app/services/embeddings.py: TURNSTONE_EMBED_* env vars, multi-backend support
- embedder.py delegates to service layer; re-exports EMBEDDING_AVAILABLE for compat
- retriever.py updated to use service layer
- Test coverage updated in tests/context/test_embedder.py
2026-05-25 11:01:25 -07:00

81 lines
2.4 KiB
Python

"""Context chunk embedding — BSL licensed.
Thin wrapper around app.services.embeddings that handles the DB I/O for
context_chunks. All backend configuration (model, device, backend type) is
delegated to the service layer via TURNSTONE_EMBED_* env vars.
Re-exports EMBEDDING_AVAILABLE so callers that imported it from here continue
to work without changes.
"""
from __future__ import annotations
import logging
import sqlite3
from pathlib import Path
from app.services.embeddings import (
EMBEDDING_AVAILABLE, # re-export for backward compat
get_embedder,
pack_vector,
)
# Public names of this module. EMBEDDING_AVAILABLE is imported from
# app.services.embeddings and listed here so existing imports from this
# module keep working.
__all__ = ["EMBEDDING_AVAILABLE", "embed_chunks"]
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
def embed_chunks(
    db_path: Path,
    document_id: str,
    # Legacy params kept for backward compat — this function never reads them;
    # backend configuration comes from the service layer.
    llm_url: str = "",
    model: str = "",
    timeout: float = 60.0,
) -> int:
    """Embed all un-embedded chunks for *document_id*.

    Selects every ``context_chunks`` row of the document whose ``embedding``
    is NULL, embeds the texts in a single ``embed_batch`` call on the embedder
    returned by ``get_embedder()``, and stores each vector (serialized with
    ``pack_vector``) back into the row. All updates are committed together.

    Args:
        db_path: Path of the SQLite database holding ``context_chunks``.
        document_id: Document whose chunks should be embedded.
        llm_url: Legacy parameter, accepted for backward compatibility and
            not used by this function.
        model: Legacy parameter, accepted for backward compatibility and
            not used by this function.
        timeout: Legacy parameter, accepted for backward compatibility and
            not used by this function.

    Returns:
        The number of chunks whose embedding was written and committed.
        Returns 0 when no embedder is available, when the document has no
        un-embedded chunks, or when embedding/storing fails (the failure is
        logged as a warning and nothing is committed).

    Raises:
        sqlite3.Error: If opening the database or reading the pending chunks
            fails. The connection is closed before the error propagates.
    """
    embedder = get_embedder()
    if embedder is None:
        return 0

    conn = sqlite3.connect(str(db_path))
    try:
        conn.execute("PRAGMA journal_mode=WAL")
        conn.row_factory = sqlite3.Row
        rows = conn.execute(
            "SELECT id, text FROM context_chunks WHERE document_id = ? AND embedding IS NULL",
            (document_id,),
        ).fetchall()
        if not rows:
            return 0

        ids = [row["id"] for row in rows]
        texts = [row["text"] for row in rows]

        count = 0
        try:
            vectors = embedder.embed_batch(texts)
            for chunk_id, vec in zip(ids, vectors):
                conn.execute(
                    "UPDATE context_chunks SET embedding = ? WHERE id = ?",
                    (pack_vector(vec), chunk_id),
                )
                count += 1
            conn.commit()
        except Exception as exc:
            # Best-effort: embedding failures must not crash the caller.
            logger.warning("Batch embedding failed for document %s: %s", document_id, exc)
            # Nothing was committed; the pending updates are discarded when the
            # connection closes, so do not report them as embedded.
            count = 0
    finally:
        # Runs on every path, including errors raised by the PRAGMA/SELECT.
        conn.close()

    logger.debug("Embedded %d chunk(s) for document %s", count, document_id)
    return count