# turnstone/tests/context/test_retriever.py

"""Tests for app/context/retriever.py."""
import sqlite3
import pytest
from pathlib import Path
from app.context.retriever import get_relevant_facts, retrieve_context, format_context_block
# Schema shared by every fixture database. Defined once so the seeded and the
# empty databases cannot drift apart.
_SCHEMA_SQL = """
CREATE TABLE context_facts (
    id TEXT PRIMARY KEY, category TEXT NOT NULL, key TEXT NOT NULL,
    value TEXT NOT NULL, source TEXT, created_at TEXT NOT NULL
);
CREATE TABLE context_documents (
    id TEXT PRIMARY KEY, filename TEXT NOT NULL, doc_type TEXT NOT NULL,
    full_text TEXT NOT NULL, file_size INTEGER, uploaded_at TEXT NOT NULL
);
CREATE TABLE context_chunks (
    id TEXT PRIMARY KEY, document_id TEXT NOT NULL
        REFERENCES context_documents(id) ON DELETE CASCADE,
    chunk_index INTEGER NOT NULL, text TEXT NOT NULL, embedding BLOB
);
"""

# Rows seeded into ``db``, in context_facts column order:
# (id, category, key, value, source, created_at).
_SEED_FACTS = (
    (
        "f1",
        "service",
        "plex",
        "port:32400 image:plexinc/pms-docker",
        "wizard",
        "2026-01-01T00:00:00+00:00",
    ),
    (
        "f2",
        "host",
        "hostname",
        "heimdall.local",
        "wizard",
        "2026-01-01T00:00:01+00:00",
    ),
)


def _create_db(db_path, facts=()):
    """Create a SQLite database at *db_path* with the context schema.

    Args:
        db_path: Filesystem path for the new database file.
        facts: Optional iterable of 6-tuples inserted into ``context_facts``.

    Returns:
        *db_path*, so fixtures can return the helper's result directly.
    """
    conn = sqlite3.connect(str(db_path))
    try:
        conn.executescript(_SCHEMA_SQL)
        # Parameterized insert: no hand-built SQL literals for seed data.
        conn.executemany(
            "INSERT INTO context_facts VALUES (?, ?, ?, ?, ?, ?)", facts
        )
        conn.commit()
    finally:
        # Release the handle even if schema creation or seeding fails.
        conn.close()
    return db_path


@pytest.fixture
def db(tmp_path):
    """Database with the full schema and two facts: a 'plex' service and a host."""
    return _create_db(tmp_path / "t.db", _SEED_FACTS)


@pytest.fixture
def empty_db(tmp_path):
    """Database with the full schema but no rows in any table."""
    return _create_db(tmp_path / "empty.db")
def test_get_relevant_facts_keyword_match(db):
    """A query that mentions 'plex' surfaces the fact keyed 'plex'."""
    query = "plex audio stopped"
    matched = get_relevant_facts(db, query)
    keys_found = [fact["key"] for fact in matched]
    assert "plex" in keys_found, f"expected 'plex' among matched keys, got {keys_found!r}"
def test_get_relevant_facts_empty_db(empty_db):
    """With the schema present but no rows, the result is an empty list."""
    result = get_relevant_facts(empty_db, "anything")
    assert result == [], result
def test_get_relevant_facts_no_tables(tmp_path):
    """A database with no tables at all yields [] instead of raising."""
    bare_path = tmp_path / "notables.db"
    # Connect and immediately close: leaves a database that has no schema.
    sqlite3.connect(str(bare_path)).close()
    assert get_relevant_facts(bare_path, "query") == []
def test_retrieve_context_returns_dataclass(db):
    """retrieve_context returns a dataclass instance with facts and a chunk list."""
    import dataclasses

    ctx = retrieve_context(db, "plex stopped")
    # The test name promises a dataclass; verify it instead of only probing
    # attributes. is_dataclass is also true for the class itself, so exclude types.
    assert dataclasses.is_dataclass(ctx) and not isinstance(ctx, type)
    assert ctx.facts
    assert isinstance(ctx.chunks, list)
def test_format_context_block_with_facts(db):
    """Formatting a context that has facts yields text that mentions the matched fact."""
    rendered = format_context_block(retrieve_context(db, "plex stopped"))
    assert rendered is not None
    assert "plex" in rendered
def test_format_context_block_empty(empty_db):
    """A context retrieved from an empty database formats to None."""
    empty_ctx = retrieve_context(empty_db, "anything")
    assert format_context_block(empty_ctx) is None