pagepiper/app/services/retriever.py
pyr0ball 17cdb552a3 fix: T7 quality — SynthesisResult.citations tuple, retriever comments, test assertion
- SynthesisResult.citations changed from list[Citation] to tuple[Citation, ...]
  so frozen=True dataclass is genuinely immutable end-to-end
- synthesize() now builds tuple via generator expression
- retriever._combined: add comment explaining L2 distance inversion
- retriever.hybrid_search: comment on _bm25._chunks private access
- test_synthesizer_builds_context_from_chunks: drop vacuous str(call_args)
  fallback; assert directly on call_args.args[0]
2026-05-04 17:51:22 -07:00

123 lines
4.1 KiB
Python

# app/services/retriever.py
"""
Hybrid BM25 + semantic retriever.
BSL 1.1 — semantic path requires PAGEPIPER_OLLAMA_URL (BYOK gate).
BM25-only path is MIT and has no gate.
"""
from __future__ import annotations
import logging
from dataclasses import dataclass
from app.services.bm25_index import BM25Index
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class RetrievedChunk:
    """One search result produced by the retriever.

    Instances are immutable. Each carries the chunk's identity and text
    together with the per-source scores that were available for it.
    """

    # Identifier of the chunk within the index.
    chunk_id: str
    # Identifier of the document the chunk belongs to.
    doc_id: str
    # Page number recorded for the chunk.
    page_number: int
    # Chunk text; the hybrid path falls back to "" when no text is found.
    text: str
    # BM25 score; the hybrid path uses 0.0 for chunks found only by vector search.
    bm25_score: float
    # Raw vector-store score, or None when no vector result covered this chunk.
    vector_score: float | None
