feat(services): add BM25 index service (MIT)
This commit is contained in:
parent
abeb6089e5
commit
2253cd7da3
2 changed files with 191 additions and 0 deletions
96
app/services/bm25_index.py
Normal file
96
app/services/bm25_index.py
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
"""
|
||||
BM25 keyword search over the page_chunks corpus.
|
||||
|
||||
MIT — no tier gate. Available to all users with no Ollama required.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sqlite3
|
||||
from dataclasses import dataclass
|
||||
|
||||
from rank_bm25 import BM25Okapi
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BM25Result:
|
||||
"""A single BM25 search result."""
|
||||
|
||||
chunk_id: str
|
||||
doc_id: str
|
||||
page_number: int
|
||||
text: str
|
||||
score: float
|
||||
|
||||
|
||||
class BM25Index:
|
||||
"""
|
||||
In-memory BM25 index over page_chunks. Rebuilt lazily on demand.
|
||||
|
||||
Thread-safety note: rebuilt synchronously in the request thread. For
|
||||
single-user local deployments this is acceptable.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._index: BM25Okapi | None = None
|
||||
self._chunks: list[dict] = []
|
||||
self._dirty: bool = True
|
||||
|
||||
def mark_dirty(self) -> None:
|
||||
"""Signal that the index needs rebuilding (call after any ingest completes)."""
|
||||
self._dirty = True
|
||||
|
||||
def ensure_fresh(self, db_path: str) -> None:
|
||||
"""Rebuild from SQLite if dirty."""
|
||||
if not self._dirty:
|
||||
return
|
||||
conn = sqlite3.connect(db_path)
|
||||
conn.row_factory = sqlite3.Row
|
||||
rows = conn.execute(
|
||||
"SELECT id, doc_id, page_number, text FROM page_chunks ORDER BY doc_id, page_number"
|
||||
).fetchall()
|
||||
conn.close()
|
||||
self._load_chunks([dict(r) for r in rows])
|
||||
self._dirty = False
|
||||
logger.info("BM25 index rebuilt: %d chunks", len(self._chunks))
|
||||
|
||||
def _load_chunks(self, chunks: list[dict]) -> None:
|
||||
self._chunks = chunks
|
||||
tokenized = [c["text"].lower().split() for c in chunks]
|
||||
self._index = BM25Okapi(tokenized) if tokenized else None
|
||||
|
||||
def query(
|
||||
self,
|
||||
query_text: str,
|
||||
top_k: int = 10,
|
||||
doc_ids: list[str] | None = None,
|
||||
) -> list[BM25Result]:
|
||||
"""Search the corpus. Returns results sorted by descending BM25 score."""
|
||||
if not self._index or not self._chunks:
|
||||
return []
|
||||
|
||||
scores = self._index.get_scores(query_text.lower().split())
|
||||
ranked = sorted(enumerate(scores), key=lambda x: x[1], reverse=True)
|
||||
|
||||
results: list[BM25Result] = []
|
||||
for i, score in ranked:
|
||||
if score <= 0:
|
||||
continue
|
||||
c = self._chunks[i]
|
||||
if doc_ids is not None and c["doc_id"] not in doc_ids:
|
||||
continue
|
||||
results.append(
|
||||
BM25Result(
|
||||
chunk_id=c["id"],
|
||||
doc_id=c["doc_id"],
|
||||
page_number=c["page_number"],
|
||||
text=c["text"],
|
||||
score=float(score),
|
||||
)
|
||||
)
|
||||
if len(results) >= top_k:
|
||||
break
|
||||
return results
|
||||
95
tests/test_bm25_index.py
Normal file
95
tests/test_bm25_index.py
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
"""Tests for app.services.bm25_index."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.services.bm25_index import BM25Index, BM25Result
|
||||
|
||||
|
||||
def _seeded_index() -> BM25Index:
|
||||
idx = BM25Index()
|
||||
idx._load_chunks(
|
||||
[
|
||||
{
|
||||
"id": "c1",
|
||||
"doc_id": "book-a",
|
||||
"page_number": 1,
|
||||
"text": "Fireball deals 8d6 fire damage on a failed Dexterity saving throw.",
|
||||
},
|
||||
{
|
||||
"id": "c2",
|
||||
"doc_id": "book-a",
|
||||
"page_number": 2,
|
||||
"text": "A wizard can cast one spell per turn unless they have Action Surge.",
|
||||
},
|
||||
{
|
||||
"id": "c3",
|
||||
"doc_id": "book-b",
|
||||
"page_number": 5,
|
||||
"text": "Grapple rules apply when the attacker uses the Attack action to grab a target.",
|
||||
},
|
||||
]
|
||||
)
|
||||
return idx
|
||||
|
||||
|
||||
def test_query_returns_relevant_result():
|
||||
idx = _seeded_index()
|
||||
results = idx.query("fireball fire damage")
|
||||
assert len(results) >= 1
|
||||
assert results[0].chunk_id == "c1"
|
||||
assert results[0].score > 0
|
||||
|
||||
|
||||
def test_query_respects_top_k():
|
||||
idx = _seeded_index()
|
||||
results = idx.query("rules", top_k=2)
|
||||
assert len(results) <= 2
|
||||
|
||||
|
||||
def test_query_filters_by_doc_id():
|
||||
idx = _seeded_index()
|
||||
results = idx.query("rules", doc_ids=["book-b"])
|
||||
assert all(r.doc_id == "book-b" for r in results)
|
||||
|
||||
|
||||
def test_query_empty_corpus_returns_empty():
|
||||
idx = BM25Index()
|
||||
idx._load_chunks([])
|
||||
results = idx.query("anything")
|
||||
assert results == []
|
||||
|
||||
|
||||
def test_mark_dirty_triggers_rebuild(tmp_path):
|
||||
import sqlite3
|
||||
|
||||
db_path = str(tmp_path / "test.db")
|
||||
conn = sqlite3.connect(db_path)
|
||||
conn.execute(
|
||||
"CREATE TABLE page_chunks(id TEXT, doc_id TEXT, page_number INT, text TEXT)"
|
||||
)
|
||||
conn.execute(
|
||||
"INSERT INTO page_chunks VALUES ('x1','doc-1',1,'Ranger favored enemy favored terrain terrain bonuses bonuses action attack')"
|
||||
)
|
||||
conn.execute(
|
||||
"INSERT INTO page_chunks VALUES ('x2','doc-1',2,'Wizard can cast spells and perform actions')"
|
||||
)
|
||||
conn.execute(
|
||||
"INSERT INTO page_chunks VALUES ('x3','doc-1',3,'Fighter attacks and deals damage with weapon')"
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
idx = BM25Index()
|
||||
idx.mark_dirty()
|
||||
idx.ensure_fresh(db_path)
|
||||
results = idx.query("ranger terrain")
|
||||
assert len(results) >= 1
|
||||
assert results[0].chunk_id == "x1"
|
||||
|
||||
|
||||
def test_bm25_result_is_frozen():
|
||||
r = BM25Result(chunk_id="x", doc_id="d", page_number=1, text="hello", score=0.5)
|
||||
with pytest.raises(Exception):
|
||||
r.score = 1.0 # type: ignore[misc]
|
||||
Loading…
Reference in a new issue