feat: Stage 5 synthesizer + pipeline orchestrator + feature flag wiring (issue #29)
- Add app/services/diagnose/synthesizer.py: SummarySynthesizer (Stage 5)
- Builds structured LLM prompt from ranked hypotheses, timeline, RAG context
- Excludes suppressed hypotheses from the narrative prompt
- Deterministic fallback when no LLM configured or LLM call fails
- Same cf-orch task endpoint + direct OpenAI-compat fallback pattern as other stages
- Replace pipeline.py stub with full run_pipeline() async generator
- Orchestrates all 5 stages via asyncio.to_thread for each synchronous stage
- Yields typed SSE event dicts: status, pipeline_stage (1-4), hypotheses, reasoning, done
- Suppressor counts (active vs suppressed) reported in stage 4 event message
- Wire MULTI_AGENT_ENABLED feature flag into diagnose_stream()
- TURNSTONE_MULTI_AGENT_DIAGNOSE=true routes through run_pipeline()
- pipeline emits its own done event; legacy path unchanged when flag is false
- Import of run_pipeline added to __init__.py
- Add 21 new tests (350 -> 371 passing):
- tests/test_diagnose_synthesizer.py: 8 tests (with/without LLM, suppressed,
empty ranked, LLM failure fallback)
- tests/test_diagnose_pipeline.py: 13 tests (flag off, flag on event sequence,
empty entries, no LLM, stage 1 cluster count message)
Closes: #29
This commit is contained in:
parent
a4f97e5a79
commit
14bec5769b
5 changed files with 1120 additions and 6 deletions
|
|
@ -23,6 +23,7 @@ from typing import Any
|
|||
from app.context.retriever import retrieve_context, format_context_block
|
||||
from app.services.llm import summarize
|
||||
from app.services.search import SearchResult, entries_in_window, search
|
||||
from app.services.diagnose.pipeline import run_pipeline
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -303,6 +304,21 @@ async def diagnose_stream(
|
|||
}
|
||||
yield {"type": "entries", "data": [dataclasses.asdict(r) for r in combined]}
|
||||
|
||||
if MULTI_AGENT_ENABLED:
|
||||
async for event in run_pipeline(
|
||||
db_path=db_path,
|
||||
entries=combined,
|
||||
ctx=ctx,
|
||||
query=query,
|
||||
since=since,
|
||||
until=until,
|
||||
llm_url=llm_url,
|
||||
llm_model=llm_model,
|
||||
llm_api_key=llm_api_key,
|
||||
):
|
||||
yield event
|
||||
return # pipeline emits its own "done" event
|
||||
|
||||
if llm_url and llm_model and combined:
|
||||
yield {"type": "status", "message": "Analyzing with LLM…"}
|
||||
reasoning = await asyncio.to_thread(
|
||||
|
|
|
|||
|
|
@ -1,18 +1,132 @@
|
|||
"""Multi-agent diagnose pipeline orchestrator — stub (Task 1)."""
|
||||
"""Multi-agent diagnose pipeline orchestrator — Stage 1–5 wiring."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import dataclasses
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from app.context.retriever import RetrievedContext
|
||||
from app.services.diagnose.classifier import SeverityClassifier
|
||||
from app.services.diagnose.hypothesizer import RootCauseHypothesizer
|
||||
from app.services.diagnose.suppressor import FalsePositiveSuppressor
|
||||
from app.services.diagnose.synthesizer import SummarySynthesizer
|
||||
from app.services.diagnose.timeline import TimelineReconstructor
|
||||
from app.services.search import SearchResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# run_pipeline() will be implemented in Task 6
|
||||
logger.debug("TimelineReconstructor available: %s", TimelineReconstructor)
|
||||
|
||||
async def run_pipeline(
    db_path: Path,
    entries: list[SearchResult],
    ctx: RetrievedContext,
    query: str,
    since: str | None,
    until: str | None,
    llm_url: str | None,
    llm_model: str | None,
    llm_api_key: str | None,
) -> AsyncGenerator[dict[str, Any], None]:
    """Run the five diagnose stages in order and yield SSE event dicts.

    Each stage object is synchronous, so its entry point is executed through
    ``asyncio.to_thread`` and awaited before the next event is yielded.

    Stage order:
        1. ``TimelineReconstructor.reconstruct(entries)``
        2. ``SeverityClassifier.classify(timeline)``
        3. ``RootCauseHypothesizer.hypothesize(classified, ctx, query, llm_*)``
        4. ``FalsePositiveSuppressor.suppress(hypotheses, db_path)``
        5. ``SummarySynthesizer.synthesize(ranked, timeline, ctx, query, llm_*)``

    Event order:
        status ("Building timeline…"), pipeline_stage 1-4, hypotheses,
        status ("Synthesizing…"), reasoning (only when stage 5 returned
        non-empty text), done.

    ``since`` and ``until`` are part of the signature but are not read here.
    """

    def stage_event(number: int, name: str, message: str) -> dict[str, Any]:
        # All four progress events share this shape; only the values differ.
        return {
            "type": "pipeline_stage",
            "stage": number,
            "name": name,
            "message": message,
        }

    # --- Stage 1: cluster the entries into a timeline -----------------------
    yield {"type": "status", "message": "Building timeline…"}
    reconstructor = TimelineReconstructor()
    timeline = await asyncio.to_thread(reconstructor.reconstruct, entries)
    yield stage_event(
        1,
        "timeline",
        f"Built {len(timeline.clusters)} clusters, {timeline.burst_count} bursts",
    )

    # --- Stage 2: severity per cluster ---------------------------------------
    classifier = SeverityClassifier()
    classified = await asyncio.to_thread(classifier.classify, timeline)
    tally: dict[str, int] = {}
    for label in classified.cluster_severities.values():
        tally.setdefault(label, 0)
        tally[label] += 1
    # Sorted by label so the message is stable regardless of cluster order.
    tally_text = ", ".join(f"{label}:{tally[label]}" for label in sorted(tally))
    yield stage_event(
        2,
        "classifier",
        f"{classified.classifier_used} classifier: {tally_text}",
    )

    # --- Stage 3: root-cause hypotheses ---------------------------------------
    hypothesizer = RootCauseHypothesizer()
    hypotheses = await asyncio.to_thread(
        hypothesizer.hypothesize,
        classified,
        ctx,
        query,
        llm_url,
        llm_model,
        llm_api_key,
    )
    yield stage_event(3, "hypotheses", f"{len(hypotheses)} hypotheses generated")

    # --- Stage 4: rank and suppress --------------------------------------------
    suppressor = FalsePositiveSuppressor()
    ranked = await asyncio.to_thread(suppressor.suppress, hypotheses, db_path)
    suppressed = len([rh for rh in ranked if rh.suppress])
    active = len(ranked) - suppressed
    yield stage_event(4, "suppressor", f"{suppressed} suppressed, {active} active")
    # The full ranked list (suppressed entries included) goes to the client.
    yield {
        "type": "hypotheses",
        "data": [dataclasses.asdict(rh) for rh in ranked],
    }

    # --- Stage 5: narrative synthesis -------------------------------------------
    yield {"type": "status", "message": "Synthesizing…"}
    synthesizer = SummarySynthesizer()
    synthesis_text = await asyncio.to_thread(
        synthesizer.synthesize,
        ranked,
        timeline,
        ctx,
        query,
        llm_url,
        llm_model,
        llm_api_key,
    )
    if synthesis_text:
        yield {"type": "reasoning", "text": synthesis_text}

    yield {"type": "done"}
|
||||
|
|
|
|||
210
app/services/diagnose/synthesizer.py
Normal file
210
app/services/diagnose/synthesizer.py
Normal file
|
|
@ -0,0 +1,210 @@
|
|||
"""Stage 5: Summary Synthesizer — deterministic narrative from ranked hypotheses.
|
||||
|
||||
Streaming upgrade (async SSE chunks) is tracked as a follow-up enhancement.
|
||||
This implementation is synchronous to match the rest of the pipeline.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
import httpx
|
||||
|
||||
from app.context.retriever import RetrievedContext
|
||||
from app.services.diagnose.models import RankedHypothesis, TimelineResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# System message sent with every synthesis request. It fixes the five-part
# response layout (verdict, timeline, root causes, recommended actions,
# open questions) that the model is asked to follow.
_SYSTEM_PROMPT = (
    "You are a Linux sysadmin diagnosing a system incident. "
    "Write a concise, actionable incident diagnosis.\n\n"
    "Format your response exactly as:\n"
    "1. VERDICT: [CRITICAL|ERROR|WARN|INFO] — <what happened> (<X>% confidence)\n"
    "2. TIMELINE: <what the logs show in sequence, 2-3 sentences>\n"
    "3. ROOT CAUSES:\n"
    " - <hypothesis 1 title> (<confidence>%)\n"
    " - <hypothesis 2 title> (<confidence>%)\n"
    "4. RECOMMENDED ACTIONS:\n"
    " - <action based on hypotheses>\n"
    "5. INVESTIGATE FURTHER: <open questions, if any>"
)
|
||||
|
||||
|
||||
def _extract_content(resp_json: dict) -> str | None:
|
||||
"""Pull text content from an OpenAI-compat chat completion response."""
|
||||
choices = resp_json.get("choices") or []
|
||||
if not choices:
|
||||
return None
|
||||
return (choices[0].get("message", {}).get("content") or "").strip() or None
|
||||
|
||||
|
||||
def _build_hypothesis_block(ranked: list[RankedHypothesis]) -> str:
|
||||
"""Build the hypothesis block for the prompt (non-suppressed only, top 3)."""
|
||||
active = [rh for rh in ranked if not rh.suppress][:3]
|
||||
if not active:
|
||||
return "(none)"
|
||||
lines: list[str] = []
|
||||
for rh in active:
|
||||
h = rh.hypothesis
|
||||
conf_pct = int(h.confidence * 100)
|
||||
similar = (
|
||||
f"Yes — suppressed, {rh.suppression_reason}"
|
||||
if rh.suppression_reason
|
||||
else "No"
|
||||
)
|
||||
novelty = f"{rh.novelty_score:.2f}"
|
||||
lines.append(
|
||||
f"- [{h.severity}, {conf_pct}%] {h.title}\n"
|
||||
f" Similar resolved incident? {similar} (novelty {novelty})"
|
||||
)
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _build_context_block(ctx: RetrievedContext) -> str:
|
||||
"""Build the runbook context block for the prompt."""
|
||||
parts: list[str] = []
|
||||
for chunk in ctx.chunks[:5]:
|
||||
filename = chunk.get("filename", "unknown")
|
||||
text = chunk.get("text", "")[:300]
|
||||
parts.append(f"[{filename}] {text}")
|
||||
return "\n".join(parts) if parts else "(none)"
|
||||
|
||||
|
||||
def _deterministic_fallback(
|
||||
ranked: list[RankedHypothesis],
|
||||
timeline: TimelineResult,
|
||||
) -> str:
|
||||
"""Build a deterministic fallback text when no LLM is available."""
|
||||
active = [rh for rh in ranked if not rh.suppress][:3]
|
||||
if active:
|
||||
top = active[0]
|
||||
verdict_severity = top.hypothesis.severity
|
||||
verdict_title = top.hypothesis.title
|
||||
verdict_conf = int(top.hypothesis.confidence * 100)
|
||||
elif ranked:
|
||||
top = ranked[0]
|
||||
verdict_severity = top.hypothesis.severity
|
||||
verdict_title = top.hypothesis.title
|
||||
verdict_conf = int(top.hypothesis.confidence * 100)
|
||||
else:
|
||||
verdict_severity = "UNKNOWN"
|
||||
verdict_title = "No hypotheses generated"
|
||||
verdict_conf = 0
|
||||
|
||||
root_causes = ", ".join(
|
||||
rh.hypothesis.title for rh in (active or ranked[:3])
|
||||
) or "None"
|
||||
|
||||
return (
|
||||
f"VERDICT: {verdict_severity} — {verdict_title} ({verdict_conf}% confidence)\n"
|
||||
f"TIMELINE: {timeline.total_entries} entries across {len(timeline.clusters)} clusters.\n"
|
||||
f"ROOT CAUSES: {root_causes}"
|
||||
)
|
||||
|
||||
|
||||
class SummarySynthesizer:
    """Stage 5 of the multi-agent diagnose pipeline.

    Turns ranked hypotheses, the reconstructed timeline and retrieved context
    into a single narrative string. The narrative comes from an LLM when a
    URL and model are supplied and the call yields text; in every other case
    the deterministic fallback built from the hypothesis data is returned.
    """

    def synthesize(
        self,
        ranked: list[RankedHypothesis],
        timeline: TimelineResult,
        ctx: RetrievedContext,
        query: str,
        llm_url: str | None = None,
        llm_model: str | None = None,
        llm_api_key: str | None = None,
    ) -> str:
        """Return the synthesis text as one string (synchronous).

        The deterministic fallback is computed first and returned unchanged
        when ``llm_url`` or ``llm_model`` is missing, or when the LLM call
        produces no text.
        """
        fallback = _deterministic_fallback(ranked, timeline)
        if not (llm_url and llm_model):
            return fallback

        hypothesis_block = _build_hypothesis_block(ranked)
        context_block = _build_context_block(ctx)
        primary_sources = ", ".join(timeline.dominant_sources[:5]) or "none"

        # User message assembled line by line; empty strings are the blank
        # lines separating the sections.
        prompt_lines = [
            f"Query: {query}",
            "",
            "Timeline summary:",
            f"- {len(timeline.clusters)} clusters, "
            f"{timeline.burst_count} bursts, "
            f"{timeline.gap_count} silence gaps",
            f"- Primary sources: {primary_sources}",
            "",
            "Top hypotheses:",
            hypothesis_block,
            "",
            "Context from runbooks:",
            context_block,
        ]
        conversation = [
            {"role": "system", "content": _SYSTEM_PROMPT},
            {"role": "user", "content": "\n".join(prompt_lines)},
        ]

        llm_text = self._call_llm(
            llm_url=llm_url,
            llm_model=llm_model,
            llm_api_key=llm_api_key,
            messages=conversation,
        )
        if llm_text:
            return llm_text
        return fallback

    def _call_llm(
        self,
        llm_url: str,
        llm_model: str,
        llm_api_key: str | None,
        messages: list[dict],
    ) -> str | None:
        """Send ``messages`` to the LLM and return the raw text content.

        The cf-orch task endpoint (``/api/inference/task``) is tried first.
        A 200 response is final, even when it carries no usable text. Any
        other outcome falls through to the OpenAI-compatible
        ``/v1/chat/completions`` endpoint. Returns ``None`` when that second
        call fails or yields no text.
        """
        auth_headers: dict[str, str] = {}
        if llm_api_key:
            auth_headers["Authorization"] = f"Bearer {llm_api_key}"
        base_url = llm_url.rstrip("/")

        # Attempt 1: task endpoint. Failures here are logged at debug level
        # only, because the direct endpoint below is the fallback.
        try:
            task_resp = httpx.post(
                f"{base_url}/api/inference/task",
                json={
                    "product": "turnstone",
                    "task": "log_analysis",
                    "payload": {"messages": messages, "stream": False},
                },
                headers=auth_headers,
                timeout=120.0,
            )
            if task_resp.status_code == 200:
                return _extract_content(task_resp.json())
            if task_resp.status_code != 404:
                # Raising sends error statuses to the handler below.
                task_resp.raise_for_status()
            logger.debug(
                "No task assignment for turnstone.log_analysis — falling back to direct model"
            )
        except Exception as exc:
            logger.debug(
                "Task endpoint unavailable (%s) — falling back to direct model", exc
            )

        # Attempt 2: direct OpenAI-compatible chat completion.
        try:
            direct_resp = httpx.post(
                f"{base_url}/v1/chat/completions",
                json={"model": llm_model, "messages": messages, "stream": False},
                headers=auth_headers,
                timeout=120.0,
            )
            direct_resp.raise_for_status()
            return _extract_content(direct_resp.json())
        except Exception as exc:
            logger.warning(
                "LLM synthesizer failed (%s): %s", type(exc).__name__, exc
            )
            return None
|
||||
489
tests/test_diagnose_pipeline.py
Normal file
489
tests/test_diagnose_pipeline.py
Normal file
|
|
@ -0,0 +1,489 @@
|
|||
"""Tests for app/services/diagnose/pipeline.py and __init__.py feature flag wiring.
|
||||
|
||||
All tests use mocking; no real LLM, ML, or DB calls are made.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.context.retriever import RetrievedContext
|
||||
from app.services.diagnose.models import (
|
||||
ClassifiedTimeline,
|
||||
Hypothesis,
|
||||
RankedHypothesis,
|
||||
TimelineResult,
|
||||
)
|
||||
from app.services.search import SearchResult
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_search_result(
    entry_id: str = "e1",
    source_id: str = "syslog",
    timestamp_iso: str | None = "2026-01-01T00:00:00+00:00",
    severity: str | None = "ERROR",
    text: str = "ssh: invalid user",
) -> SearchResult:
    """Build a ``SearchResult`` fixture; only the listed fields are tunable."""
    # Fields the tests never vary are pinned here.
    constants = {
        "sequence": 1,
        "repeat_count": 1,
        "out_of_order": False,
        "matched_patterns": ["ssh_fail"],
        "rank": 1.0,
    }
    return SearchResult(
        entry_id=entry_id,
        source_id=source_id,
        timestamp_iso=timestamp_iso,
        severity=severity,
        text=text,
        **constants,
    )
|
||||
|
||||
|
||||
def _make_ctx() -> RetrievedContext:
    """Build an empty ``RetrievedContext`` (no facts, no chunks)."""
    empty_ctx = RetrievedContext(facts=[], chunks=[])
    return empty_ctx
|
||||
|
||||
|
||||
def _make_timeline(n_clusters: int = 2) -> TimelineResult:
    """Build a fixed ``TimelineResult`` fixture covering a one-hour window.

    NOTE(review): ``n_clusters`` is accepted but never read. ``clusters`` is
    always an empty tuple, so the stage 1 message built from this fixture
    reports 0 clusters.
    """
    window = ("2026-01-01T00:00:00+00:00", "2026-01-01T01:00:00+00:00")
    return TimelineResult(
        clusters=(),
        total_entries=5,
        window_start=window[0],
        window_end=window[1],
        gap_count=0,
        burst_count=1,
        dominant_sources=("syslog",),
    )
|
||||
|
||||
|
||||
def _make_classified(timeline: TimelineResult | None = None) -> ClassifiedTimeline:
    """Wrap a timeline in a ``ClassifiedTimeline`` with no cluster severities.

    A default timeline from ``_make_timeline`` is used when none is given.
    """
    base_timeline = timeline if timeline else _make_timeline()
    return ClassifiedTimeline(
        timeline=base_timeline,
        cluster_severities={},
        classifier_used="regex",
        model_id=None,
    )
|
||||
|
||||
|
||||
def _make_hypothesis(
    hypothesis_id: str = "h1",
    title: str = "SSH flood",
    confidence: float = 0.87,
    severity: str = "CRITICAL",
) -> Hypothesis:
    """Build a ``Hypothesis`` fixture backed by a single cluster ``c1``."""
    fixture = Hypothesis(
        hypothesis_id=hypothesis_id,
        title=title,
        severity=severity,  # type: ignore[arg-type]
        confidence=confidence,
        description="Multiple failed SSH attempts.",
        supporting_cluster_ids=("c1",),
        runbook_refs=(),
    )
    return fixture
|
||||
|
||||
|
||||
def _make_ranked(hypothesis: Hypothesis | None = None, suppress: bool = False) -> RankedHypothesis:
    """Build a ``RankedHypothesis`` fixture.

    A default hypothesis is created when none is given. A suppression reason
    is set only when ``suppress`` is true.
    """
    reason = None
    if suppress:
        reason = "similar to known"
    wrapped = hypothesis if hypothesis else _make_hypothesis()
    return RankedHypothesis(
        hypothesis=wrapped,
        novelty_score=0.95,
        similarity_to_known=0.05,
        suppress=suppress,
        suppression_reason=reason,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helper: collect all events from run_pipeline
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
async def _collect_pipeline_events(**kwargs) -> list[dict[str, Any]]:
    """Drain ``run_pipeline(**kwargs)`` and return its events in yield order."""
    # Imported here so the name is resolved at call time, not at collection.
    from app.services.diagnose.pipeline import run_pipeline

    return [event async for event in run_pipeline(**kwargs)]
|
||||
|
||||
|
||||
def _default_pipeline_kwargs(entries=None, db_path=None) -> dict:
    """Build the keyword arguments for a ``run_pipeline`` call with no LLM.

    Args:
        entries: Log entries to pass through. ``None`` selects a single
            default entry; an explicit empty list is passed through as-is.
        db_path: Database path. ``None`` selects ``/tmp/fake.db``.

    Returns:
        A dict with every ``run_pipeline`` parameter; the three ``llm_*``
        values are ``None``.
    """
    # ``is None`` rather than ``or``: an explicit ``entries=[]`` must stay
    # empty. With ``or`` it was silently replaced by the default entry, so
    # the empty-entries tests never passed an empty list.
    if entries is None:
        entries = [_make_search_result()]
    if db_path is None:
        db_path = Path("/tmp/fake.db")
    return dict(
        db_path=db_path,
        entries=entries,
        ctx=_make_ctx(),
        query="ssh brute force",
        since="2026-01-01T00:00:00+00:00",
        until="2026-01-01T01:00:00+00:00",
        llm_url=None,
        llm_model=None,
        llm_api_key=None,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mock factories for all 5 stage classes
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _mock_all_stages(
    hypotheses=None,
    ranked=None,
    synthesis_text="VERDICT: CRITICAL — SSH flood (87% confidence)",
):
    """Return a mapping of patch target to mock class for all five stages.

    Each mock class, when instantiated, gives an object whose stage method
    returns a canned value: a fixture timeline, its classified form, the
    given (or default one-item) hypotheses and ranked lists, and
    ``synthesis_text``.
    """
    timeline = _make_timeline()
    classified = _make_classified(timeline)
    hyps = [_make_hypothesis()] if hypotheses is None else hypotheses
    rnk = [_make_ranked()] if ranked is None else ranked

    # (class name in the pipeline module, stage method, canned return value)
    stage_table = (
        ("TimelineReconstructor", "reconstruct", timeline),
        ("SeverityClassifier", "classify", classified),
        ("RootCauseHypothesizer", "hypothesize", hyps),
        ("FalsePositiveSuppressor", "suppress", rnk),
        ("SummarySynthesizer", "synthesize", synthesis_text),
    )

    patches = {}
    for class_name, method_name, canned in stage_table:
        stage_cls = MagicMock()
        getattr(stage_cls.return_value, method_name).return_value = canned
        patches[f"app.services.diagnose.pipeline.{class_name}"] = stage_cls
    return patches
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. Feature flag off: legacy summarize() path runs, not run_pipeline
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFeatureFlagOff:
    """With the feature flag off, diagnose_stream stays on the legacy path."""

    @pytest.mark.asyncio
    async def test_legacy_path_when_flag_off(self):
        """With MULTI_AGENT_ENABLED=False, run_pipeline is never called."""
        from app.services import diagnose as diagnose_module

        found = [_make_search_result()]

        with (
            patch.object(diagnose_module, "MULTI_AGENT_ENABLED", False),
            patch("app.services.diagnose.search", return_value=found),
            patch("app.services.diagnose.entries_in_window", return_value=[]),
            patch("app.services.diagnose.retrieve_context", return_value=_make_ctx()),
            patch("app.services.diagnose.format_context_block", return_value=None),
            patch("app.services.diagnose.run_pipeline") as mock_pipeline,
            patch("app.services.diagnose.summarize", return_value=None),
        ):
            events = [
                event
                async for event in diagnose_module.diagnose_stream(
                    db_path=Path("/tmp/fake.db"),
                    query="ssh failures",
                    llm_url=None,
                    llm_model=None,
                )
            ]

        # The multi-agent pipeline must not have been touched.
        mock_pipeline.assert_not_called()

        # The SSE sequence must still terminate with "done".
        emitted = [e["type"] for e in events]
        assert "done" in emitted
        assert emitted[-1] == "done"

    @pytest.mark.asyncio
    async def test_legacy_done_event_is_last(self):
        """Legacy path: done is always the last event."""
        from app.services import diagnose as diagnose_module

        with (
            patch.object(diagnose_module, "MULTI_AGENT_ENABLED", False),
            patch("app.services.diagnose.search", return_value=[]),
            patch("app.services.diagnose.entries_in_window", return_value=[]),
            patch("app.services.diagnose.retrieve_context", return_value=_make_ctx()),
            patch("app.services.diagnose.format_context_block", return_value=None),
        ):
            events = [
                event
                async for event in diagnose_module.diagnose_stream(
                    db_path=Path("/tmp/fake.db"),
                    query="check logs",
                )
            ]

        assert events[-1] == {"type": "done"}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. Feature flag on, all stages mocked: verify SSE event sequence
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestFeatureFlagOn:
    """Event-sequence checks for run_pipeline with every stage class mocked."""

    @staticmethod
    def _patch_stages(mocks):
        """Return one patcher replacing all five stage classes in the pipeline module.

        ``mocks`` is the target-to-mock mapping produced by ``_mock_all_stages``.
        """
        module = "app.services.diagnose.pipeline"
        stage_classes = (
            "TimelineReconstructor",
            "SeverityClassifier",
            "RootCauseHypothesizer",
            "FalsePositiveSuppressor",
            "SummarySynthesizer",
        )
        replacements = {name: mocks[f"{module}.{name}"] for name in stage_classes}
        return patch.multiple(module, **replacements)

    @pytest.mark.asyncio
    async def test_pipeline_stage_events_in_order(self):
        """pipeline_stage events must be emitted stages 1→2→3→4 in order."""
        mocks = _mock_all_stages()
        kwargs = _default_pipeline_kwargs()
        with self._patch_stages(mocks):
            events = await _collect_pipeline_events(**kwargs)

        stages = [e["stage"] for e in events if e.get("type") == "pipeline_stage"]
        assert stages == [1, 2, 3, 4]

    @pytest.mark.asyncio
    async def test_hypotheses_event_after_stage4(self):
        """hypotheses event must appear after pipeline_stage stage=4."""
        mocks = _mock_all_stages()
        kwargs = _default_pipeline_kwargs()
        with self._patch_stages(mocks):
            events = await _collect_pipeline_events(**kwargs)

        stage4_idx = next(
            pos
            for pos, event in enumerate(events)
            if event.get("type") == "pipeline_stage" and event.get("stage") == 4
        )
        hyp_idx = next(
            pos for pos, event in enumerate(events) if event.get("type") == "hypotheses"
        )
        assert hyp_idx > stage4_idx

    @pytest.mark.asyncio
    async def test_reasoning_event_emitted(self):
        """reasoning event must be present when synthesizer returns text."""
        mocks = _mock_all_stages(synthesis_text="VERDICT: CRITICAL — SSH flood")
        kwargs = _default_pipeline_kwargs()
        with self._patch_stages(mocks):
            events = await _collect_pipeline_events(**kwargs)

        reasoning_events = [e for e in events if e.get("type") == "reasoning"]
        assert len(reasoning_events) == 1
        assert "VERDICT" in reasoning_events[0]["text"]

    @pytest.mark.asyncio
    async def test_done_event_is_last(self):
        """done must always be the last event in the pipeline sequence."""
        mocks = _mock_all_stages()
        kwargs = _default_pipeline_kwargs()
        with self._patch_stages(mocks):
            events = await _collect_pipeline_events(**kwargs)

        assert events[-1] == {"type": "done"}

    @pytest.mark.asyncio
    async def test_pipeline_wired_from_diagnose_stream(self):
        """diagnose_stream routes through run_pipeline when flag is on."""
        from app.services import diagnose as diagnose_module

        found = [_make_search_result()]

        async def fake_pipeline(**kwargs):
            # Minimal stand-in: one status, one stage event, then done.
            yield {"type": "status", "message": "Building timeline…"}
            yield {"type": "pipeline_stage", "stage": 1, "name": "timeline", "message": "Built 1 clusters, 0 bursts"}
            yield {"type": "done"}

        with (
            patch.object(diagnose_module, "MULTI_AGENT_ENABLED", True),
            patch("app.services.diagnose.search", return_value=found),
            patch("app.services.diagnose.entries_in_window", return_value=[]),
            patch("app.services.diagnose.retrieve_context", return_value=_make_ctx()),
            patch("app.services.diagnose.format_context_block", return_value=None),
            patch("app.services.diagnose.run_pipeline", side_effect=fake_pipeline),
        ):
            events = [
                event
                async for event in diagnose_module.diagnose_stream(
                    db_path=Path("/tmp/fake.db"),
                    query="ssh failures",
                )
            ]

        emitted = [e["type"] for e in events]
        assert "pipeline_stage" in emitted
        assert emitted[-1] == "done"
        # Exactly one done event: it came from the pipeline, and the legacy
        # path did not append a second one.
        assert emitted.count("done") == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. Empty entries: pipeline completes with done
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestEmptyEntries:
    """run_pipeline called with ``entries=[]`` and stages returning nothing."""

    @staticmethod
    def _patch_stages(mocks):
        """Return one patcher replacing all five stage classes in the pipeline module.

        ``mocks`` is the target-to-mock mapping produced by ``_mock_all_stages``.
        """
        module = "app.services.diagnose.pipeline"
        stage_classes = (
            "TimelineReconstructor",
            "SeverityClassifier",
            "RootCauseHypothesizer",
            "FalsePositiveSuppressor",
            "SummarySynthesizer",
        )
        replacements = {name: mocks[f"{module}.{name}"] for name in stage_classes}
        return patch.multiple(module, **replacements)

    @pytest.mark.asyncio
    async def test_empty_entries_pipeline_completes(self):
        """Pipeline with entries=[] must still complete and emit done."""
        mocks = _mock_all_stages(hypotheses=[], ranked=[])
        kwargs = _default_pipeline_kwargs(entries=[])
        with self._patch_stages(mocks):
            events = await _collect_pipeline_events(**kwargs)

        emitted = [e["type"] for e in events]
        assert "done" in emitted
        assert emitted[-1] == "done"

    @pytest.mark.asyncio
    async def test_empty_entries_all_stage_events_present(self):
        """Even with empty entries, all 4 pipeline_stage events are emitted."""
        mocks = _mock_all_stages(hypotheses=[], ranked=[])
        kwargs = _default_pipeline_kwargs(entries=[])
        with self._patch_stages(mocks):
            events = await _collect_pipeline_events(**kwargs)

        stage_events = [e for e in events if e.get("type") == "pipeline_stage"]
        assert len(stage_events) == 4
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. No LLM: Stage 3 and Stage 5 return empty/fallback; done still emitted
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestNoLLM:
    """Pipeline behaviour when the default kwargs (no LLM settings) are used."""

    @pytest.mark.asyncio
    async def test_no_llm_pipeline_completes_with_done(self):
        """The stream must end with a bare ``{"type": "done"}`` event."""
        stage_mocks = _mock_all_stages(
            hypotheses=[],
            ranked=[],
            synthesis_text="VERDICT: UNKNOWN — no hypotheses generated",
        )
        # The default kwargs are relied on to leave llm_url and llm_model
        # as None, so nothing LLM-related is passed explicitly here.
        pipeline_kwargs = _default_pipeline_kwargs()

        with (
            patch(
                "app.services.diagnose.pipeline.TimelineReconstructor",
                stage_mocks["app.services.diagnose.pipeline.TimelineReconstructor"],
            ),
            patch(
                "app.services.diagnose.pipeline.SeverityClassifier",
                stage_mocks["app.services.diagnose.pipeline.SeverityClassifier"],
            ),
            patch(
                "app.services.diagnose.pipeline.RootCauseHypothesizer",
                stage_mocks["app.services.diagnose.pipeline.RootCauseHypothesizer"],
            ),
            patch(
                "app.services.diagnose.pipeline.FalsePositiveSuppressor",
                stage_mocks["app.services.diagnose.pipeline.FalsePositiveSuppressor"],
            ),
            patch(
                "app.services.diagnose.pipeline.SummarySynthesizer",
                stage_mocks["app.services.diagnose.pipeline.SummarySynthesizer"],
            ),
        ):
            collected = await _collect_pipeline_events(**pipeline_kwargs)

        final_event = collected[-1]
        assert final_event == {"type": "done"}

    @pytest.mark.asyncio
    async def test_no_llm_no_reasoning_event_when_synthesis_empty(self):
        """An empty synthesis string must not produce any reasoning event."""
        stage_mocks = _mock_all_stages(synthesis_text="")
        pipeline_kwargs = _default_pipeline_kwargs()

        with (
            patch(
                "app.services.diagnose.pipeline.TimelineReconstructor",
                stage_mocks["app.services.diagnose.pipeline.TimelineReconstructor"],
            ),
            patch(
                "app.services.diagnose.pipeline.SeverityClassifier",
                stage_mocks["app.services.diagnose.pipeline.SeverityClassifier"],
            ),
            patch(
                "app.services.diagnose.pipeline.RootCauseHypothesizer",
                stage_mocks["app.services.diagnose.pipeline.RootCauseHypothesizer"],
            ),
            patch(
                "app.services.diagnose.pipeline.FalsePositiveSuppressor",
                stage_mocks["app.services.diagnose.pipeline.FalsePositiveSuppressor"],
            ),
            patch(
                "app.services.diagnose.pipeline.SummarySynthesizer",
                stage_mocks["app.services.diagnose.pipeline.SummarySynthesizer"],
            ),
        ):
            collected = await _collect_pipeline_events(**pipeline_kwargs)

        reasoning_events = [
            event for event in collected if event.get("type") == "reasoning"
        ]
        assert not reasoning_events
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5. Stage 1 cluster count in pipeline_stage message
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestStage1Message:
    """Checks on the stage-1 ``pipeline_stage`` event emitted by the pipeline."""

    @pytest.mark.asyncio
    async def test_stage1_message_contains_cluster_count(self):
        """pipeline_stage stage=1 message must report cluster and burst counts.

        The timeline fixture has zero clusters and three bursts, so both
        numbers must appear in the message as standalone numbers.
        """
        # Local import: only this test needs regex tokenisation.
        import re

        timeline = TimelineResult(
            clusters=tuple(),
            total_entries=10,
            window_start=None,
            window_end=None,
            gap_count=0,
            burst_count=3,
            dominant_sources=("syslog",),
        )
        classified = _make_classified(timeline)

        # Hand-built stage mocks so the exact timeline above reaches stage 1.
        mock_reconstructor = MagicMock()
        mock_reconstructor.return_value.reconstruct.return_value = timeline
        mock_classifier = MagicMock()
        mock_classifier.return_value.classify.return_value = classified
        mock_hypothesizer = MagicMock()
        mock_hypothesizer.return_value.hypothesize.return_value = []
        mock_suppressor = MagicMock()
        mock_suppressor.return_value.suppress.return_value = []
        mock_synthesizer = MagicMock()
        mock_synthesizer.return_value.synthesize.return_value = (
            "VERDICT: INFO — nothing found"
        )

        kwargs = _default_pipeline_kwargs()
        with (
            patch("app.services.diagnose.pipeline.TimelineReconstructor", mock_reconstructor),
            patch("app.services.diagnose.pipeline.SeverityClassifier", mock_classifier),
            patch("app.services.diagnose.pipeline.RootCauseHypothesizer", mock_hypothesizer),
            patch("app.services.diagnose.pipeline.FalsePositiveSuppressor", mock_suppressor),
            patch("app.services.diagnose.pipeline.SummarySynthesizer", mock_synthesizer),
        ):
            events = await _collect_pipeline_events(**kwargs)

        stage1 = next(
            e for e in events
            if e.get("type") == "pipeline_stage" and e.get("stage") == 1
        )

        # Compare whole numeric tokens rather than substrings: with
        # total_entries=10, a plain `"0" in message` check would pass on the
        # "10" alone even if the cluster count were missing.
        # NOTE(review): gap_count is also 0, so if the message reports gaps the
        # "0" token is still ambiguous — confirm against the message format.
        numbers = re.findall(r"\d+", stage1["message"])
        assert "0" in numbers  # cluster count (empty clusters tuple)
        assert "3" in numbers  # burst count

    @pytest.mark.asyncio
    async def test_stage1_name_is_timeline(self):
        """pipeline_stage stage=1 must have name='timeline'."""
        mocks = _mock_all_stages()
        kwargs = _default_pipeline_kwargs()

        with (
            patch(
                "app.services.diagnose.pipeline.TimelineReconstructor",
                mocks["app.services.diagnose.pipeline.TimelineReconstructor"],
            ),
            patch(
                "app.services.diagnose.pipeline.SeverityClassifier",
                mocks["app.services.diagnose.pipeline.SeverityClassifier"],
            ),
            patch(
                "app.services.diagnose.pipeline.RootCauseHypothesizer",
                mocks["app.services.diagnose.pipeline.RootCauseHypothesizer"],
            ),
            patch(
                "app.services.diagnose.pipeline.FalsePositiveSuppressor",
                mocks["app.services.diagnose.pipeline.FalsePositiveSuppressor"],
            ),
            patch(
                "app.services.diagnose.pipeline.SummarySynthesizer",
                mocks["app.services.diagnose.pipeline.SummarySynthesizer"],
            ),
        ):
            events = await _collect_pipeline_events(**kwargs)

        stage1 = next(
            e for e in events
            if e.get("type") == "pipeline_stage" and e.get("stage") == 1
        )
        assert stage1["name"] == "timeline"
|
||||
285
tests/test_diagnose_synthesizer.py
Normal file
285
tests/test_diagnose_synthesizer.py
Normal file
|
|
@ -0,0 +1,285 @@
|
|||
"""Tests for app/services/diagnose/synthesizer.py — SummarySynthesizer.
|
||||
|
||||
All tests use mocking; no real LLM calls are made.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from app.context.retriever import RetrievedContext
|
||||
from app.services.diagnose.models import Hypothesis, RankedHypothesis, TimelineResult
|
||||
from app.services.diagnose.synthesizer import SummarySynthesizer
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixture helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def _make_hypothesis(
    hypothesis_id: str = "h1",
    title: str = "SSH flood from external IPs",
    description: str = "Repeated failed login attempts from multiple IPs.",
    confidence: float = 0.87,
    severity: str = "CRITICAL",
) -> Hypothesis:
    """Build a ``Hypothesis`` fixture with overridable headline fields.

    The supporting cluster ids are fixed to ``("c1",)`` and the runbook
    references to an empty tuple; everything else comes from the arguments.
    """
    field_values = {
        "hypothesis_id": hypothesis_id,
        "title": title,
        "description": description,
        "confidence": confidence,
        "supporting_cluster_ids": ("c1",),
        "runbook_refs": (),
        # Passed as a plain str; the original call site silenced an
        # arg-type complaint from the type checker for this field.
        "severity": severity,
    }
    return Hypothesis(**field_values)
|
||||
|
||||
|
||||
def _make_ranked(
    hypothesis: Hypothesis | None = None,
    novelty_score: float = 0.95,
    similarity_to_known: float = 0.05,
    suppress: bool = False,
    suppression_reason: str | None = None,
) -> RankedHypothesis:
    """Wrap a hypothesis in a ``RankedHypothesis`` fixture.

    When no (truthy) hypothesis is supplied, the default one from
    ``_make_hypothesis()`` is used. The remaining arguments are passed
    through unchanged.
    """
    wrapped = hypothesis if hypothesis else _make_hypothesis()
    return RankedHypothesis(
        hypothesis=wrapped,
        novelty_score=novelty_score,
        similarity_to_known=similarity_to_known,
        suppress=suppress,
        suppression_reason=suppression_reason,
    )
|
||||
|
||||
|
||||
def _make_timeline(
    total_entries: int = 42,
    n_clusters: int = 3,
) -> TimelineResult:
    """Build a ``TimelineResult`` fixture covering a fixed one-hour window.

    Only ``total_entries`` influences the result. ``n_clusters`` is accepted
    but not used by the body: ``clusters`` is always an empty tuple.
    """
    window = ("2026-01-01T00:00:00+00:00", "2026-01-01T01:00:00+00:00")
    return TimelineResult(
        clusters=(),
        total_entries=total_entries,
        window_start=window[0],
        window_end=window[1],
        gap_count=1,
        burst_count=2,
        dominant_sources=("syslog", "auth"),
    )
|
||||
|
||||
|
||||
def _make_ctx(chunks: list[dict] | None = None) -> RetrievedContext:
    """Build a ``RetrievedContext`` fixture with one fact and optional chunks.

    Args:
        chunks: Chunk dicts to place in the context. ``None`` (the default)
            selects a single built-in runbook chunk. An explicitly passed
            empty list is kept as-is, so a context without chunks can be
            built.

    Returns:
        A ``RetrievedContext`` holding one fixed network fact and the chunks.
    """
    # Test for None rather than truthiness: `chunks or [...]` would silently
    # replace an explicit empty list with the default chunk.
    if chunks is None:
        chunks = [{"filename": "runbook.md", "text": "Restart sshd if flooded"}]
    return RetrievedContext(
        facts=[{"category": "network", "key": "host", "value": "heimdall", "source": "facts"}],
        chunks=chunks,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Test cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
class TestSynthesizerWithHypotheses:
    """Synthesis with one active hypothesis and a mocked, successful LLM call."""

    def test_returns_verdict_string_with_llm(self):
        """The text returned for a successful LLM response contains VERDICT."""
        synthesizer = SummarySynthesizer()
        hypotheses = [_make_ranked()]
        timeline = _make_timeline()
        context = _make_ctx()

        # Fake HTTP response in the OpenAI-style "choices" shape.
        llm_response = MagicMock(status_code=200)
        llm_response.json.return_value = {
            "choices": [{"message": {"content": "VERDICT: CRITICAL — SSH flood (87% confidence)\nTIMELINE: lots of hits."}}]
        }

        with patch("httpx.post", return_value=llm_response):
            summary = synthesizer.synthesize(
                ranked=hypotheses,
                timeline=timeline,
                ctx=context,
                query="ssh brute force",
                llm_url="http://localhost:11434",
                llm_model="llama3",
            )

        assert "VERDICT" in summary

    def test_returns_nonempty_string(self):
        """The synthesizer returns a non-empty ``str`` for a successful call."""
        synthesizer = SummarySynthesizer()
        hypotheses = [_make_ranked()]
        timeline = _make_timeline()
        context = _make_ctx()

        llm_response = MagicMock(status_code=200)
        llm_response.json.return_value = {
            "choices": [{"message": {"content": "VERDICT: CRITICAL — SSH flood (87% confidence)"}}]
        }

        with patch("httpx.post", return_value=llm_response):
            summary = synthesizer.synthesize(
                ranked=hypotheses,
                timeline=timeline,
                ctx=context,
                query="why is auth failing",
                llm_url="http://localhost:11434",
                llm_model="llama3",
            )

        assert isinstance(summary, str)
        assert summary != ""
|
||||
|
||||
|
||||
class TestSynthesizerSuppressedHypotheses:
    """A suppressed hypothesis must not appear in the prompt sent to the LLM."""

    def test_suppressed_hypotheses_excluded_from_prompt(self):
        """Only the active hypothesis title shows up in the user message."""
        active = _make_ranked(
            hypothesis=_make_hypothesis(
                hypothesis_id="h1",
                title="SSH flood from external IPs",
                severity="CRITICAL",
                confidence=0.87,
            ),
            suppress=False,
            novelty_score=0.95,
        )
        suppressed = _make_ranked(
            hypothesis=_make_hypothesis(
                hypothesis_id="h2",
                title="Wazuh alert processing backlog",
                severity="ERROR",
                confidence=0.72,
            ),
            suppress=True,
            suppression_reason="similar to 2025-04 SSH incident",
            novelty_score=0.1,
        )

        sent_messages: list = []

        def fake_post(url, json=None, headers=None, timeout=None):
            # Record the chat messages from either request shape: nested
            # under "payload", or at the top level of the JSON body.
            body = json or {}
            if "payload" in body:
                sent_messages.extend(body["payload"].get("messages", []))
            elif "messages" in body:
                sent_messages.extend(body["messages"])
            response = MagicMock(status_code=200)
            response.json.return_value = {
                "choices": [{"message": {"content": "VERDICT: CRITICAL — SSH flood"}}]
            }
            return response

        synthesizer = SummarySynthesizer()
        with patch("httpx.post", side_effect=fake_post):
            synthesizer.synthesize(
                ranked=[active, suppressed],
                timeline=_make_timeline(),
                ctx=_make_ctx(),
                query="auth failures",
                llm_url="http://localhost:11434",
                llm_model="llama3",
            )

        # Take the first user-role message; fall back to "" when none was sent.
        user_content = ""
        for message in sent_messages:
            if message.get("role") == "user":
                user_content = message["content"]
                break

        # The active title must be in the prompt; the suppressed title must
        # not occur anywhere in it.
        assert "SSH flood from external IPs" in user_content
        assert "Wazuh alert processing backlog" not in user_content
|
||||
|
||||
|
||||
class TestSynthesizerNoLLM:
    """Without a usable LLM the synthesizer must still return VERDICT text."""

    def test_no_llm_url_returns_fallback(self):
        """Neither llm_url nor llm_model given: a non-empty VERDICT string."""
        synthesizer = SummarySynthesizer()

        summary = synthesizer.synthesize(
            ranked=[_make_ranked()],
            timeline=_make_timeline(),
            ctx=_make_ctx(),
            query="disk errors",
        )

        assert isinstance(summary, str)
        assert summary != ""
        assert "VERDICT" in summary

    def test_no_llm_model_returns_fallback(self):
        """llm_url without llm_model: the text names the default hypothesis."""
        synthesizer = SummarySynthesizer()
        hypotheses = [_make_ranked()]

        # llm_model is deliberately left out of the call.
        summary = synthesizer.synthesize(
            ranked=hypotheses,
            timeline=_make_timeline(),
            ctx=_make_ctx(),
            query="oom killer",
            llm_url="http://localhost:11434",
        )

        assert "VERDICT" in summary
        assert "SSH flood from external IPs" in summary

    def test_llm_failure_returns_fallback(self):
        """When httpx.post raises, a VERDICT string is still returned."""
        synthesizer = SummarySynthesizer()
        hypotheses = [_make_ranked()]

        failing_post = patch("httpx.post", side_effect=ConnectionError("refused"))
        with failing_post:
            summary = synthesizer.synthesize(
                ranked=hypotheses,
                timeline=_make_timeline(),
                ctx=_make_ctx(),
                query="why is disk full",
                llm_url="http://localhost:11434",
                llm_model="llama3",
            )

        assert "VERDICT" in summary
        assert len(summary) > 0
|
||||
|
||||
|
||||
class TestSynthesizerEmptyRanked:
    """An empty ranked list must produce text instead of raising."""

    def test_empty_ranked_no_llm_returns_fallback(self):
        """No hypotheses and no LLM settings: a non-empty VERDICT string."""
        summary = SummarySynthesizer().synthesize(
            ranked=[],
            timeline=_make_timeline(),
            ctx=_make_ctx(),
            query="check everything",
        )

        assert isinstance(summary, str)
        assert summary != ""
        assert "VERDICT" in summary

    def test_empty_ranked_with_llm_returns_fallback_or_llm_text(self):
        """No hypotheses but LLM settings given: some non-empty string comes back."""
        synthesizer = SummarySynthesizer()

        llm_response = MagicMock(status_code=200)
        llm_response.json.return_value = {
            "choices": [{"message": {"content": "VERDICT: UNKNOWN — no hypotheses generated"}}]
        }

        with patch("httpx.post", return_value=llm_response):
            summary = synthesizer.synthesize(
                ranked=[],
                timeline=_make_timeline(),
                ctx=_make_ctx(),
                query="nothing found",
                llm_url="http://localhost:11434",
                llm_model="llama3",
            )

        assert isinstance(summary, str)
        assert summary != ""
|
||||
Loading…
Reference in a new issue