diff --git a/app/services/diagnose.py b/app/services/diagnose.py index b8a0362..4d1c887 100644 --- a/app/services/diagnose.py +++ b/app/services/diagnose.py @@ -10,6 +10,7 @@ from datetime import datetime, timedelta, timezone from pathlib import Path from typing import Any +from app.context.retriever import retrieve_context, format_context_block from app.services.llm import summarize from app.services.search import SearchResult, entries_in_window, search @@ -177,6 +178,15 @@ async def diagnose_stream( until = until or parsed_until time_detected = keywords != query + yield {"type": "status", "message": "Loading environment context…"} + ctx = await asyncio.to_thread(lambda: retrieve_context(db_path, query)) + context_block = format_context_block(ctx) + yield { + "type": "context", + "facts": ctx.facts, + "chunks": ctx.chunks, + } + yield {"type": "status", "message": "Searching logs…"} if source_browse: @@ -237,7 +247,7 @@ async def diagnose_stream( if llm_url and llm_model and combined: yield {"type": "status", "message": "Analyzing with LLM…"} reasoning = await asyncio.to_thread( - lambda: summarize(query, combined, llm_url, llm_model, llm_api_key) + lambda: summarize(query, combined, llm_url, llm_model, llm_api_key, context_block=context_block) ) if reasoning: yield {"type": "reasoning", "text": reasoning} diff --git a/app/services/llm.py b/app/services/llm.py index c152c8b..6d6c219 100644 --- a/app/services/llm.py +++ b/app/services/llm.py @@ -14,7 +14,7 @@ You are a homelab diagnostic assistant. A user described a symptom and the syste Analyze the log entries below and write a 2-4 sentence plain-language diagnosis. Focus on errors and their likely root cause. Be specific and concise — name the services involved, not generic platitudes. User query: {query} - +{context_section} Log entries ({n} shown, highest severity first): {log_block} @@ -47,11 +47,20 @@ def summarize( llm_model: str, api_key: str | None = None, timeout: float = 120.0, + context_block: str | None = None, ) -> str | None: if not entries: return None log_block = _build_context(entries) - prompt = _PROMPT_TEMPLATE.format(query=query, n=min(len(entries), 25), log_block=log_block) + context_section = ( + f"\nEnvironment context:\n{context_block}\n" if context_block else "" + ) + prompt = _PROMPT_TEMPLATE.format( + query=query, + n=min(len(entries), 25), + log_block=log_block, + context_section=context_section, + ) headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} messages = [{"role": "user", "content": prompt}] diff --git a/tests/context/test_diagnose_context.py b/tests/context/test_diagnose_context.py new file mode 100644 index 0000000..f34da5f --- /dev/null +++ b/tests/context/test_diagnose_context.py @@ -0,0 +1,118 @@ +"""Verify context SSE event and LLM prompt injection.""" +import asyncio +import sqlite3 +from pathlib import Path +from unittest.mock import patch +import pytest +from app.services.llm import summarize +from app.services.search import SearchResult + + +def _entry(text: str) -> SearchResult: + return SearchResult( + entry_id="x", source_id="svc", sequence=0, + timestamp_iso="2026-05-13T00:00:00+00:00", + severity="ERROR", text=text, + matched_patterns=[], repeat_count=1, out_of_order=False, rank=0.0, + ) + + +def test_summarize_includes_context_block(): + captured = {} + + def fake_post(url, json=None, headers=None, timeout=None): + captured["json"] = json + raise ConnectionError("offline") + + with patch("app.services.llm.httpx.post", side_effect=fake_post): + summarize( + "plex stopped", + [_entry("plex error")], + llm_url="http://localhost:11434", + llm_model="llama3", + context_block="Known environment facts:\n [service] plex: port:32400", + ) + + messages = captured.get("json", {}).get("messages", []) + content = " ".join(m.get("content", "") for m in messages) + assert "Known environment facts" in content + assert "plex: port:32400" in content + + +def test_summarize_without_context_block_unchanged(): + """When context_block is None the prompt must not contain the context header.""" + captured = {} + + def fake_post(url, json=None, headers=None, timeout=None): + captured["json"] = json + raise ConnectionError("offline") + + with patch("app.services.llm.httpx.post", side_effect=fake_post): + summarize( + "plex stopped", + [_entry("plex error")], + llm_url="http://localhost:11434", + llm_model="llama3", + context_block=None, + ) + + messages = captured.get("json", {}).get("messages", []) + content = " ".join(m.get("content", "") for m in messages) + assert "Known environment facts" not in content + + +@pytest.fixture +def db_with_facts(tmp_path): + db_path = tmp_path / "t.db" + conn = sqlite3.connect(str(db_path)) + conn.executescript(""" + CREATE TABLE log_entries ( + id TEXT PRIMARY KEY, source_id TEXT NOT NULL, sequence INTEGER NOT NULL, + timestamp_raw TEXT, timestamp_iso TEXT, ingest_time TEXT NOT NULL, + severity TEXT, repeat_count INTEGER DEFAULT 1, out_of_order INTEGER DEFAULT 0, + matched_patterns TEXT DEFAULT '[]', text TEXT NOT NULL + ); + CREATE VIRTUAL TABLE IF NOT EXISTS log_fts USING fts5( + text, entry_id UNINDEXED, source_id UNINDEXED, sequence UNINDEXED, + severity UNINDEXED, timestamp_iso UNINDEXED, matched_patterns UNINDEXED, + repeat_count UNINDEXED, out_of_order UNINDEXED, tokenize='porter ascii' + ); + CREATE TABLE context_facts ( + id TEXT PRIMARY KEY, category TEXT NOT NULL, key TEXT NOT NULL, + value TEXT NOT NULL, source TEXT, created_at TEXT NOT NULL + ); + CREATE TABLE context_documents ( + id TEXT PRIMARY KEY, filename TEXT NOT NULL, doc_type TEXT NOT NULL, + full_text TEXT NOT NULL, file_size INTEGER, uploaded_at TEXT NOT NULL + ); + CREATE TABLE context_chunks ( + id TEXT PRIMARY KEY, document_id TEXT NOT NULL + REFERENCES context_documents(id) ON DELETE CASCADE, + chunk_index INTEGER NOT NULL, text TEXT NOT NULL, embedding BLOB + ); + INSERT INTO context_facts VALUES ( + 'f1','service','plex','port:32400','wizard','2026-05-13T00:00:00+00:00' + ); + """) + conn.commit() + conn.close() + return db_path + + +def test_diagnose_stream_emits_context_event(db_with_facts): + from app.services.diagnose import diagnose_stream + events = [] + + async def collect(): + async for evt in diagnose_stream( + db_path=db_with_facts, + query="plex stopped", + ): + events.append(evt) + + asyncio.run(collect()) + types = [e["type"] for e in events] + assert "context" in types + ctx_event = next(e for e in events if e["type"] == "context") + assert "facts" in ctx_event + assert any(f["key"] == "plex" for f in ctx_event["facts"])