From de662725ee123bb8ead37c6bdccd11e7478718ba Mon Sep 17 00:00:00 2001 From: pyr0ball Date: Wed, 13 May 2026 16:23:54 -0700 Subject: [PATCH] =?UTF-8?q?feat:=20context=20retriever=20=E2=80=94=20keywo?= =?UTF-8?q?rd=20fact=20lookup=20and=20chunk=20search?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/context/retriever.py | 88 +++++++++++++++++++++++++++++ tests/context/test_retriever.py | 98 +++++++++++++++++++++++++++++++++ 2 files changed, 186 insertions(+) create mode 100644 app/context/retriever.py create mode 100644 tests/context/test_retriever.py diff --git a/app/context/retriever.py b/app/context/retriever.py new file mode 100644 index 0000000..6b42c8e --- /dev/null +++ b/app/context/retriever.py @@ -0,0 +1,88 @@ +"""Context retrieval — structured keyword lookup (Free) + chunk search — MIT licensed.""" +from __future__ import annotations + +import sqlite3 +from dataclasses import dataclass, field +from pathlib import Path + + +@dataclass +class RetrievedContext: + facts: list[dict[str, str]] = field(default_factory=list) + chunks: list[dict[str, str]] = field(default_factory=list) + + +def get_relevant_facts(db_path: Path, query: str) -> list[dict[str, str]]: + """Keyword match against context_facts. Always runs — Free tier.""" + try: + conn = sqlite3.connect(str(db_path)) + conn.execute("PRAGMA journal_mode=WAL") + conn.row_factory = sqlite3.Row + keywords = [w.lower() for w in query.split() if len(w) > 2] + if not keywords: + rows = conn.execute( + "SELECT category, key, value, source FROM context_facts" + " ORDER BY category LIMIT 20" + ).fetchall() + else: + conditions = " OR ".join( + "(LOWER(key) LIKE ? OR LOWER(value) LIKE ?)" for _ in keywords + ) + params: list[str] = [] + for kw in keywords: + params.extend([f"%{kw}%", f"%{kw}%"]) + rows = conn.execute( + f"SELECT category, key, value, source FROM context_facts" + f" WHERE {conditions} ORDER BY category LIMIT 10", + params, + ).fetchall() + conn.close() + return [dict(r) for r in rows] + except sqlite3.OperationalError: + return [] + + +def _search_chunks(db_path: Path, query: str) -> list[dict[str, str]]: + """Keyword search across context_chunks. Fallback when no embeddings.""" + try: + conn = sqlite3.connect(str(db_path)) + conn.execute("PRAGMA journal_mode=WAL") + conn.row_factory = sqlite3.Row + keywords = [w.lower() for w in query.split() if len(w) > 2][:5] + if not keywords: + conn.close() + return [] + conditions = " OR ".join("LOWER(cc.text) LIKE ?" for _ in keywords) + params = [f"%{kw}%" for kw in keywords] + rows = conn.execute( + f"SELECT cc.text, cd.filename FROM context_chunks cc" + f" JOIN context_documents cd ON cc.document_id = cd.id" + f" WHERE {conditions} LIMIT 3", + params, + ).fetchall() + conn.close() + return [{"text": r["text"], "filename": r["filename"]} for r in rows] + except sqlite3.OperationalError: + return [] + + +def retrieve_context(db_path: Path, query: str) -> RetrievedContext: + """Retrieve structured facts and relevant chunks for a query.""" + return RetrievedContext( + facts=get_relevant_facts(db_path, query), + chunks=_search_chunks(db_path, query), + ) + + +def format_context_block(ctx: RetrievedContext) -> str | None: + """Format context for injection into LLM prompt. Returns None when empty.""" + lines: list[str] = [] + if ctx.facts: + lines.append("Known environment facts:") + for f in ctx.facts: + lines.append(f" [{f['category']}] {f['key']}: {f['value']}") + if ctx.chunks: + lines.append("Relevant documentation:") + for c in ctx.chunks: + lines.append(f" [{c['filename']}] {c['text'][:200]}") + return "\n".join(lines) if lines else None diff --git a/tests/context/test_retriever.py b/tests/context/test_retriever.py new file mode 100644 index 0000000..126f3b3 --- /dev/null +++ b/tests/context/test_retriever.py @@ -0,0 +1,98 @@ +"""Tests for app/context/retriever.py.""" +import sqlite3 +import pytest +from pathlib import Path +from app.context.retriever import get_relevant_facts, retrieve_context, format_context_block + + +@pytest.fixture +def db(tmp_path): + db_path = tmp_path / "t.db" + conn = sqlite3.connect(str(db_path)) + conn.executescript(""" + CREATE TABLE context_facts ( + id TEXT PRIMARY KEY, category TEXT NOT NULL, key TEXT NOT NULL, + value TEXT NOT NULL, source TEXT, created_at TEXT NOT NULL + ); + CREATE TABLE context_documents ( + id TEXT PRIMARY KEY, filename TEXT NOT NULL, doc_type TEXT NOT NULL, + full_text TEXT NOT NULL, file_size INTEGER, uploaded_at TEXT NOT NULL + ); + CREATE TABLE context_chunks ( + id TEXT PRIMARY KEY, document_id TEXT NOT NULL + REFERENCES context_documents(id) ON DELETE CASCADE, + chunk_index INTEGER NOT NULL, text TEXT NOT NULL, embedding BLOB + ); + """) + conn.execute( + "INSERT INTO context_facts VALUES ('f1','service','plex','port:32400 image:plexinc/pms-docker','wizard','2026-01-01T00:00:00+00:00')" + ) + conn.execute( + "INSERT INTO context_facts VALUES ('f2','host','hostname','heimdall.local','wizard','2026-01-01T00:00:01+00:00')" + ) + conn.commit() + conn.close() + return db_path + + +@pytest.fixture +def empty_db(tmp_path): + db_path = tmp_path / "empty.db" + conn = sqlite3.connect(str(db_path)) + conn.executescript(""" + CREATE TABLE context_facts ( + id TEXT PRIMARY KEY, category TEXT NOT NULL, key TEXT NOT NULL, + value TEXT NOT NULL, source TEXT, created_at TEXT NOT NULL + ); + CREATE TABLE context_documents ( + id TEXT PRIMARY KEY, filename TEXT NOT NULL, doc_type TEXT NOT NULL, + full_text TEXT NOT NULL, file_size INTEGER, uploaded_at TEXT NOT NULL + ); + CREATE TABLE context_chunks ( + id TEXT PRIMARY KEY, document_id TEXT NOT NULL + REFERENCES context_documents(id) ON DELETE CASCADE, + chunk_index INTEGER NOT NULL, text TEXT NOT NULL, embedding BLOB + ); + """) + conn.commit() + conn.close() + return db_path + + +def test_get_relevant_facts_keyword_match(db): + facts = get_relevant_facts(db, "plex audio stopped") + keys = [f["key"] for f in facts] + assert "plex" in keys + + +def test_get_relevant_facts_empty_db(empty_db): + facts = get_relevant_facts(empty_db, "anything") + assert facts == [] + + +def test_get_relevant_facts_no_tables(tmp_path): + db_path = tmp_path / "notables.db" + conn = sqlite3.connect(str(db_path)) + conn.close() + # Should return empty without crashing + facts = get_relevant_facts(db_path, "query") + assert facts == [] + + +def test_retrieve_context_returns_dataclass(db): + ctx = retrieve_context(db, "plex stopped") + assert ctx.facts + assert isinstance(ctx.chunks, list) + + +def test_format_context_block_with_facts(db): + ctx = retrieve_context(db, "plex stopped") + block = format_context_block(ctx) + assert block is not None + assert "plex" in block + + +def test_format_context_block_empty(empty_db): + ctx = retrieve_context(empty_db, "anything") + block = format_context_block(ctx) + assert block is None