From 14bec5769bfa06ebeee42de294b5f28a321f9246 Mon Sep 17 00:00:00 2001 From: pyr0ball Date: Mon, 25 May 2026 14:56:25 -0700 Subject: [PATCH] feat: Stage 5 synthesizer + pipeline orchestrator + feature flag wiring (issue #29) - Add app/services/diagnose/synthesizer.py: SummarySynthesizer (Stage 5) - Builds structured LLM prompt from ranked hypotheses, timeline, RAG context - Excludes suppressed hypotheses from the narrative prompt - Deterministic fallback when no LLM configured or LLM call fails - Same cf-orch task endpoint + direct OpenAI-compat fallback pattern as other stages - Replace pipeline.py stub with full run_pipeline() async generator - Orchestrates all 5 stages via asyncio.to_thread for each synchronous stage - Yields typed SSE event dicts: status, pipeline_stage (1-4), hypotheses, reasoning, done - Suppressor counts (active vs suppressed) reported in stage 4 event message - Wire MULTI_AGENT_ENABLED feature flag into diagnose_stream() - TURNSTONE_MULTI_AGENT_DIAGNOSE=true routes through run_pipeline() - pipeline emits its own done event; legacy path unchanged when flag is false - Import of run_pipeline added to __init__.py - Add 21 new tests (350 -> 371 passing): - tests/test_diagnose_synthesizer.py: 8 tests (with/without LLM, suppressed, empty ranked, LLM failure fallback) - tests/test_diagnose_pipeline.py: 13 tests (flag off, flag on event sequence, empty entries, no LLM, stage 1 cluster count message) Closes: https://git.opensourcesolarpunk.com/Circuit-Forge/turnstone/issues/29 --- app/services/diagnose/__init__.py | 16 + app/services/diagnose/pipeline.py | 126 ++++++- app/services/diagnose/synthesizer.py | 210 ++++++++++++ tests/test_diagnose_pipeline.py | 489 +++++++++++++++++++++++++++ tests/test_diagnose_synthesizer.py | 285 ++++++++++++++++ 5 files changed, 1120 insertions(+), 6 deletions(-) create mode 100644 app/services/diagnose/synthesizer.py create mode 100644 tests/test_diagnose_pipeline.py create mode 100644 
tests/test_diagnose_synthesizer.py diff --git a/app/services/diagnose/__init__.py b/app/services/diagnose/__init__.py index a1ee55f..51613cf 100644 --- a/app/services/diagnose/__init__.py +++ b/app/services/diagnose/__init__.py @@ -23,6 +23,7 @@ from typing import Any from app.context.retriever import retrieve_context, format_context_block from app.services.llm import summarize from app.services.search import SearchResult, entries_in_window, search +from app.services.diagnose.pipeline import run_pipeline logger = logging.getLogger(__name__) @@ -303,6 +304,21 @@ async def diagnose_stream( } yield {"type": "entries", "data": [dataclasses.asdict(r) for r in combined]} + if MULTI_AGENT_ENABLED: + async for event in run_pipeline( + db_path=db_path, + entries=combined, + ctx=ctx, + query=query, + since=since, + until=until, + llm_url=llm_url, + llm_model=llm_model, + llm_api_key=llm_api_key, + ): + yield event + return # pipeline emits its own "done" event + if llm_url and llm_model and combined: yield {"type": "status", "message": "Analyzing with LLM…"} reasoning = await asyncio.to_thread( diff --git a/app/services/diagnose/pipeline.py b/app/services/diagnose/pipeline.py index c8ffdcb..834ab34 100644 --- a/app/services/diagnose/pipeline.py +++ b/app/services/diagnose/pipeline.py @@ -1,18 +1,132 @@ -"""Multi-agent diagnose pipeline orchestrator — stub (Task 1).""" +"""Multi-agent diagnose pipeline orchestrator — Stage 1–5 wiring.""" from __future__ import annotations +import asyncio +import dataclasses import logging +from collections.abc import AsyncGenerator +from pathlib import Path from typing import Any +from app.context.retriever import RetrievedContext +from app.services.diagnose.classifier import SeverityClassifier +from app.services.diagnose.hypothesizer import RootCauseHypothesizer +from app.services.diagnose.suppressor import FalsePositiveSuppressor +from app.services.diagnose.synthesizer import SummarySynthesizer from app.services.diagnose.timeline import 
async def run_pipeline(
    db_path: Path,
    entries: list[SearchResult],
    ctx: RetrievedContext,
    query: str,
    since: str | None,
    until: str | None,
    llm_url: str | None,
    llm_model: str | None,
    llm_api_key: str | None,
) -> AsyncGenerator[dict[str, Any], None]:
    """Drive the five diagnose stages in order and yield SSE event dicts.

    Each stage exposes a synchronous method, so every call is handed to a
    worker thread through ``asyncio.to_thread``.

    Stage order:
        1. ``TimelineReconstructor().reconstruct(entries)``
        2. ``SeverityClassifier().classify(timeline)``
        3. ``RootCauseHypothesizer().hypothesize(classified, ctx, query, llm_*)``
        4. ``FalsePositiveSuppressor().suppress(hypotheses, db_path)``
        5. ``SummarySynthesizer().synthesize(ranked, timeline, ctx, query, llm_*)``

    Events, in the order they are yielded:
        {"type": "status", "message": "Building timeline…"}
        {"type": "pipeline_stage", "stage": 1, ...}
        {"type": "pipeline_stage", "stage": 2, ...}
        {"type": "pipeline_stage", "stage": 3, ...}
        {"type": "pipeline_stage", "stage": 4, ...}
        {"type": "hypotheses", "data": [...]}
        {"type": "status", "message": "Synthesizing…"}
        {"type": "reasoning", "text": "..."}  — skipped when synthesis is empty
        {"type": "done"}

    ``since`` and ``until`` belong to the call signature but are not read
    anywhere in this function.
    """

    def stage_event(number: int, label: str, text: str) -> dict[str, Any]:
        # Key order is kept stable so serialized events look the same.
        return {
            "type": "pipeline_stage",
            "stage": number,
            "name": label,
            "message": text,
        }

    # --- Stage 1: cluster the raw entries into a timeline -----------------
    yield {"type": "status", "message": "Building timeline…"}
    reconstructor = TimelineReconstructor()
    timeline = await asyncio.to_thread(reconstructor.reconstruct, entries)
    yield stage_event(
        1,
        "timeline",
        f"Built {len(timeline.clusters)} clusters, {timeline.burst_count} bursts",
    )

    # --- Stage 2: attach a severity to every cluster ----------------------
    classifier = SeverityClassifier()
    classified = await asyncio.to_thread(classifier.classify, timeline)
    labels = list(classified.cluster_severities.values())
    # One "severity:count" pair per distinct severity, sorted by severity.
    breakdown = ", ".join(
        f"{label}:{labels.count(label)}" for label in sorted(set(labels))
    )
    yield stage_event(
        2,
        "classifier",
        f"{classified.classifier_used} classifier: {breakdown}",
    )

    # --- Stage 3: propose root causes -------------------------------------
    hypothesizer = RootCauseHypothesizer()
    hypotheses = await asyncio.to_thread(
        hypothesizer.hypothesize,
        classified,
        ctx,
        query,
        llm_url,
        llm_model,
        llm_api_key,
    )
    yield stage_event(3, "hypotheses", f"{len(hypotheses)} hypotheses generated")

    # --- Stage 4: rank hypotheses and flag the ones to suppress -----------
    suppressor = FalsePositiveSuppressor()
    ranked = await asyncio.to_thread(suppressor.suppress, hypotheses, db_path)
    suppressed = len([rh for rh in ranked if rh.suppress])
    active = len(ranked) - suppressed
    yield stage_event(4, "suppressor", f"{suppressed} suppressed, {active} active")
    yield {
        "type": "hypotheses",
        "data": [dataclasses.asdict(rh) for rh in ranked],
    }

    # --- Stage 5: narrative summary ---------------------------------------
    yield {"type": "status", "message": "Synthesizing…"}
    synthesizer = SummarySynthesizer()
    synthesis_text = await asyncio.to_thread(
        synthesizer.synthesize,
        ranked,
        timeline,
        ctx,
        query,
        llm_url,
        llm_model,
        llm_api_key,
    )
    if synthesis_text:
        yield {"type": "reasoning", "text": synthesis_text}

    yield {"type": "done"}
def _to_percent(fraction: float) -> int:
    """Convert a 0–1 confidence into a whole percentage, rounded to nearest.

    Truncating with ``int()`` under-reports values whose float product lands
    just below an integer: ``0.29 * 100`` is stored as 28.999…, which
    ``int()`` turns into 28.
    """
    return round(fraction * 100)


def _top_active(ranked: list[RankedHypothesis], limit: int = 3) -> list[RankedHypothesis]:
    """Return the first *limit* hypotheses whose ``suppress`` flag is falsy."""
    return [rh for rh in ranked if not rh.suppress][:limit]


def _extract_content(resp_json: dict) -> str | None:
    """Pull the text content out of an OpenAI-compatible chat completion.

    Returns ``None`` when there are no choices or the first choice carries
    empty or whitespace-only content.
    """
    choices = resp_json.get("choices") or []
    if not choices:
        return None
    content = choices[0].get("message", {}).get("content") or ""
    return content.strip() or None


def _build_hypothesis_block(ranked: list[RankedHypothesis]) -> str:
    """Render the prompt section for the top three non-suppressed hypotheses.

    Returns the literal ``"(none)"`` when every hypothesis is suppressed or
    *ranked* is empty.
    """
    active = _top_active(ranked)
    if not active:
        return "(none)"

    rendered: list[str] = []
    for rh in active:
        h = rh.hypothesis
        # NOTE(review): entries here have suppress == False, yet a non-empty
        # suppression_reason is still labelled "suppressed" — confirm this
        # wording is intended.
        similar = (
            f"Yes — suppressed, {rh.suppression_reason}"
            if rh.suppression_reason
            else "No"
        )
        rendered.append(
            f"- [{h.severity}, {_to_percent(h.confidence)}%] {h.title}\n"
            f" Similar resolved incident? {similar} (novelty {rh.novelty_score:.2f})"
        )
    return "\n".join(rendered)


def _build_context_block(ctx: RetrievedContext) -> str:
    """Render up to five context chunks as ``[filename] text`` lines.

    Each chunk's text is cut to 300 characters. A chunk whose ``text`` is
    missing or ``None`` contributes an empty text instead of raising.
    """
    parts = [
        f"[{chunk.get('filename', 'unknown')}] {(chunk.get('text') or '')[:300]}"
        for chunk in ctx.chunks[:5]
    ]
    return "\n".join(parts) if parts else "(none)"


def _deterministic_fallback(
    ranked: list[RankedHypothesis],
    timeline: TimelineResult,
) -> str:
    """Build the narrative used when no LLM answer is available.

    The verdict comes from the first non-suppressed hypothesis. If every
    hypothesis is suppressed, the first ranked one is used. If there are
    none at all, a fixed ``UNKNOWN`` verdict is produced.
    """
    # Prefer active hypotheses; fall back to the first three ranked ones so
    # the summary is never empty while hypotheses exist.
    candidates = _top_active(ranked) or ranked[:3]

    if candidates:
        lead = candidates[0].hypothesis
        verdict_severity = lead.severity
        verdict_title = lead.title
        verdict_conf = _to_percent(lead.confidence)
    else:
        verdict_severity = "UNKNOWN"
        verdict_title = "No hypotheses generated"
        verdict_conf = 0

    root_causes = ", ".join(rh.hypothesis.title for rh in candidates) or "None"

    return (
        f"VERDICT: {verdict_severity} — {verdict_title} ({verdict_conf}% confidence)\n"
        f"TIMELINE: {timeline.total_entries} entries across {len(timeline.clusters)} clusters.\n"
        f"ROOT CAUSES: {root_causes}"
    )


class SummarySynthesizer:
    """Stage 5 of the multi-agent diagnose pipeline.

    Turns ranked hypotheses, the reconstructed timeline and retrieved
    context into a readable incident narrative. Without an LLM URL and
    model, or when the LLM yields no text, the deterministic fallback is
    returned instead.
    """

    def synthesize(
        self,
        ranked: list[RankedHypothesis],
        timeline: TimelineResult,
        ctx: RetrievedContext,
        query: str,
        llm_url: str | None = None,
        llm_model: str | None = None,
        llm_api_key: str | None = None,
    ) -> str:
        """Return the synthesis text as a single string (synchronous).

        Args:
            ranked: Hypotheses with suppression flags; suppressed ones are
                left out of the LLM prompt.
            timeline: Source of the cluster, burst and gap counts and the
                dominant sources quoted in the prompt.
            ctx: Retrieved context whose chunks are quoted in the prompt.
            query: The user's query, placed at the top of the prompt.
            llm_url: Base URL of the LLM service, or ``None`` for no LLM.
            llm_model: Model name for the direct chat-completions call.
            llm_api_key: Optional bearer token.

        Returns:
            The LLM's answer, or the deterministic fallback when no LLM is
            configured or the call produced nothing.
        """
        fallback = _deterministic_fallback(ranked, timeline)
        if not (llm_url and llm_model):
            return fallback

        hypothesis_block = _build_hypothesis_block(ranked)
        context_block = _build_context_block(ctx)
        dominant = ", ".join(timeline.dominant_sources[:5]) or "none"

        user_message = (
            f"Query: {query}\n\n"
            f"Timeline summary:\n"
            f"- {len(timeline.clusters)} clusters, "
            f"{timeline.burst_count} bursts, "
            f"{timeline.gap_count} silence gaps\n"
            f"- Primary sources: {dominant}\n\n"
            f"Top hypotheses:\n{hypothesis_block}\n\n"
            f"Context from runbooks:\n{context_block}"
        )

        answer = self._call_llm(
            llm_url=llm_url,
            llm_model=llm_model,
            llm_api_key=llm_api_key,
            messages=[
                {"role": "system", "content": _SYSTEM_PROMPT},
                {"role": "user", "content": user_message},
            ],
        )
        return answer or fallback

    def _call_llm(
        self,
        llm_url: str,
        llm_model: str,
        llm_api_key: str | None,
        messages: list[dict],
    ) -> str | None:
        """Send *messages* to the LLM and return the raw text, or ``None``.

        The cf-orch task endpoint is tried first. Any exception there, or a
        404, leads to a second attempt against the OpenAI-compatible
        ``/v1/chat/completions`` endpoint. A failure of that second attempt
        is logged as a warning and reported as ``None``.
        """
        base_url = llm_url.rstrip("/")
        headers = {"Authorization": f"Bearer {llm_api_key}"} if llm_api_key else {}

        try:
            resp = httpx.post(
                f"{base_url}/api/inference/task",
                json={
                    "product": "turnstone",
                    "task": "log_analysis",
                    "payload": {"messages": messages, "stream": False},
                },
                headers=headers,
                timeout=120.0,
            )
            if resp.status_code == 200:
                return _extract_content(resp.json())
            if resp.status_code != 404:
                # Error statuses raise here and are handled by the except
                # below, which also routes to the direct-model fallback.
                resp.raise_for_status()
            logger.debug(
                "No task assignment for turnstone.log_analysis — falling back to direct model"
            )
        except Exception as exc:
            # Deliberately broad: the task endpoint is best-effort.
            logger.debug(
                "Task endpoint unavailable (%s) — falling back to direct model", exc
            )

        try:
            resp = httpx.post(
                f"{base_url}/v1/chat/completions",
                json={"model": llm_model, "messages": messages, "stream": False},
                headers=headers,
                timeout=120.0,
            )
            resp.raise_for_status()
            return _extract_content(resp.json())
        except Exception as exc:
            logger.warning(
                "LLM synthesizer failed (%s): %s", type(exc).__name__, exc
            )
            return None
# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------

# Dotted path of the module whose stage classes the tests replace.
_PIPELINE_MODULE = "app.services.diagnose.pipeline"


def _make_search_result(
    entry_id: str = "e1",
    source_id: str = "syslog",
    timestamp_iso: str | None = "2026-01-01T00:00:00+00:00",
    severity: str | None = "ERROR",
    text: str = "ssh: invalid user",
) -> SearchResult:
    """Build one SearchResult with fixed sequence, rank and matched pattern."""
    return SearchResult(
        entry_id=entry_id,
        source_id=source_id,
        sequence=1,
        timestamp_iso=timestamp_iso,
        severity=severity,
        repeat_count=1,
        out_of_order=False,
        matched_patterns=["ssh_fail"],
        text=text,
        rank=1.0,
    )


def _make_ctx() -> RetrievedContext:
    """Build a RetrievedContext with no facts and no chunks."""
    return RetrievedContext(facts=[], chunks=[])


def _make_timeline(n_clusters: int = 2) -> TimelineResult:
    """Build a TimelineResult with five entries, one burst and no clusters.

    NOTE(review): ``n_clusters`` is accepted for call compatibility but is
    not used — ``clusters`` is always an empty tuple. Confirm whether real
    clusters should be generated here.
    """
    return TimelineResult(
        clusters=tuple(),
        total_entries=5,
        window_start="2026-01-01T00:00:00+00:00",
        window_end="2026-01-01T01:00:00+00:00",
        gap_count=0,
        burst_count=1,
        dominant_sources=("syslog",),
    )


def _make_classified(timeline: TimelineResult | None = None) -> ClassifiedTimeline:
    """Wrap *timeline* (default: ``_make_timeline()``) with no severities."""
    tl = _make_timeline() if timeline is None else timeline
    return ClassifiedTimeline(
        timeline=tl,
        cluster_severities={},
        classifier_used="regex",
        model_id=None,
    )


def _make_hypothesis(
    hypothesis_id: str = "h1",
    title: str = "SSH flood",
    confidence: float = 0.87,
    severity: str = "CRITICAL",
) -> Hypothesis:
    """Build a Hypothesis backed by cluster ``c1`` and no runbook refs."""
    return Hypothesis(
        hypothesis_id=hypothesis_id,
        title=title,
        description="Multiple failed SSH attempts.",
        confidence=confidence,
        supporting_cluster_ids=("c1",),
        runbook_refs=(),
        severity=severity,  # type: ignore[arg-type]
    )


def _make_ranked(hypothesis: Hypothesis | None = None, suppress: bool = False) -> RankedHypothesis:
    """Build a RankedHypothesis; a reason is attached only when suppressed."""
    h = _make_hypothesis() if hypothesis is None else hypothesis
    return RankedHypothesis(
        hypothesis=h,
        novelty_score=0.95,
        similarity_to_known=0.05,
        suppress=suppress,
        suppression_reason="similar to known" if suppress else None,
    )


# ---------------------------------------------------------------------------
# Helper: collect all events from run_pipeline
# ---------------------------------------------------------------------------

async def _collect_pipeline_events(**kwargs) -> list[dict[str, Any]]:
    """Run run_pipeline and collect every yielded event into a list."""
    from app.services.diagnose.pipeline import run_pipeline

    return [event async for event in run_pipeline(**kwargs)]


def _default_pipeline_kwargs(entries=None, db_path=None) -> dict:
    """Return keyword arguments for run_pipeline with no LLM configured.

    Defaults apply only when an argument is ``None``, so an explicit empty
    ``entries=[]`` reaches the pipeline unchanged.
    """
    return {
        "db_path": Path("/tmp/fake.db") if db_path is None else db_path,
        "entries": [_make_search_result()] if entries is None else entries,
        "ctx": _make_ctx(),
        "query": "ssh brute force",
        "since": "2026-01-01T00:00:00+00:00",
        "until": "2026-01-01T01:00:00+00:00",
        "llm_url": None,
        "llm_model": None,
        "llm_api_key": None,
    }


# ---------------------------------------------------------------------------
# Mock factories for all 5 stage classes
# ---------------------------------------------------------------------------

def _mock_all_stages(
    hypotheses=None,
    ranked=None,
    synthesis_text="VERDICT: CRITICAL — SSH flood (87% confidence)",
    timeline=None,
):
    """Return ``{patch target: mock class}`` for the five stage classes.

    Each mock class returns an instance whose stage method returns the
    matching canned value. ``timeline`` defaults to ``_make_timeline()``.
    """
    tl = _make_timeline() if timeline is None else timeline
    classified = _make_classified(tl)
    hyps = [_make_hypothesis()] if hypotheses is None else hypotheses
    rnk = [_make_ranked()] if ranked is None else ranked

    def stage_mock(method_name: str, result) -> MagicMock:
        factory = MagicMock()
        getattr(factory.return_value, method_name).return_value = result
        return factory

    return {
        f"{_PIPELINE_MODULE}.TimelineReconstructor": stage_mock("reconstruct", tl),
        f"{_PIPELINE_MODULE}.SeverityClassifier": stage_mock("classify", classified),
        f"{_PIPELINE_MODULE}.RootCauseHypothesizer": stage_mock("hypothesize", hyps),
        f"{_PIPELINE_MODULE}.FalsePositiveSuppressor": stage_mock("suppress", rnk),
        f"{_PIPELINE_MODULE}.SummarySynthesizer": stage_mock("synthesize", synthesis_text),
    }


def _patch_stages(mocks: dict[str, MagicMock]):
    """Return one context manager that applies every stage patch in *mocks*.

    Raises:
        ValueError: if a target does not live in the pipeline module.
    """
    replacements: dict[str, MagicMock] = {}
    for target, mock in mocks.items():
        module, _, attribute = target.rpartition(".")
        if module != _PIPELINE_MODULE:
            raise ValueError(f"unexpected patch target: {target}")
        replacements[attribute] = mock
    return patch.multiple(_PIPELINE_MODULE, **replacements)


def _stage_event(events: list[dict[str, Any]], stage: int) -> dict[str, Any]:
    """Return the pipeline_stage event with the given stage number."""
    return next(
        e for e in events
        if e.get("type") == "pipeline_stage" and e.get("stage") == stage
    )


# ---------------------------------------------------------------------------
# 1. Feature flag off: legacy summarize() path runs, not run_pipeline
# ---------------------------------------------------------------------------

class TestFeatureFlagOff:
    @pytest.mark.asyncio
    async def test_legacy_path_when_flag_off(self):
        """With MULTI_AGENT_ENABLED=False, run_pipeline is never called."""
        from app.services import diagnose as diagnose_module

        entries = [_make_search_result()]

        with (
            patch.object(diagnose_module, "MULTI_AGENT_ENABLED", False),
            patch("app.services.diagnose.search", return_value=entries),
            patch("app.services.diagnose.entries_in_window", return_value=[]),
            patch("app.services.diagnose.retrieve_context", return_value=_make_ctx()),
            patch("app.services.diagnose.format_context_block", return_value=None),
            patch("app.services.diagnose.run_pipeline") as mock_pipeline,
            patch("app.services.diagnose.summarize", return_value=None),
        ):
            events = [
                event
                async for event in diagnose_module.diagnose_stream(
                    db_path=Path("/tmp/fake.db"),
                    query="ssh failures",
                    llm_url=None,
                    llm_model=None,
                )
            ]

        # run_pipeline must NOT have been called
        mock_pipeline.assert_not_called()

        # SSE sequence must end with done
        types = [e["type"] for e in events]
        assert "done" in types
        assert types[-1] == "done"

    @pytest.mark.asyncio
    async def test_legacy_done_event_is_last(self):
        """Legacy path: done is always the last event."""
        from app.services import diagnose as diagnose_module

        with (
            patch.object(diagnose_module, "MULTI_AGENT_ENABLED", False),
            patch("app.services.diagnose.search", return_value=[]),
            patch("app.services.diagnose.entries_in_window", return_value=[]),
            patch("app.services.diagnose.retrieve_context", return_value=_make_ctx()),
            patch("app.services.diagnose.format_context_block", return_value=None),
        ):
            events = [
                event
                async for event in diagnose_module.diagnose_stream(
                    db_path=Path("/tmp/fake.db"),
                    query="check logs",
                )
            ]

        assert events[-1] == {"type": "done"}


# ---------------------------------------------------------------------------
# 2. Feature flag on, all stages mocked: verify SSE event sequence
# ---------------------------------------------------------------------------

class TestFeatureFlagOn:
    @pytest.mark.asyncio
    async def test_pipeline_stage_events_in_order(self):
        """pipeline_stage events must be emitted stages 1→2→3→4 in order."""
        with _patch_stages(_mock_all_stages()):
            events = await _collect_pipeline_events(**_default_pipeline_kwargs())

        stages = [e["stage"] for e in events if e.get("type") == "pipeline_stage"]
        assert stages == [1, 2, 3, 4]

    @pytest.mark.asyncio
    async def test_hypotheses_event_after_stage4(self):
        """hypotheses event must appear after pipeline_stage stage=4."""
        with _patch_stages(_mock_all_stages()):
            events = await _collect_pipeline_events(**_default_pipeline_kwargs())

        stage4_idx = events.index(_stage_event(events, 4))
        hyp_idx = next(i for i, e in enumerate(events) if e.get("type") == "hypotheses")
        assert hyp_idx > stage4_idx

    @pytest.mark.asyncio
    async def test_reasoning_event_emitted(self):
        """reasoning event must be present when synthesizer returns text."""
        mocks = _mock_all_stages(synthesis_text="VERDICT: CRITICAL — SSH flood")
        with _patch_stages(mocks):
            events = await _collect_pipeline_events(**_default_pipeline_kwargs())

        reasoning_events = [e for e in events if e.get("type") == "reasoning"]
        assert len(reasoning_events) == 1
        assert "VERDICT" in reasoning_events[0]["text"]

    @pytest.mark.asyncio
    async def test_done_event_is_last(self):
        """done must always be the last event in the pipeline sequence."""
        with _patch_stages(_mock_all_stages()):
            events = await _collect_pipeline_events(**_default_pipeline_kwargs())

        assert events[-1] == {"type": "done"}

    @pytest.mark.asyncio
    async def test_pipeline_wired_from_diagnose_stream(self):
        """diagnose_stream routes through run_pipeline when flag is on."""
        from app.services import diagnose as diagnose_module

        entries = [_make_search_result()]

        async def fake_pipeline(**kwargs):
            yield {"type": "status", "message": "Building timeline…"}
            yield {"type": "pipeline_stage", "stage": 1, "name": "timeline", "message": "Built 1 clusters, 0 bursts"}
            yield {"type": "done"}

        with (
            patch.object(diagnose_module, "MULTI_AGENT_ENABLED", True),
            patch("app.services.diagnose.search", return_value=entries),
            patch("app.services.diagnose.entries_in_window", return_value=[]),
            patch("app.services.diagnose.retrieve_context", return_value=_make_ctx()),
            patch("app.services.diagnose.format_context_block", return_value=None),
            patch("app.services.diagnose.run_pipeline", side_effect=fake_pipeline),
        ):
            events = [
                event
                async for event in diagnose_module.diagnose_stream(
                    db_path=Path("/tmp/fake.db"),
                    query="ssh failures",
                )
            ]

        types = [e["type"] for e in events]
        assert "pipeline_stage" in types
        assert types[-1] == "done"
        # Exactly one done: the pipeline's own, with no second one from the
        # legacy path after it.
        assert types.count("done") == 1


# ---------------------------------------------------------------------------
# 3. Empty entries: pipeline completes with done
# ---------------------------------------------------------------------------

class TestEmptyEntries:
    @pytest.mark.asyncio
    async def test_empty_entries_pipeline_completes(self):
        """Pipeline with entries=[] must still complete and emit done."""
        mocks = _mock_all_stages(hypotheses=[], ranked=[])
        with _patch_stages(mocks):
            events = await _collect_pipeline_events(**_default_pipeline_kwargs(entries=[]))

        types = [e["type"] for e in events]
        assert "done" in types
        assert types[-1] == "done"

    @pytest.mark.asyncio
    async def test_empty_entries_all_stage_events_present(self):
        """Even with empty entries, all 4 pipeline_stage events are emitted."""
        mocks = _mock_all_stages(hypotheses=[], ranked=[])
        with _patch_stages(mocks):
            events = await _collect_pipeline_events(**_default_pipeline_kwargs(entries=[]))

        stage_events = [e for e in events if e.get("type") == "pipeline_stage"]
        assert len(stage_events) == 4


# ---------------------------------------------------------------------------
# 4. No LLM: Stage 3 and Stage 5 return empty/fallback; done still emitted
# ---------------------------------------------------------------------------

class TestNoLLM:
    @pytest.mark.asyncio
    async def test_no_llm_pipeline_completes_with_done(self):
        """No llm_url/llm_model: pipeline runs all stages and emits done."""
        mocks = _mock_all_stages(
            hypotheses=[],
            ranked=[],
            synthesis_text="VERDICT: UNKNOWN — no hypotheses generated",
        )
        # llm_url and llm_model are already None in the default kwargs.
        with _patch_stages(mocks):
            events = await _collect_pipeline_events(**_default_pipeline_kwargs())

        assert events[-1] == {"type": "done"}

    @pytest.mark.asyncio
    async def test_no_llm_no_reasoning_event_when_synthesis_empty(self):
        """When synthesizer returns empty string, no reasoning event is emitted."""
        mocks = _mock_all_stages(synthesis_text="")
        with _patch_stages(mocks):
            events = await _collect_pipeline_events(**_default_pipeline_kwargs())

        reasoning_events = [e for e in events if e.get("type") == "reasoning"]
        assert len(reasoning_events) == 0


# ---------------------------------------------------------------------------
# 5. Stage 1 cluster count in pipeline_stage message
# ---------------------------------------------------------------------------

class TestStage1Message:
    @pytest.mark.asyncio
    async def test_stage1_message_contains_cluster_count(self):
        """pipeline_stage stage=1 message must report cluster count."""
        timeline = TimelineResult(
            clusters=tuple(),
            total_entries=10,
            window_start=None,
            window_end=None,
            gap_count=0,
            burst_count=3,
            dominant_sources=("syslog",),
        )
        mocks = _mock_all_stages(
            hypotheses=[],
            ranked=[],
            synthesis_text="VERDICT: INFO — nothing found",
            timeline=timeline,
        )

        with _patch_stages(mocks):
            events = await _collect_pipeline_events(**_default_pipeline_kwargs())

        # 0 clusters (empty tuple), 3 bursts. The exact text is pinned
        # because a bare `"0" in message` would pass on almost any string.
        assert _stage_event(events, 1)["message"] == "Built 0 clusters, 3 bursts"

    @pytest.mark.asyncio
    async def test_stage1_name_is_timeline(self):
        """pipeline_stage stage=1 must have name='timeline'."""
        with _patch_stages(_mock_all_stages()):
            events = await _collect_pipeline_events(**_default_pipeline_kwargs())

        assert _stage_event(events, 1)["name"] == "timeline"
+""" +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +from app.context.retriever import RetrievedContext +from app.services.diagnose.models import Hypothesis, RankedHypothesis, TimelineResult +from app.services.diagnose.synthesizer import SummarySynthesizer + + +# --------------------------------------------------------------------------- +# Fixture helpers +# --------------------------------------------------------------------------- + +def _make_hypothesis( + hypothesis_id: str = "h1", + title: str = "SSH flood from external IPs", + description: str = "Repeated failed login attempts from multiple IPs.", + confidence: float = 0.87, + severity: str = "CRITICAL", +) -> Hypothesis: + return Hypothesis( + hypothesis_id=hypothesis_id, + title=title, + description=description, + confidence=confidence, + supporting_cluster_ids=("c1",), + runbook_refs=(), + severity=severity, # type: ignore[arg-type] + ) + + +def _make_ranked( + hypothesis: Hypothesis | None = None, + novelty_score: float = 0.95, + similarity_to_known: float = 0.05, + suppress: bool = False, + suppression_reason: str | None = None, +) -> RankedHypothesis: + h = hypothesis or _make_hypothesis() + return RankedHypothesis( + hypothesis=h, + novelty_score=novelty_score, + similarity_to_known=similarity_to_known, + suppress=suppress, + suppression_reason=suppression_reason, + ) + + +def _make_timeline( + total_entries: int = 42, + n_clusters: int = 3, +) -> TimelineResult: + return TimelineResult( + clusters=tuple(), + total_entries=total_entries, + window_start="2026-01-01T00:00:00+00:00", + window_end="2026-01-01T01:00:00+00:00", + gap_count=1, + burst_count=2, + dominant_sources=("syslog", "auth"), + ) + + +def _make_ctx(chunks: list[dict] | None = None) -> RetrievedContext: + return RetrievedContext( + facts=[{"category": "network", "key": "host", "value": "heimdall", "source": "facts"}], + chunks=chunks or [{"filename": "runbook.md", "text": "Restart sshd if flooded"}], 
+ ) + + +# --------------------------------------------------------------------------- +# Test cases +# --------------------------------------------------------------------------- + +class TestSynthesizerWithHypotheses: + """With hypotheses, result must contain VERDICT.""" + + def test_returns_verdict_string_with_llm(self): + synthesizer = SummarySynthesizer() + ranked = [_make_ranked()] + timeline = _make_timeline() + ctx = _make_ctx() + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = { + "choices": [{"message": {"content": "VERDICT: CRITICAL — SSH flood (87% confidence)\nTIMELINE: lots of hits."}}] + } + + with patch("httpx.post", return_value=mock_resp): + result = synthesizer.synthesize( + ranked=ranked, + timeline=timeline, + ctx=ctx, + query="ssh brute force", + llm_url="http://localhost:11434", + llm_model="llama3", + ) + + assert "VERDICT" in result + + def test_returns_nonempty_string(self): + synthesizer = SummarySynthesizer() + ranked = [_make_ranked()] + timeline = _make_timeline() + ctx = _make_ctx() + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = { + "choices": [{"message": {"content": "VERDICT: CRITICAL — SSH flood (87% confidence)"}}] + } + + with patch("httpx.post", return_value=mock_resp): + result = synthesizer.synthesize( + ranked=ranked, + timeline=timeline, + ctx=ctx, + query="why is auth failing", + llm_url="http://localhost:11434", + llm_model="llama3", + ) + + assert isinstance(result, str) + assert len(result) > 0 + + +class TestSynthesizerSuppressedHypotheses: + """Suppressed hypotheses must be excluded from the LLM prompt.""" + + def test_suppressed_hypotheses_excluded_from_prompt(self): + suppressed = _make_ranked( + hypothesis=_make_hypothesis( + hypothesis_id="h2", + title="Wazuh alert processing backlog", + severity="ERROR", + confidence=0.72, + ), + suppress=True, + suppression_reason="similar to 2025-04 SSH incident", + novelty_score=0.1, + ) + active 
= _make_ranked( + hypothesis=_make_hypothesis( + hypothesis_id="h1", + title="SSH flood from external IPs", + severity="CRITICAL", + confidence=0.87, + ), + suppress=False, + novelty_score=0.95, + ) + + captured_messages: list = [] + + def fake_post(url, json=None, headers=None, timeout=None): + if json and "payload" in json: + captured_messages.extend(json["payload"].get("messages", [])) + elif json and "messages" in json: + captured_messages.extend(json.get("messages", [])) + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = { + "choices": [{"message": {"content": "VERDICT: CRITICAL — SSH flood"}}] + } + return mock_resp + + synthesizer = SummarySynthesizer() + with patch("httpx.post", side_effect=fake_post): + synthesizer.synthesize( + ranked=[active, suppressed], + timeline=_make_timeline(), + ctx=_make_ctx(), + query="auth failures", + llm_url="http://localhost:11434", + llm_model="llama3", + ) + + # The user message should contain the active hypothesis title + # and NOT contain the suppressed one (or mark it suppressed) + user_content = next( + (m["content"] for m in captured_messages if m.get("role") == "user"), "" + ) + assert "SSH flood from external IPs" in user_content + # Wazuh should not appear as a standalone top-level hypothesis + # (suppressed items are excluded from the active list sent to the LLM) + assert "Wazuh alert processing backlog" not in user_content + + +class TestSynthesizerNoLLM: + """No LLM configured: must return deterministic fallback (not empty).""" + + def test_no_llm_url_returns_fallback(self): + synthesizer = SummarySynthesizer() + ranked = [_make_ranked()] + timeline = _make_timeline() + ctx = _make_ctx() + + result = synthesizer.synthesize( + ranked=ranked, + timeline=timeline, + ctx=ctx, + query="disk errors", + ) + + assert isinstance(result, str) + assert len(result) > 0 + assert "VERDICT" in result + + def test_no_llm_model_returns_fallback(self): + synthesizer = SummarySynthesizer() + 
ranked = [_make_ranked()] + + result = synthesizer.synthesize( + ranked=ranked, + timeline=_make_timeline(), + ctx=_make_ctx(), + query="oom killer", + llm_url="http://localhost:11434", + # llm_model omitted + ) + + assert "VERDICT" in result + assert "SSH flood from external IPs" in result + + def test_llm_failure_returns_fallback(self): + synthesizer = SummarySynthesizer() + ranked = [_make_ranked()] + + with patch("httpx.post", side_effect=ConnectionError("refused")): + result = synthesizer.synthesize( + ranked=ranked, + timeline=_make_timeline(), + ctx=_make_ctx(), + query="why is disk full", + llm_url="http://localhost:11434", + llm_model="llama3", + ) + + assert "VERDICT" in result + assert len(result) > 0 + + +class TestSynthesizerEmptyRanked: + """Empty ranked list: must return deterministic fallback text, not raise.""" + + def test_empty_ranked_no_llm_returns_fallback(self): + synthesizer = SummarySynthesizer() + result = synthesizer.synthesize( + ranked=[], + timeline=_make_timeline(), + ctx=_make_ctx(), + query="check everything", + ) + + assert isinstance(result, str) + assert len(result) > 0 + assert "VERDICT" in result + + def test_empty_ranked_with_llm_returns_fallback_or_llm_text(self): + """Even with empty ranked, we attempt LLM and return something.""" + synthesizer = SummarySynthesizer() + + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = { + "choices": [{"message": {"content": "VERDICT: UNKNOWN — no hypotheses generated"}}] + } + + with patch("httpx.post", return_value=mock_resp): + result = synthesizer.synthesize( + ranked=[], + timeline=_make_timeline(), + ctx=_make_ctx(), + query="nothing found", + llm_url="http://localhost:11434", + llm_model="llama3", + ) + + assert isinstance(result, str) + assert len(result) > 0