feat: inject environment context into diagnose pipeline and LLM prompt
- Add a context_block parameter to summarize() and thread it into _PROMPT_TEMPLATE.
- Wire retrieve_context/format_context_block into diagnose_stream() before the log search, and emit a context SSE event (facts + chunks) to the client.
- Add 3 new tests covering prompt injection and SSE event emission (155 total, all pass).
This commit is contained in:
parent
abb61a6e90
commit
f19f896300
3 changed files with 140 additions and 3 deletions
|
|
@ -10,6 +10,7 @@ from datetime import datetime, timedelta, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from app.context.retriever import retrieve_context, format_context_block
|
||||||
from app.services.llm import summarize
|
from app.services.llm import summarize
|
||||||
from app.services.search import SearchResult, entries_in_window, search
|
from app.services.search import SearchResult, entries_in_window, search
|
||||||
|
|
||||||
|
|
@ -177,6 +178,15 @@ async def diagnose_stream(
|
||||||
until = until or parsed_until
|
until = until or parsed_until
|
||||||
time_detected = keywords != query
|
time_detected = keywords != query
|
||||||
|
|
||||||
|
yield {"type": "status", "message": "Loading environment context…"}
|
||||||
|
ctx = await asyncio.to_thread(lambda: retrieve_context(db_path, query))
|
||||||
|
context_block = format_context_block(ctx)
|
||||||
|
yield {
|
||||||
|
"type": "context",
|
||||||
|
"facts": ctx.facts,
|
||||||
|
"chunks": ctx.chunks,
|
||||||
|
}
|
||||||
|
|
||||||
yield {"type": "status", "message": "Searching logs…"}
|
yield {"type": "status", "message": "Searching logs…"}
|
||||||
|
|
||||||
if source_browse:
|
if source_browse:
|
||||||
|
|
@ -237,7 +247,7 @@ async def diagnose_stream(
|
||||||
if llm_url and llm_model and combined:
|
if llm_url and llm_model and combined:
|
||||||
yield {"type": "status", "message": "Analyzing with LLM…"}
|
yield {"type": "status", "message": "Analyzing with LLM…"}
|
||||||
reasoning = await asyncio.to_thread(
|
reasoning = await asyncio.to_thread(
|
||||||
lambda: summarize(query, combined, llm_url, llm_model, llm_api_key)
|
lambda: summarize(query, combined, llm_url, llm_model, llm_api_key, context_block=context_block)
|
||||||
)
|
)
|
||||||
if reasoning:
|
if reasoning:
|
||||||
yield {"type": "reasoning", "text": reasoning}
|
yield {"type": "reasoning", "text": reasoning}
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ You are a homelab diagnostic assistant. A user described a symptom and the syste
|
||||||
Analyze the log entries below and write a 2-4 sentence plain-language diagnosis. Focus on errors and their likely root cause. Be specific and concise — name the services involved, not generic platitudes.
|
Analyze the log entries below and write a 2-4 sentence plain-language diagnosis. Focus on errors and their likely root cause. Be specific and concise — name the services involved, not generic platitudes.
|
||||||
|
|
||||||
User query: {query}
|
User query: {query}
|
||||||
|
{context_section}
|
||||||
Log entries ({n} shown, highest severity first):
|
Log entries ({n} shown, highest severity first):
|
||||||
{log_block}
|
{log_block}
|
||||||
|
|
||||||
|
|
@ -47,11 +47,20 @@ def summarize(
|
||||||
llm_model: str,
|
llm_model: str,
|
||||||
api_key: str | None = None,
|
api_key: str | None = None,
|
||||||
timeout: float = 120.0,
|
timeout: float = 120.0,
|
||||||
|
context_block: str | None = None,
|
||||||
) -> str | None:
|
) -> str | None:
|
||||||
if not entries:
|
if not entries:
|
||||||
return None
|
return None
|
||||||
log_block = _build_context(entries)
|
log_block = _build_context(entries)
|
||||||
prompt = _PROMPT_TEMPLATE.format(query=query, n=min(len(entries), 25), log_block=log_block)
|
context_section = (
|
||||||
|
f"\nEnvironment context:\n{context_block}\n" if context_block else ""
|
||||||
|
)
|
||||||
|
prompt = _PROMPT_TEMPLATE.format(
|
||||||
|
query=query,
|
||||||
|
n=min(len(entries), 25),
|
||||||
|
log_block=log_block,
|
||||||
|
context_section=context_section,
|
||||||
|
)
|
||||||
headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
|
headers = {"Authorization": f"Bearer {api_key}"} if api_key else {}
|
||||||
messages = [{"role": "user", "content": prompt}]
|
messages = [{"role": "user", "content": prompt}]
|
||||||
|
|
||||||
|
|
|
||||||
118
tests/context/test_diagnose_context.py
Normal file
118
tests/context/test_diagnose_context.py
Normal file
|
|
@ -0,0 +1,118 @@
|
||||||
|
"""Verify context SSE event and LLM prompt injection."""
|
||||||
|
import asyncio
|
||||||
|
import sqlite3
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
import pytest
|
||||||
|
from app.services.llm import summarize
|
||||||
|
from app.services.search import SearchResult
|
||||||
|
|
||||||
|
|
||||||
|
def _entry(text: str) -> SearchResult:
    """Build a fixed ERROR-severity ``SearchResult`` whose log text is *text*.

    Every other field is a constant placeholder so tests only vary the text.
    """
    fields = {
        "entry_id": "x",
        "source_id": "svc",
        "sequence": 0,
        "timestamp_iso": "2026-05-13T00:00:00+00:00",
        "severity": "ERROR",
        "text": text,
        "matched_patterns": [],
        "repeat_count": 1,
        "out_of_order": False,
        "rank": 0.0,
    }
    return SearchResult(**fields)
|
||||||
|
|
||||||
|
|
||||||
|
def test_summarize_includes_context_block():
    """A supplied context_block must show up in the prompt sent to the LLM."""
    seen = {}

    def fake_post(url, json=None, headers=None, timeout=None):
        # Record the outgoing request body, then fail the request so the test
        # never needs a real LLM endpoint.
        seen["json"] = json
        raise ConnectionError("offline")

    env_context = "Known environment facts:\n [service] plex: port:32400"
    log_entries = [_entry("plex error")]

    with patch("app.services.llm.httpx.post", side_effect=fake_post):
        summarize(
            "plex stopped",
            log_entries,
            llm_url="http://localhost:11434",
            llm_model="llama3",
            context_block=env_context,
        )

    payload = seen.get("json", {})
    message_texts = [msg.get("content", "") for msg in payload.get("messages", [])]
    content = " ".join(message_texts)

    assert "Known environment facts" in content
    assert "plex: port:32400" in content
|
||||||
|
|
||||||
|
|
||||||
|
def test_summarize_without_context_block_unchanged():
    """When context_block is None the prompt must not contain the context header."""
    captured = {}

    def fake_post(url, json=None, headers=None, timeout=None):
        # Capture the outgoing payload, then abort the request.
        captured["json"] = json
        raise ConnectionError("offline")

    with patch("app.services.llm.httpx.post", side_effect=fake_post):
        summarize(
            "plex stopped",
            [_entry("plex error")],
            llm_url="http://localhost:11434",
            llm_model="llama3",
            context_block=None,
        )

    # Guard against a vacuous pass: if the request was never issued, the
    # negative assertions below would hold trivially on an empty string.
    assert "json" in captured
    messages = captured["json"].get("messages", [])
    content = " ".join(m.get("content", "") for m in messages)

    # The prompt was really built (it carries the user query) ...
    assert "plex stopped" in content
    # ... but the section header that introduces a context block is absent.
    assert "Environment context:" not in content
    assert "Known environment facts" not in content
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def db_with_facts(tmp_path):
    """Create a throwaway SQLite database seeded with one context fact.

    The schema holds the log tables (``log_entries``, ``log_fts``) and the
    context tables (``context_facts``, ``context_documents``,
    ``context_chunks``). A single ``service`` fact for ``plex`` is inserted.

    Returns:
        Path to the database file inside pytest's ``tmp_path``.
    """
    db_path = tmp_path / "t.db"
    conn = sqlite3.connect(str(db_path))
    try:
        conn.executescript("""
            CREATE TABLE log_entries (
                id TEXT PRIMARY KEY, source_id TEXT NOT NULL, sequence INTEGER NOT NULL,
                timestamp_raw TEXT, timestamp_iso TEXT, ingest_time TEXT NOT NULL,
                severity TEXT, repeat_count INTEGER DEFAULT 1, out_of_order INTEGER DEFAULT 0,
                matched_patterns TEXT DEFAULT '[]', text TEXT NOT NULL
            );
            CREATE VIRTUAL TABLE IF NOT EXISTS log_fts USING fts5(
                text, entry_id UNINDEXED, source_id UNINDEXED, sequence UNINDEXED,
                severity UNINDEXED, timestamp_iso UNINDEXED, matched_patterns UNINDEXED,
                repeat_count UNINDEXED, out_of_order UNINDEXED, tokenize='porter ascii'
            );
            CREATE TABLE context_facts (
                id TEXT PRIMARY KEY, category TEXT NOT NULL, key TEXT NOT NULL,
                value TEXT NOT NULL, source TEXT, created_at TEXT NOT NULL
            );
            CREATE TABLE context_documents (
                id TEXT PRIMARY KEY, filename TEXT NOT NULL, doc_type TEXT NOT NULL,
                full_text TEXT NOT NULL, file_size INTEGER, uploaded_at TEXT NOT NULL
            );
            CREATE TABLE context_chunks (
                id TEXT PRIMARY KEY, document_id TEXT NOT NULL
                    REFERENCES context_documents(id) ON DELETE CASCADE,
                chunk_index INTEGER NOT NULL, text TEXT NOT NULL, embedding BLOB
            );
            INSERT INTO context_facts VALUES (
                'f1','service','plex','port:32400','wizard','2026-05-13T00:00:00+00:00'
            );
        """)
        conn.commit()
    finally:
        # Close even when schema creation fails so the temp file is not left
        # held open by a leaked connection.
        conn.close()
    return db_path
|
||||||
|
|
||||||
|
|
||||||
|
def test_diagnose_stream_emits_context_event(db_with_facts):
    """diagnose_stream must emit a ``context`` event that carries the seeded fact."""
    from app.services.diagnose import diagnose_stream

    async def drain():
        # Exhaust the async generator and hand back everything it yielded.
        stream = diagnose_stream(
            db_path=db_with_facts,
            query="plex stopped",
        )
        return [evt async for evt in stream]

    events = asyncio.run(drain())

    event_types = [evt["type"] for evt in events]
    assert "context" in event_types

    ctx_event = next(evt for evt in events if evt["type"] == "context")
    assert "facts" in ctx_event
    assert any(fact["key"] == "plex" for fact in ctx_event["facts"])
|
||||||
Loading…
Reference in a new issue