feat(api): add BM25 search endpoint (MIT, no tier gate)
This commit is contained in:
parent
c6fa9baf2c
commit
6869f32392
2 changed files with 130 additions and 2 deletions
|
|
@ -1,5 +1,64 @@
|
||||||
# app/api/search.py
|
# app/api/search.py
|
||||||
"""Search API — BM25 keyword and RAG vector search (Task 5)."""
|
"""
|
||||||
from fastapi import APIRouter
|
BM25 keyword search across the document library.
|
||||||
|
|
||||||
|
MIT — no tier gate. No Ollama required.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Annotated
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from app.services.bm25_index import BM25Index
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
router = APIRouter(prefix="/api/search", tags=["search"])
|
router = APIRouter(prefix="/api/search", tags=["search"])
|
||||||
|
|
||||||
|
|
||||||
|
class SearchRequest(BaseModel):
|
||||||
|
query: str
|
||||||
|
top_k: int = Field(default=10, ge=1, le=50)
|
||||||
|
doc_ids: list[str] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class SearchResult(BaseModel):
|
||||||
|
chunk_id: str
|
||||||
|
doc_id: str
|
||||||
|
page_number: int
|
||||||
|
text_snippet: str # first 300 chars of the page text
|
||||||
|
bm25_score: float
|
||||||
|
|
||||||
|
|
||||||
|
def _get_bm25() -> BM25Index:
|
||||||
|
from app.main import _bm25
|
||||||
|
return _bm25
|
||||||
|
|
||||||
|
|
||||||
|
def _get_db_path() -> str:
|
||||||
|
"""Read lazily so test fixtures (monkeypatch.setattr) take effect."""
|
||||||
|
import pathlib
|
||||||
|
data_dir = pathlib.Path(os.environ.get("PAGEPIPER_DATA_DIR", "data"))
|
||||||
|
return str(data_dir / "pagepiper.db")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("")
|
||||||
|
def search(
|
||||||
|
req: SearchRequest,
|
||||||
|
bm25: Annotated[BM25Index, Depends(_get_bm25)],
|
||||||
|
) -> list[SearchResult]:
|
||||||
|
bm25.ensure_fresh(_get_db_path())
|
||||||
|
hits = bm25.query(req.query, top_k=req.top_k, doc_ids=req.doc_ids)
|
||||||
|
return [
|
||||||
|
SearchResult(
|
||||||
|
chunk_id=h.chunk_id,
|
||||||
|
doc_id=h.doc_id,
|
||||||
|
page_number=h.page_number,
|
||||||
|
text_snippet=h.text[:300],
|
||||||
|
bm25_score=h.score,
|
||||||
|
)
|
||||||
|
for h in hits
|
||||||
|
]
|
||||||
|
|
|
||||||
69
tests/test_search_api.py
Normal file
69
tests/test_search_api.py
Normal file
|
|
@ -0,0 +1,69 @@
|
||||||
|
# tests/test_search_api.py
|
||||||
|
"""Tests for POST /api/search — BM25 keyword search (MIT, no tier gate)."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
|
||||||
|
|
||||||
|
def _add_chunks(db_path: str, doc_id: str, chunks: list[dict]) -> None:
|
||||||
|
conn = sqlite3.connect(db_path)
|
||||||
|
conn.execute(
|
||||||
|
"INSERT OR IGNORE INTO documents(id, title, file_path, status) VALUES (?,'Book','p.pdf','ready')",
|
||||||
|
[doc_id],
|
||||||
|
)
|
||||||
|
for c in chunks:
|
||||||
|
conn.execute(
|
||||||
|
"INSERT INTO page_chunks(doc_id, page_number, text, source, word_count) VALUES (?,?,?,?,?)",
|
||||||
|
[doc_id, c["page_number"], c["text"], "text_layer", len(c["text"].split())],
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_search_returns_results(client, test_db, monkeypatch):
|
||||||
|
import app.api.search as _search_mod
|
||||||
|
monkeypatch.setattr(_search_mod, "_get_db_path", lambda: test_db)
|
||||||
|
# BM25Okapi IDF is 0 when df == N/2 (e.g. 2 docs, 1 match → log(1.0) = 0).
|
||||||
|
# Add a 3rd unrelated chunk so relevant terms score above zero.
|
||||||
|
_add_chunks(test_db, "book-a", [
|
||||||
|
{"page_number": 1, "text": "Fireball deals 8d6 fire damage on a failed saving throw."},
|
||||||
|
{"page_number": 2, "text": "Cure Wounds restores hit points to a living creature."},
|
||||||
|
{"page_number": 3, "text": "Shield grants plus five to armor class until next turn."},
|
||||||
|
])
|
||||||
|
|
||||||
|
resp = client.post("/api/search", json={"query": "fireball fire damage"})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
results = resp.json()
|
||||||
|
assert len(results) >= 1
|
||||||
|
assert results[0]["page_number"] == 1
|
||||||
|
assert results[0]["bm25_score"] > 0
|
||||||
|
assert "text_snippet" in results[0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_search_empty_index_returns_empty(client):
|
||||||
|
resp = client.post("/api/search", json={"query": "anything"})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json() == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_search_filters_by_doc_ids(client, test_db, monkeypatch):
|
||||||
|
import app.api.search as _search_mod
|
||||||
|
monkeypatch.setattr(_search_mod, "_get_db_path", lambda: test_db)
|
||||||
|
# Three chunks so BM25Okapi IDF is non-zero for terms appearing in one doc.
|
||||||
|
_add_chunks(test_db, "book-a", [
|
||||||
|
{"page_number": 1, "text": "Grapple rules for melee attacks."},
|
||||||
|
{"page_number": 2, "text": "Shield spell protects from incoming blows."},
|
||||||
|
])
|
||||||
|
_add_chunks(test_db, "book-b", [{"page_number": 3, "text": "Grapple also applies to ranged attacks."}])
|
||||||
|
|
||||||
|
resp = client.post("/api/search", json={"query": "grapple", "doc_ids": ["book-a"]})
|
||||||
|
assert resp.status_code == 200
|
||||||
|
results = resp.json()
|
||||||
|
assert len(results) >= 1, "expected at least one grapple result from book-a"
|
||||||
|
assert all(r["doc_id"] == "book-a" for r in results)
|
||||||
|
|
||||||
|
|
||||||
|
def test_search_has_no_tier_gate(client):
|
||||||
|
# Search endpoint must return 200 with no PAGEPIPER_OLLAMA_URL set
|
||||||
|
resp = client.post("/api/search", json={"query": "anything"})
|
||||||
|
assert resp.status_code == 200 # Not 402
|
||||||
Loading…
Reference in a new issue