- Add app/db/ abstraction layer: Backend enum, DbConn wrapper, dialect helper (q() for ? vs %s paramstyle), get_conn(), tenant_id()
- Auto-detect the backend from DATABASE_URL; SQLite remains the default when it is unset, so local deployments need no config change
- Add a tenant_id column to all three logical DBs (main, context, incidents); an idempotent ALTER TABLE migration runs on existing DBs before the schema scripts
- All INSERTs inject tenant_id; SELECTs use (tenant_id = ? OR tenant_id = '') for backward compatibility with pre-namespacing rows
- Add docker-compose.yml with a named volume, turnstone_pgdata (survives rebuilds), and optional external Postgres support via a DATABASE_URL override
- Add scripts/migrate_sqlite_to_postgres.py: a one-shot, idempotent migration for existing SQLite data; ON CONFLICT DO NOTHING makes re-runs safe
- Fix the SSH glean path in pipeline.py to use ensure_schema + get_conn (it was still using raw sqlite3.connect and the old _SCHEMA without tenant_id)
- Fix FTS5 JOIN ambiguity: qualify repeat_count as f.repeat_count in search
- Update all tests to use ensure_*_schema fixtures; add row_factory where needed
- 394/394 tests passing

Closes: #42
Closes: #50
142 lines
5.9 KiB
Python
142 lines
5.9 KiB
Python
"""Tests for hybrid BM25 + vector search (turnstone #15).
|
|
|
|
All embedding calls are mocked so no model weights are needed.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import sqlite3
|
|
import uuid
|
|
from pathlib import Path
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from app.services.search import _bm25_search, _hybrid_search, search
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@pytest.fixture()
def db(tmp_path: Path) -> Path:
    """Create a small on-disk SQLite DB under ``tmp_path`` with an FTS index.

    The schema comes from ``ensure_schema``. Two ``log_entries`` rows are
    inserted, one ERROR and one INFO, both with an empty ``tenant_id``.
    ``build_fts_index`` is run afterwards.

    Args:
        tmp_path: pytest's per-test temporary directory.

    Returns:
        Path to the populated database file.
    """
    from contextlib import closing

    from app.glean.pipeline import ensure_schema
    from app.services.search import build_fts_index

    db_path = tmp_path / "test.db"
    ensure_schema(db_path)

    entries = [
        ("database connection refused backend gone away", "ERROR"),
        ("mDNS avahi heartbeat ok", "INFO"),
    ]
    rows = [
        (str(uuid.uuid4()), "", "src", seq, None, None, "2026-01-01T00:00:00", sev, 1, 0, "[]", text)
        for seq, (text, sev) in enumerate(entries)
    ]

    # closing() releases the handle even if an INSERT fails. sqlite3's own
    # connection context manager only commits or rolls back; it never closes.
    with closing(sqlite3.connect(str(db_path))) as conn:
        conn.executemany(
            "INSERT INTO log_entries(id, tenant_id, source_id, sequence, timestamp_raw, "
            "timestamp_iso, ingest_time, severity, repeat_count, out_of_order, "
            "matched_patterns, text) VALUES (?,?,?,?,?,?,?,?,?,?,?,?)",
            rows,
        )
        conn.commit()

    # The index is built after the rows are committed; presumably
    # build_fts_index reads log_entries, so ordering matters.
    build_fts_index(db_path)
    return db_path
|
|
|
|
|
|
def _make_embedder(
    vecs: list[list[float]],
    *,
    query_vec: list[float] | tuple[float, ...] = (0.9, 0.1),
) -> MagicMock:
    """Return a mock embedder with canned query and batch embeddings.

    Args:
        vecs: Vectors returned by every call to ``embed_batch``, as float32
            arrays in the given order, regardless of the texts passed in.
        query_vec: Vector returned, as a float32 array, by every call to
            ``embed``. Defaults to ``(0.9, 0.1)``, the query embedding the
            hybrid-search tests are written against.

    Returns:
        A ``MagicMock`` whose ``embed`` and ``embed_batch`` return fixed values.
    """
    embedder = MagicMock()
    embedder.embed.return_value = np.array(query_vec, dtype=np.float32)
    embedder.embed_batch.return_value = [np.array(v, dtype=np.float32) for v in vecs]
    return embedder
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _bm25_search
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestBm25Search:
    """Tests for the BM25-only search path (``_bm25_search``) against the ``db`` fixture."""

    def test_returns_results(self, db: Path) -> None:
        """A query matching the ERROR row returns that row."""
        results = _bm25_search(db, "database connection")
        assert len(results) >= 1
        assert any("database" in r.text for r in results)

    def test_empty_query_returns_empty(self, db: Path) -> None:
        """An empty query string yields no results."""
        results = _bm25_search(db, "")
        assert results == []

    def test_rank_is_negative(self, db: Path) -> None:
        """Every hit carries a negative rank.

        Presumably this is the raw FTS5 ``bm25()`` score, which SQLite reports
        as negative (more negative means a better match).
        """
        results = _bm25_search(db, "database")
        # Guard against a vacuous pass: all() over an empty list is True.
        assert results
        assert all(r.rank < 0 for r in results)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _hybrid_search
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestHybridSearch:
    """Tests for ``_hybrid_search`` with the embedding backend mocked out.

    ``EMBEDDING_AVAILABLE`` and ``get_embedder`` are patched on
    ``app.services.embeddings`` so no model weights are needed.
    """

    # Patch targets shared by every test in this class.
    _AVAILABLE = "app.services.embeddings.EMBEDDING_AVAILABLE"
    _GET_EMBEDDER = "app.services.embeddings.get_embedder"

    def test_falls_back_to_bm25_when_embedding_unavailable(self, db: Path) -> None:
        """With embeddings disabled, the BM25 hit is still returned."""
        with patch(self._AVAILABLE, False):
            results = _hybrid_search(db, "database connection")
        assert any("database" in r.text for r in results)

    def test_falls_back_when_embedder_returns_none(self, db: Path) -> None:
        """When ``get_embedder()`` returns None, the BM25 hit is still returned."""
        with patch(self._AVAILABLE, True), \
                patch(self._GET_EMBEDDER, return_value=None):
            results = _hybrid_search(db, "database connection")
        assert any("database" in r.text for r in results)

    def test_reranks_with_cosine_scores(self, db: Path) -> None:
        """Re-ranking with mocked embeddings completes and returns a list.

        NOTE(review): only the return type is asserted, not the ordering.
        "connection refused" matches only the "database" row lexically; the
        avahi row contains neither term. It is unclear from here whether the
        avahi row ever reaches the candidate set. Confirm against
        ``_hybrid_search`` before asserting an order.
        """
        # embed_batch vectors, presumably mapped to candidates in order: the
        # "database" row gets a low cosine to the query embedding [0.9, 0.1]
        # (the _make_embedder default), and the avahi row gets a high one.
        embedder = _make_embedder([
            [0.1, 0.9],
            [0.95, 0.05],
        ])
        with patch(self._AVAILABLE, True), \
                patch(self._GET_EMBEDDER, return_value=embedder):
            results = _hybrid_search(db, "connection refused", limit=10)
        assert isinstance(results, list)

    def test_embedding_failure_falls_back_gracefully(self, db: Path) -> None:
        """An exception from ``embed`` does not propagate out of ``_hybrid_search``."""
        embedder = MagicMock()
        embedder.embed.side_effect = RuntimeError("embed failed")
        with patch(self._AVAILABLE, True), \
                patch(self._GET_EMBEDDER, return_value=embedder):
            results = _hybrid_search(db, "database connection")
        assert isinstance(results, list)

    def test_respects_limit(self, db: Path) -> None:
        """``limit`` caps the number of returned results."""
        embedder = _make_embedder([[0.5, 0.5], [0.5, 0.5]])
        with patch(self._AVAILABLE, True), \
                patch(self._GET_EMBEDDER, return_value=embedder):
            results = _hybrid_search(db, "database connection", limit=1)
        assert len(results) <= 1
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# search() dispatcher
|
|
# ---------------------------------------------------------------------------
|
|
|
|
class TestSearchDispatcher:
    """Tests that ``search()`` routes to BM25 or hybrid search based on ``semantic``."""

    def test_semantic_false_calls_bm25(self, db: Path) -> None:
        """``semantic=False`` calls BM25 once and never calls hybrid search."""
        with patch("app.services.search._bm25_search", wraps=_bm25_search) as mock_bm25, \
                patch("app.services.search._hybrid_search") as mock_hybrid:
            search(db, "database", semantic=False)
        mock_bm25.assert_called_once()
        mock_hybrid.assert_not_called()

    def test_semantic_true_calls_hybrid(self, db: Path) -> None:
        """``semantic=True`` dispatches to hybrid search."""
        with patch("app.services.search._hybrid_search", return_value=[]) as mock_hybrid:
            search(db, "database", semantic=True)
        mock_hybrid.assert_called_once()

    def test_default_is_bm25(self, db: Path) -> None:
        """Omitting ``semantic`` behaves like ``semantic=False``: BM25 only."""
        # Mirror test_semantic_false_calls_bm25. Asserting only that hybrid
        # was NOT called would also pass if the default path called neither.
        with patch("app.services.search._bm25_search", wraps=_bm25_search) as mock_bm25, \
                patch("app.services.search._hybrid_search") as mock_hybrid:
            search(db, "database")
        mock_bm25.assert_called_once()
        mock_hybrid.assert_not_called()
|