turnstone/app/context/retriever.py
pyr0ball 854818ca1a fix(db): add timeout=30s to all sqlite3.connect() calls across app
Watcher, REST endpoints, services (search, incidents, blocklist),
MCP server, context retriever, embedder, glean_scheduler, and
doc_upload all used the default 5-second SQLite busy timeout.
During collect glean write phases, watcher flush threads were hitting
'database is locked' errors when the glean held the write lock longer
than 5 seconds.

All connections now use timeout=30.0, matching the pipeline fix
from commit ee39ffb. No logic changes.
2026-05-26 23:12:48 -07:00

183 lines
6.8 KiB
Python

"""Context retrieval — structured keyword lookup (Free) + chunk search — MIT licensed.
Two retrieval modes for context_chunks:
Vector search — cosine similarity over stored embeddings (when available)
Keyword search — LIKE-based fallback when no embedder is configured
Both modes are called from retrieve_context(); the best available mode is used
automatically so callers need not check EMBEDDING_AVAILABLE themselves.
"""
from __future__ import annotations
import logging
import sqlite3
from dataclasses import dataclass, field
from pathlib import Path
import numpy as np
from app.services.embeddings import (
EMBEDDING_AVAILABLE,
cosine_similarity,
get_embedder,
unpack_vector,
)
logger = logging.getLogger(__name__)
@dataclass
class RetrievedContext:
    """Container for everything retrieved for a single query.

    Both fields default to a fresh empty list, so an instance created with no
    arguments represents "nothing found".
    """

    # Rows matched in context_facts; each dict carries the selected columns
    # (category, key, value, source).
    facts: list[dict[str, str]] = field(default_factory=list)
    # Matching document chunks; each dict carries "text" and "filename".
    chunks: list[dict[str, str]] = field(default_factory=list)
# ── Structured fact retrieval (always runs) ───────────────────────────────────
def get_relevant_facts(db_path: Path, query: str) -> list[dict[str, str]]:
    """Keyword match against context_facts. Always runs — Free tier.

    The query is split on whitespace; words longer than two characters are
    lower-cased and matched as substrings of the ``key`` and ``value`` columns.
    When the query yields no usable keywords, the first 20 facts (ordered by
    category) are returned instead.

    Args:
        db_path: Location of the SQLite database file.
        query: Free-text query to extract keywords from.

    Returns:
        Up to 10 matching rows (20 in the no-keyword case) as dicts with the
        keys ``category``, ``key``, ``value`` and ``source``. An empty list is
        returned when SQLite raises ``OperationalError`` (for example a missing
        table or a lock that outlives the 30 s busy timeout).
    """
    keywords = [w.lower() for w in query.split() if len(w) > 2]

    params: list[str] = []
    if keywords:
        # ESCAPE makes "%" and "_" typed by the user match literally rather
        # than acting as LIKE wildcards.
        conditions = " OR ".join(
            "(LOWER(key) LIKE ? ESCAPE '\\' OR LOWER(value) LIKE ? ESCAPE '\\')"
            for _ in keywords
        )
        sql = (
            "SELECT category, key, value, source FROM context_facts"
            f" WHERE {conditions} ORDER BY category LIMIT 10"
        )
        for kw in keywords:
            literal = (
                kw.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
            )
            pattern = f"%{literal}%"
            params.extend([pattern, pattern])
    else:
        sql = (
            "SELECT category, key, value, source FROM context_facts"
            " ORDER BY category LIMIT 20"
        )

    try:
        conn = sqlite3.connect(str(db_path), timeout=30.0)
        try:
            conn.execute("PRAGMA journal_mode=WAL")
            conn.row_factory = sqlite3.Row
            rows = conn.execute(sql, params).fetchall()
        finally:
            # Close on the error path too, so a failed query never leaves the
            # connection (and any lock it holds) open.
            conn.close()
    except sqlite3.OperationalError:
        # Deliberate best-effort: context is optional, so degrade to "no facts".
        return []
    return [dict(r) for r in rows]
