feat: Stage 5 synthesizer + pipeline orchestrator + feature flag wiring (issue #29)

- Add app/services/diagnose/synthesizer.py: SummarySynthesizer (Stage 5)
  - Builds structured LLM prompt from ranked hypotheses, timeline, RAG context
  - Excludes suppressed hypotheses from the narrative prompt
  - Deterministic fallback when no LLM configured or LLM call fails
  - Same cf-orch task endpoint + direct OpenAI-compat fallback pattern as other stages

- Replace pipeline.py stub with full run_pipeline() async generator
  - Orchestrates all 5 stages via asyncio.to_thread for each synchronous stage
  - Yields typed SSE event dicts: status, pipeline_stage (1-4), hypotheses, reasoning, done
  - Suppressor counts (active vs suppressed) reported in stage 4 event message

- Wire MULTI_AGENT_ENABLED feature flag into diagnose_stream()
  - TURNSTONE_MULTI_AGENT_DIAGNOSE=true routes through run_pipeline()
  - pipeline emits its own done event; legacy path unchanged when flag is false
  - Import of run_pipeline added to __init__.py

- Add 21 new tests (350 -> 371 passing):
  - tests/test_diagnose_synthesizer.py: 8 tests (with/without LLM, suppressed,
    empty ranked, LLM failure fallback)
  - tests/test_diagnose_pipeline.py: 13 tests (flag off, flag on event sequence,
    empty entries, no LLM, stage 1 cluster count message)

Closes: #29
This commit is contained in:
pyr0ball 2026-05-25 14:56:25 -07:00
parent a4f97e5a79
commit 14bec5769b
5 changed files with 1120 additions and 6 deletions

View file

@ -23,6 +23,7 @@ from typing import Any
from app.context.retriever import retrieve_context, format_context_block
from app.services.llm import summarize
from app.services.search import SearchResult, entries_in_window, search
from app.services.diagnose.pipeline import run_pipeline
logger = logging.getLogger(__name__)
@ -303,6 +304,21 @@ async def diagnose_stream(
}
yield {"type": "entries", "data": [dataclasses.asdict(r) for r in combined]}
if MULTI_AGENT_ENABLED:
async for event in run_pipeline(
db_path=db_path,
entries=combined,
ctx=ctx,
query=query,
since=since,
until=until,
llm_url=llm_url,
llm_model=llm_model,
llm_api_key=llm_api_key,
):
yield event
return # pipeline emits its own "done" event
if llm_url and llm_model and combined:
yield {"type": "status", "message": "Analyzing with LLM…"}
reasoning = await asyncio.to_thread(

View file

@ -1,18 +1,132 @@
"""Multi-agent diagnose pipeline orchestrator — stub (Task 1)."""
"""Multi-agent diagnose pipeline orchestrator — Stage 15 wiring."""
from __future__ import annotations
import asyncio
import dataclasses
import logging
from collections.abc import AsyncGenerator
from pathlib import Path
from typing import Any
from app.context.retriever import RetrievedContext
from app.services.diagnose.classifier import SeverityClassifier
from app.services.diagnose.hypothesizer import RootCauseHypothesizer
from app.services.diagnose.suppressor import FalsePositiveSuppressor
from app.services.diagnose.synthesizer import SummarySynthesizer
from app.services.diagnose.timeline import TimelineReconstructor
from app.services.search import SearchResult
logger = logging.getLogger(__name__)
# run_pipeline() will be implemented in Task 6
logger.debug("TimelineReconstructor available: %s", TimelineReconstructor)
async def run_pipeline(
db_path: Path,
entries: list[SearchResult],
ctx: RetrievedContext,
query: str,
since: str | None,
until: str | None,
llm_url: str | None,
llm_model: str | None,
llm_api_key: str | None,
) -> AsyncGenerator[dict[str, Any], None]:
"""Async generator that runs all 5 pipeline stages and yields SSE event dicts.
async def run_pipeline(*args: Any, **kwargs: Any) -> None:
"""Placeholder — implemented in Task 6."""
return None
Stages:
1. TimelineReconstructor cluster log entries by time
2. SeverityClassifier annotate clusters with severity
3. RootCauseHypothesizer generate hypotheses via LLM
4. FalsePositiveSuppressor rank and suppress known patterns
5. SummarySynthesizer produce a narrative diagnosis
Yields events in order:
{"type": "status", "message": "Building timeline…"}
{"type": "pipeline_stage", "stage": 1, ...}
{"type": "pipeline_stage", "stage": 2, ...}
{"type": "pipeline_stage", "stage": 3, ...}
{"type": "pipeline_stage", "stage": 4, ...}
{"type": "hypotheses", "data": [...]}
{"type": "status", "message": "Synthesizing…"}
{"type": "reasoning", "text": "..."} only when synthesis produces text
{"type": "done"}
"""
# Stage 1: Timeline reconstruction
yield {"type": "status", "message": "Building timeline…"}
timeline = await asyncio.to_thread(
TimelineReconstructor().reconstruct, entries
)
n_clusters = len(timeline.clusters)
burst = timeline.burst_count
yield {
"type": "pipeline_stage",
"stage": 1,
"name": "timeline",
"message": f"Built {n_clusters} clusters, {burst} bursts",
}
# Stage 2: Severity classification
classified = await asyncio.to_thread(
SeverityClassifier().classify, timeline
)
sev_counts: dict[str, int] = {}
for sev in classified.cluster_severities.values():
sev_counts[sev] = sev_counts.get(sev, 0) + 1
counts_str = ", ".join(f"{k}:{v}" for k, v in sorted(sev_counts.items()))
yield {
"type": "pipeline_stage",
"stage": 2,
"name": "classifier",
"message": f"{classified.classifier_used} classifier: {counts_str}",
}
# Stage 3: Root-cause hypotheses
hypotheses = await asyncio.to_thread(
RootCauseHypothesizer().hypothesize,
classified,
ctx,
query,
llm_url,
llm_model,
llm_api_key,
)
yield {
"type": "pipeline_stage",
"stage": 3,
"name": "hypotheses",
"message": f"{len(hypotheses)} hypotheses generated",
}
# Stage 4: False-positive suppression
ranked = await asyncio.to_thread(
FalsePositiveSuppressor().suppress, hypotheses, db_path
)
suppressed = sum(1 for rh in ranked if rh.suppress)
active = len(ranked) - suppressed
yield {
"type": "pipeline_stage",
"stage": 4,
"name": "suppressor",
"message": f"{suppressed} suppressed, {active} active",
}
yield {
"type": "hypotheses",
"data": [dataclasses.asdict(rh) for rh in ranked],
}
# Stage 5: Summary synthesis
yield {"type": "status", "message": "Synthesizing…"}
synthesis_text = await asyncio.to_thread(
SummarySynthesizer().synthesize,
ranked,
timeline,
ctx,
query,
llm_url,
llm_model,
llm_api_key,
)
if synthesis_text:
yield {"type": "reasoning", "text": synthesis_text}
yield {"type": "done"}

View file

@ -0,0 +1,210 @@
"""Stage 5: Summary Synthesizer — deterministic narrative from ranked hypotheses.
Streaming upgrade (async SSE chunks) is tracked as a follow-up enhancement.
This implementation is synchronous to match the rest of the pipeline.
"""
from __future__ import annotations
import logging
import httpx
from app.context.retriever import RetrievedContext
from app.services.diagnose.models import RankedHypothesis, TimelineResult
logger = logging.getLogger(__name__)
_SYSTEM_PROMPT = (
"You are a Linux sysadmin diagnosing a system incident. "
"Write a concise, actionable incident diagnosis.\n\n"
"Format your response exactly as:\n"
"1. VERDICT: [CRITICAL|ERROR|WARN|INFO] — <what happened> (<X>% confidence)\n"
"2. TIMELINE: <what the logs show in sequence, 2-3 sentences>\n"
"3. ROOT CAUSES:\n"
" - <hypothesis 1 title> (<confidence>%)\n"
" - <hypothesis 2 title> (<confidence>%)\n"
"4. RECOMMENDED ACTIONS:\n"
" - <action based on hypotheses>\n"
"5. INVESTIGATE FURTHER: <open questions, if any>"
)
def _extract_content(resp_json: dict) -> str | None:
"""Pull text content from an OpenAI-compat chat completion response."""
choices = resp_json.get("choices") or []
if not choices:
return None
return (choices[0].get("message", {}).get("content") or "").strip() or None
def _build_hypothesis_block(ranked: list[RankedHypothesis]) -> str:
"""Build the hypothesis block for the prompt (non-suppressed only, top 3)."""
active = [rh for rh in ranked if not rh.suppress][:3]
if not active:
return "(none)"
lines: list[str] = []
for rh in active:
h = rh.hypothesis
conf_pct = int(h.confidence * 100)
similar = (
f"Yes — suppressed, {rh.suppression_reason}"
if rh.suppression_reason
else "No"
)
novelty = f"{rh.novelty_score:.2f}"
lines.append(
f"- [{h.severity}, {conf_pct}%] {h.title}\n"
f" Similar resolved incident? {similar} (novelty {novelty})"
)
return "\n".join(lines)
def _build_context_block(ctx: RetrievedContext) -> str:
"""Build the runbook context block for the prompt."""
parts: list[str] = []
for chunk in ctx.chunks[:5]:
filename = chunk.get("filename", "unknown")
text = chunk.get("text", "")[:300]
parts.append(f"[{filename}] {text}")
return "\n".join(parts) if parts else "(none)"
def _deterministic_fallback(
ranked: list[RankedHypothesis],
timeline: TimelineResult,
) -> str:
"""Build a deterministic fallback text when no LLM is available."""
active = [rh for rh in ranked if not rh.suppress][:3]
if active:
top = active[0]
verdict_severity = top.hypothesis.severity
verdict_title = top.hypothesis.title
verdict_conf = int(top.hypothesis.confidence * 100)
elif ranked:
top = ranked[0]
verdict_severity = top.hypothesis.severity
verdict_title = top.hypothesis.title
verdict_conf = int(top.hypothesis.confidence * 100)
else:
verdict_severity = "UNKNOWN"
verdict_title = "No hypotheses generated"
verdict_conf = 0
root_causes = ", ".join(
rh.hypothesis.title for rh in (active or ranked[:3])
) or "None"
return (
f"VERDICT: {verdict_severity}{verdict_title} ({verdict_conf}% confidence)\n"
f"TIMELINE: {timeline.total_entries} entries across {len(timeline.clusters)} clusters.\n"
f"ROOT CAUSES: {root_causes}"
)
class SummarySynthesizer:
"""Stage 5 of the multi-agent diagnose pipeline.
Synthesizes a human-readable incident narrative from ranked hypotheses,
the reconstructed timeline, and RAG context. When no LLM is configured,
returns a deterministic fallback built from the hypothesis data.
"""
def synthesize(
self,
ranked: list[RankedHypothesis],
timeline: TimelineResult,
ctx: RetrievedContext,
query: str,
llm_url: str | None = None,
llm_model: str | None = None,
llm_api_key: str | None = None,
) -> str:
"""Return synthesis text (single string, synchronous).
Falls back to a deterministic narrative when no LLM URL or model is
provided, or when the LLM call fails.
"""
fallback = _deterministic_fallback(ranked, timeline)
if not llm_url or not llm_model:
return fallback
hypothesis_block = _build_hypothesis_block(ranked)
context_block = _build_context_block(ctx)
dominant = ", ".join(timeline.dominant_sources[:5]) or "none"
user_message = (
f"Query: {query}\n\n"
f"Timeline summary:\n"
f"- {len(timeline.clusters)} clusters, "
f"{timeline.burst_count} bursts, "
f"{timeline.gap_count} silence gaps\n"
f"- Primary sources: {dominant}\n\n"
f"Top hypotheses:\n{hypothesis_block}\n\n"
f"Context from runbooks:\n{context_block}"
)
messages = [
{"role": "system", "content": _SYSTEM_PROMPT},
{"role": "user", "content": user_message},
]
result = self._call_llm(
llm_url=llm_url,
llm_model=llm_model,
llm_api_key=llm_api_key,
messages=messages,
)
return result if result else fallback
def _call_llm(
self,
llm_url: str,
llm_model: str,
llm_api_key: str | None,
messages: list[dict],
) -> str | None:
"""Send messages to the LLM and return raw text content.
Tries the cf-orch task endpoint first, falls back to direct OpenAI-compat.
"""
headers = {"Authorization": f"Bearer {llm_api_key}"} if llm_api_key else {}
task_url = f"{llm_url.rstrip('/')}/api/inference/task"
try:
resp = httpx.post(
task_url,
json={
"product": "turnstone",
"task": "log_analysis",
"payload": {"messages": messages, "stream": False},
},
headers=headers,
timeout=120.0,
)
if resp.status_code == 200:
return _extract_content(resp.json())
if resp.status_code != 404:
resp.raise_for_status()
logger.debug(
"No task assignment for turnstone.log_analysis — falling back to direct model"
)
except Exception as exc:
logger.debug(
"Task endpoint unavailable (%s) — falling back to direct model", exc
)
try:
resp = httpx.post(
f"{llm_url.rstrip('/')}/v1/chat/completions",
json={"model": llm_model, "messages": messages, "stream": False},
headers=headers,
timeout=120.0,
)
resp.raise_for_status()
return _extract_content(resp.json())
except Exception as exc:
logger.warning(
"LLM synthesizer failed (%s): %s", type(exc).__name__, exc
)
return None

View file

@ -0,0 +1,489 @@
"""Tests for app/services/diagnose/pipeline.py and __init__.py feature flag wiring.
All tests use mocking; no real LLM, ML, or DB calls are made.
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
from app.context.retriever import RetrievedContext
from app.services.diagnose.models import (
ClassifiedTimeline,
Hypothesis,
RankedHypothesis,
TimelineResult,
)
from app.services.search import SearchResult
# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------
def _make_search_result(
entry_id: str = "e1",
source_id: str = "syslog",
timestamp_iso: str | None = "2026-01-01T00:00:00+00:00",
severity: str | None = "ERROR",
text: str = "ssh: invalid user",
) -> SearchResult:
return SearchResult(
entry_id=entry_id,
source_id=source_id,
sequence=1,
timestamp_iso=timestamp_iso,
severity=severity,
repeat_count=1,
out_of_order=False,
matched_patterns=["ssh_fail"],
text=text,
rank=1.0,
)
def _make_ctx() -> RetrievedContext:
return RetrievedContext(facts=[], chunks=[])
def _make_timeline(n_clusters: int = 2) -> TimelineResult:
return TimelineResult(
clusters=tuple(),
total_entries=5,
window_start="2026-01-01T00:00:00+00:00",
window_end="2026-01-01T01:00:00+00:00",
gap_count=0,
burst_count=1,
dominant_sources=("syslog",),
)
def _make_classified(timeline: TimelineResult | None = None) -> ClassifiedTimeline:
tl = timeline or _make_timeline()
return ClassifiedTimeline(
timeline=tl,
cluster_severities={},
classifier_used="regex",
model_id=None,
)
def _make_hypothesis(
hypothesis_id: str = "h1",
title: str = "SSH flood",
confidence: float = 0.87,
severity: str = "CRITICAL",
) -> Hypothesis:
return Hypothesis(
hypothesis_id=hypothesis_id,
title=title,
description="Multiple failed SSH attempts.",
confidence=confidence,
supporting_cluster_ids=("c1",),
runbook_refs=(),
severity=severity, # type: ignore[arg-type]
)
def _make_ranked(hypothesis: Hypothesis | None = None, suppress: bool = False) -> RankedHypothesis:
h = hypothesis or _make_hypothesis()
return RankedHypothesis(
hypothesis=h,
novelty_score=0.95,
similarity_to_known=0.05,
suppress=suppress,
suppression_reason="similar to known" if suppress else None,
)
# ---------------------------------------------------------------------------
# Helper: collect all events from run_pipeline
# ---------------------------------------------------------------------------
async def _collect_pipeline_events(**kwargs) -> list[dict[str, Any]]:
"""Run run_pipeline and collect all yielded events into a list."""
from app.services.diagnose.pipeline import run_pipeline
events = []
async for event in run_pipeline(**kwargs):
events.append(event)
return events
def _default_pipeline_kwargs(entries=None, db_path=None) -> dict:
return dict(
db_path=db_path or Path("/tmp/fake.db"),
entries=entries or [_make_search_result()],
ctx=_make_ctx(),
query="ssh brute force",
since="2026-01-01T00:00:00+00:00",
until="2026-01-01T01:00:00+00:00",
llm_url=None,
llm_model=None,
llm_api_key=None,
)
# ---------------------------------------------------------------------------
# Mock factories for all 5 stage classes
# ---------------------------------------------------------------------------
def _mock_all_stages(
hypotheses=None,
ranked=None,
synthesis_text="VERDICT: CRITICAL — SSH flood (87% confidence)",
):
"""Return a dict of patch targets and their mock return values."""
timeline = _make_timeline()
classified = _make_classified(timeline)
hyps = hypotheses if hypotheses is not None else [_make_hypothesis()]
rnk = ranked if ranked is not None else [_make_ranked()]
mock_reconstructor = MagicMock()
mock_reconstructor.return_value.reconstruct.return_value = timeline
mock_classifier = MagicMock()
mock_classifier.return_value.classify.return_value = classified
mock_hypothesizer = MagicMock()
mock_hypothesizer.return_value.hypothesize.return_value = hyps
mock_suppressor = MagicMock()
mock_suppressor.return_value.suppress.return_value = rnk
mock_synthesizer = MagicMock()
mock_synthesizer.return_value.synthesize.return_value = synthesis_text
return {
"app.services.diagnose.pipeline.TimelineReconstructor": mock_reconstructor,
"app.services.diagnose.pipeline.SeverityClassifier": mock_classifier,
"app.services.diagnose.pipeline.RootCauseHypothesizer": mock_hypothesizer,
"app.services.diagnose.pipeline.FalsePositiveSuppressor": mock_suppressor,
"app.services.diagnose.pipeline.SummarySynthesizer": mock_synthesizer,
}
# ---------------------------------------------------------------------------
# 1. Feature flag off: legacy summarize() path runs, not run_pipeline
# ---------------------------------------------------------------------------
class TestFeatureFlagOff:
@pytest.mark.asyncio
async def test_legacy_path_when_flag_off(self):
"""With MULTI_AGENT_ENABLED=False, run_pipeline is never called."""
from app.services import diagnose as diagnose_module
entries = [_make_search_result()]
with (
patch.object(diagnose_module, "MULTI_AGENT_ENABLED", False),
patch("app.services.diagnose.search", return_value=entries),
patch("app.services.diagnose.entries_in_window", return_value=[]),
patch("app.services.diagnose.retrieve_context", return_value=_make_ctx()),
patch("app.services.diagnose.format_context_block", return_value=None),
patch("app.services.diagnose.run_pipeline") as mock_pipeline,
patch("app.services.diagnose.summarize", return_value=None),
):
events = []
async for event in diagnose_module.diagnose_stream(
db_path=Path("/tmp/fake.db"),
query="ssh failures",
llm_url=None,
llm_model=None,
):
events.append(event)
# run_pipeline must NOT have been called
mock_pipeline.assert_not_called()
# SSE sequence must end with done
types = [e["type"] for e in events]
assert "done" in types
assert types[-1] == "done"
@pytest.mark.asyncio
async def test_legacy_done_event_is_last(self):
"""Legacy path: done is always the last event."""
from app.services import diagnose as diagnose_module
with (
patch.object(diagnose_module, "MULTI_AGENT_ENABLED", False),
patch("app.services.diagnose.search", return_value=[]),
patch("app.services.diagnose.entries_in_window", return_value=[]),
patch("app.services.diagnose.retrieve_context", return_value=_make_ctx()),
patch("app.services.diagnose.format_context_block", return_value=None),
):
events = []
async for event in diagnose_module.diagnose_stream(
db_path=Path("/tmp/fake.db"),
query="check logs",
):
events.append(event)
assert events[-1] == {"type": "done"}
# ---------------------------------------------------------------------------
# 2. Feature flag on, all stages mocked: verify SSE event sequence
# ---------------------------------------------------------------------------
class TestFeatureFlagOn:
@pytest.mark.asyncio
async def test_pipeline_stage_events_in_order(self):
"""pipeline_stage events must be emitted stages 1→2→3→4 in order."""
mocks = _mock_all_stages()
kwargs = _default_pipeline_kwargs()
with (
patch("app.services.diagnose.pipeline.TimelineReconstructor", mocks["app.services.diagnose.pipeline.TimelineReconstructor"]),
patch("app.services.diagnose.pipeline.SeverityClassifier", mocks["app.services.diagnose.pipeline.SeverityClassifier"]),
patch("app.services.diagnose.pipeline.RootCauseHypothesizer", mocks["app.services.diagnose.pipeline.RootCauseHypothesizer"]),
patch("app.services.diagnose.pipeline.FalsePositiveSuppressor", mocks["app.services.diagnose.pipeline.FalsePositiveSuppressor"]),
patch("app.services.diagnose.pipeline.SummarySynthesizer", mocks["app.services.diagnose.pipeline.SummarySynthesizer"]),
):
events = await _collect_pipeline_events(**kwargs)
stage_events = [e for e in events if e.get("type") == "pipeline_stage"]
stages = [e["stage"] for e in stage_events]
assert stages == [1, 2, 3, 4]
@pytest.mark.asyncio
async def test_hypotheses_event_after_stage4(self):
"""hypotheses event must appear after pipeline_stage stage=4."""
mocks = _mock_all_stages()
kwargs = _default_pipeline_kwargs()
with (
patch("app.services.diagnose.pipeline.TimelineReconstructor", mocks["app.services.diagnose.pipeline.TimelineReconstructor"]),
patch("app.services.diagnose.pipeline.SeverityClassifier", mocks["app.services.diagnose.pipeline.SeverityClassifier"]),
patch("app.services.diagnose.pipeline.RootCauseHypothesizer", mocks["app.services.diagnose.pipeline.RootCauseHypothesizer"]),
patch("app.services.diagnose.pipeline.FalsePositiveSuppressor", mocks["app.services.diagnose.pipeline.FalsePositiveSuppressor"]),
patch("app.services.diagnose.pipeline.SummarySynthesizer", mocks["app.services.diagnose.pipeline.SummarySynthesizer"]),
):
events = await _collect_pipeline_events(**kwargs)
stage4_idx = next(
i for i, e in enumerate(events)
if e.get("type") == "pipeline_stage" and e.get("stage") == 4
)
hyp_idx = next(i for i, e in enumerate(events) if e.get("type") == "hypotheses")
assert hyp_idx > stage4_idx
@pytest.mark.asyncio
async def test_reasoning_event_emitted(self):
"""reasoning event must be present when synthesizer returns text."""
mocks = _mock_all_stages(synthesis_text="VERDICT: CRITICAL — SSH flood")
kwargs = _default_pipeline_kwargs()
with (
patch("app.services.diagnose.pipeline.TimelineReconstructor", mocks["app.services.diagnose.pipeline.TimelineReconstructor"]),
patch("app.services.diagnose.pipeline.SeverityClassifier", mocks["app.services.diagnose.pipeline.SeverityClassifier"]),
patch("app.services.diagnose.pipeline.RootCauseHypothesizer", mocks["app.services.diagnose.pipeline.RootCauseHypothesizer"]),
patch("app.services.diagnose.pipeline.FalsePositiveSuppressor", mocks["app.services.diagnose.pipeline.FalsePositiveSuppressor"]),
patch("app.services.diagnose.pipeline.SummarySynthesizer", mocks["app.services.diagnose.pipeline.SummarySynthesizer"]),
):
events = await _collect_pipeline_events(**kwargs)
reasoning_events = [e for e in events if e.get("type") == "reasoning"]
assert len(reasoning_events) == 1
assert "VERDICT" in reasoning_events[0]["text"]
@pytest.mark.asyncio
async def test_done_event_is_last(self):
"""done must always be the last event in the pipeline sequence."""
mocks = _mock_all_stages()
kwargs = _default_pipeline_kwargs()
with (
patch("app.services.diagnose.pipeline.TimelineReconstructor", mocks["app.services.diagnose.pipeline.TimelineReconstructor"]),
patch("app.services.diagnose.pipeline.SeverityClassifier", mocks["app.services.diagnose.pipeline.SeverityClassifier"]),
patch("app.services.diagnose.pipeline.RootCauseHypothesizer", mocks["app.services.diagnose.pipeline.RootCauseHypothesizer"]),
patch("app.services.diagnose.pipeline.FalsePositiveSuppressor", mocks["app.services.diagnose.pipeline.FalsePositiveSuppressor"]),
patch("app.services.diagnose.pipeline.SummarySynthesizer", mocks["app.services.diagnose.pipeline.SummarySynthesizer"]),
):
events = await _collect_pipeline_events(**kwargs)
assert events[-1] == {"type": "done"}
@pytest.mark.asyncio
async def test_pipeline_wired_from_diagnose_stream(self):
"""diagnose_stream routes through run_pipeline when flag is on."""
from app.services import diagnose as diagnose_module
entries = [_make_search_result()]
async def fake_pipeline(**kwargs):
yield {"type": "status", "message": "Building timeline…"}
yield {"type": "pipeline_stage", "stage": 1, "name": "timeline", "message": "Built 1 clusters, 0 bursts"}
yield {"type": "done"}
with (
patch.object(diagnose_module, "MULTI_AGENT_ENABLED", True),
patch("app.services.diagnose.search", return_value=entries),
patch("app.services.diagnose.entries_in_window", return_value=[]),
patch("app.services.diagnose.retrieve_context", return_value=_make_ctx()),
patch("app.services.diagnose.format_context_block", return_value=None),
patch("app.services.diagnose.run_pipeline", side_effect=fake_pipeline),
):
events = []
async for event in diagnose_module.diagnose_stream(
db_path=Path("/tmp/fake.db"),
query="ssh failures",
):
events.append(event)
types = [e["type"] for e in events]
assert "pipeline_stage" in types
assert types[-1] == "done"
# Legacy summarize() must NOT have been called — done event came from pipeline
assert types.count("done") == 1
# ---------------------------------------------------------------------------
# 3. Empty entries: pipeline completes with done
# ---------------------------------------------------------------------------
class TestEmptyEntries:
@pytest.mark.asyncio
async def test_empty_entries_pipeline_completes(self):
"""Pipeline with entries=[] must still complete and emit done."""
mocks = _mock_all_stages(hypotheses=[], ranked=[])
kwargs = _default_pipeline_kwargs(entries=[])
with (
patch("app.services.diagnose.pipeline.TimelineReconstructor", mocks["app.services.diagnose.pipeline.TimelineReconstructor"]),
patch("app.services.diagnose.pipeline.SeverityClassifier", mocks["app.services.diagnose.pipeline.SeverityClassifier"]),
patch("app.services.diagnose.pipeline.RootCauseHypothesizer", mocks["app.services.diagnose.pipeline.RootCauseHypothesizer"]),
patch("app.services.diagnose.pipeline.FalsePositiveSuppressor", mocks["app.services.diagnose.pipeline.FalsePositiveSuppressor"]),
patch("app.services.diagnose.pipeline.SummarySynthesizer", mocks["app.services.diagnose.pipeline.SummarySynthesizer"]),
):
events = await _collect_pipeline_events(**kwargs)
types = [e["type"] for e in events]
assert "done" in types
assert types[-1] == "done"
@pytest.mark.asyncio
async def test_empty_entries_all_stage_events_present(self):
"""Even with empty entries, all 4 pipeline_stage events are emitted."""
mocks = _mock_all_stages(hypotheses=[], ranked=[])
kwargs = _default_pipeline_kwargs(entries=[])
with (
patch("app.services.diagnose.pipeline.TimelineReconstructor", mocks["app.services.diagnose.pipeline.TimelineReconstructor"]),
patch("app.services.diagnose.pipeline.SeverityClassifier", mocks["app.services.diagnose.pipeline.SeverityClassifier"]),
patch("app.services.diagnose.pipeline.RootCauseHypothesizer", mocks["app.services.diagnose.pipeline.RootCauseHypothesizer"]),
patch("app.services.diagnose.pipeline.FalsePositiveSuppressor", mocks["app.services.diagnose.pipeline.FalsePositiveSuppressor"]),
patch("app.services.diagnose.pipeline.SummarySynthesizer", mocks["app.services.diagnose.pipeline.SummarySynthesizer"]),
):
events = await _collect_pipeline_events(**kwargs)
stage_events = [e for e in events if e.get("type") == "pipeline_stage"]
assert len(stage_events) == 4
# ---------------------------------------------------------------------------
# 4. No LLM: Stage 3 and Stage 5 return empty/fallback; done still emitted
# ---------------------------------------------------------------------------
class TestNoLLM:
@pytest.mark.asyncio
async def test_no_llm_pipeline_completes_with_done(self):
"""No llm_url/llm_model: pipeline runs all stages and emits done."""
mocks = _mock_all_stages(hypotheses=[], ranked=[], synthesis_text="VERDICT: UNKNOWN — no hypotheses generated")
kwargs = _default_pipeline_kwargs()
# llm_url and llm_model already None in default kwargs
with (
patch("app.services.diagnose.pipeline.TimelineReconstructor", mocks["app.services.diagnose.pipeline.TimelineReconstructor"]),
patch("app.services.diagnose.pipeline.SeverityClassifier", mocks["app.services.diagnose.pipeline.SeverityClassifier"]),
patch("app.services.diagnose.pipeline.RootCauseHypothesizer", mocks["app.services.diagnose.pipeline.RootCauseHypothesizer"]),
patch("app.services.diagnose.pipeline.FalsePositiveSuppressor", mocks["app.services.diagnose.pipeline.FalsePositiveSuppressor"]),
patch("app.services.diagnose.pipeline.SummarySynthesizer", mocks["app.services.diagnose.pipeline.SummarySynthesizer"]),
):
events = await _collect_pipeline_events(**kwargs)
assert events[-1] == {"type": "done"}
@pytest.mark.asyncio
async def test_no_llm_no_reasoning_event_when_synthesis_empty(self):
"""When synthesizer returns empty string, no reasoning event is emitted."""
mocks = _mock_all_stages(synthesis_text="")
kwargs = _default_pipeline_kwargs()
with (
patch("app.services.diagnose.pipeline.TimelineReconstructor", mocks["app.services.diagnose.pipeline.TimelineReconstructor"]),
patch("app.services.diagnose.pipeline.SeverityClassifier", mocks["app.services.diagnose.pipeline.SeverityClassifier"]),
patch("app.services.diagnose.pipeline.RootCauseHypothesizer", mocks["app.services.diagnose.pipeline.RootCauseHypothesizer"]),
patch("app.services.diagnose.pipeline.FalsePositiveSuppressor", mocks["app.services.diagnose.pipeline.FalsePositiveSuppressor"]),
patch("app.services.diagnose.pipeline.SummarySynthesizer", mocks["app.services.diagnose.pipeline.SummarySynthesizer"]),
):
events = await _collect_pipeline_events(**kwargs)
reasoning_events = [e for e in events if e.get("type") == "reasoning"]
assert len(reasoning_events) == 0
# ---------------------------------------------------------------------------
# 5. Stage 1 cluster count in pipeline_stage message
# ---------------------------------------------------------------------------
class TestStage1Message:
@pytest.mark.asyncio
async def test_stage1_message_contains_cluster_count(self):
"""pipeline_stage stage=1 message must report cluster count."""
timeline = TimelineResult(
clusters=tuple(),
total_entries=10,
window_start=None,
window_end=None,
gap_count=0,
burst_count=3,
dominant_sources=("syslog",),
)
classified = _make_classified(timeline)
mock_reconstructor = MagicMock()
mock_reconstructor.return_value.reconstruct.return_value = timeline
mock_classifier = MagicMock()
mock_classifier.return_value.classify.return_value = classified
mock_hypothesizer = MagicMock()
mock_hypothesizer.return_value.hypothesize.return_value = []
mock_suppressor = MagicMock()
mock_suppressor.return_value.suppress.return_value = []
mock_synthesizer = MagicMock()
mock_synthesizer.return_value.synthesize.return_value = "VERDICT: INFO — nothing found"
kwargs = _default_pipeline_kwargs()
with (
patch("app.services.diagnose.pipeline.TimelineReconstructor", mock_reconstructor),
patch("app.services.diagnose.pipeline.SeverityClassifier", mock_classifier),
patch("app.services.diagnose.pipeline.RootCauseHypothesizer", mock_hypothesizer),
patch("app.services.diagnose.pipeline.FalsePositiveSuppressor", mock_suppressor),
patch("app.services.diagnose.pipeline.SummarySynthesizer", mock_synthesizer),
):
events = await _collect_pipeline_events(**kwargs)
stage1 = next(e for e in events if e.get("type") == "pipeline_stage" and e.get("stage") == 1)
# 0 clusters (empty tuple), 3 bursts
assert "0" in stage1["message"] # cluster count
assert "3" in stage1["message"] # burst count
@pytest.mark.asyncio
async def test_stage1_name_is_timeline(self):
"""pipeline_stage stage=1 must have name='timeline'."""
mocks = _mock_all_stages()
kwargs = _default_pipeline_kwargs()
with (
patch("app.services.diagnose.pipeline.TimelineReconstructor", mocks["app.services.diagnose.pipeline.TimelineReconstructor"]),
patch("app.services.diagnose.pipeline.SeverityClassifier", mocks["app.services.diagnose.pipeline.SeverityClassifier"]),
patch("app.services.diagnose.pipeline.RootCauseHypothesizer", mocks["app.services.diagnose.pipeline.RootCauseHypothesizer"]),
patch("app.services.diagnose.pipeline.FalsePositiveSuppressor", mocks["app.services.diagnose.pipeline.FalsePositiveSuppressor"]),
patch("app.services.diagnose.pipeline.SummarySynthesizer", mocks["app.services.diagnose.pipeline.SummarySynthesizer"]),
):
events = await _collect_pipeline_events(**kwargs)
stage1 = next(e for e in events if e.get("type") == "pipeline_stage" and e.get("stage") == 1)
assert stage1["name"] == "timeline"

View file

@ -0,0 +1,285 @@
"""Tests for app/services/diagnose/synthesizer.py — SummarySynthesizer.
All tests use mocking; no real LLM calls are made.
"""
from __future__ import annotations
from unittest.mock import MagicMock, patch
from app.context.retriever import RetrievedContext
from app.services.diagnose.models import Hypothesis, RankedHypothesis, TimelineResult
from app.services.diagnose.synthesizer import SummarySynthesizer
# ---------------------------------------------------------------------------
# Fixture helpers
# ---------------------------------------------------------------------------
def _make_hypothesis(
hypothesis_id: str = "h1",
title: str = "SSH flood from external IPs",
description: str = "Repeated failed login attempts from multiple IPs.",
confidence: float = 0.87,
severity: str = "CRITICAL",
) -> Hypothesis:
return Hypothesis(
hypothesis_id=hypothesis_id,
title=title,
description=description,
confidence=confidence,
supporting_cluster_ids=("c1",),
runbook_refs=(),
severity=severity, # type: ignore[arg-type]
)
def _make_ranked(
hypothesis: Hypothesis | None = None,
novelty_score: float = 0.95,
similarity_to_known: float = 0.05,
suppress: bool = False,
suppression_reason: str | None = None,
) -> RankedHypothesis:
h = hypothesis or _make_hypothesis()
return RankedHypothesis(
hypothesis=h,
novelty_score=novelty_score,
similarity_to_known=similarity_to_known,
suppress=suppress,
suppression_reason=suppression_reason,
)
def _make_timeline(
total_entries: int = 42,
n_clusters: int = 3,
) -> TimelineResult:
return TimelineResult(
clusters=tuple(),
total_entries=total_entries,
window_start="2026-01-01T00:00:00+00:00",
window_end="2026-01-01T01:00:00+00:00",
gap_count=1,
burst_count=2,
dominant_sources=("syslog", "auth"),
)
def _make_ctx(chunks: list[dict] | None = None) -> RetrievedContext:
return RetrievedContext(
facts=[{"category": "network", "key": "host", "value": "heimdall", "source": "facts"}],
chunks=chunks or [{"filename": "runbook.md", "text": "Restart sshd if flooded"}],
)
# ---------------------------------------------------------------------------
# Test cases
# ---------------------------------------------------------------------------
class TestSynthesizerWithHypotheses:
"""With hypotheses, result must contain VERDICT."""
def test_returns_verdict_string_with_llm(self):
synthesizer = SummarySynthesizer()
ranked = [_make_ranked()]
timeline = _make_timeline()
ctx = _make_ctx()
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.json.return_value = {
"choices": [{"message": {"content": "VERDICT: CRITICAL — SSH flood (87% confidence)\nTIMELINE: lots of hits."}}]
}
with patch("httpx.post", return_value=mock_resp):
result = synthesizer.synthesize(
ranked=ranked,
timeline=timeline,
ctx=ctx,
query="ssh brute force",
llm_url="http://localhost:11434",
llm_model="llama3",
)
assert "VERDICT" in result
def test_returns_nonempty_string(self):
synthesizer = SummarySynthesizer()
ranked = [_make_ranked()]
timeline = _make_timeline()
ctx = _make_ctx()
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.json.return_value = {
"choices": [{"message": {"content": "VERDICT: CRITICAL — SSH flood (87% confidence)"}}]
}
with patch("httpx.post", return_value=mock_resp):
result = synthesizer.synthesize(
ranked=ranked,
timeline=timeline,
ctx=ctx,
query="why is auth failing",
llm_url="http://localhost:11434",
llm_model="llama3",
)
assert isinstance(result, str)
assert len(result) > 0
class TestSynthesizerSuppressedHypotheses:
"""Suppressed hypotheses must be excluded from the LLM prompt."""
def test_suppressed_hypotheses_excluded_from_prompt(self):
suppressed = _make_ranked(
hypothesis=_make_hypothesis(
hypothesis_id="h2",
title="Wazuh alert processing backlog",
severity="ERROR",
confidence=0.72,
),
suppress=True,
suppression_reason="similar to 2025-04 SSH incident",
novelty_score=0.1,
)
active = _make_ranked(
hypothesis=_make_hypothesis(
hypothesis_id="h1",
title="SSH flood from external IPs",
severity="CRITICAL",
confidence=0.87,
),
suppress=False,
novelty_score=0.95,
)
captured_messages: list = []
def fake_post(url, json=None, headers=None, timeout=None):
if json and "payload" in json:
captured_messages.extend(json["payload"].get("messages", []))
elif json and "messages" in json:
captured_messages.extend(json.get("messages", []))
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.json.return_value = {
"choices": [{"message": {"content": "VERDICT: CRITICAL — SSH flood"}}]
}
return mock_resp
synthesizer = SummarySynthesizer()
with patch("httpx.post", side_effect=fake_post):
synthesizer.synthesize(
ranked=[active, suppressed],
timeline=_make_timeline(),
ctx=_make_ctx(),
query="auth failures",
llm_url="http://localhost:11434",
llm_model="llama3",
)
# The user message should contain the active hypothesis title
# and NOT contain the suppressed one (or mark it suppressed)
user_content = next(
(m["content"] for m in captured_messages if m.get("role") == "user"), ""
)
assert "SSH flood from external IPs" in user_content
# Wazuh should not appear as a standalone top-level hypothesis
# (suppressed items are excluded from the active list sent to the LLM)
assert "Wazuh alert processing backlog" not in user_content
class TestSynthesizerNoLLM:
"""No LLM configured: must return deterministic fallback (not empty)."""
def test_no_llm_url_returns_fallback(self):
synthesizer = SummarySynthesizer()
ranked = [_make_ranked()]
timeline = _make_timeline()
ctx = _make_ctx()
result = synthesizer.synthesize(
ranked=ranked,
timeline=timeline,
ctx=ctx,
query="disk errors",
)
assert isinstance(result, str)
assert len(result) > 0
assert "VERDICT" in result
def test_no_llm_model_returns_fallback(self):
synthesizer = SummarySynthesizer()
ranked = [_make_ranked()]
result = synthesizer.synthesize(
ranked=ranked,
timeline=_make_timeline(),
ctx=_make_ctx(),
query="oom killer",
llm_url="http://localhost:11434",
# llm_model omitted
)
assert "VERDICT" in result
assert "SSH flood from external IPs" in result
def test_llm_failure_returns_fallback(self):
synthesizer = SummarySynthesizer()
ranked = [_make_ranked()]
with patch("httpx.post", side_effect=ConnectionError("refused")):
result = synthesizer.synthesize(
ranked=ranked,
timeline=_make_timeline(),
ctx=_make_ctx(),
query="why is disk full",
llm_url="http://localhost:11434",
llm_model="llama3",
)
assert "VERDICT" in result
assert len(result) > 0
class TestSynthesizerEmptyRanked:
"""Empty ranked list: must return deterministic fallback text, not raise."""
def test_empty_ranked_no_llm_returns_fallback(self):
synthesizer = SummarySynthesizer()
result = synthesizer.synthesize(
ranked=[],
timeline=_make_timeline(),
ctx=_make_ctx(),
query="check everything",
)
assert isinstance(result, str)
assert len(result) > 0
assert "VERDICT" in result
def test_empty_ranked_with_llm_returns_fallback_or_llm_text(self):
"""Even with empty ranked, we attempt LLM and return something."""
synthesizer = SummarySynthesizer()
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.json.return_value = {
"choices": [{"message": {"content": "VERDICT: UNKNOWN — no hypotheses generated"}}]
}
with patch("httpx.post", return_value=mock_resp):
result = synthesizer.synthesize(
ranked=[],
timeline=_make_timeline(),
ctx=_make_ctx(),
query="nothing found",
llm_url="http://localhost:11434",
llm_model="llama3",
)
assert isinstance(result, str)
assert len(result) > 0