feat: inject environment context into diagnose pipeline and LLM prompt
- Add context_block param to summarize() and thread it into _PROMPT_TEMPLATE - Wire retrieve_context/format_context_block into diagnose_stream() before log search; emit context SSE event (facts + chunks) to the client - 3 new tests covering prompt injection and SSE event emission (155 total, all pass)
This commit is contained in:
parent
9c8c60e461
commit
2c408907ac
3 changed files with 140 additions and 3 deletions
|
|
@ -10,6 +10,7 @@ from datetime import datetime, timedelta, timezone
|
|||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from app.context.retriever import retrieve_context, format_context_block
|
||||
from app.services.llm import summarize
|
||||
from app.services.search import SearchResult, entries_in_window, search
|
||||
|
||||
|
|
@ -177,6 +178,15 @@ async def diagnose_stream(
|
|||
until = until or parsed_until
|
||||
time_detected = keywords != query
|
||||
|
||||
yield {"type": "status", "message": "Loading environment context…"}
|
||||
ctx = await asyncio.to_thread(lambda: retrieve_context(db_path, query))
|
||||
context_block = format_context_block(ctx)
|
||||
yield {
|
||||
"type": "context",
|
||||
"facts": ctx.facts,
|
||||
"chunks": ctx.chunks,
|
||||
}
|
||||
|
||||
yield {"type": "status", "message": "Searching logs…"}
|
||||
|
||||
if source_browse:
|
||||
|
|
@ -237,7 +247,7 @@ async def diagnose_stream(
|
|||
if llm_url and llm_model and combined:
|
||||
yield {"type": "status", "message": "Analyzing with LLM…"}
|
||||
reasoning = await asyncio.to_thread(
|
||||
lambda: summarize(query, combined, llm_url, llm_model, llm_api_key)
|
||||
lambda: summarize(query, combined, llm_url, llm_model, llm_api_key, context_block=context_block)
|
||||
)
|
||||
if reasoning:
|
||||
yield {"type": "reasoning", "text": reasoning}
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ You are a homelab diagnostic assistant. A user described a symptom and the syste
|
|||
Analyze the log entries below and write a 2-4 sentence plain-language diagnosis. Focus on errors and their likely root cause. Be specific and concise — name the services involved, not generic platitudes.
|
||||
|
||||
User query: {query}
|
||||
|
||||
{context_section}
|
||||
Log entries ({n} shown, highest severity first):
|
||||
{log_block}
|
||||
|
||||
|
|
@ -47,11 +47,20 @@ def summarize(
|
|||
llm_model: str,
|
||||
api_key: str | None = None,
|
||||
timeout: float = 120.0,
|
||||
context_block: str | None = None,
|
||||
) -> str | None:
|
||||
if not entries:
|
||||
return None
|
||||
log_block = _build_context(entries)
|
||||
prompt = _PROMPT_TEMPLATE.format(query=query, n=min(len(entries), 25), log_block=log_block)
|
||||
context_section = (
|
||||
f"\nEnvironment context:\n{context_block}\n" if context_block else ""
|
||||
)
|
||||
prompt = _PROMPT_TEMPLATE.format(
|
||||
query=query,
|
||||
n=min(len(entries), 25),
|
||||
log_block=log_block,
|
||||
context_section=context_section,
|
||||
)
|
||||
headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
|
||||
messages = [{"role": "user", "content": prompt}]
|
||||
|
||||
|
|
|
|||
118
tests/context/test_diagnose_context.py
Normal file
118
tests/context/test_diagnose_context.py
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
"""Verify context SSE event and LLM prompt injection."""
|
||||
import asyncio
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
import pytest
|
||||
from app.services.llm import summarize
|
||||
from app.services.search import SearchResult
|
||||
|
||||
|
||||
def _entry(text: str) -> SearchResult:
|
||||
return SearchResult(
|
||||
entry_id="x", source_id="svc", sequence=0,
|
||||
timestamp_iso="2026-05-13T00:00:00+00:00",
|
||||
severity="ERROR", text=text,
|
||||
matched_patterns=[], repeat_count=1, out_of_order=False, rank=0.0,
|
||||
)
|
||||
|
||||
|
||||
def test_summarize_includes_context_block():
|
||||
captured = {}
|
||||
|
||||
def fake_post(url, json=None, headers=None, timeout=None):
|
||||
captured["json"] = json
|
||||
raise ConnectionError("offline")
|
||||
|
||||
with patch("app.services.llm.httpx.post", side_effect=fake_post):
|
||||
summarize(
|
||||
"plex stopped",
|
||||
[_entry("plex error")],
|
||||
llm_url="http://localhost:11434",
|
||||
llm_model="llama3",
|
||||
context_block="Known environment facts:\n [service] plex: port:32400",
|
||||
)
|
||||
|
||||
messages = captured.get("json", {}).get("messages", [])
|
||||
content = " ".join(m.get("content", "") for m in messages)
|
||||
assert "Known environment facts" in content
|
||||
assert "plex: port:32400" in content
|
||||
|
||||
|
||||
def test_summarize_without_context_block_unchanged():
|
||||
"""When context_block is None the prompt must not contain the context header."""
|
||||
captured = {}
|
||||
|
||||
def fake_post(url, json=None, headers=None, timeout=None):
|
||||
captured["json"] = json
|
||||
raise ConnectionError("offline")
|
||||
|
||||
with patch("app.services.llm.httpx.post", side_effect=fake_post):
|
||||
summarize(
|
||||
"plex stopped",
|
||||
[_entry("plex error")],
|
||||
llm_url="http://localhost:11434",
|
||||
llm_model="llama3",
|
||||
context_block=None,
|
||||
)
|
||||
|
||||
messages = captured.get("json", {}).get("messages", [])
|
||||
content = " ".join(m.get("content", "") for m in messages)
|
||||
assert "Known environment facts" not in content
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db_with_facts(tmp_path):
|
||||
db_path = tmp_path / "t.db"
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
conn.executescript("""
|
||||
CREATE TABLE log_entries (
|
||||
id TEXT PRIMARY KEY, source_id TEXT NOT NULL, sequence INTEGER NOT NULL,
|
||||
timestamp_raw TEXT, timestamp_iso TEXT, ingest_time TEXT NOT NULL,
|
||||
severity TEXT, repeat_count INTEGER DEFAULT 1, out_of_order INTEGER DEFAULT 0,
|
||||
matched_patterns TEXT DEFAULT '[]', text TEXT NOT NULL
|
||||
);
|
||||
CREATE VIRTUAL TABLE IF NOT EXISTS log_fts USING fts5(
|
||||
text, entry_id UNINDEXED, source_id UNINDEXED, sequence UNINDEXED,
|
||||
severity UNINDEXED, timestamp_iso UNINDEXED, matched_patterns UNINDEXED,
|
||||
repeat_count UNINDEXED, out_of_order UNINDEXED, tokenize='porter ascii'
|
||||
);
|
||||
CREATE TABLE context_facts (
|
||||
id TEXT PRIMARY KEY, category TEXT NOT NULL, key TEXT NOT NULL,
|
||||
value TEXT NOT NULL, source TEXT, created_at TEXT NOT NULL
|
||||
);
|
||||
CREATE TABLE context_documents (
|
||||
id TEXT PRIMARY KEY, filename TEXT NOT NULL, doc_type TEXT NOT NULL,
|
||||
full_text TEXT NOT NULL, file_size INTEGER, uploaded_at TEXT NOT NULL
|
||||
);
|
||||
CREATE TABLE context_chunks (
|
||||
id TEXT PRIMARY KEY, document_id TEXT NOT NULL
|
||||
REFERENCES context_documents(id) ON DELETE CASCADE,
|
||||
chunk_index INTEGER NOT NULL, text TEXT NOT NULL, embedding BLOB
|
||||
);
|
||||
INSERT INTO context_facts VALUES (
|
||||
'f1','service','plex','port:32400','wizard','2026-05-13T00:00:00+00:00'
|
||||
);
|
||||
""")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return db_path
|
||||
|
||||
|
||||
def test_diagnose_stream_emits_context_event(db_with_facts):
|
||||
from app.services.diagnose import diagnose_stream
|
||||
events = []
|
||||
|
||||
async def collect():
|
||||
async for evt in diagnose_stream(
|
||||
db_path=db_with_facts,
|
||||
query="plex stopped",
|
||||
):
|
||||
events.append(evt)
|
||||
|
||||
asyncio.run(collect())
|
||||
types = [e["type"] for e in events]
|
||||
assert "context" in types
|
||||
ctx_event = next(e for e in events if e["type"] == "context")
|
||||
assert "facts" in ctx_event
|
||||
assert any(f["key"] == "plex" for f in ctx_event["facts"])
|
||||
Loading…
Reference in a new issue