- `SynthesisResult.citations` changed from `list[Citation]` to `tuple[Citation, ...]` so the `frozen=True` dataclass is genuinely immutable end-to-end.
- `synthesize()` now builds the tuple via a generator expression.
- `retriever._combined`: add a comment explaining the L2-distance inversion.
- `retriever.hybrid_search`: add a comment on the private `_bm25._chunks` access.
- `test_synthesizer_builds_context_from_chunks`: drop the vacuous `str(call_args)` fallback; assert directly on `call_args.args[0]`.
123 lines
4.1 KiB
Python
123 lines
4.1 KiB
Python
# app/services/retriever.py
|
|
"""
|
|
Hybrid BM25 + semantic retriever.
|
|
|
|
BSL 1.1 — semantic path requires PAGEPIPER_OLLAMA_URL (BYOK gate).
|
|
BM25-only path is MIT and has no gate.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from dataclasses import dataclass
|
|
|
|
from app.services.bm25_index import BM25Index
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass(frozen=True)
|
|
class RetrievedChunk:
|
|
"""A chunk returned by the retriever, with source scores."""
|
|
|
|
chunk_id: str
|
|
doc_id: str
|
|
page_number: int
|
|
text: str
|
|
bm25_score: float
|
|
vector_score: float | None
|
|
|
|
|
|
class Retriever:
    """Hybrid BM25 + semantic retriever.

    BM25 lookups always go through the injected BM25 index; the semantic path
    is only taken when ``hybrid_search`` receives an ``llm`` that can embed
    the query.
    """

    # sqlite-vec table holding page embeddings, and the embedding width.
    # NOTE(review): the width must match the model behind ``llm.embed`` —
    # override on a subclass if a different embedding model is used.
    VEC_TABLE = "page_vecs"
    VEC_DIMENSIONS = 768

    # Weights of the normalised BM25 score and of the vector similarity in the
    # fused ranking score.
    BM25_WEIGHT = 0.5
    VECTOR_WEIGHT = 0.5

    def __init__(self, bm25: BM25Index) -> None:
        self._bm25 = bm25

    def hybrid_search(
        self,
        query: str,
        top_k: int,
        doc_ids: list[str] | None,
        db_path: str,
        vec_db_path: str,
        llm,  # LLMRouter | None — caller must pass
    ) -> list[RetrievedChunk]:
        """
        Merge BM25 and semantic results.

        Falls back to BM25-only if llm is None.

        Args:
            query: Free-text search query.
            top_k: Maximum number of chunks to return; ``<= 0`` yields ``[]``.
            doc_ids: Restrict results to these documents, or ``None`` for all.
            db_path: Database the BM25 index is refreshed from.
            vec_db_path: sqlite-vec database holding the page embeddings.
            llm: Object exposing ``embed(list[str])``, or ``None`` to skip the
                semantic path.

        Returns:
            Up to ``top_k`` chunks ordered by fused score, best first. Chunks
            found only by the vector search carry ``bm25_score=0.0``; chunks
            found only by BM25 carry ``vector_score=None``.
        """
        if top_k <= 0:
            return []
        if llm is None:
            return self._bm25_only(query, top_k, doc_ids, db_path)

        # Over-fetch from both sources so re-ranking has a wider pool.
        pool_size = top_k * 2

        self._bm25.ensure_fresh(db_path)
        # BM25 hits seed the pool; vector hits then enrich or extend it.
        merged: dict[str, RetrievedChunk] = {
            r.chunk_id: RetrievedChunk(
                chunk_id=r.chunk_id,
                doc_id=r.doc_id,
                page_number=r.page_number,
                text=r.text,
                bm25_score=r.score,
                vector_score=None,
            )
            for r in self._bm25.query(query, top_k=pool_size, doc_ids=doc_ids)
        }

        vec_hits = self._vector_hits(query, pool_size, doc_ids, vec_db_path, llm)

        # Vector-only hits need their text from the BM25 corpus; fetch it in a
        # single pass rather than rescanning the corpus for every hit.
        texts = self._chunk_texts({h.id for h in vec_hits if h.id not in merged})

        for hit in vec_hits:
            existing = merged.get(hit.id)
            if existing is not None:
                # Already in the pool: keep its data, attach the vector score.
                merged[hit.id] = RetrievedChunk(
                    chunk_id=existing.chunk_id,
                    doc_id=existing.doc_id,
                    page_number=existing.page_number,
                    text=existing.text,
                    bm25_score=existing.bm25_score,
                    vector_score=hit.score,
                )
            else:
                merged[hit.id] = RetrievedChunk(
                    chunk_id=hit.id,
                    doc_id=hit.metadata.get("doc_id", ""),
                    page_number=int(hit.metadata.get("page_number", 0)),
                    text=texts.get(hit.id, ""),
                    bm25_score=0.0,
                    vector_score=hit.score,
                )

        return self._rank(list(merged.values()), top_k)

    def _vector_hits(self, query: str, limit: int, doc_ids: list[str] | None, vec_db_path: str, llm):
        """Embed *query* and return up to *limit* nearest page vectors within *doc_ids*."""
        # Imported here so the BM25-only path never loads the vector backend.
        from circuitforge_core.vector.sqlite_vec import LocalSQLiteVecStore

        query_vec = llm.embed([query])[0]
        store = LocalSQLiteVecStore(
            db_path=vec_db_path, table=self.VEC_TABLE, dimensions=self.VEC_DIMENSIONS
        )
        # The store filter carries a single doc_id; with several ids the hits
        # are filtered here instead (which may leave fewer than *limit*).
        filter_meta = {"doc_id": doc_ids[0]} if doc_ids and len(doc_ids) == 1 else None
        hits = store.query(query_vec, top_k=limit, filter_metadata=filter_meta)
        if doc_ids and len(doc_ids) > 1:
            wanted = set(doc_ids)
            hits = [h for h in hits if h.metadata.get("doc_id") in wanted]
        return hits

    def _chunk_texts(self, chunk_ids: set[str]) -> dict[str, str]:
        """Map each requested chunk id found in the BM25 corpus to its text."""
        texts: dict[str, str] = {}
        if not chunk_ids:
            return texts
        # _chunks is the loaded list of dicts from BM25Index; no public accessor exists
        for chunk in self._bm25._chunks:
            cid = chunk["id"]
            # First occurrence wins, matching a front-to-back search.
            if cid in chunk_ids and cid not in texts:
                texts[cid] = chunk["text"]
                if len(texts) == len(chunk_ids):
                    break
        return texts

    @staticmethod
    def _vector_similarity(distance: float | None) -> float:
        """Convert a vector-store distance into a higher-is-better similarity."""
        # sqlite-vec returns L2 distance (lower=better); invert to [0,1] higher-is-better
        return 1.0 / (1.0 + distance) if distance is not None else 0.0

    @classmethod
    def _rank(cls, candidates: list[RetrievedChunk], top_k: int) -> list[RetrievedChunk]:
        """Sort *candidates* by fused score (best first) and keep the first *top_k*."""
        # Raw BM25 scores are open-ended, while the vector similarity lies in
        # (0, 1]. Rescale BM25 by the pool's best score so both weighted terms
        # are comparable. Otherwise any strong lexical match swamps the
        # semantic signal and vector-only hits can never rank high. If no
        # score is positive, raw scores are used unchanged.
        best_bm25 = max((c.bm25_score for c in candidates), default=0.0)
        scale = best_bm25 if best_bm25 > 0 else 1.0

        def fused(c: RetrievedChunk) -> float:
            return (
                cls.BM25_WEIGHT * (c.bm25_score / scale)
                + cls.VECTOR_WEIGHT * cls._vector_similarity(c.vector_score)
            )

        return sorted(candidates, key=fused, reverse=True)[:top_k]

    def _bm25_only(
        self, query: str, top_k: int, doc_ids: list[str] | None, db_path: str
    ) -> list[RetrievedChunk]:
        """Return BM25 hits as RetrievedChunks in index order, with no vector score."""
        self._bm25.ensure_fresh(db_path)
        return [
            RetrievedChunk(
                chunk_id=r.chunk_id,
                doc_id=r.doc_id,
                page_number=r.page_number,
                text=r.text,
                bm25_score=r.score,
                vector_score=None,
            )
            for r in self._bm25.query(query, top_k=top_k, doc_ids=doc_ids)
        ]
|