diff --git a/app/services/bm25_index.py b/app/services/bm25_index.py new file mode 100644 index 0000000..c3f03c6 --- /dev/null +++ b/app/services/bm25_index.py @@ -0,0 +1,96 @@ +""" +BM25 keyword search over the page_chunks corpus. + +MIT — no tier gate. Available to all users with no Ollama required. +""" + +from __future__ import annotations + +import logging +import sqlite3 +from dataclasses import dataclass + +from rank_bm25 import BM25Okapi + +logger = logging.getLogger(__name__) + + +@dataclass(frozen=True) +class BM25Result: + """A single BM25 search result.""" + + chunk_id: str + doc_id: str + page_number: int + text: str + score: float + + +class BM25Index: + """ + In-memory BM25 index over page_chunks. Rebuilt lazily on demand. + + Thread-safety note: rebuilt synchronously in the request thread. For + single-user local deployments this is acceptable. + """ + + def __init__(self) -> None: + self._index: BM25Okapi | None = None + self._chunks: list[dict] = [] + self._dirty: bool = True + + def mark_dirty(self) -> None: + """Signal that the index needs rebuilding (call after any ingest completes).""" + self._dirty = True + + def ensure_fresh(self, db_path: str) -> None: + """Rebuild from SQLite if dirty.""" + if not self._dirty: + return + conn = sqlite3.connect(db_path) + conn.row_factory = sqlite3.Row + rows = conn.execute( + "SELECT id, doc_id, page_number, text FROM page_chunks ORDER BY doc_id, page_number" + ).fetchall() + conn.close() + self._load_chunks([dict(r) for r in rows]) + self._dirty = False + logger.info("BM25 index rebuilt: %d chunks", len(self._chunks)) + + def _load_chunks(self, chunks: list[dict]) -> None: + self._chunks = chunks + tokenized = [c["text"].lower().split() for c in chunks] + self._index = BM25Okapi(tokenized) if tokenized else None + + def query( + self, + query_text: str, + top_k: int = 10, + doc_ids: list[str] | None = None, + ) -> list[BM25Result]: + """Search the corpus. Returns results sorted by descending BM25 score.""" + if not self._index or not self._chunks: + return [] + + scores = self._index.get_scores(query_text.lower().split()) + ranked = sorted(enumerate(scores), key=lambda x: x[1], reverse=True) + + results: list[BM25Result] = [] + for i, score in ranked: + if score <= 0: + continue + c = self._chunks[i] + if doc_ids is not None and c["doc_id"] not in doc_ids: + continue + results.append( + BM25Result( + chunk_id=c["id"], + doc_id=c["doc_id"], + page_number=c["page_number"], + text=c["text"], + score=float(score), + ) + ) + if len(results) >= top_k: + break + return results diff --git a/tests/test_bm25_index.py b/tests/test_bm25_index.py new file mode 100644 index 0000000..221cea2 --- /dev/null +++ b/tests/test_bm25_index.py @@ -0,0 +1,95 @@ +"""Tests for app.services.bm25_index.""" + +from __future__ import annotations + +import pytest + +from app.services.bm25_index import BM25Index, BM25Result + + +def _seeded_index() -> BM25Index: + idx = BM25Index() + idx._load_chunks( + [ + { + "id": "c1", + "doc_id": "book-a", + "page_number": 1, + "text": "Fireball deals 8d6 fire damage on a failed Dexterity saving throw.", + }, + { + "id": "c2", + "doc_id": "book-a", + "page_number": 2, + "text": "A wizard can cast one spell per turn unless they have Action Surge.", + }, + { + "id": "c3", + "doc_id": "book-b", + "page_number": 5, + "text": "Grapple rules apply when the attacker uses the Attack action to grab a target.", + }, + ] + ) + return idx + + +def test_query_returns_relevant_result(): + idx = _seeded_index() + results = idx.query("fireball fire damage") + assert len(results) >= 1 + assert results[0].chunk_id == "c1" + assert results[0].score > 0 + + +def test_query_respects_top_k(): + idx = _seeded_index() + results = idx.query("rules", top_k=2) + assert len(results) <= 2 + + +def test_query_filters_by_doc_id(): + idx = _seeded_index() + results = idx.query("rules", doc_ids=["book-b"]) + assert all(r.doc_id == "book-b" for r in results) + + +def test_query_empty_corpus_returns_empty(): + idx = BM25Index() + idx._load_chunks([]) + results = idx.query("anything") + assert results == [] + + +def test_mark_dirty_triggers_rebuild(tmp_path): + import sqlite3 + + db_path = str(tmp_path / "test.db") + conn = sqlite3.connect(db_path) + conn.execute( + "CREATE TABLE page_chunks(id TEXT, doc_id TEXT, page_number INT, text TEXT)" + ) + conn.execute( + "INSERT INTO page_chunks VALUES ('x1','doc-1',1,'Ranger favored enemy favored terrain terrain bonuses bonuses action attack')" + ) + conn.execute( + "INSERT INTO page_chunks VALUES ('x2','doc-1',2,'Wizard can cast spells and perform actions')" + ) + conn.execute( + "INSERT INTO page_chunks VALUES ('x3','doc-1',3,'Fighter attacks and deals damage with weapon')" + ) + conn.commit() + conn.close() + + idx = BM25Index() + idx.mark_dirty() + idx.ensure_fresh(db_path) + results = idx.query("ranger terrain") + assert len(results) >= 1 + assert results[0].chunk_id == "x1" + + +def test_bm25_result_is_frozen(): + r = BM25Result(chunk_id="x", doc_id="d", page_number=1, text="hello", score=0.5) + with pytest.raises(Exception): + r.score = 1.0 # type: ignore[misc]