feat: inject environment context into diagnose pipeline and LLM prompt

- Add context_block param to summarize() and thread it into _PROMPT_TEMPLATE
- Wire retrieve_context/format_context_block into diagnose_stream() before
  log search; emit context SSE event (facts + chunks) to the client
- 3 new tests covering prompt injection and SSE event emission (155 total, all pass)
This commit is contained in:
pyr0ball 2026-05-13 16:29:26 -07:00
parent abb61a6e90
commit f19f896300
3 changed files with 140 additions and 3 deletions

View file

@ -10,6 +10,7 @@ from datetime import datetime, timedelta, timezone
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
from app.context.retriever import retrieve_context, format_context_block
from app.services.llm import summarize from app.services.llm import summarize
from app.services.search import SearchResult, entries_in_window, search from app.services.search import SearchResult, entries_in_window, search
@ -177,6 +178,15 @@ async def diagnose_stream(
until = until or parsed_until until = until or parsed_until
time_detected = keywords != query time_detected = keywords != query
yield {"type": "status", "message": "Loading environment context…"}
ctx = await asyncio.to_thread(lambda: retrieve_context(db_path, query))
context_block = format_context_block(ctx)
yield {
"type": "context",
"facts": ctx.facts,
"chunks": ctx.chunks,
}
yield {"type": "status", "message": "Searching logs…"} yield {"type": "status", "message": "Searching logs…"}
if source_browse: if source_browse:
@ -237,7 +247,7 @@ async def diagnose_stream(
if llm_url and llm_model and combined: if llm_url and llm_model and combined:
yield {"type": "status", "message": "Analyzing with LLM…"} yield {"type": "status", "message": "Analyzing with LLM…"}
reasoning = await asyncio.to_thread( reasoning = await asyncio.to_thread(
lambda: summarize(query, combined, llm_url, llm_model, llm_api_key) lambda: summarize(query, combined, llm_url, llm_model, llm_api_key, context_block=context_block)
) )
if reasoning: if reasoning:
yield {"type": "reasoning", "text": reasoning} yield {"type": "reasoning", "text": reasoning}

View file

@ -14,7 +14,7 @@ You are a homelab diagnostic assistant. A user described a symptom and the syste
Analyze the log entries below and write a 2-4 sentence plain-language diagnosis. Focus on errors and their likely root cause. Be specific and concise name the services involved, not generic platitudes. Analyze the log entries below and write a 2-4 sentence plain-language diagnosis. Focus on errors and their likely root cause. Be specific and concise name the services involved, not generic platitudes.
User query: {query} User query: {query}
{context_section}
Log entries ({n} shown, highest severity first): Log entries ({n} shown, highest severity first):
{log_block} {log_block}
@ -47,11 +47,20 @@ def summarize(
llm_model: str, llm_model: str,
api_key: str | None = None, api_key: str | None = None,
timeout: float = 120.0, timeout: float = 120.0,
context_block: str | None = None,
) -> str | None: ) -> str | None:
if not entries: if not entries:
return None return None
log_block = _build_context(entries) log_block = _build_context(entries)
prompt = _PROMPT_TEMPLATE.format(query=query, n=min(len(entries), 25), log_block=log_block) context_section = (
f"\nEnvironment context:\n{context_block}\n" if context_block else ""
)
prompt = _PROMPT_TEMPLATE.format(
query=query,
n=min(len(entries), 25),
log_block=log_block,
context_section=context_section,
)
headers = {"Authorization": f"Bearer {api_key}"} if api_key else {} headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
messages = [{"role": "user", "content": prompt}] messages = [{"role": "user", "content": prompt}]

View file

@ -0,0 +1,118 @@
"""Verify context SSE event and LLM prompt injection."""
import asyncio
import sqlite3
from pathlib import Path
from unittest.mock import patch
import pytest
from app.services.llm import summarize
from app.services.search import SearchResult
def _entry(text: str) -> SearchResult:
    """Build a placeholder ERROR-severity ``SearchResult`` whose body is *text*.

    Every field other than ``text`` is a fixed dummy value, so tests only
    need to vary the log message itself.
    """
    fields = {
        "entry_id": "x",
        "source_id": "svc",
        "sequence": 0,
        "timestamp_iso": "2026-05-13T00:00:00+00:00",
        "severity": "ERROR",
        "text": text,
        "matched_patterns": [],
        "repeat_count": 1,
        "out_of_order": False,
        "rank": 0.0,
    }
    return SearchResult(**fields)
def test_summarize_includes_context_block():
    """A supplied ``context_block`` must show up in the prompt sent to the LLM."""
    recorded = {}

    def fake_post(url, json=None, headers=None, timeout=None):
        # Record the outgoing payload, then fail the call so no request is made.
        recorded["json"] = json
        raise ConnectionError("offline")

    env_context = "Known environment facts:\n [service] plex: port:32400"

    with patch("app.services.llm.httpx.post", side_effect=fake_post):
        summarize(
            "plex stopped",
            [_entry("plex error")],
            llm_url="http://localhost:11434",
            llm_model="llama3",
            context_block=env_context,
        )

    payload = recorded.get("json", {})
    prompt_text = " ".join(
        message.get("content", "") for message in payload.get("messages", [])
    )
    assert "Known environment facts" in prompt_text
    assert "plex: port:32400" in prompt_text
def test_summarize_without_context_block_unchanged():
    """When context_block is None the prompt must not contain the context header.

    The header that ``summarize`` itself adds is "Environment context:", so
    that is the string whose absence proves the section was omitted.  The
    request payload is also required to have been captured; otherwise the
    negative assertions would pass vacuously if no request were ever built.
    """
    captured = {}

    def fake_post(url, json=None, headers=None, timeout=None):
        # Record the outgoing payload, then fail the call so no request is made.
        captured["json"] = json
        raise ConnectionError("offline")

    with patch("app.services.llm.httpx.post", side_effect=fake_post):
        summarize(
            "plex stopped",
            [_entry("plex error")],
            llm_url="http://localhost:11434",
            llm_model="llama3",
            context_block=None,
        )

    messages = captured.get("json", {}).get("messages", [])
    # Guard against a vacuous pass: a prompt must actually have been sent.
    assert messages, "summarize() did not send any messages to the LLM endpoint"
    content = " ".join(m.get("content", "") for m in messages)
    assert "Environment context:" not in content
    assert "Known environment facts" not in content
@pytest.fixture
def db_with_facts(tmp_path):
    """Create a temporary SQLite database seeded with one context fact.

    The database file is ``t.db`` under pytest's ``tmp_path``.  It holds the
    ``log_entries`` table, the ``log_fts`` FTS5 virtual table, and the
    ``context_facts`` / ``context_documents`` / ``context_chunks`` tables,
    plus a single fact row (category ``service``, key ``plex``, value
    ``port:32400``).

    Returns:
        Path to the created database file.
    """
    db_path = tmp_path / "t.db"
    conn = sqlite3.connect(str(db_path))
    try:
        conn.executescript("""
CREATE TABLE log_entries (
id TEXT PRIMARY KEY, source_id TEXT NOT NULL, sequence INTEGER NOT NULL,
timestamp_raw TEXT, timestamp_iso TEXT, ingest_time TEXT NOT NULL,
severity TEXT, repeat_count INTEGER DEFAULT 1, out_of_order INTEGER DEFAULT 0,
matched_patterns TEXT DEFAULT '[]', text TEXT NOT NULL
);
CREATE VIRTUAL TABLE IF NOT EXISTS log_fts USING fts5(
text, entry_id UNINDEXED, source_id UNINDEXED, sequence UNINDEXED,
severity UNINDEXED, timestamp_iso UNINDEXED, matched_patterns UNINDEXED,
repeat_count UNINDEXED, out_of_order UNINDEXED, tokenize='porter ascii'
);
CREATE TABLE context_facts (
id TEXT PRIMARY KEY, category TEXT NOT NULL, key TEXT NOT NULL,
value TEXT NOT NULL, source TEXT, created_at TEXT NOT NULL
);
CREATE TABLE context_documents (
id TEXT PRIMARY KEY, filename TEXT NOT NULL, doc_type TEXT NOT NULL,
full_text TEXT NOT NULL, file_size INTEGER, uploaded_at TEXT NOT NULL
);
CREATE TABLE context_chunks (
id TEXT PRIMARY KEY, document_id TEXT NOT NULL
REFERENCES context_documents(id) ON DELETE CASCADE,
chunk_index INTEGER NOT NULL, text TEXT NOT NULL, embedding BLOB
);
INSERT INTO context_facts VALUES (
'f1','service','plex','port:32400','wizard','2026-05-13T00:00:00+00:00'
);
""")
        conn.commit()
    finally:
        # Close the handle even if schema setup fails, so the fixture never
        # leaks an open connection to the temp database file.
        conn.close()
    return db_path
def test_diagnose_stream_emits_context_event(db_with_facts):
    """``diagnose_stream`` must emit a ``context`` event carrying the plex fact."""
    from app.services.diagnose import diagnose_stream

    async def drain():
        # Exhaust the async generator and hand back everything it yielded.
        return [
            evt
            async for evt in diagnose_stream(
                db_path=db_with_facts,
                query="plex stopped",
            )
        ]

    events = asyncio.run(drain())

    context_events = [evt for evt in events if evt["type"] == "context"]
    assert context_events
    ctx_event = context_events[0]
    assert "facts" in ctx_event
    assert any(fact["key"] == "plex" for fact in ctx_event["facts"])