diff --git a/app/services/bm25_index.py b/app/services/bm25_index.py index c3f03c6..5aee25e 100644 --- a/app/services/bm25_index.py +++ b/app/services/bm25_index.py @@ -47,12 +47,18 @@ class BM25Index: """Rebuild from SQLite if dirty.""" if not self._dirty: return - conn = sqlite3.connect(db_path) - conn.row_factory = sqlite3.Row - rows = conn.execute( - "SELECT id, doc_id, page_number, text FROM page_chunks ORDER BY doc_id, page_number" - ).fetchall() - conn.close() + try: + conn = sqlite3.connect(db_path) + conn.row_factory = sqlite3.Row + try: + rows = conn.execute( + "SELECT id, doc_id, page_number, text FROM page_chunks ORDER BY doc_id, page_number" + ).fetchall() + finally: + conn.close() + except sqlite3.Error as exc: + logger.error("BM25 index rebuild failed: %s", exc) + return self._load_chunks([dict(r) for r in rows]) self._dirty = False logger.info("BM25 index rebuilt: %d chunks", len(self._chunks)) diff --git a/tests/test_bm25_index.py b/tests/test_bm25_index.py index 221cea2..43eea5b 100644 --- a/tests/test_bm25_index.py +++ b/tests/test_bm25_index.py @@ -43,9 +43,10 @@ def test_query_returns_relevant_result(): def test_query_respects_top_k(): + # "action" matches all three chunks; top_k=2 must hard-cap the result list idx = _seeded_index() - results = idx.query("rules", top_k=2) - assert len(results) <= 2 + results = idx.query("action", top_k=2) + assert len(results) == 2 def test_query_filters_by_doc_id():