# ── Chunk retrieval: vector path ──────────────────────────────────────────────
def _search_chunks_vector(
    db_path: Path,
    query: str,
    top_k: int = 3,
) -> list[dict[str, str]]:
    """Cosine similarity search over embedded context_chunks.

    Embeds *query*, loads every stored chunk embedding into memory and scores
    each one in-process. Chunks whose BLOB size does not correspond to the
    current model dimension are skipped (stale embeddings from a previous
    model).

    Args:
        db_path: Location of the SQLite database file.
        query: Text to embed and compare against stored chunks.
        top_k: Maximum number of results; non-positive values yield ``[]``.

    Returns:
        At most *top_k* dicts with ``text`` and ``filename``, ordered by
        similarity descending. An empty list is returned when no embedder is
        configured, the query cannot be embedded, or SQLite raises
        ``OperationalError``.
    """
    if top_k <= 0:
        # A negative slice bound would otherwise return "all but the last k".
        return []

    embedder = get_embedder()
    if embedder is None:
        return []
    try:
        query_vec: np.ndarray = embedder.embed(query)
        model_dim: int = embedder.dim
    except Exception as exc:
        logger.warning("Query embedding failed: %s", exc)
        return []

    try:
        conn = sqlite3.connect(str(db_path), timeout=30.0)
        try:
            conn.execute("PRAGMA journal_mode=WAL")
            conn.row_factory = sqlite3.Row
            rows = conn.execute(
                "SELECT cc.id, cc.text, cc.embedding, cd.filename"
                " FROM context_chunks cc"
                " JOIN context_documents cd ON cc.document_id = cd.id"
                " WHERE cc.embedding IS NOT NULL"
            ).fetchall()
        finally:
            # Close on the error path too, not just after a successful fetch.
            conn.close()
    except sqlite3.OperationalError:
        return []

    scored: list[tuple[float, dict[str, str]]] = []
    for row in rows:
        blob: bytes = row["embedding"]
        # Guard against blobs from a different-dimension model.
        # NOTE(review): the // 4 assumes 4 bytes per component — confirm
        # against unpack_vector.
        if len(blob) // 4 != model_dim:
            continue
        try:
            chunk_vec = unpack_vector(blob)
            score = cosine_similarity(query_vec, chunk_vec)
        except Exception as exc:
            # One undecodable chunk must not sink the whole search, but leave
            # a trace so corrupt rows can be found.
            logger.debug("Skipping chunk %s during vector search: %s", row["id"], exc)
            continue
        scored.append((score, {"text": row["text"], "filename": row["filename"]}))

    scored.sort(key=lambda t: t[0], reverse=True)
    return [item for _, item in scored[:top_k]]
# ── Chunk retrieval: keyword fallback ─────────────────────────────────────────
def _search_chunks_keyword(db_path: Path, query: str) -> list[dict[str, str]]:
    """LIKE-based keyword search across context_chunks. Fallback when no embedder.

    The query is split on whitespace; words longer than two characters are
    lower-cased and only the first five are used. A chunk matches when its
    text contains any of those keywords as a substring.

    Args:
        db_path: Location of the SQLite database file.
        query: Free-text query to extract keywords from.

    Returns:
        Up to three dicts with ``text`` and ``filename``. An empty list is
        returned when the query has no usable keywords or SQLite raises
        ``OperationalError``.
    """
    keywords = [w.lower() for w in query.split() if len(w) > 2][:5]
    if not keywords:
        # Nothing to search for — skip opening the database entirely.
        return []

    # ESCAPE makes "%" and "_" typed by the user match literally rather than
    # acting as LIKE wildcards.
    conditions = " OR ".join("LOWER(cc.text) LIKE ? ESCAPE '\\'" for _ in keywords)
    params = [
        "%"
        + kw.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        + "%"
        for kw in keywords
    ]

    try:
        conn = sqlite3.connect(str(db_path), timeout=30.0)
        try:
            conn.execute("PRAGMA journal_mode=WAL")
            conn.row_factory = sqlite3.Row
            rows = conn.execute(
                "SELECT cc.text, cd.filename FROM context_chunks cc"
                " JOIN context_documents cd ON cc.document_id = cd.id"
                f" WHERE {conditions} LIMIT 3",
                params,
            ).fetchall()
        finally:
            # Close on the error path too, so a failed query never leaves the
            # connection open.
            conn.close()
    except sqlite3.OperationalError:
        # Deliberate best-effort: context is optional, so degrade to "no chunks".
        return []
    return [{"text": r["text"], "filename": r["filename"]} for r in rows]
# ── Public interface ──────────────────────────────────────────────────────────
def retrieve_context(db_path: Path, query: str) -> RetrievedContext:
    """Gather structured facts plus the most relevant chunks for *query*.

    Facts are always looked up by keyword. For chunks, vector search is tried
    first when EMBEDDING_AVAILABLE is truthy; whenever that path is disabled
    or comes back empty, the keyword search supplies the chunks instead.
    """
    facts = get_relevant_facts(db_path, query)

    # Start from the vector results when embeddings are enabled ...
    chunks = _search_chunks_vector(db_path, query) if EMBEDDING_AVAILABLE else []
    # ... and fall back to keyword search when there is nothing to show.
    if not chunks:
        chunks = _search_chunks_keyword(db_path, query)

    return RetrievedContext(facts=facts, chunks=chunks)
def format_context_block(ctx: RetrievedContext) -> str | None:
    """Render retrieved context as text for an LLM prompt.

    Facts are listed under a "Known environment facts:" heading and chunks
    under "Relevant documentation:", with each chunk's text cut to its first
    200 characters. Returns None when there are neither facts nor chunks.
    """
    rendered: list[str] = []

    if ctx.facts:
        rendered.append("Known environment facts:")
        rendered.extend(
            f" [{fact['category']}] {fact['key']}: {fact['value']}"
            for fact in ctx.facts
        )

    if ctx.chunks:
        rendered.append("Relevant documentation:")
        rendered.extend(
            f" [{chunk['filename']}] {chunk['text'][:200]}"
            for chunk in ctx.chunks
        )

    if not rendered:
        return None
    return "\n".join(rendered)