96 lines
2.8 KiB
Python
96 lines
2.8 KiB
Python
"""
|
|
BM25 keyword search over the page_chunks corpus.
|
|
|
|
MIT — no tier gate. Available to all users with no Ollama required.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import sqlite3
|
|
from dataclasses import dataclass
|
|
|
|
from rank_bm25 import BM25Okapi
|
|
|
|
# Module-scoped logger so callers can tune this module's verbosity by name.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass(frozen=True)
class BM25Result:
    """One ranked hit produced by a BM25 search.

    Instances are immutable (``frozen=True``), so they compare and hash by value.
    """

    # Primary key of the matching row (the ``id`` column of page_chunks).
    chunk_id: str
    # Identifier of the document the chunk belongs to.
    doc_id: str
    # Page number of the chunk, as stored alongside it.
    page_number: int
    # Untokenized chunk text.
    text: str
    # BM25 relevance score; results are ordered by this, highest first.
    score: float
|
|
|
|
|
|
class BM25Index:
    """
    In-memory BM25 index over page_chunks. Rebuilt lazily on demand.

    Typical flow: call ``mark_dirty()`` after an ingest, ``ensure_fresh(db_path)``
    before searching, then ``query(...)``.

    Thread-safety note: rebuilt synchronously in the request thread. For
    single-user local deployments this is acceptable.
    """

    def __init__(self) -> None:
        self._index: BM25Okapi | None = None
        # Chunk rows in the same order as the documents handed to BM25Okapi,
        # so score position i maps back to self._chunks[i].
        self._chunks: list[dict] = []
        # Start dirty so the first ensure_fresh() call performs a build.
        self._dirty: bool = True

    def mark_dirty(self) -> None:
        """Signal that the index needs rebuilding (call after any ingest completes)."""
        self._dirty = True

    def ensure_fresh(self, db_path: str) -> None:
        """Rebuild from the SQLite database at *db_path* if the index is dirty.

        Reads every page_chunks row ordered by (doc_id, page_number). The dirty
        flag is cleared only after a successful rebuild, so if the read raises,
        the previous index is kept and the rebuild is retried on the next call.
        """
        if not self._dirty:
            return
        conn = sqlite3.connect(db_path)
        try:
            conn.row_factory = sqlite3.Row
            rows = conn.execute(
                "SELECT id, doc_id, page_number, text FROM page_chunks ORDER BY doc_id, page_number"
            ).fetchall()
        finally:
            # Close even when the query fails (e.g. missing table) so the
            # connection is not leaked.
            conn.close()
        self._load_chunks([dict(r) for r in rows])
        self._dirty = False
        logger.info("BM25 index rebuilt: %d chunks", len(self._chunks))

    @staticmethod
    def _tokenize(text: str | None) -> list[str]:
        """Lower-case whitespace tokenizer shared by indexing and querying."""
        # A NULL/empty text yields no tokens instead of failing on None.lower().
        return text.lower().split() if text else []

    def _load_chunks(self, chunks: list[dict]) -> None:
        """Replace the corpus with *chunks* and rebuild the BM25 model.

        Each chunk dict is expected to carry the ``id``, ``doc_id``,
        ``page_number`` and ``text`` keys selected by ensure_fresh().
        """
        self._chunks = chunks
        tokenized = [self._tokenize(c["text"]) for c in chunks]
        # rank_bm25's BM25Okapi averages IDF over the vocabulary size. A corpus
        # with no tokens at all (no chunks, or only blank ones) therefore raises
        # ZeroDivisionError, so in that case leave the index unset.
        self._index = BM25Okapi(tokenized) if any(tokenized) else None

    def query(
        self,
        query_text: str,
        top_k: int = 10,
        doc_ids: list[str] | None = None,
    ) -> list[BM25Result]:
        """Search the corpus. Returns results sorted by descending BM25 score.

        Args:
            query_text: Free-text query, tokenized exactly like the corpus.
            top_k: Maximum number of results; values <= 0 return an empty list.
            doc_ids: If not None, only chunks whose doc_id is in this collection
                are returned (an empty collection therefore returns nothing).

        Returns:
            Up to *top_k* results with a strictly positive score. Equal scores
            keep corpus order. Empty when no index has been built.
        """
        if self._index is None or not self._chunks or top_k <= 0:
            return []

        # Build the filter once: O(1) membership per chunk instead of a list scan.
        allowed = None if doc_ids is None else set(doc_ids)

        scores = self._index.get_scores(self._tokenize(query_text))
        # sorted() stays stable with reverse=True, so ties keep ascending chunk order.
        ranked = sorted(enumerate(scores), key=lambda pair: pair[1], reverse=True)

        results: list[BM25Result] = []
        for i, score in ranked:
            # Chunks sharing no query term score 0; rank_bm25 can also produce
            # negative scores for very common terms. Neither counts as a match.
            if score <= 0:
                continue
            c = self._chunks[i]
            if allowed is not None and c["doc_id"] not in allowed:
                continue
            results.append(
                BM25Result(
                    chunk_id=c["id"],
                    doc_id=c["doc_id"],
                    page_number=c["page_number"],
                    text=c["text"],
                    score=float(score),
                )
            )
            if len(results) >= top_k:
                break
        return results
|