Merge feat/15-hybrid-rag: hybrid BM25 + vector re-ranking for diagnose search (#15)
This commit is contained in:
commit
eac9a4ba28
4 changed files with 249 additions and 2 deletions
|
|
@ -341,6 +341,7 @@ def search_logs(
|
|||
since: Annotated[str | None, Query(description="ISO timestamp lower bound")] = None,
|
||||
until: Annotated[str | None, Query(description="ISO timestamp upper bound")] = None,
|
||||
limit: Annotated[int, Query(ge=1, le=500)] = 50,
|
||||
semantic: Annotated[bool, Query(description="Hybrid BM25+vector re-ranking (requires embedding backend)")] = False,
|
||||
) -> dict:
|
||||
if not q:
|
||||
return {"count": 0, "results": []}
|
||||
|
|
@ -352,6 +353,7 @@ def search_logs(
|
|||
since=since,
|
||||
until=until,
|
||||
limit=limit,
|
||||
semantic=semantic,
|
||||
)
|
||||
return {"count": len(results), "results": [dataclasses.asdict(r) for r in results]}
|
||||
|
||||
|
|
|
|||
|
|
@ -261,6 +261,7 @@ async def diagnose_stream(
|
|||
until=until,
|
||||
limit=150,
|
||||
or_mode=True,
|
||||
semantic=True,
|
||||
)
|
||||
),
|
||||
asyncio.to_thread(
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
"""FTS5-based log search with severity, source, and pattern filters."""
|
||||
"""FTS5-based log search with optional hybrid BM25 + vector re-ranking."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
|
@ -96,8 +96,109 @@ def search(
|
|||
limit: int = 20,
|
||||
include_repeats: bool = False,
|
||||
or_mode: bool = False,
|
||||
semantic: bool = False,
|
||||
) -> list[SearchResult]:
|
||||
"""Full-text search with optional filters. Returns results ranked by relevance."""
|
||||
"""Full-text search with optional filters. Returns results ranked by relevance.
|
||||
|
||||
When ``semantic=True`` and an embedding backend is configured, the BM25
|
||||
candidate pool is re-ranked using hybrid scoring (BM25 + cosine similarity).
|
||||
Falls back silently to pure BM25 when the embedder is unavailable.
|
||||
"""
|
||||
if semantic:
|
||||
return _hybrid_search(
|
||||
db_path, query, severity=severity, source_filter=source_filter,
|
||||
pattern_filter=pattern_filter, since=since, until=until, limit=limit,
|
||||
include_repeats=include_repeats, or_mode=or_mode,
|
||||
)
|
||||
return _bm25_search(
|
||||
db_path, query, severity=severity, source_filter=source_filter,
|
||||
pattern_filter=pattern_filter, since=since, until=until, limit=limit,
|
||||
include_repeats=include_repeats, or_mode=or_mode,
|
||||
)
|
||||
|
||||
|
||||
def _hybrid_search(
    db_path: Path,
    query: str,
    severity: str | None = None,
    source_filter: str | None = None,
    pattern_filter: str | None = None,
    since: str | None = None,
    until: str | None = None,
    limit: int = 20,
    include_repeats: bool = False,
    or_mode: bool = False,
    alpha: float = 0.6,
    beta: float = 0.4,
) -> list[SearchResult]:
    """BM25 + vector re-ranking (late-fusion hybrid search).

    Fetches an oversized BM25 candidate pool, embeds the query and each
    candidate text in-process, then combines scores:

        hybrid_score = alpha * bm25_normalized + beta * max(cosine_sim, 0)

    BM25 normalization: FTS5 rank is negative (more negative = better match).
    We take the magnitude and divide by the pool maximum, so BM25 scores land
    in [0, 1] — 1.0 for the top BM25 hit, approaching 0 for the weakest.

    Falls back to pure BM25 (the first ``limit`` candidates, in BM25 order)
    when the embedding backend is unavailable, when embedding raises, or when
    the backend returns fewer vectors than there are candidates.

    Args:
        db_path: SQLite database to search.
        query: Search text; used both for the BM25 query and for embedding.
        severity, source_filter, pattern_filter, since, until,
        include_repeats, or_mode: Passed through unchanged to ``_bm25_search``.
        limit: Maximum number of results returned.
        alpha: Weight of the normalized BM25 score.
        beta: Weight of the (non-negative) cosine similarity.

    Returns:
        At most ``limit`` results, best hybrid score first.
    """
    from app.services.embeddings import EMBEDDING_AVAILABLE, cosine_similarity, get_embedder

    # Fetch a large candidate pool — 5x limit, minimum 100 entries — so the
    # re-ranker can promote hits that BM25 alone would have cut off.
    pool_limit = max(limit * 5, 100)
    candidates = _bm25_search(
        db_path, query, severity=severity, source_filter=source_filter,
        pattern_filter=pattern_filter, since=since, until=until,
        limit=pool_limit, include_repeats=include_repeats, or_mode=or_mode,
    )

    if not candidates:
        return []

    if not EMBEDDING_AVAILABLE:
        return candidates[:limit]

    embedder = get_embedder()
    if embedder is None:
        return candidates[:limit]

    try:
        query_vec = embedder.embed(query)
        # Materialize inside the try so a lazy or non-iterable batch result
        # also triggers the BM25 fallback instead of failing later.
        candidate_vecs = list(embedder.embed_batch([r.text for r in candidates]))
    except Exception as exc:
        logger.warning("Hybrid search embedding failed (%s) — falling back to BM25", exc)
        return candidates[:limit]

    # zip() below stops at the shortest input: with fewer vectors than
    # candidates, the unmatched candidates would silently vanish from the
    # results. Prefer an intact BM25 ranking over a truncated hybrid one.
    if len(candidate_vecs) < len(candidates):
        logger.warning(
            "Hybrid search got %d embeddings for %d candidates — falling back to BM25",
            len(candidate_vecs),
            len(candidates),
        )
        return candidates[:limit]

    # Normalize BM25 ranks: FTS5 rank is negative, flip and scale to [0, 1].
    # ``or 1.0`` avoids dividing by zero when every rank in the pool is 0.
    abs_ranks = [abs(r.rank) for r in candidates]
    max_rank = max(abs_ranks) or 1.0

    scored: list[tuple[float, SearchResult]] = []
    for result, abs_rank, cand_vec in zip(candidates, abs_ranks, candidate_vecs):
        bm25_norm = abs_rank / max_rank
        cos_sim = cosine_similarity(query_vec, cand_vec)
        # Clamp negative similarity to 0. Written as a comparison rather than
        # max() so that a NaN similarity (e.g. from a zero-norm vector) also
        # becomes 0 instead of poisoning the sort order.
        semantic_score = cos_sim if cos_sim > 0.0 else 0.0
        hybrid = alpha * bm25_norm + beta * semantic_score
        scored.append((hybrid, result))

    # list.sort is stable, so equal hybrid scores keep their BM25 order.
    scored.sort(key=lambda x: x[0], reverse=True)
    return [r for _, r in scored[:limit]]
|
||||
|
||||
|
||||
def _bm25_search(
|
||||
db_path: Path,
|
||||
query: str,
|
||||
severity: str | None = None,
|
||||
source_filter: str | None = None,
|
||||
pattern_filter: str | None = None,
|
||||
since: str | None = None,
|
||||
until: str | None = None,
|
||||
limit: int = 20,
|
||||
include_repeats: bool = False,
|
||||
or_mode: bool = False,
|
||||
) -> list[SearchResult]:
|
||||
"""Pure BM25 FTS5 search — internal helper used by both search() and _hybrid_search()."""
|
||||
conn = sqlite3.connect(str(db_path), timeout=30.0)
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
conn.row_factory = sqlite3.Row
|
||||
|
|
|
|||
143
tests/test_hybrid_search.py
Normal file
143
tests/test_hybrid_search.py
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
"""Tests for hybrid BM25 + vector search (turnstone #15).
|
||||
|
||||
All embedding calls are mocked so no model weights are needed.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from app.services.search import _bm25_search, _hybrid_search, search
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@pytest.fixture()
def db(tmp_path: Path) -> Path:
    """Tiny on-disk SQLite DB (under ``tmp_path``) with FTS index and two log entries.

    The schema is created by ``ensure_schema``; two rows are inserted (one
    ERROR, one INFO) and the FTS index is then built with ``build_fts_index``.
    Returns the path of the database file.
    """
    from contextlib import closing

    from app.glean.pipeline import ensure_schema
    from app.services.search import build_fts_index

    db_path = tmp_path / "test.db"
    ensure_schema(db_path)

    entries = [
        ("database connection refused backend gone away", "ERROR"),
        ("mDNS avahi heartbeat ok", "INFO"),
    ]

    # closing() releases the connection even when an INSERT raises (for
    # example after a schema change), so a failing fixture does not leave
    # an open handle on the temp database.
    with closing(sqlite3.connect(str(db_path))) as conn:
        for i, (text, sev) in enumerate(entries):
            # Columns: id, source_id, sequence, timestamp_raw, timestamp_iso,
            #          ingest_time, severity, repeat_count, out_of_order,
            #          matched_patterns, text
            conn.execute(
                "INSERT INTO log_entries VALUES (?,?,?,?,?,?,?,?,?,?,?)",
                (str(uuid.uuid4()), "src", i, None, None, "2026-01-01T00:00:00", sev, 1, 0, "[]", text),
            )
        conn.commit()

    build_fts_index(db_path)
    return db_path
|
||||
|
||||
|
||||
def _make_embedder(vecs: list[list[float]]) -> MagicMock:
    """Build a stand-in embedder with canned outputs.

    ``embed`` always yields the fixed query vector ``[0.9, 0.1]``;
    ``embed_batch`` always yields *vecs* converted to float32 arrays, in the
    given order, regardless of the arguments it is called with.
    """
    query_vector = np.array([0.9, 0.1], dtype=np.float32)
    batch_vectors = [np.array(row, dtype=np.float32) for row in vecs]

    mock_embedder = MagicMock()
    mock_embedder.configure_mock(
        **{
            "embed.return_value": query_vector,
            "embed_batch.return_value": batch_vectors,
        }
    )
    return mock_embedder
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _bm25_search
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestBm25Search:
    """Smoke tests for the pure-BM25 helper against the two-row fixture DB."""

    def test_returns_results(self, db: Path) -> None:
        # A query naming words from the ERROR row must surface that row.
        hits = _bm25_search(db, "database connection")
        assert len(hits) >= 1
        database_hits = [hit for hit in hits if "database" in hit.text]
        assert database_hits

    def test_empty_query_returns_empty(self, db: Path) -> None:
        # An empty query short-circuits to an empty list.
        assert _bm25_search(db, "") == []

    def test_rank_is_negative(self, db: Path) -> None:
        # The hybrid scorer relies on FTS5 ranks being negative.
        for hit in _bm25_search(db, "database"):
            assert hit.rank < 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _hybrid_search
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestHybridSearch:
    """Tests for ``_hybrid_search`` with the embedding backend mocked out.

    Fallback tests run against the real fixture DB and compare with plain
    BM25 output. Re-ranking and limit tests stub ``_bm25_search`` with a
    fixed candidate pool so the expected ordering is fully deterministic and
    does not depend on how FTS5 happens to score the fixture rows.
    """

    DB_TEXT = "database connection refused backend gone away"
    AVAHI_TEXT = "mDNS avahi heartbeat ok"

    @staticmethod
    def _fake_candidates(*pairs: tuple[str, float]) -> list:
        """Build stand-in BM25 results; only ``.text`` and ``.rank`` are read by the re-ranker."""
        from types import SimpleNamespace

        return [SimpleNamespace(text=text, rank=rank) for text, rank in pairs]

    def test_falls_back_to_bm25_when_embedding_unavailable(self, db: Path) -> None:
        expected = _bm25_search(db, "database connection")
        with patch("app.services.embeddings.EMBEDDING_AVAILABLE", False):
            results = _hybrid_search(db, "database connection")
        assert any("database" in r.text for r in results)
        assert [r.text for r in results] == [r.text for r in expected]

    def test_falls_back_when_embedder_returns_none(self, db: Path) -> None:
        expected = _bm25_search(db, "database connection")
        with patch("app.services.embeddings.EMBEDDING_AVAILABLE", True), \
             patch("app.services.embeddings.get_embedder", return_value=None):
            results = _hybrid_search(db, "database connection")
        assert any("database" in r.text for r in results)
        assert [r.text for r in results] == [r.text for r in expected]

    def test_reranks_with_cosine_scores(self, db: Path) -> None:
        # BM25 puts the database entry first (rank -2.0 vs -1.8), but the
        # avahi entry is far closer to the query vector [0.9, 0.1], so the
        # hybrid score must float it to the top:
        #   database: 0.6 * 1.0 + 0.4 * ~0.22 ≈ 0.69
        #   avahi:    0.6 * 0.9 + 0.4 * ~1.00 ≈ 0.94
        candidates = self._fake_candidates(
            (self.DB_TEXT, -2.0),
            (self.AVAHI_TEXT, -1.8),
        )
        embedder = _make_embedder([
            [0.1, 0.9],    # database entry — low cosine to query
            [0.95, 0.05],  # avahi entry — high cosine to query
        ])
        with patch("app.services.embeddings.EMBEDDING_AVAILABLE", True), \
             patch("app.services.embeddings.get_embedder", return_value=embedder), \
             patch("app.services.search._bm25_search", return_value=candidates):
            results = _hybrid_search(db, "connection", limit=10)

        assert [r.text for r in results] == [self.AVAHI_TEXT, self.DB_TEXT]
        # Every candidate text must be embedded, in pool order.
        embedder.embed_batch.assert_called_once_with([self.DB_TEXT, self.AVAHI_TEXT])

    def test_embedding_failure_falls_back_gracefully(self, db: Path) -> None:
        expected = _bm25_search(db, "database connection")
        embedder = MagicMock()
        embedder.embed.side_effect = RuntimeError("embed failed")
        with patch("app.services.embeddings.EMBEDDING_AVAILABLE", True), \
             patch("app.services.embeddings.get_embedder", return_value=embedder):
            results = _hybrid_search(db, "database connection")
        assert isinstance(results, list)
        assert [r.text for r in results] == [r.text for r in expected]
        embedder.embed.assert_called_once()

    def test_respects_limit(self, db: Path) -> None:
        # Three candidates with identical cosine scores: the order is then
        # decided by BM25 alone and the result must be cut to ``limit``.
        candidates = self._fake_candidates(
            ("first", -3.0),
            ("second", -2.0),
            ("third", -1.0),
        )
        embedder = _make_embedder([[0.5, 0.5], [0.5, 0.5], [0.5, 0.5]])
        with patch("app.services.embeddings.EMBEDDING_AVAILABLE", True), \
             patch("app.services.embeddings.get_embedder", return_value=embedder), \
             patch("app.services.search._bm25_search", return_value=candidates):
            results = _hybrid_search(db, "database connection", limit=2)
        assert [r.text for r in results] == ["first", "second"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# search() dispatcher
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSearchDispatcher:
    """Checks that ``search()`` routes to the right backend based on ``semantic``."""

    def test_semantic_false_calls_bm25(self, db: Path) -> None:
        # Spy on the real BM25 helper while replacing the hybrid path entirely.
        with patch("app.services.search._bm25_search", wraps=_bm25_search) as bm25_spy:
            with patch("app.services.search._hybrid_search") as hybrid_stub:
                search(db, "database", semantic=False)
        assert bm25_spy.call_count == 1
        assert not hybrid_stub.called

    def test_semantic_true_calls_hybrid(self, db: Path) -> None:
        with patch("app.services.search._hybrid_search", return_value=[]) as hybrid_stub:
            search(db, "database", semantic=True)
        assert hybrid_stub.call_count == 1

    def test_default_is_bm25(self, db: Path) -> None:
        # Omitting ``semantic`` must behave like ``semantic=False``.
        with patch("app.services.search._hybrid_search") as hybrid_stub:
            search(db, "database")
        assert not hybrid_stub.called
|
||||
Loading…
Reference in a new issue