95 lines
2.6 KiB
Python
95 lines
2.6 KiB
Python
"""Tests for app.services.bm25_index."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import pytest
|
|
|
|
from app.services.bm25_index import BM25Index, BM25Result
|
|
|
|
|
|
def _seeded_index() -> BM25Index:
    """Return a new ``BM25Index`` pre-loaded with three fixture chunks.

    Chunks ``c1`` and ``c2`` are tagged ``book-a`` and ``c3`` is tagged
    ``book-b``. The chunk dicts are handed straight to ``_load_chunks``.
    """
    columns = ("id", "doc_id", "page_number", "text")
    rows = (
        (
            "c1",
            "book-a",
            1,
            "Fireball deals 8d6 fire damage on a failed Dexterity saving throw.",
        ),
        (
            "c2",
            "book-a",
            2,
            "A wizard can cast one spell per turn unless they have Action Surge.",
        ),
        (
            "c3",
            "book-b",
            5,
            "Grapple rules apply when the attacker uses the Attack action to grab a target.",
        ),
    )
    index = BM25Index()
    # Each row becomes a dict keyed by ``columns`` (same key order as above).
    index._load_chunks([dict(zip(columns, row)) for row in rows])
    return index
|
|
|
|
|
|
def test_query_returns_relevant_result():
    """A fireball query ranks chunk ``c1`` first with a positive score."""
    hits = _seeded_index().query("fireball fire damage")
    assert len(hits) >= 1
    best = hits[0]
    assert best.chunk_id == "c1"
    assert best.score > 0
|
|
|
|
|
|
def test_query_respects_top_k():
    """Passing ``top_k=2`` yields no more than two results."""
    limit = 2
    hits = _seeded_index().query("rules", top_k=limit)
    # NOTE(review): "rules" literally occurs in only one fixture text, so this
    # cap may never be exercised -- consider a query that matches more chunks.
    assert len(hits) <= limit
|
|
|
|
|
|
def test_query_filters_by_doc_id():
    """Every result of a ``doc_ids=["book-b"]`` query belongs to ``book-b``."""
    index = _seeded_index()
    hits = index.query("rules", doc_ids=["book-b"])
    # NOTE(review): an empty result list passes this check vacuously.
    for hit in hits:
        assert hit.doc_id == "book-b"
|
|
|
|
|
|
def test_query_empty_corpus_returns_empty():
    """Querying an index loaded with zero chunks returns an empty list."""
    empty_index = BM25Index()
    empty_index._load_chunks([])
    assert empty_index.query("anything") == []
|
|
|
|
|
|
def test_mark_dirty_triggers_rebuild(tmp_path):
    """After ``mark_dirty``, ``ensure_fresh`` makes DB rows queryable.

    A throwaway SQLite database with a ``page_chunks`` table is created under
    ``tmp_path`` and filled with three rows. A new index is marked dirty,
    refreshed against that database, and must then rank row ``x1`` first for
    a query using its distinctive terms.

    Args:
        tmp_path: pytest's per-test temporary directory fixture.
    """
    import sqlite3
    from contextlib import closing

    db_path = str(tmp_path / "test.db")
    rows = [
        (
            "x1",
            "doc-1",
            1,
            "Ranger favored enemy favored terrain terrain bonuses bonuses action attack",
        ),
        ("x2", "doc-1", 2, "Wizard can cast spells and perform actions"),
        ("x3", "doc-1", 3, "Fighter attacks and deals damage with weapon"),
    ]

    # closing() releases the connection even if a statement raises; sqlite3's
    # own context manager only handles the transaction, not the handle.
    with closing(sqlite3.connect(db_path)) as conn:
        conn.execute(
            "CREATE TABLE page_chunks(id TEXT, doc_id TEXT, page_number INT, text TEXT)"
        )
        conn.executemany("INSERT INTO page_chunks VALUES (?, ?, ?, ?)", rows)
        conn.commit()

    idx = BM25Index()
    idx.mark_dirty()
    idx.ensure_fresh(db_path)
    results = idx.query("ranger terrain")
    assert len(results) >= 1
    assert results[0].chunk_id == "x1"
|
|
|
|
|
|
def test_bm25_result_is_frozen():
    """Assigning to a field of a constructed ``BM25Result`` must raise."""
    fields = {
        "chunk_id": "x",
        "doc_id": "d",
        "page_number": 1,
        "text": "hello",
        "score": 0.5,
    }
    result = BM25Result(**fields)
    # NOTE(review): the concrete exception type depends on how BM25Result is
    # made immutable, which is not visible here; hence the broad match.
    with pytest.raises(Exception):
        result.score = 1.0  # type: ignore[misc]
|