- Add a `context_block` parameter to `summarize()` and thread it into `_PROMPT_TEMPLATE`.
- Wire `retrieve_context` / `format_context_block` into `diagnose_stream()` before the log search, and emit a `context` SSE event (facts + chunks) to the client.
- Add 3 new tests covering prompt injection and SSE event emission (155 total, all passing).
118 lines · 4.2 KiB · Python
"""Verify context SSE event and LLM prompt injection."""
|
|
import asyncio
|
|
import sqlite3
|
|
from pathlib import Path
|
|
from unittest.mock import patch
|
|
import pytest
|
|
from app.services.llm import summarize
|
|
from app.services.search import SearchResult
|
|
|
|
|
|
def _entry(text: str) -> SearchResult:
    """Return a minimal ERROR-severity SearchResult whose log line is *text*.

    All other fields are fixed placeholders; only the text varies between
    calls, which is all the prompt-building tests need.
    """
    fields = dict(
        entry_id="x",
        source_id="svc",
        sequence=0,
        timestamp_iso="2026-05-13T00:00:00+00:00",
        severity="ERROR",
        text=text,
        matched_patterns=[],
        repeat_count=1,
        out_of_order=False,
        rank=0.0,
    )
    return SearchResult(**fields)
|
|
|
|
|
|
def test_summarize_includes_context_block():
    """summarize() must splice a supplied context_block into the LLM request.

    ``httpx.post`` inside ``app.services.llm`` is replaced by a fake that
    records the request payload and then raises ConnectionError, so no LLM
    is contacted and summarize() takes its offline path.
    """
    captured = {}

    def fake_post(url, *args, **kwargs):
        # Accept any call shape: a new keyword passed by summarize() must not
        # surface as a TypeError that hides the payload we want to inspect.
        captured["json"] = kwargs.get("json")
        raise ConnectionError("offline")

    with patch("app.services.llm.httpx.post", side_effect=fake_post):
        summarize(
            "plex stopped",
            [_entry("plex error")],
            llm_url="http://localhost:11434",
            llm_model="llama3",
            context_block="Known environment facts:\n [service] plex: port:32400",
        )

    # Fail loudly (with a clear reason) if no request was ever attempted.
    assert captured.get("json"), "summarize() never posted a request to the LLM"
    messages = captured["json"].get("messages", [])
    # `or ""` also tolerates an explicit None content value.
    content = " ".join(m.get("content") or "" for m in messages)
    assert "Known environment facts" in content
    assert "plex: port:32400" in content
|
|
|
|
|
|
def test_summarize_without_context_block_unchanged():
    """When context_block is None the prompt must not contain the context header.

    The request payload is captured via a fake ``httpx.post`` that raises
    ConnectionError after recording it. The test first checks that a
    non-empty prompt was actually sent; otherwise the negative assertion
    below would pass trivially on an empty string.
    """
    captured = {}

    def fake_post(url, *args, **kwargs):
        # Accept any call shape so signature drift in summarize() cannot
        # silently prevent the payload from being captured.
        captured["json"] = kwargs.get("json")
        raise ConnectionError("offline")

    with patch("app.services.llm.httpx.post", side_effect=fake_post):
        summarize(
            "plex stopped",
            [_entry("plex error")],
            llm_url="http://localhost:11434",
            llm_model="llama3",
            context_block=None,
        )

    assert captured.get("json"), "summarize() never posted a request to the LLM"
    messages = captured["json"].get("messages", [])
    content = " ".join(m.get("content") or "" for m in messages)
    # Guard against a vacuous pass: an empty prompt trivially lacks the header.
    assert content, "LLM request carried no message content"
    assert "Known environment facts" not in content
|
|
|
|
|
|
@pytest.fixture
def db_with_facts(tmp_path):
    """Create a throwaway SQLite DB under *tmp_path* and return its path.

    The script creates the tables ``log_entries``, ``log_fts`` (FTS5),
    ``context_facts``, ``context_documents`` and ``context_chunks``. It seeds
    exactly one fact: category ``service``, key ``plex``, value ``port:32400``.
    """
    db_path = tmp_path / "t.db"
    conn = sqlite3.connect(str(db_path))
    try:
        conn.executescript("""
            CREATE TABLE log_entries (
                id TEXT PRIMARY KEY, source_id TEXT NOT NULL, sequence INTEGER NOT NULL,
                timestamp_raw TEXT, timestamp_iso TEXT, ingest_time TEXT NOT NULL,
                severity TEXT, repeat_count INTEGER DEFAULT 1, out_of_order INTEGER DEFAULT 0,
                matched_patterns TEXT DEFAULT '[]', text TEXT NOT NULL
            );
            CREATE VIRTUAL TABLE IF NOT EXISTS log_fts USING fts5(
                text, entry_id UNINDEXED, source_id UNINDEXED, sequence UNINDEXED,
                severity UNINDEXED, timestamp_iso UNINDEXED, matched_patterns UNINDEXED,
                repeat_count UNINDEXED, out_of_order UNINDEXED, tokenize='porter ascii'
            );
            CREATE TABLE context_facts (
                id TEXT PRIMARY KEY, category TEXT NOT NULL, key TEXT NOT NULL,
                value TEXT NOT NULL, source TEXT, created_at TEXT NOT NULL
            );
            CREATE TABLE context_documents (
                id TEXT PRIMARY KEY, filename TEXT NOT NULL, doc_type TEXT NOT NULL,
                full_text TEXT NOT NULL, file_size INTEGER, uploaded_at TEXT NOT NULL
            );
            CREATE TABLE context_chunks (
                id TEXT PRIMARY KEY, document_id TEXT NOT NULL
                    REFERENCES context_documents(id) ON DELETE CASCADE,
                chunk_index INTEGER NOT NULL, text TEXT NOT NULL, embedding BLOB
            );
            INSERT INTO context_facts VALUES (
                'f1','service','plex','port:32400','wizard','2026-05-13T00:00:00+00:00'
            );
        """)
        conn.commit()
    finally:
        # Close even if the schema script fails, so the handle is not leaked
        # until garbage collection.
        conn.close()
    return db_path
|
|
|
|
|
|
def test_diagnose_stream_emits_context_event(db_with_facts):
    """diagnose_stream() should yield a 'context' event listing the seeded plex fact."""
    from app.services.diagnose import diagnose_stream

    async def gather_events():
        # Drain the async generator into a list in one comprehension.
        return [
            evt
            async for evt in diagnose_stream(
                db_path=db_with_facts,
                query="plex stopped",
            )
        ]

    events = asyncio.run(gather_events())

    assert "context" in [evt["type"] for evt in events]

    ctx_event = next(evt for evt in events if evt["type"] == "context")
    assert "facts" in ctx_event
    facts = ctx_event["facts"]
    assert any(fact["key"] == "plex" for fact in facts)
|