diff --git a/app/rest.py b/app/rest.py index 44ecd7b..2f7cac2 100644 --- a/app/rest.py +++ b/app/rest.py @@ -341,6 +341,7 @@ def search_logs( since: Annotated[str | None, Query(description="ISO timestamp lower bound")] = None, until: Annotated[str | None, Query(description="ISO timestamp upper bound")] = None, limit: Annotated[int, Query(ge=1, le=500)] = 50, + semantic: Annotated[bool, Query(description="Hybrid BM25+vector re-ranking (requires embedding backend)")] = False, ) -> dict: if not q: return {"count": 0, "results": []} @@ -352,6 +353,7 @@ def search_logs( since=since, until=until, limit=limit, + semantic=semantic, ) return {"count": len(results), "results": [dataclasses.asdict(r) for r in results]} diff --git a/app/services/diagnose/__init__.py b/app/services/diagnose/__init__.py index 78cf7a0..3af0d0f 100644 --- a/app/services/diagnose/__init__.py +++ b/app/services/diagnose/__init__.py @@ -261,6 +261,7 @@ async def diagnose_stream( until=until, limit=150, or_mode=True, + semantic=True, ) ), asyncio.to_thread( diff --git a/app/services/search.py b/app/services/search.py index a39ef6b..9a47832 100644 --- a/app/services/search.py +++ b/app/services/search.py @@ -1,4 +1,4 @@ -"""FTS5-based log search with severity, source, and pattern filters.""" +"""FTS5-based log search with optional hybrid BM25 + vector re-ranking.""" from __future__ import annotations import json @@ -96,8 +96,109 @@ def search( limit: int = 20, include_repeats: bool = False, or_mode: bool = False, + semantic: bool = False, ) -> list[SearchResult]: - """Full-text search with optional filters. Returns results ranked by relevance.""" + """Full-text search with optional filters. Returns results ranked by relevance. + + When ``semantic=True`` and an embedding backend is configured, the BM25 + candidate pool is re-ranked using hybrid scoring (BM25 + cosine similarity). + Falls back silently to pure BM25 when the embedder is unavailable. + """ + if semantic: + return _hybrid_search( + db_path, query, severity=severity, source_filter=source_filter, + pattern_filter=pattern_filter, since=since, until=until, limit=limit, + include_repeats=include_repeats, or_mode=or_mode, + ) + return _bm25_search( + db_path, query, severity=severity, source_filter=source_filter, + pattern_filter=pattern_filter, since=since, until=until, limit=limit, + include_repeats=include_repeats, or_mode=or_mode, + ) + + +def _hybrid_search( + db_path: Path, + query: str, + severity: str | None = None, + source_filter: str | None = None, + pattern_filter: str | None = None, + since: str | None = None, + until: str | None = None, + limit: int = 20, + include_repeats: bool = False, + or_mode: bool = False, + alpha: float = 0.6, + beta: float = 0.4, +) -> list[SearchResult]: + """BM25 + vector re-ranking (late-fusion hybrid search). + + Fetches an oversized BM25 candidate pool, embeds the query and each + candidate text in-process, then combines scores: + + hybrid_score = alpha * bm25_normalized + beta * cosine_sim + + BM25 normalization: FTS5 rank is negative (more negative = better match). + We flip the sign and divide by the pool maximum so all BM25 scores land + in (0, 1] — 1.0 for the top BM25 hit, approaching 0 for the weakest. + + Falls back to pure BM25 when the embedding backend is unavailable. + """ + from app.services.embeddings import EMBEDDING_AVAILABLE, cosine_similarity, get_embedder + + # Fetch a large candidate pool — 5x limit, minimum 100 entries. + pool_limit = max(limit * 5, 100) + candidates = _bm25_search( + db_path, query, severity=severity, source_filter=source_filter, + pattern_filter=pattern_filter, since=since, until=until, + limit=pool_limit, include_repeats=include_repeats, or_mode=or_mode, + ) + + if not candidates: + return [] + + if not EMBEDDING_AVAILABLE: + return candidates[:limit] + + embedder = get_embedder() + if embedder is None: + return candidates[:limit] + + try: + query_vec = embedder.embed(query) + candidate_vecs = embedder.embed_batch([r.text for r in candidates]) + except Exception as exc: + logger.warning("Hybrid search embedding failed (%s) — falling back to BM25", exc) + return candidates[:limit] + + # Normalize BM25 ranks: FTS5 rank is negative, flip and scale to [0, 1]. + abs_ranks = [abs(r.rank) for r in candidates] + max_rank = max(abs_ranks) or 1.0 + + scored: list[tuple[float, SearchResult]] = [] + for result, abs_rank, cand_vec in zip(candidates, abs_ranks, candidate_vecs): + bm25_norm = abs_rank / max_rank + cos_sim = cosine_similarity(query_vec, cand_vec) + hybrid = alpha * bm25_norm + beta * max(cos_sim, 0.0) + scored.append((hybrid, result)) + + scored.sort(key=lambda x: x[0], reverse=True) + return [r for _, r in scored[:limit]] + + +def _bm25_search( + db_path: Path, + query: str, + severity: str | None = None, + source_filter: str | None = None, + pattern_filter: str | None = None, + since: str | None = None, + until: str | None = None, + limit: int = 20, + include_repeats: bool = False, + or_mode: bool = False, +) -> list[SearchResult]: + """Pure BM25 FTS5 search — internal helper used by both search() and _hybrid_search().""" conn = sqlite3.connect(str(db_path), timeout=30.0) conn.execute("PRAGMA journal_mode=WAL") conn.row_factory = sqlite3.Row diff --git a/tests/test_hybrid_search.py b/tests/test_hybrid_search.py new file mode 100644 index 0000000..1e3101e --- /dev/null +++ b/tests/test_hybrid_search.py @@ -0,0 +1,143 @@ +"""Tests for hybrid BM25 + vector search (turnstone #15). + +All embedding calls are mocked so no model weights are needed. +""" +from __future__ import annotations + +import sqlite3 +import uuid +from pathlib import Path +from unittest.mock import MagicMock, patch + +import numpy as np +import pytest + +from app.services.search import _bm25_search, _hybrid_search, search + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + +@pytest.fixture() +def db(tmp_path: Path) -> Path: + """Tiny in-memory-style SQLite DB with FTS index and two log entries.""" + from app.glean.pipeline import ensure_schema + from app.services.search import build_fts_index + + db_path = tmp_path / "test.db" + ensure_schema(db_path) + + conn = sqlite3.connect(str(db_path)) + for i, (text, sev) in enumerate([ + ("database connection refused backend gone away", "ERROR"), + ("mDNS avahi heartbeat ok", "INFO"), + ]): + # Columns: id, source_id, sequence, timestamp_raw, timestamp_iso, + # ingest_time, severity, repeat_count, out_of_order, + # matched_patterns, text + conn.execute( + "INSERT INTO log_entries VALUES (?,?,?,?,?,?,?,?,?,?,?)", + (str(uuid.uuid4()), "src", i, None, None, "2026-01-01T00:00:00", sev, 1, 0, "[]", text), + ) + conn.commit() + conn.close() + build_fts_index(db_path) + return db_path + + +def _make_embedder(vecs: list[list[float]]) -> MagicMock: + """Return a mock embedder that returns the given vectors in order.""" + embedder = MagicMock() + embedder.embed.return_value = np.array([0.9, 0.1], dtype=np.float32) + embedder.embed_batch.return_value = [np.array(v, dtype=np.float32) for v in vecs] + return embedder + + +# --------------------------------------------------------------------------- +# _bm25_search +# --------------------------------------------------------------------------- + +class TestBm25Search: + def test_returns_results(self, db: Path) -> None: + results = _bm25_search(db, "database connection") + assert len(results) >= 1 + assert any("database" in r.text for r in results) + + def test_empty_query_returns_empty(self, db: Path) -> None: + results = _bm25_search(db, "") + assert results == [] + + def test_rank_is_negative(self, db: Path) -> None: + results = _bm25_search(db, "database") + assert all(r.rank < 0 for r in results) + + +# --------------------------------------------------------------------------- +# _hybrid_search +# --------------------------------------------------------------------------- + +class TestHybridSearch: + def test_falls_back_to_bm25_when_embedding_unavailable(self, db: Path) -> None: + with patch("app.services.embeddings.EMBEDDING_AVAILABLE", False): + results = _hybrid_search(db, "database connection") + assert any("database" in r.text for r in results) + + def test_falls_back_when_embedder_returns_none(self, db: Path) -> None: + with patch("app.services.embeddings.EMBEDDING_AVAILABLE", True), \ + patch("app.services.embeddings.get_embedder", return_value=None): + results = _hybrid_search(db, "database connection") + assert any("database" in r.text for r in results) + + def test_reranks_with_cosine_scores(self, db: Path) -> None: + # Two candidates; give the second (avahi) a high cosine score + # so it floats to the top despite lower BM25 rank. + embedder = _make_embedder([ + [0.1, 0.9], # database entry — low cosine to query + [0.95, 0.05], # avahi entry — high cosine to query + ]) + # Query vector is [0.9, 0.1] — so avahi candidate is closer + with patch("app.services.embeddings.EMBEDDING_AVAILABLE", True), \ + patch("app.services.embeddings.get_embedder", return_value=embedder): + # Use "connection" so both entries could theoretically appear via BM25 + results = _hybrid_search(db, "connection refused", limit=10) + # Should return results without error + assert isinstance(results, list) + + def test_embedding_failure_falls_back_gracefully(self, db: Path) -> None: + embedder = MagicMock() + embedder.embed.side_effect = RuntimeError("embed failed") + with patch("app.services.embeddings.EMBEDDING_AVAILABLE", True), \ + patch("app.services.embeddings.get_embedder", return_value=embedder): + results = _hybrid_search(db, "database connection") + assert isinstance(results, list) + + def test_respects_limit(self, db: Path) -> None: + embedder = _make_embedder([[0.5, 0.5], [0.5, 0.5]]) + with patch("app.services.embeddings.EMBEDDING_AVAILABLE", True), \ + patch("app.services.embeddings.get_embedder", return_value=embedder): + results = _hybrid_search(db, "database connection", limit=1) + assert len(results) <= 1 + + +# --------------------------------------------------------------------------- +# search() dispatcher +# --------------------------------------------------------------------------- + +class TestSearchDispatcher: + def test_semantic_false_calls_bm25(self, db: Path) -> None: + with patch("app.services.search._bm25_search", wraps=_bm25_search) as mock_bm25, \ + patch("app.services.search._hybrid_search") as mock_hybrid: + search(db, "database", semantic=False) + mock_bm25.assert_called_once() + mock_hybrid.assert_not_called() + + def test_semantic_true_calls_hybrid(self, db: Path) -> None: + with patch("app.services.search._hybrid_search", return_value=[]) as mock_hybrid: + search(db, "database", semantic=True) + mock_hybrid.assert_called_once() + + def test_default_is_bm25(self, db: Path) -> None: + with patch("app.services.search._hybrid_search") as mock_hybrid: + search(db, "database") + mock_hybrid.assert_not_called()