98 lines
3.2 KiB
Python
98 lines
3.2 KiB
Python
"""Tests for app/context/retriever.py."""
|
|
import sqlite3
|
|
import pytest
|
|
from pathlib import Path
|
|
from app.context.retriever import get_relevant_facts, retrieve_context, format_context_block
|
|
|
|
|
|
@pytest.fixture
def db(tmp_path: Path) -> Path:
    """Create a temporary SQLite database seeded with two context facts.

    The database file ``t.db`` is created under ``tmp_path`` with the
    ``context_facts``, ``context_documents`` and ``context_chunks`` tables.
    Two rows are inserted into ``context_facts``: a ``service`` fact with key
    ``plex`` and a ``host`` fact with key ``hostname``.

    Args:
        tmp_path: Directory in which the database file is created.

    Returns:
        Path of the populated database file.
    """
    db_path = tmp_path / "t.db"
    conn = sqlite3.connect(str(db_path))
    try:
        conn.executescript("""
            CREATE TABLE context_facts (
                id TEXT PRIMARY KEY, category TEXT NOT NULL, key TEXT NOT NULL,
                value TEXT NOT NULL, source TEXT, created_at TEXT NOT NULL
            );
            CREATE TABLE context_documents (
                id TEXT PRIMARY KEY, filename TEXT NOT NULL, doc_type TEXT NOT NULL,
                full_text TEXT NOT NULL, file_size INTEGER, uploaded_at TEXT NOT NULL
            );
            CREATE TABLE context_chunks (
                id TEXT PRIMARY KEY, document_id TEXT NOT NULL
                    REFERENCES context_documents(id) ON DELETE CASCADE,
                chunk_index INTEGER NOT NULL, text TEXT NOT NULL, embedding BLOB
            );
        """)
        conn.execute(
            "INSERT INTO context_facts VALUES ('f1','service','plex','port:32400 image:plexinc/pms-docker','wizard','2026-01-01T00:00:00+00:00')"
        )
        conn.execute(
            "INSERT INTO context_facts VALUES ('f2','host','hostname','heimdall.local','wizard','2026-01-01T00:00:01+00:00')"
        )
        conn.commit()
    finally:
        # Close the handle even when schema creation or seeding raises, so a
        # failing setup does not leave an open connection behind.
        conn.close()
    return db_path
|
|
|
|
|
|
@pytest.fixture
def empty_db(tmp_path: Path) -> Path:
    """Create a temporary SQLite database that has the schema but no rows.

    The database file ``empty.db`` is created under ``tmp_path`` with the
    ``context_facts``, ``context_documents`` and ``context_chunks`` tables;
    nothing is inserted into any of them.

    Args:
        tmp_path: Directory in which the database file is created.

    Returns:
        Path of the empty database file.
    """
    db_path = tmp_path / "empty.db"
    conn = sqlite3.connect(str(db_path))
    try:
        conn.executescript("""
            CREATE TABLE context_facts (
                id TEXT PRIMARY KEY, category TEXT NOT NULL, key TEXT NOT NULL,
                value TEXT NOT NULL, source TEXT, created_at TEXT NOT NULL
            );
            CREATE TABLE context_documents (
                id TEXT PRIMARY KEY, filename TEXT NOT NULL, doc_type TEXT NOT NULL,
                full_text TEXT NOT NULL, file_size INTEGER, uploaded_at TEXT NOT NULL
            );
            CREATE TABLE context_chunks (
                id TEXT PRIMARY KEY, document_id TEXT NOT NULL
                    REFERENCES context_documents(id) ON DELETE CASCADE,
                chunk_index INTEGER NOT NULL, text TEXT NOT NULL, embedding BLOB
            );
        """)
        conn.commit()
    finally:
        # Close the handle even when schema creation raises, so a failing
        # setup does not leave an open connection behind.
        conn.close()
    return db_path
|
|
|
|
|
|
def test_get_relevant_facts_keyword_match(db):
    """A query containing "plex" returns a fact whose key is ``plex``."""
    matched = get_relevant_facts(db, "plex audio stopped")
    assert "plex" in [fact["key"] for fact in matched]
|
|
|
|
|
|
def test_get_relevant_facts_empty_db(empty_db):
    """With the tables present but no rows, the result is an empty list."""
    assert get_relevant_facts(empty_db, "anything") == []
|
|
|
|
|
|
def test_get_relevant_facts_no_tables(tmp_path):
    """A database in which no tables were created yields an empty list."""
    bare_db = tmp_path / "notables.db"
    # Open and immediately close a connection; no schema is ever created.
    sqlite3.connect(str(bare_db)).close()
    # The lookup is expected to return an empty list rather than raise.
    assert get_relevant_facts(bare_db, "query") == []
|
|
|
|
|
|
def test_retrieve_context_returns_dataclass(db):
    """The retrieved context has truthy ``facts`` and a list of ``chunks``."""
    context = retrieve_context(db, "plex stopped")
    assert context.facts
    assert isinstance(context.chunks, list)
|
|
|
|
|
|
def test_format_context_block_with_facts(db):
    """Formatting a context retrieved for a plex query mentions "plex"."""
    rendered = format_context_block(retrieve_context(db, "plex stopped"))
    assert rendered is not None
    assert "plex" in rendered
|
|
|
|
|
|
def test_format_context_block_empty(empty_db):
    """Formatting a context retrieved from a row-less database gives ``None``."""
    rendered = format_context_block(retrieve_context(empty_db, "anything"))
    assert rendered is None
|