"""Tests for app.services.bm25_index.""" from __future__ import annotations import pytest from app.services.bm25_index import BM25Index, BM25Result def _seeded_index() -> BM25Index: idx = BM25Index() idx._load_chunks( [ { "id": "c1", "doc_id": "book-a", "page_number": 1, "text": "Fireball deals 8d6 fire damage on a failed Dexterity saving throw.", }, { "id": "c2", "doc_id": "book-a", "page_number": 2, "text": "A wizard can cast one spell per turn unless they have Action Surge.", }, { "id": "c3", "doc_id": "book-b", "page_number": 5, "text": "Grapple rules apply when the attacker uses the Attack action to grab a target.", }, ] ) return idx def test_query_returns_relevant_result(): idx = _seeded_index() results = idx.query("fireball fire damage") assert len(results) >= 1 assert results[0].chunk_id == "c1" assert results[0].score > 0 def test_query_respects_top_k(): # "action" matches all three chunks; top_k=2 must hard-cap the result list idx = _seeded_index() results = idx.query("action", top_k=2) assert len(results) == 2 def test_query_filters_by_doc_id(): idx = _seeded_index() results = idx.query("rules", doc_ids=["book-b"]) assert all(r.doc_id == "book-b" for r in results) def test_query_empty_corpus_returns_empty(): idx = BM25Index() idx._load_chunks([]) results = idx.query("anything") assert results == [] def test_mark_dirty_triggers_rebuild(tmp_path): import sqlite3 db_path = str(tmp_path / "test.db") conn = sqlite3.connect(db_path) conn.execute( "CREATE TABLE page_chunks(id TEXT, doc_id TEXT, page_number INT, text TEXT)" ) conn.execute( "INSERT INTO page_chunks VALUES ('x1','doc-1',1,'Ranger favored enemy favored terrain terrain bonuses bonuses action attack')" ) conn.execute( "INSERT INTO page_chunks VALUES ('x2','doc-1',2,'Wizard can cast spells and perform actions')" ) conn.execute( "INSERT INTO page_chunks VALUES ('x3','doc-1',3,'Fighter attacks and deals damage with weapon')" ) conn.commit() conn.close() idx = BM25Index() idx.mark_dirty() idx.ensure_fresh(db_path) results = idx.query("ranger terrain") assert len(results) >= 1 assert results[0].chunk_id == "x1" def test_bm25_result_is_frozen(): r = BM25Result(chunk_id="x", doc_id="d", page_number=1, text="hello", score=0.5) with pytest.raises(Exception): r.score = 1.0 # type: ignore[misc]