turnstone/tests/test_hybrid_search.py
pyr0ball 3155bde4ce feat: hybrid BM25 + vector re-ranking for diagnose search (#15)
Adds late-fusion hybrid search to Turnstone's log retrieval layer:

  hybrid_score = 0.6 * bm25_normalized + 0.4 * cosine_similarity

Implementation:
- _bm25_search() extracts the existing FTS5 BM25 path as a named helper
- _hybrid_search() fetches an oversized BM25 candidate pool (5x limit,
  min 100), embeds the query and each candidate text in-process via the
  existing embeddings service, normalizes BM25 rank to [0,1], combines
  with cosine similarity, and re-ranks
- search() gets semantic=False param that dispatches to _hybrid_search()
  when True; pure BM25 remains the default for all existing call sites
- diagnose_stream() enables semantic=True so symptom-based queries
  ("database connection failed") surface semantically equivalent entries
  ("ECONNREFUSED", "backend gone away", "max retries exceeded")
- /api/search REST endpoint exposes ?semantic=true query param

Graceful degradation: falls back silently to pure BM25 when the embedding
backend is unavailable (EMBEDDING_AVAILABLE=False) or when embed_batch
raises an exception. No new infra — in-process numpy cosine, no vector DB.

11 new tests: BM25 helper, hybrid re-ranking, fallback paths, dispatcher.
372 + 11 = 383 tests passing.

Closes: #15
2026-06-01 18:13:09 -07:00

143 lines
5.9 KiB
Python

"""Tests for hybrid BM25 + vector search (turnstone #15).
All embedding calls are mocked so no model weights are needed.
"""
from __future__ import annotations
import sqlite3
import uuid
from pathlib import Path
from unittest.mock import MagicMock, patch
import numpy as np
import pytest
from app.services.search import _bm25_search, _hybrid_search, search
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture()
def db(tmp_path: Path) -> Path:
    """Build a small on-disk SQLite DB with an FTS index and two log entries.

    The database file lives under pytest's ``tmp_path``. It is created with the
    application schema, seeded with one ERROR entry describing a database
    connection failure and one unrelated INFO entry, and then indexed with
    ``build_fts_index``.

    Returns:
        Path to the populated database file.
    """
    from contextlib import closing

    from app.glean.pipeline import ensure_schema
    from app.services.search import build_fts_index

    db_path = tmp_path / "test.db"
    ensure_schema(db_path)
    rows = [
        ("database connection refused backend gone away", "ERROR"),
        ("mDNS avahi heartbeat ok", "INFO"),
    ]
    # closing() guarantees the connection is released even if an INSERT
    # fails; sqlite3's own connection context manager only commits/rolls back
    # and never closes.
    with closing(sqlite3.connect(str(db_path))) as conn:
        for i, (text, sev) in enumerate(rows):
            # Columns: id, source_id, sequence, timestamp_raw, timestamp_iso,
            # ingest_time, severity, repeat_count, out_of_order,
            # matched_patterns, text
            conn.execute(
                "INSERT INTO log_entries VALUES (?,?,?,?,?,?,?,?,?,?,?)",
                (str(uuid.uuid4()), "src", i, None, None, "2026-01-01T00:00:00", sev, 1, 0, "[]", text),
            )
        conn.commit()
    build_fts_index(db_path)
    return db_path
def _make_embedder(
    vecs: list[list[float]],
    query_vec: list[float] | None = None,
) -> MagicMock:
    """Return a mock embedder with canned vectors.

    ``embed`` always returns ``query_vec`` as a float32 array. The default is
    ``[0.9, 0.1]``, which is the query vector the re-ranking test reasons
    about. ``embed_batch`` always returns one float32 array per row of
    ``vecs``, in the given order, regardless of how many texts it is called
    with.

    Args:
        vecs: Rows that ``embed_batch`` returns.
        query_vec: Vector that ``embed`` returns; ``None`` keeps the
            historical ``[0.9, 0.1]`` default.

    Returns:
        A ``MagicMock`` standing in for the object ``get_embedder`` returns.
    """
    if query_vec is None:
        query_vec = [0.9, 0.1]
    embedder = MagicMock()
    embedder.embed.return_value = np.array(query_vec, dtype=np.float32)
    embedder.embed_batch.return_value = [np.array(v, dtype=np.float32) for v in vecs]
    return embedder
# ---------------------------------------------------------------------------
# _bm25_search
# ---------------------------------------------------------------------------
class TestBm25Search:
    """Tests for the plain FTS5/BM25 helper ``_bm25_search``."""

    def test_returns_results(self, db: Path) -> None:
        results = _bm25_search(db, "database connection")
        assert results, "expected at least one BM25 match"
        assert any("database" in r.text for r in results)

    def test_empty_query_returns_empty(self, db: Path) -> None:
        assert _bm25_search(db, "") == []

    def test_rank_is_negative(self, db: Path) -> None:
        results = _bm25_search(db, "database")
        # Guard against a vacuous pass: all() is True for an empty list.
        assert results, "expected at least one BM25 match"
        # FTS5's bm25() reports better matches as numerically smaller
        # (negative) values.
        assert all(r.rank < 0 for r in results)
# ---------------------------------------------------------------------------
# _hybrid_search
# ---------------------------------------------------------------------------
class TestHybridSearch:
    """Tests for ``_hybrid_search``: fallback paths, re-ranking path, limit."""

    def test_falls_back_to_bm25_when_embedding_unavailable(self, db: Path) -> None:
        with patch("app.services.embeddings.EMBEDDING_AVAILABLE", False):
            results = _hybrid_search(db, "database connection")
        assert any("database" in r.text for r in results)

    def test_falls_back_when_embedder_returns_none(self, db: Path) -> None:
        with patch("app.services.embeddings.EMBEDDING_AVAILABLE", True), \
                patch("app.services.embeddings.get_embedder", return_value=None):
            results = _hybrid_search(db, "database connection")
        assert any("database" in r.text for r in results)

    def test_reranks_with_cosine_scores(self, db: Path) -> None:
        # Only the "database ..." entry contains the query terms, so it is the
        # sole BM25 candidate. The avahi entry shares no term with the query
        # and therefore cannot enter the candidate pool, however high its
        # cosine score would be. Re-ranking may only reorder BM25 candidates;
        # it must never surface unrelated entries.
        embedder = _make_embedder([
            [0.1, 0.9],    # database entry: low cosine to query [0.9, 0.1]
            [0.95, 0.05],  # would be close to the query, but is never a candidate
        ])
        with patch("app.services.embeddings.EMBEDDING_AVAILABLE", True), \
                patch("app.services.embeddings.get_embedder", return_value=embedder):
            results = _hybrid_search(db, "connection refused", limit=10)
        assert any("database" in r.text for r in results)
        assert not any("avahi" in r.text for r in results)

    def test_embedding_failure_falls_back_gracefully(self, db: Path) -> None:
        embedder = MagicMock()
        embedder.embed.side_effect = RuntimeError("embed failed")
        with patch("app.services.embeddings.EMBEDDING_AVAILABLE", True), \
                patch("app.services.embeddings.get_embedder", return_value=embedder):
            results = _hybrid_search(db, "database connection")
        # Fallback must still return the BM25 match, not just "some list".
        assert any("database" in r.text for r in results)

    def test_embed_batch_failure_falls_back_gracefully(self, db: Path) -> None:
        # The candidate-embedding step failing must also degrade to BM25.
        embedder = _make_embedder([])
        embedder.embed_batch.side_effect = RuntimeError("embed_batch failed")
        with patch("app.services.embeddings.EMBEDDING_AVAILABLE", True), \
                patch("app.services.embeddings.get_embedder", return_value=embedder):
            results = _hybrid_search(db, "database connection")
        assert any("database" in r.text for r in results)

    def test_respects_limit(self, db: Path) -> None:
        # The fixture's "database ..." entry matches the query, so exactly one
        # result is expected; an empty result would hide a broken hybrid path.
        embedder = _make_embedder([[0.5, 0.5], [0.5, 0.5]])
        with patch("app.services.embeddings.EMBEDDING_AVAILABLE", True), \
                patch("app.services.embeddings.get_embedder", return_value=embedder):
            results = _hybrid_search(db, "database connection", limit=1)
        assert len(results) == 1
# ---------------------------------------------------------------------------
# search() dispatcher
# ---------------------------------------------------------------------------
class TestSearchDispatcher:
    """Tests that ``search()`` routes to the BM25 or the hybrid backend."""

    def test_semantic_false_calls_bm25(self, db: Path) -> None:
        with patch("app.services.search._bm25_search", wraps=_bm25_search) as mock_bm25, \
                patch("app.services.search._hybrid_search") as mock_hybrid:
            search(db, "database", semantic=False)
        mock_bm25.assert_called_once()
        mock_hybrid.assert_not_called()

    def test_semantic_true_calls_hybrid(self, db: Path) -> None:
        with patch("app.services.search._hybrid_search", return_value=[]) as mock_hybrid:
            search(db, "database", semantic=True)
        mock_hybrid.assert_called_once()

    def test_default_is_bm25(self, db: Path) -> None:
        # Check both sides: the hybrid path is untouched AND the BM25 path
        # actually ran (not merely "nothing was called").
        with patch("app.services.search._bm25_search", wraps=_bm25_search) as mock_bm25, \
                patch("app.services.search._hybrid_search") as mock_hybrid:
            search(db, "database")
        mock_bm25.assert_called_once()
        mock_hybrid.assert_not_called()