class Retriever:
    """Hybrid retriever that merges BM25 keyword hits with vector-store hits.

    The BM25 index is injected at construction time; the vector database
    path and the embedding client are supplied per call to
    :meth:`hybrid_search`.
    """

    # Vector-store layout used by the semantic path.
    _VEC_TABLE = "page_vecs"
    _VEC_DIMENSIONS = 768
    # Relative weight of each signal in the combined ranking score.
    _BM25_WEIGHT = 0.5
    _VECTOR_WEIGHT = 0.5

    def __init__(self, bm25: BM25Index) -> None:
        """Keep a reference to the BM25 index used by every search."""
        self._bm25 = bm25

    def hybrid_search(
        self,
        query: str,
        top_k: int,
        doc_ids: list[str] | None,
        db_path: str,
        vec_db_path: str,
        llm,  # LLMRouter | None — caller must pass
    ) -> list[RetrievedChunk]:
        """
        Merge BM25 and semantic results.

        Falls back to BM25-only if llm is None.

        Args:
            query: Search text; used for the BM25 query and embedded via
                ``llm.embed`` for the vector query.
            top_k: Maximum number of chunks to return. Each source is asked
                for ``top_k * 2`` candidates before merging.
            doc_ids: Optional document ids to restrict results to. ``None``
                or an empty list applies no restriction.
            db_path: Path handed to ``BM25Index.ensure_fresh``.
            vec_db_path: Path of the vector database opened for the
                semantic query.
            llm: Object exposing ``embed(list[str])``, or ``None`` to skip
                the semantic path entirely.

        Returns:
            At most ``top_k`` chunks ordered by combined score, best first.
        """
        if llm is None:
            return self._bm25_only(query, top_k, doc_ids, db_path)

        # Imported lazily so the BM25-only path has no dependency on the
        # vector-store package.
        from circuitforge_core.vector.sqlite_vec import LocalSQLiteVecStore

        self._bm25.ensure_fresh(db_path)
        bm25_hits = {
            r.chunk_id: r
            for r in self._bm25.query(query, top_k=top_k * 2, doc_ids=doc_ids)
        }

        vec = llm.embed([query])[0]
        store = LocalSQLiteVecStore(
            db_path=vec_db_path,
            table=self._VEC_TABLE,
            dimensions=self._VEC_DIMENSIONS,
        )
        # A single doc id can be pushed down to the store as a metadata
        # filter; several ids are filtered after the query instead.
        filter_meta = {"doc_id": doc_ids[0]} if doc_ids and len(doc_ids) == 1 else None
        vec_hits = store.query(vec, top_k=top_k * 2, filter_metadata=filter_meta)
        if doc_ids and len(doc_ids) > 1:
            # NOTE(review): this filter runs after the top_k * 2 cut, so
            # fewer vector candidates than requested may survive it.
            allowed_docs = set(doc_ids)
            vec_hits = [h for h in vec_hits if h.metadata.get("doc_id") in allowed_docs]

        # Merge: BM25 hits take priority; vector hits fill in additional results
        merged: dict[str, RetrievedChunk] = {
            cid: RetrievedChunk(
                chunk_id=cid,
                doc_id=r.doc_id,
                page_number=r.page_number,
                text=r.text,
                bm25_score=r.score,
                vector_score=None,
            )
            for cid, r in bm25_hits.items()
        }

        # id -> chunk lookup, built at most once and only when a vector-only
        # hit actually needs its text (previously the whole chunk list was
        # rescanned for every vector hit, including ones already merged).
        chunk_by_id: dict | None = None
        for vh in vec_hits:
            existing = merged.get(vh.id)
            if existing is not None:
                # Chunk was already found (by BM25 or an earlier vector hit):
                # keep its fields, attach this hit's vector score.
                merged[vh.id] = RetrievedChunk(
                    chunk_id=existing.chunk_id,
                    doc_id=existing.doc_id,
                    page_number=existing.page_number,
                    text=existing.text,
                    bm25_score=existing.bm25_score,
                    vector_score=vh.score,
                )
                continue

            if chunk_by_id is None:
                chunk_by_id = {}
                # _chunks is the loaded list of dicts from BM25Index; no public accessor exists
                for chunk in self._bm25._chunks:
                    # setdefault keeps the first chunk for a repeated id,
                    # matching a front-to-back scan of the list.
                    chunk_by_id.setdefault(chunk["id"], chunk)

            source_chunk = chunk_by_id.get(vh.id)
            merged[vh.id] = RetrievedChunk(
                chunk_id=vh.id,
                doc_id=vh.metadata.get("doc_id", ""),
                page_number=int(vh.metadata.get("page_number", 0)),
                text=source_chunk["text"] if source_chunk is not None else "",
                bm25_score=0.0,
                vector_score=vh.score,
            )

        def _combined(r: RetrievedChunk) -> float:
            # sqlite-vec returns L2 distance (lower=better); invert to [0,1] higher-is-better
            vec_term = (1.0 / (1.0 + r.vector_score)) if r.vector_score is not None else 0.0
            # NOTE(review): the BM25 score is used as returned by BM25Index;
            # nothing here rescales it to the same range as vec_term.
            return r.bm25_score * self._BM25_WEIGHT + vec_term * self._VECTOR_WEIGHT

        ranked = sorted(merged.values(), key=_combined, reverse=True)
        return ranked[:top_k]

    def _bm25_only(
        self, query: str, top_k: int, doc_ids: list[str] | None, db_path: str
    ) -> list[RetrievedChunk]:
        """Return up to ``top_k`` BM25 hits as chunks with no vector score.

        Refreshes the index via ``ensure_fresh(db_path)`` first, then keeps
        the order in which the BM25 index returns its results.
        """
        self._bm25.ensure_fresh(db_path)
        return [
            RetrievedChunk(
                chunk_id=r.chunk_id,
                doc_id=r.doc_id,
                page_number=r.page_number,
                text=r.text,
                bm25_score=r.score,
                vector_score=None,
            )
            for r in self._bm25.query(query, top_k=top_k, doc_ids=doc_ids)
        ]