pagepiper/app/services/retriever.py
pyr0ball 347b391c6e fix: prevent LLM hallucination when retrieval returns low-signal results
- Strengthen synthesizer system prompt: hard 'respond with exactly' constraint
  instead of soft 'say so'; removes any wiggle room for the model to supplement
  from training data
- Add early return in synthesize() when chunks is empty (belt-and-suspenders
  alongside the existing guard in chat.py)
- Add MIN_SIGNAL threshold (0.01) in retriever: if the top combined score is
  below the threshold, return empty so the caller's no-results path fires instead
  of sending noise chunks to the LLM
2026-05-06 10:17:51 -07:00

203 lines
7.2 KiB
Python

# app/services/retriever.py
"""
Hybrid BM25 + semantic retriever.
BSL 1.1 — semantic path requires PAGEPIPER_OLLAMA_URL (BYOK gate).
BM25-only path is MIT and has no gate.
"""
from __future__ import annotations
import logging
import sqlite3
from dataclasses import dataclass
from app.services.bm25_index import BM25Index
logger = logging.getLogger(__name__)
def _fetch_adjacent(
hits: list["RetrievedChunk"],
db_path: str,
window: int = 1,
) -> list["RetrievedChunk"]:
"""Return chunks immediately before/after each hit that aren't already in the hit set.
Definitional passages often start mid-sentence because the EPUB/PDF chunk
boundary fell mid-paragraph. Fetching the preceding chunk restores the subject
so the LLM can understand 'them' / 'they' references correctly.
"""
if not hits:
return []
existing_keys = {(c.doc_id, c.page_number) for c in hits}
needed: dict[str, set[int]] = {}
for c in hits:
for delta in range(-window, window + 1):
if delta == 0:
continue
adj_page = c.page_number + delta
if adj_page > 0 and (c.doc_id, adj_page) not in existing_keys:
needed.setdefault(c.doc_id, set()).add(adj_page)
if not needed:
return []
extra: list[RetrievedChunk] = []
try:
conn = sqlite3.connect(db_path)
conn.row_factory = sqlite3.Row
for doc_id, pages in needed.items():
placeholders = ",".join("?" * len(pages))
rows = conn.execute(
f"SELECT id, doc_id, page_number, text FROM page_chunks "
f"WHERE doc_id=? AND page_number IN ({placeholders})",
[doc_id] + sorted(pages),
).fetchall()
for row in rows:
extra.append(
RetrievedChunk(
chunk_id=row["id"],
doc_id=row["doc_id"],
page_number=row["page_number"],
text=row["text"],
bm25_score=0.0,
vector_score=None,
)
)
conn.close()
except Exception as exc:
logger.warning("Context expansion query failed (non-fatal): %s", exc)
return extra
@dataclass(frozen=True)
class RetrievedChunk:
"""A chunk returned by the retriever, with source scores."""
chunk_id: str
doc_id: str
page_number: int
text: str
bm25_score: float
vector_score: float | None
class Retriever:
    """Hybrid BM25 + vector retriever over the ``page_chunks`` corpus.

    Wraps a BM25Index. The vector-store path and the embedding-capable LLM
    router are supplied per call to :meth:`hybrid_search`.
    """

    # Minimum score the best result must reach. Anything weaker is treated as
    # noise and the search returns [] so the caller's no-results path fires
    # instead of sending low-confidence chunks to the LLM, where it would fill
    # the gaps with training data.
    MIN_SIGNAL: float = 0.01

    def __init__(self, bm25: BM25Index) -> None:
        self._bm25 = bm25

    def hybrid_search(
        self,
        query: str,
        top_k: int,
        doc_ids: list[str] | None,
        db_path: str,
        vec_db_path: str,
        llm,  # LLMRouter | None — caller must pass
    ) -> list[RetrievedChunk]:
        """Merge BM25 and semantic results, then add adjacent-page context.

        Falls back to BM25-only retrieval if ``llm`` is None or if embedding
        the query raises.

        Args:
            query: Free-text search query.
            top_k: Number of ranked chunks to return (before context expansion).
            doc_ids: Restrict results to these documents; falsy means all.
            db_path: SQLite database holding the chunk text.
            vec_db_path: SQLite database holding the ``page_vecs`` vector table.
            llm: Object exposing ``embed(list[str])``, or None for BM25-only.

        Returns:
            Up to ``top_k`` ranked chunks followed by their adjacent-page
            chunks, or [] when the best combined score is below MIN_SIGNAL.
        """
        if llm is None:
            return self._bm25_only(query, top_k, doc_ids, db_path)
        from circuitforge_core.vector.sqlite_vec import LocalSQLiteVecStore

        self._bm25.ensure_fresh(db_path)
        bm25_hits = {
            r.chunk_id: r
            for r in self._bm25.query(query, top_k=top_k * 2, doc_ids=doc_ids)
        }
        try:
            vec = llm.embed([query])[0]
        except Exception as exc:
            logger.warning("Embed failed, falling back to BM25-only: %s", exc)
            return self._bm25_only(query, top_k, doc_ids, db_path)

        from app.config import VEC_DIMENSIONS

        store = LocalSQLiteVecStore(db_path=vec_db_path, table="page_vecs", dimensions=VEC_DIMENSIONS)
        # sqlite-vec applies filter_metadata as a Python post-filter after fetching k
        # nearest globally. When the corpus spans many documents and only a subset is
        # selected, most of those k candidates are from non-target docs and get dropped,
        # leaving too few vector hits. Oversample heavily and filter in Python instead.
        if doc_ids:
            wanted_docs = set(doc_ids)
            vec_candidates = store.query(vec, top_k=top_k * 20)
            vec_hits = [h for h in vec_candidates if h.metadata.get("doc_id") in wanted_docs]
        else:
            vec_hits = store.query(vec, top_k=top_k * 2)

        # Merge: BM25 hits seed the result; vector hits either attach a vector
        # score to an existing entry or add a vector-only entry.
        merged: dict[str, RetrievedChunk] = {}
        for cid, r in bm25_hits.items():
            merged[cid] = RetrievedChunk(
                chunk_id=cid,
                doc_id=r.doc_id,
                page_number=r.page_number,
                text=r.text,
                bm25_score=r.score,
                vector_score=None,
            )

        # Vector-only hits carry no text; resolve all of them with a single pass
        # over the BM25 corpus instead of one full scan per hit.
        texts = self._texts_for({vh.entry_id for vh in vec_hits if vh.entry_id not in merged})
        for vh in vec_hits:
            existing = merged.get(vh.entry_id)
            if existing is not None:
                merged[vh.entry_id] = RetrievedChunk(
                    chunk_id=existing.chunk_id,
                    doc_id=existing.doc_id,
                    page_number=existing.page_number,
                    text=existing.text,
                    bm25_score=existing.bm25_score,
                    vector_score=vh.score,
                )
            else:
                merged[vh.entry_id] = RetrievedChunk(
                    chunk_id=vh.entry_id,
                    doc_id=vh.metadata.get("doc_id", ""),
                    page_number=int(vh.metadata.get("page_number", 0)),
                    # "" when the id is not in the BM25 corpus
                    # (NOTE(review): possibly a stale vector entry — confirm).
                    text=texts.get(vh.entry_id, ""),
                    bm25_score=0.0,
                    vector_score=vh.score,
                )

        ranked = sorted(merged.values(), key=self._combined_score, reverse=True)[:top_k]
        # NOTE(review): a vector hit alone contributes 0.5 / (1 + d), which is
        # >= MIN_SIGNAL unless the distance d exceeds 49, so this guard rarely
        # fires once any vector hit exists — confirm the threshold against the
        # vector store's actual distance range.
        if not ranked or self._combined_score(ranked[0]) < self.MIN_SIGNAL:
            return []
        return ranked + _fetch_adjacent(ranked, db_path)

    @staticmethod
    def _combined_score(r: RetrievedChunk) -> float:
        """Blend BM25 and vector scores with equal 0.5 weights (higher is better).

        ``vector_score`` is treated as an L2 distance (lower = closer, per the
        sqlite-vec store) and mapped to (0, 1] via ``1 / (1 + d)``. A missing
        vector score contributes 0. BM25 scores are used unnormalized.
        """
        vec = (1.0 / (1.0 + r.vector_score)) if r.vector_score is not None else 0.0
        return r.bm25_score * 0.5 + vec * 0.5

    def _texts_for(self, chunk_ids: set[str]) -> dict[str, str]:
        """Map each requested chunk id to its text from the BM25 corpus.

        Ids missing from the corpus are absent from the result. If an id occurs
        more than once, the first occurrence wins. Stops scanning as soon as
        every id is resolved.
        """
        found: dict[str, str] = {}
        if not chunk_ids:
            return found
        # _chunks is the loaded list of dicts from BM25Index; no public accessor exists
        for c in self._bm25._chunks:
            cid = c["id"]
            if cid in chunk_ids and cid not in found:
                found[cid] = c["text"]
                if len(found) == len(chunk_ids):
                    break
        return found

    def _bm25_only(
        self, query: str, top_k: int, doc_ids: list[str] | None, db_path: str
    ) -> list[RetrievedChunk]:
        """Lexical-only retrieval: BM25 hits plus adjacent-page context.

        Returns [] when there are no hits or the best BM25 score is below
        MIN_SIGNAL.
        """
        self._bm25.ensure_fresh(db_path)
        hits = [
            RetrievedChunk(
                chunk_id=r.chunk_id,
                doc_id=r.doc_id,
                page_number=r.page_number,
                text=r.text,
                bm25_score=r.score,
                vector_score=None,
            )
            for r in self._bm25.query(query, top_k=top_k, doc_ids=doc_ids)
        ]
        # max() rather than hits[0] so the gate does not rely on query() order.
        if not hits or max(h.bm25_score for h in hits) < self.MIN_SIGNAL:
            return []
        return hits + _fetch_adjacent(hits, db_path)