feat(diagnose): 5-stage multi-agent diagnose pipeline (#29) #39
5 changed files with 1120 additions and 6 deletions
|
|
@ -23,6 +23,7 @@ from typing import Any
|
||||||
from app.context.retriever import retrieve_context, format_context_block
|
from app.context.retriever import retrieve_context, format_context_block
|
||||||
from app.services.llm import summarize
|
from app.services.llm import summarize
|
||||||
from app.services.search import SearchResult, entries_in_window, search
|
from app.services.search import SearchResult, entries_in_window, search
|
||||||
|
from app.services.diagnose.pipeline import run_pipeline
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -303,6 +304,21 @@ async def diagnose_stream(
|
||||||
}
|
}
|
||||||
yield {"type": "entries", "data": [dataclasses.asdict(r) for r in combined]}
|
yield {"type": "entries", "data": [dataclasses.asdict(r) for r in combined]}
|
||||||
|
|
||||||
|
if MULTI_AGENT_ENABLED:
|
||||||
|
async for event in run_pipeline(
|
||||||
|
db_path=db_path,
|
||||||
|
entries=combined,
|
||||||
|
ctx=ctx,
|
||||||
|
query=query,
|
||||||
|
since=since,
|
||||||
|
until=until,
|
||||||
|
llm_url=llm_url,
|
||||||
|
llm_model=llm_model,
|
||||||
|
llm_api_key=llm_api_key,
|
||||||
|
):
|
||||||
|
yield event
|
||||||
|
return # pipeline emits its own "done" event
|
||||||
|
|
||||||
if llm_url and llm_model and combined:
|
if llm_url and llm_model and combined:
|
||||||
yield {"type": "status", "message": "Analyzing with LLM…"}
|
yield {"type": "status", "message": "Analyzing with LLM…"}
|
||||||
reasoning = await asyncio.to_thread(
|
reasoning = await asyncio.to_thread(
|
||||||
|
|
|
||||||
|
|
@ -1,18 +1,132 @@
|
||||||
"""Multi-agent diagnose pipeline orchestrator — stub (Task 1)."""
|
"""Multi-agent diagnose pipeline orchestrator — Stage 1–5 wiring."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
from app.context.retriever import RetrievedContext
|
||||||
|
from app.services.diagnose.classifier import SeverityClassifier
|
||||||
|
from app.services.diagnose.hypothesizer import RootCauseHypothesizer
|
||||||
|
from app.services.diagnose.suppressor import FalsePositiveSuppressor
|
||||||
|
from app.services.diagnose.synthesizer import SummarySynthesizer
|
||||||
from app.services.diagnose.timeline import TimelineReconstructor
|
from app.services.diagnose.timeline import TimelineReconstructor
|
||||||
|
from app.services.search import SearchResult
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# run_pipeline() will be implemented in Task 6
|
|
||||||
logger.debug("TimelineReconstructor available: %s", TimelineReconstructor)
|
|
||||||
|
|
||||||
|
async def run_pipeline(
|
||||||
|
db_path: Path,
|
||||||
|
entries: list[SearchResult],
|
||||||
|
ctx: RetrievedContext,
|
||||||
|
query: str,
|
||||||
|
since: str | None,
|
||||||
|
until: str | None,
|
||||||
|
llm_url: str | None,
|
||||||
|
llm_model: str | None,
|
||||||
|
llm_api_key: str | None,
|
||||||
|
) -> AsyncGenerator[dict[str, Any], None]:
|
||||||
|
"""Async generator that runs all 5 pipeline stages and yields SSE event dicts.
|
||||||
|
|
||||||
async def run_pipeline(*args: Any, **kwargs: Any) -> None:
|
Stages:
|
||||||
"""Placeholder — implemented in Task 6."""
|
1. TimelineReconstructor — cluster log entries by time
|
||||||
return None
|
2. SeverityClassifier — annotate clusters with severity
|
||||||
|
3. RootCauseHypothesizer — generate hypotheses via LLM
|
||||||
|
4. FalsePositiveSuppressor — rank and suppress known patterns
|
||||||
|
5. SummarySynthesizer — produce a narrative diagnosis
|
||||||
|
|
||||||
|
Yields events in order:
|
||||||
|
{"type": "status", "message": "Building timeline…"}
|
||||||
|
{"type": "pipeline_stage", "stage": 1, ...}
|
||||||
|
{"type": "pipeline_stage", "stage": 2, ...}
|
||||||
|
{"type": "pipeline_stage", "stage": 3, ...}
|
||||||
|
{"type": "pipeline_stage", "stage": 4, ...}
|
||||||
|
{"type": "hypotheses", "data": [...]}
|
||||||
|
{"type": "status", "message": "Synthesizing…"}
|
||||||
|
{"type": "reasoning", "text": "..."} — only when synthesis produces text
|
||||||
|
{"type": "done"}
|
||||||
|
"""
|
||||||
|
# Stage 1: Timeline reconstruction
|
||||||
|
yield {"type": "status", "message": "Building timeline…"}
|
||||||
|
timeline = await asyncio.to_thread(
|
||||||
|
TimelineReconstructor().reconstruct, entries
|
||||||
|
)
|
||||||
|
n_clusters = len(timeline.clusters)
|
||||||
|
burst = timeline.burst_count
|
||||||
|
yield {
|
||||||
|
"type": "pipeline_stage",
|
||||||
|
"stage": 1,
|
||||||
|
"name": "timeline",
|
||||||
|
"message": f"Built {n_clusters} clusters, {burst} bursts",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Stage 2: Severity classification
|
||||||
|
classified = await asyncio.to_thread(
|
||||||
|
SeverityClassifier().classify, timeline
|
||||||
|
)
|
||||||
|
sev_counts: dict[str, int] = {}
|
||||||
|
for sev in classified.cluster_severities.values():
|
||||||
|
sev_counts[sev] = sev_counts.get(sev, 0) + 1
|
||||||
|
counts_str = ", ".join(f"{k}:{v}" for k, v in sorted(sev_counts.items()))
|
||||||
|
yield {
|
||||||
|
"type": "pipeline_stage",
|
||||||
|
"stage": 2,
|
||||||
|
"name": "classifier",
|
||||||
|
"message": f"{classified.classifier_used} classifier: {counts_str}",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Stage 3: Root-cause hypotheses
|
||||||
|
hypotheses = await asyncio.to_thread(
|
||||||
|
RootCauseHypothesizer().hypothesize,
|
||||||
|
classified,
|
||||||
|
ctx,
|
||||||
|
query,
|
||||||
|
llm_url,
|
||||||
|
llm_model,
|
||||||
|
llm_api_key,
|
||||||
|
)
|
||||||
|
yield {
|
||||||
|
"type": "pipeline_stage",
|
||||||
|
"stage": 3,
|
||||||
|
"name": "hypotheses",
|
||||||
|
"message": f"{len(hypotheses)} hypotheses generated",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Stage 4: False-positive suppression
|
||||||
|
ranked = await asyncio.to_thread(
|
||||||
|
FalsePositiveSuppressor().suppress, hypotheses, db_path
|
||||||
|
)
|
||||||
|
suppressed = sum(1 for rh in ranked if rh.suppress)
|
||||||
|
active = len(ranked) - suppressed
|
||||||
|
yield {
|
||||||
|
"type": "pipeline_stage",
|
||||||
|
"stage": 4,
|
||||||
|
"name": "suppressor",
|
||||||
|
"message": f"{suppressed} suppressed, {active} active",
|
||||||
|
}
|
||||||
|
yield {
|
||||||
|
"type": "hypotheses",
|
||||||
|
"data": [dataclasses.asdict(rh) for rh in ranked],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Stage 5: Summary synthesis
|
||||||
|
yield {"type": "status", "message": "Synthesizing…"}
|
||||||
|
synthesis_text = await asyncio.to_thread(
|
||||||
|
SummarySynthesizer().synthesize,
|
||||||
|
ranked,
|
||||||
|
timeline,
|
||||||
|
ctx,
|
||||||
|
query,
|
||||||
|
llm_url,
|
||||||
|
llm_model,
|
||||||
|
llm_api_key,
|
||||||
|
)
|
||||||
|
if synthesis_text:
|
||||||
|
yield {"type": "reasoning", "text": synthesis_text}
|
||||||
|
|
||||||
|
yield {"type": "done"}
|
||||||
|
|
|
||||||
210
app/services/diagnose/synthesizer.py
Normal file
210
app/services/diagnose/synthesizer.py
Normal file
|
|
@ -0,0 +1,210 @@
|
||||||
|
"""Stage 5: Summary Synthesizer — deterministic narrative from ranked hypotheses.
|
||||||
|
|
||||||
|
Streaming upgrade (async SSE chunks) is tracked as a follow-up enhancement.
|
||||||
|
This implementation is synchronous to match the rest of the pipeline.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from app.context.retriever import RetrievedContext
|
||||||
|
from app.services.diagnose.models import RankedHypothesis, TimelineResult
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_SYSTEM_PROMPT = (
|
||||||
|
"You are a Linux sysadmin diagnosing a system incident. "
|
||||||
|
"Write a concise, actionable incident diagnosis.\n\n"
|
||||||
|
"Format your response exactly as:\n"
|
||||||
|
"1. VERDICT: [CRITICAL|ERROR|WARN|INFO] — <what happened> (<X>% confidence)\n"
|
||||||
|
"2. TIMELINE: <what the logs show in sequence, 2-3 sentences>\n"
|
||||||
|
"3. ROOT CAUSES:\n"
|
||||||
|
" - <hypothesis 1 title> (<confidence>%)\n"
|
||||||
|
" - <hypothesis 2 title> (<confidence>%)\n"
|
||||||
|
"4. RECOMMENDED ACTIONS:\n"
|
||||||
|
" - <action based on hypotheses>\n"
|
||||||
|
"5. INVESTIGATE FURTHER: <open questions, if any>"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_content(resp_json: dict) -> str | None:
|
||||||
|
"""Pull text content from an OpenAI-compat chat completion response."""
|
||||||
|
choices = resp_json.get("choices") or []
|
||||||
|
if not choices:
|
||||||
|
return None
|
||||||
|
return (choices[0].get("message", {}).get("content") or "").strip() or None
|
||||||
|
|
||||||
|
|
||||||
|
def _build_hypothesis_block(ranked: list[RankedHypothesis]) -> str:
|
||||||
|
"""Build the hypothesis block for the prompt (non-suppressed only, top 3)."""
|
||||||
|
active = [rh for rh in ranked if not rh.suppress][:3]
|
||||||
|
if not active:
|
||||||
|
return "(none)"
|
||||||
|
lines: list[str] = []
|
||||||
|
for rh in active:
|
||||||
|
h = rh.hypothesis
|
||||||
|
conf_pct = int(h.confidence * 100)
|
||||||
|
similar = (
|
||||||
|
f"Yes — suppressed, {rh.suppression_reason}"
|
||||||
|
if rh.suppression_reason
|
||||||
|
else "No"
|
||||||
|
)
|
||||||
|
novelty = f"{rh.novelty_score:.2f}"
|
||||||
|
lines.append(
|
||||||
|
f"- [{h.severity}, {conf_pct}%] {h.title}\n"
|
||||||
|
f" Similar resolved incident? {similar} (novelty {novelty})"
|
||||||
|
)
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
def _build_context_block(ctx: RetrievedContext) -> str:
|
||||||
|
"""Build the runbook context block for the prompt."""
|
||||||
|
parts: list[str] = []
|
||||||
|
for chunk in ctx.chunks[:5]:
|
||||||
|
filename = chunk.get("filename", "unknown")
|
||||||
|
text = chunk.get("text", "")[:300]
|
||||||
|
parts.append(f"[{filename}] {text}")
|
||||||
|
return "\n".join(parts) if parts else "(none)"
|
||||||
|
|
||||||
|
|
||||||
|
def _deterministic_fallback(
|
||||||
|
ranked: list[RankedHypothesis],
|
||||||
|
timeline: TimelineResult,
|
||||||
|
) -> str:
|
||||||
|
"""Build a deterministic fallback text when no LLM is available."""
|
||||||
|
active = [rh for rh in ranked if not rh.suppress][:3]
|
||||||
|
if active:
|
||||||
|
top = active[0]
|
||||||
|
verdict_severity = top.hypothesis.severity
|
||||||
|
verdict_title = top.hypothesis.title
|
||||||
|
verdict_conf = int(top.hypothesis.confidence * 100)
|
||||||
|
elif ranked:
|
||||||
|
top = ranked[0]
|
||||||
|
verdict_severity = top.hypothesis.severity
|
||||||
|
verdict_title = top.hypothesis.title
|
||||||
|
verdict_conf = int(top.hypothesis.confidence * 100)
|
||||||
|
else:
|
||||||
|
verdict_severity = "UNKNOWN"
|
||||||
|
verdict_title = "No hypotheses generated"
|
||||||
|
verdict_conf = 0
|
||||||
|
|
||||||
|
root_causes = ", ".join(
|
||||||
|
rh.hypothesis.title for rh in (active or ranked[:3])
|
||||||
|
) or "None"
|
||||||
|
|
||||||
|
return (
|
||||||
|
f"VERDICT: {verdict_severity} — {verdict_title} ({verdict_conf}% confidence)\n"
|
||||||
|
f"TIMELINE: {timeline.total_entries} entries across {len(timeline.clusters)} clusters.\n"
|
||||||
|
f"ROOT CAUSES: {root_causes}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class SummarySynthesizer:
|
||||||
|
"""Stage 5 of the multi-agent diagnose pipeline.
|
||||||
|
|
||||||
|
Synthesizes a human-readable incident narrative from ranked hypotheses,
|
||||||
|
the reconstructed timeline, and RAG context. When no LLM is configured,
|
||||||
|
returns a deterministic fallback built from the hypothesis data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def synthesize(
|
||||||
|
self,
|
||||||
|
ranked: list[RankedHypothesis],
|
||||||
|
timeline: TimelineResult,
|
||||||
|
ctx: RetrievedContext,
|
||||||
|
query: str,
|
||||||
|
llm_url: str | None = None,
|
||||||
|
llm_model: str | None = None,
|
||||||
|
llm_api_key: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Return synthesis text (single string, synchronous).
|
||||||
|
|
||||||
|
Falls back to a deterministic narrative when no LLM URL or model is
|
||||||
|
provided, or when the LLM call fails.
|
||||||
|
"""
|
||||||
|
fallback = _deterministic_fallback(ranked, timeline)
|
||||||
|
|
||||||
|
if not llm_url or not llm_model:
|
||||||
|
return fallback
|
||||||
|
|
||||||
|
hypothesis_block = _build_hypothesis_block(ranked)
|
||||||
|
context_block = _build_context_block(ctx)
|
||||||
|
dominant = ", ".join(timeline.dominant_sources[:5]) or "none"
|
||||||
|
|
||||||
|
user_message = (
|
||||||
|
f"Query: {query}\n\n"
|
||||||
|
f"Timeline summary:\n"
|
||||||
|
f"- {len(timeline.clusters)} clusters, "
|
||||||
|
f"{timeline.burst_count} bursts, "
|
||||||
|
f"{timeline.gap_count} silence gaps\n"
|
||||||
|
f"- Primary sources: {dominant}\n\n"
|
||||||
|
f"Top hypotheses:\n{hypothesis_block}\n\n"
|
||||||
|
f"Context from runbooks:\n{context_block}"
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": _SYSTEM_PROMPT},
|
||||||
|
{"role": "user", "content": user_message},
|
||||||
|
]
|
||||||
|
|
||||||
|
result = self._call_llm(
|
||||||
|
llm_url=llm_url,
|
||||||
|
llm_model=llm_model,
|
||||||
|
llm_api_key=llm_api_key,
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
return result if result else fallback
|
||||||
|
|
||||||
|
def _call_llm(
|
||||||
|
self,
|
||||||
|
llm_url: str,
|
||||||
|
llm_model: str,
|
||||||
|
llm_api_key: str | None,
|
||||||
|
messages: list[dict],
|
||||||
|
) -> str | None:
|
||||||
|
"""Send messages to the LLM and return raw text content.
|
||||||
|
|
||||||
|
Tries the cf-orch task endpoint first, falls back to direct OpenAI-compat.
|
||||||
|
"""
|
||||||
|
headers = {"Authorization": f"Bearer {llm_api_key}"} if llm_api_key else {}
|
||||||
|
|
||||||
|
task_url = f"{llm_url.rstrip('/')}/api/inference/task"
|
||||||
|
try:
|
||||||
|
resp = httpx.post(
|
||||||
|
task_url,
|
||||||
|
json={
|
||||||
|
"product": "turnstone",
|
||||||
|
"task": "log_analysis",
|
||||||
|
"payload": {"messages": messages, "stream": False},
|
||||||
|
},
|
||||||
|
headers=headers,
|
||||||
|
timeout=120.0,
|
||||||
|
)
|
||||||
|
if resp.status_code == 200:
|
||||||
|
return _extract_content(resp.json())
|
||||||
|
if resp.status_code != 404:
|
||||||
|
resp.raise_for_status()
|
||||||
|
logger.debug(
|
||||||
|
"No task assignment for turnstone.log_analysis — falling back to direct model"
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.debug(
|
||||||
|
"Task endpoint unavailable (%s) — falling back to direct model", exc
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
resp = httpx.post(
|
||||||
|
f"{llm_url.rstrip('/')}/v1/chat/completions",
|
||||||
|
json={"model": llm_model, "messages": messages, "stream": False},
|
||||||
|
headers=headers,
|
||||||
|
timeout=120.0,
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return _extract_content(resp.json())
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"LLM synthesizer failed (%s): %s", type(exc).__name__, exc
|
||||||
|
)
|
||||||
|
return None
|
||||||
489
tests/test_diagnose_pipeline.py
Normal file
489
tests/test_diagnose_pipeline.py
Normal file
|
|
@ -0,0 +1,489 @@
|
||||||
|
"""Tests for app/services/diagnose/pipeline.py and __init__.py feature flag wiring.
|
||||||
|
|
||||||
|
All tests use mocking; no real LLM, ML, or DB calls are made.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.context.retriever import RetrievedContext
|
||||||
|
from app.services.diagnose.models import (
|
||||||
|
ClassifiedTimeline,
|
||||||
|
Hypothesis,
|
||||||
|
RankedHypothesis,
|
||||||
|
TimelineResult,
|
||||||
|
)
|
||||||
|
from app.services.search import SearchResult
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Shared helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_search_result(
|
||||||
|
entry_id: str = "e1",
|
||||||
|
source_id: str = "syslog",
|
||||||
|
timestamp_iso: str | None = "2026-01-01T00:00:00+00:00",
|
||||||
|
severity: str | None = "ERROR",
|
||||||
|
text: str = "ssh: invalid user",
|
||||||
|
) -> SearchResult:
|
||||||
|
return SearchResult(
|
||||||
|
entry_id=entry_id,
|
||||||
|
source_id=source_id,
|
||||||
|
sequence=1,
|
||||||
|
timestamp_iso=timestamp_iso,
|
||||||
|
severity=severity,
|
||||||
|
repeat_count=1,
|
||||||
|
out_of_order=False,
|
||||||
|
matched_patterns=["ssh_fail"],
|
||||||
|
text=text,
|
||||||
|
rank=1.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_ctx() -> RetrievedContext:
|
||||||
|
return RetrievedContext(facts=[], chunks=[])
|
||||||
|
|
||||||
|
|
||||||
|
def _make_timeline(n_clusters: int = 2) -> TimelineResult:
|
||||||
|
return TimelineResult(
|
||||||
|
clusters=tuple(),
|
||||||
|
total_entries=5,
|
||||||
|
window_start="2026-01-01T00:00:00+00:00",
|
||||||
|
window_end="2026-01-01T01:00:00+00:00",
|
||||||
|
gap_count=0,
|
||||||
|
burst_count=1,
|
||||||
|
dominant_sources=("syslog",),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_classified(timeline: TimelineResult | None = None) -> ClassifiedTimeline:
|
||||||
|
tl = timeline or _make_timeline()
|
||||||
|
return ClassifiedTimeline(
|
||||||
|
timeline=tl,
|
||||||
|
cluster_severities={},
|
||||||
|
classifier_used="regex",
|
||||||
|
model_id=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_hypothesis(
|
||||||
|
hypothesis_id: str = "h1",
|
||||||
|
title: str = "SSH flood",
|
||||||
|
confidence: float = 0.87,
|
||||||
|
severity: str = "CRITICAL",
|
||||||
|
) -> Hypothesis:
|
||||||
|
return Hypothesis(
|
||||||
|
hypothesis_id=hypothesis_id,
|
||||||
|
title=title,
|
||||||
|
description="Multiple failed SSH attempts.",
|
||||||
|
confidence=confidence,
|
||||||
|
supporting_cluster_ids=("c1",),
|
||||||
|
runbook_refs=(),
|
||||||
|
severity=severity, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_ranked(hypothesis: Hypothesis | None = None, suppress: bool = False) -> RankedHypothesis:
|
||||||
|
h = hypothesis or _make_hypothesis()
|
||||||
|
return RankedHypothesis(
|
||||||
|
hypothesis=h,
|
||||||
|
novelty_score=0.95,
|
||||||
|
similarity_to_known=0.05,
|
||||||
|
suppress=suppress,
|
||||||
|
suppression_reason="similar to known" if suppress else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helper: collect all events from run_pipeline
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def _collect_pipeline_events(**kwargs) -> list[dict[str, Any]]:
|
||||||
|
"""Run run_pipeline and collect all yielded events into a list."""
|
||||||
|
from app.services.diagnose.pipeline import run_pipeline
|
||||||
|
events = []
|
||||||
|
async for event in run_pipeline(**kwargs):
|
||||||
|
events.append(event)
|
||||||
|
return events
|
||||||
|
|
||||||
|
|
||||||
|
def _default_pipeline_kwargs(entries=None, db_path=None) -> dict:
|
||||||
|
return dict(
|
||||||
|
db_path=db_path or Path("/tmp/fake.db"),
|
||||||
|
entries=entries or [_make_search_result()],
|
||||||
|
ctx=_make_ctx(),
|
||||||
|
query="ssh brute force",
|
||||||
|
since="2026-01-01T00:00:00+00:00",
|
||||||
|
until="2026-01-01T01:00:00+00:00",
|
||||||
|
llm_url=None,
|
||||||
|
llm_model=None,
|
||||||
|
llm_api_key=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Mock factories for all 5 stage classes
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _mock_all_stages(
|
||||||
|
hypotheses=None,
|
||||||
|
ranked=None,
|
||||||
|
synthesis_text="VERDICT: CRITICAL — SSH flood (87% confidence)",
|
||||||
|
):
|
||||||
|
"""Return a dict of patch targets and their mock return values."""
|
||||||
|
timeline = _make_timeline()
|
||||||
|
classified = _make_classified(timeline)
|
||||||
|
hyps = hypotheses if hypotheses is not None else [_make_hypothesis()]
|
||||||
|
rnk = ranked if ranked is not None else [_make_ranked()]
|
||||||
|
|
||||||
|
mock_reconstructor = MagicMock()
|
||||||
|
mock_reconstructor.return_value.reconstruct.return_value = timeline
|
||||||
|
|
||||||
|
mock_classifier = MagicMock()
|
||||||
|
mock_classifier.return_value.classify.return_value = classified
|
||||||
|
|
||||||
|
mock_hypothesizer = MagicMock()
|
||||||
|
mock_hypothesizer.return_value.hypothesize.return_value = hyps
|
||||||
|
|
||||||
|
mock_suppressor = MagicMock()
|
||||||
|
mock_suppressor.return_value.suppress.return_value = rnk
|
||||||
|
|
||||||
|
mock_synthesizer = MagicMock()
|
||||||
|
mock_synthesizer.return_value.synthesize.return_value = synthesis_text
|
||||||
|
|
||||||
|
return {
|
||||||
|
"app.services.diagnose.pipeline.TimelineReconstructor": mock_reconstructor,
|
||||||
|
"app.services.diagnose.pipeline.SeverityClassifier": mock_classifier,
|
||||||
|
"app.services.diagnose.pipeline.RootCauseHypothesizer": mock_hypothesizer,
|
||||||
|
"app.services.diagnose.pipeline.FalsePositiveSuppressor": mock_suppressor,
|
||||||
|
"app.services.diagnose.pipeline.SummarySynthesizer": mock_synthesizer,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 1. Feature flag off: legacy summarize() path runs, not run_pipeline
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestFeatureFlagOff:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_legacy_path_when_flag_off(self):
|
||||||
|
"""With MULTI_AGENT_ENABLED=False, run_pipeline is never called."""
|
||||||
|
from app.services import diagnose as diagnose_module
|
||||||
|
|
||||||
|
entries = [_make_search_result()]
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(diagnose_module, "MULTI_AGENT_ENABLED", False),
|
||||||
|
patch("app.services.diagnose.search", return_value=entries),
|
||||||
|
patch("app.services.diagnose.entries_in_window", return_value=[]),
|
||||||
|
patch("app.services.diagnose.retrieve_context", return_value=_make_ctx()),
|
||||||
|
patch("app.services.diagnose.format_context_block", return_value=None),
|
||||||
|
patch("app.services.diagnose.run_pipeline") as mock_pipeline,
|
||||||
|
patch("app.services.diagnose.summarize", return_value=None),
|
||||||
|
):
|
||||||
|
events = []
|
||||||
|
async for event in diagnose_module.diagnose_stream(
|
||||||
|
db_path=Path("/tmp/fake.db"),
|
||||||
|
query="ssh failures",
|
||||||
|
llm_url=None,
|
||||||
|
llm_model=None,
|
||||||
|
):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
# run_pipeline must NOT have been called
|
||||||
|
mock_pipeline.assert_not_called()
|
||||||
|
|
||||||
|
# SSE sequence must end with done
|
||||||
|
types = [e["type"] for e in events]
|
||||||
|
assert "done" in types
|
||||||
|
assert types[-1] == "done"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_legacy_done_event_is_last(self):
|
||||||
|
"""Legacy path: done is always the last event."""
|
||||||
|
from app.services import diagnose as diagnose_module
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(diagnose_module, "MULTI_AGENT_ENABLED", False),
|
||||||
|
patch("app.services.diagnose.search", return_value=[]),
|
||||||
|
patch("app.services.diagnose.entries_in_window", return_value=[]),
|
||||||
|
patch("app.services.diagnose.retrieve_context", return_value=_make_ctx()),
|
||||||
|
patch("app.services.diagnose.format_context_block", return_value=None),
|
||||||
|
):
|
||||||
|
events = []
|
||||||
|
async for event in diagnose_module.diagnose_stream(
|
||||||
|
db_path=Path("/tmp/fake.db"),
|
||||||
|
query="check logs",
|
||||||
|
):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
assert events[-1] == {"type": "done"}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 2. Feature flag on, all stages mocked: verify SSE event sequence
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestFeatureFlagOn:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pipeline_stage_events_in_order(self):
|
||||||
|
"""pipeline_stage events must be emitted stages 1→2→3→4 in order."""
|
||||||
|
mocks = _mock_all_stages()
|
||||||
|
kwargs = _default_pipeline_kwargs()
|
||||||
|
with (
|
||||||
|
patch("app.services.diagnose.pipeline.TimelineReconstructor", mocks["app.services.diagnose.pipeline.TimelineReconstructor"]),
|
||||||
|
patch("app.services.diagnose.pipeline.SeverityClassifier", mocks["app.services.diagnose.pipeline.SeverityClassifier"]),
|
||||||
|
patch("app.services.diagnose.pipeline.RootCauseHypothesizer", mocks["app.services.diagnose.pipeline.RootCauseHypothesizer"]),
|
||||||
|
patch("app.services.diagnose.pipeline.FalsePositiveSuppressor", mocks["app.services.diagnose.pipeline.FalsePositiveSuppressor"]),
|
||||||
|
patch("app.services.diagnose.pipeline.SummarySynthesizer", mocks["app.services.diagnose.pipeline.SummarySynthesizer"]),
|
||||||
|
):
|
||||||
|
events = await _collect_pipeline_events(**kwargs)
|
||||||
|
|
||||||
|
stage_events = [e for e in events if e.get("type") == "pipeline_stage"]
|
||||||
|
stages = [e["stage"] for e in stage_events]
|
||||||
|
assert stages == [1, 2, 3, 4]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_hypotheses_event_after_stage4(self):
|
||||||
|
"""hypotheses event must appear after pipeline_stage stage=4."""
|
||||||
|
mocks = _mock_all_stages()
|
||||||
|
kwargs = _default_pipeline_kwargs()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.services.diagnose.pipeline.TimelineReconstructor", mocks["app.services.diagnose.pipeline.TimelineReconstructor"]),
|
||||||
|
patch("app.services.diagnose.pipeline.SeverityClassifier", mocks["app.services.diagnose.pipeline.SeverityClassifier"]),
|
||||||
|
patch("app.services.diagnose.pipeline.RootCauseHypothesizer", mocks["app.services.diagnose.pipeline.RootCauseHypothesizer"]),
|
||||||
|
patch("app.services.diagnose.pipeline.FalsePositiveSuppressor", mocks["app.services.diagnose.pipeline.FalsePositiveSuppressor"]),
|
||||||
|
patch("app.services.diagnose.pipeline.SummarySynthesizer", mocks["app.services.diagnose.pipeline.SummarySynthesizer"]),
|
||||||
|
):
|
||||||
|
events = await _collect_pipeline_events(**kwargs)
|
||||||
|
|
||||||
|
stage4_idx = next(
|
||||||
|
i for i, e in enumerate(events)
|
||||||
|
if e.get("type") == "pipeline_stage" and e.get("stage") == 4
|
||||||
|
)
|
||||||
|
hyp_idx = next(i for i, e in enumerate(events) if e.get("type") == "hypotheses")
|
||||||
|
assert hyp_idx > stage4_idx
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reasoning_event_emitted(self):
|
||||||
|
"""reasoning event must be present when synthesizer returns text."""
|
||||||
|
mocks = _mock_all_stages(synthesis_text="VERDICT: CRITICAL — SSH flood")
|
||||||
|
kwargs = _default_pipeline_kwargs()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.services.diagnose.pipeline.TimelineReconstructor", mocks["app.services.diagnose.pipeline.TimelineReconstructor"]),
|
||||||
|
patch("app.services.diagnose.pipeline.SeverityClassifier", mocks["app.services.diagnose.pipeline.SeverityClassifier"]),
|
||||||
|
patch("app.services.diagnose.pipeline.RootCauseHypothesizer", mocks["app.services.diagnose.pipeline.RootCauseHypothesizer"]),
|
||||||
|
patch("app.services.diagnose.pipeline.FalsePositiveSuppressor", mocks["app.services.diagnose.pipeline.FalsePositiveSuppressor"]),
|
||||||
|
patch("app.services.diagnose.pipeline.SummarySynthesizer", mocks["app.services.diagnose.pipeline.SummarySynthesizer"]),
|
||||||
|
):
|
||||||
|
events = await _collect_pipeline_events(**kwargs)
|
||||||
|
|
||||||
|
reasoning_events = [e for e in events if e.get("type") == "reasoning"]
|
||||||
|
assert len(reasoning_events) == 1
|
||||||
|
assert "VERDICT" in reasoning_events[0]["text"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_done_event_is_last(self):
|
||||||
|
"""done must always be the last event in the pipeline sequence."""
|
||||||
|
mocks = _mock_all_stages()
|
||||||
|
kwargs = _default_pipeline_kwargs()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.services.diagnose.pipeline.TimelineReconstructor", mocks["app.services.diagnose.pipeline.TimelineReconstructor"]),
|
||||||
|
patch("app.services.diagnose.pipeline.SeverityClassifier", mocks["app.services.diagnose.pipeline.SeverityClassifier"]),
|
||||||
|
patch("app.services.diagnose.pipeline.RootCauseHypothesizer", mocks["app.services.diagnose.pipeline.RootCauseHypothesizer"]),
|
||||||
|
patch("app.services.diagnose.pipeline.FalsePositiveSuppressor", mocks["app.services.diagnose.pipeline.FalsePositiveSuppressor"]),
|
||||||
|
patch("app.services.diagnose.pipeline.SummarySynthesizer", mocks["app.services.diagnose.pipeline.SummarySynthesizer"]),
|
||||||
|
):
|
||||||
|
events = await _collect_pipeline_events(**kwargs)
|
||||||
|
|
||||||
|
assert events[-1] == {"type": "done"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pipeline_wired_from_diagnose_stream(self):
|
||||||
|
"""diagnose_stream routes through run_pipeline when flag is on."""
|
||||||
|
from app.services import diagnose as diagnose_module
|
||||||
|
|
||||||
|
entries = [_make_search_result()]
|
||||||
|
|
||||||
|
async def fake_pipeline(**kwargs):
|
||||||
|
yield {"type": "status", "message": "Building timeline…"}
|
||||||
|
yield {"type": "pipeline_stage", "stage": 1, "name": "timeline", "message": "Built 1 clusters, 0 bursts"}
|
||||||
|
yield {"type": "done"}
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(diagnose_module, "MULTI_AGENT_ENABLED", True),
|
||||||
|
patch("app.services.diagnose.search", return_value=entries),
|
||||||
|
patch("app.services.diagnose.entries_in_window", return_value=[]),
|
||||||
|
patch("app.services.diagnose.retrieve_context", return_value=_make_ctx()),
|
||||||
|
patch("app.services.diagnose.format_context_block", return_value=None),
|
||||||
|
patch("app.services.diagnose.run_pipeline", side_effect=fake_pipeline),
|
||||||
|
):
|
||||||
|
events = []
|
||||||
|
async for event in diagnose_module.diagnose_stream(
|
||||||
|
db_path=Path("/tmp/fake.db"),
|
||||||
|
query="ssh failures",
|
||||||
|
):
|
||||||
|
events.append(event)
|
||||||
|
|
||||||
|
types = [e["type"] for e in events]
|
||||||
|
assert "pipeline_stage" in types
|
||||||
|
assert types[-1] == "done"
|
||||||
|
# Legacy summarize() must NOT have been called — done event came from pipeline
|
||||||
|
assert types.count("done") == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 3. Empty entries: pipeline completes with done
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestEmptyEntries:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_entries_pipeline_completes(self):
|
||||||
|
"""Pipeline with entries=[] must still complete and emit done."""
|
||||||
|
mocks = _mock_all_stages(hypotheses=[], ranked=[])
|
||||||
|
kwargs = _default_pipeline_kwargs(entries=[])
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.services.diagnose.pipeline.TimelineReconstructor", mocks["app.services.diagnose.pipeline.TimelineReconstructor"]),
|
||||||
|
patch("app.services.diagnose.pipeline.SeverityClassifier", mocks["app.services.diagnose.pipeline.SeverityClassifier"]),
|
||||||
|
patch("app.services.diagnose.pipeline.RootCauseHypothesizer", mocks["app.services.diagnose.pipeline.RootCauseHypothesizer"]),
|
||||||
|
patch("app.services.diagnose.pipeline.FalsePositiveSuppressor", mocks["app.services.diagnose.pipeline.FalsePositiveSuppressor"]),
|
||||||
|
patch("app.services.diagnose.pipeline.SummarySynthesizer", mocks["app.services.diagnose.pipeline.SummarySynthesizer"]),
|
||||||
|
):
|
||||||
|
events = await _collect_pipeline_events(**kwargs)
|
||||||
|
|
||||||
|
types = [e["type"] for e in events]
|
||||||
|
assert "done" in types
|
||||||
|
assert types[-1] == "done"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_entries_all_stage_events_present(self):
|
||||||
|
"""Even with empty entries, all 4 pipeline_stage events are emitted."""
|
||||||
|
mocks = _mock_all_stages(hypotheses=[], ranked=[])
|
||||||
|
kwargs = _default_pipeline_kwargs(entries=[])
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.services.diagnose.pipeline.TimelineReconstructor", mocks["app.services.diagnose.pipeline.TimelineReconstructor"]),
|
||||||
|
patch("app.services.diagnose.pipeline.SeverityClassifier", mocks["app.services.diagnose.pipeline.SeverityClassifier"]),
|
||||||
|
patch("app.services.diagnose.pipeline.RootCauseHypothesizer", mocks["app.services.diagnose.pipeline.RootCauseHypothesizer"]),
|
||||||
|
patch("app.services.diagnose.pipeline.FalsePositiveSuppressor", mocks["app.services.diagnose.pipeline.FalsePositiveSuppressor"]),
|
||||||
|
patch("app.services.diagnose.pipeline.SummarySynthesizer", mocks["app.services.diagnose.pipeline.SummarySynthesizer"]),
|
||||||
|
):
|
||||||
|
events = await _collect_pipeline_events(**kwargs)
|
||||||
|
|
||||||
|
stage_events = [e for e in events if e.get("type") == "pipeline_stage"]
|
||||||
|
assert len(stage_events) == 4
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 4. No LLM: Stage 3 and Stage 5 return empty/fallback; done still emitted
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestNoLLM:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_llm_pipeline_completes_with_done(self):
|
||||||
|
"""No llm_url/llm_model: pipeline runs all stages and emits done."""
|
||||||
|
mocks = _mock_all_stages(hypotheses=[], ranked=[], synthesis_text="VERDICT: UNKNOWN — no hypotheses generated")
|
||||||
|
kwargs = _default_pipeline_kwargs()
|
||||||
|
# llm_url and llm_model already None in default kwargs
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.services.diagnose.pipeline.TimelineReconstructor", mocks["app.services.diagnose.pipeline.TimelineReconstructor"]),
|
||||||
|
patch("app.services.diagnose.pipeline.SeverityClassifier", mocks["app.services.diagnose.pipeline.SeverityClassifier"]),
|
||||||
|
patch("app.services.diagnose.pipeline.RootCauseHypothesizer", mocks["app.services.diagnose.pipeline.RootCauseHypothesizer"]),
|
||||||
|
patch("app.services.diagnose.pipeline.FalsePositiveSuppressor", mocks["app.services.diagnose.pipeline.FalsePositiveSuppressor"]),
|
||||||
|
patch("app.services.diagnose.pipeline.SummarySynthesizer", mocks["app.services.diagnose.pipeline.SummarySynthesizer"]),
|
||||||
|
):
|
||||||
|
events = await _collect_pipeline_events(**kwargs)
|
||||||
|
|
||||||
|
assert events[-1] == {"type": "done"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_llm_no_reasoning_event_when_synthesis_empty(self):
|
||||||
|
"""When synthesizer returns empty string, no reasoning event is emitted."""
|
||||||
|
mocks = _mock_all_stages(synthesis_text="")
|
||||||
|
kwargs = _default_pipeline_kwargs()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.services.diagnose.pipeline.TimelineReconstructor", mocks["app.services.diagnose.pipeline.TimelineReconstructor"]),
|
||||||
|
patch("app.services.diagnose.pipeline.SeverityClassifier", mocks["app.services.diagnose.pipeline.SeverityClassifier"]),
|
||||||
|
patch("app.services.diagnose.pipeline.RootCauseHypothesizer", mocks["app.services.diagnose.pipeline.RootCauseHypothesizer"]),
|
||||||
|
patch("app.services.diagnose.pipeline.FalsePositiveSuppressor", mocks["app.services.diagnose.pipeline.FalsePositiveSuppressor"]),
|
||||||
|
patch("app.services.diagnose.pipeline.SummarySynthesizer", mocks["app.services.diagnose.pipeline.SummarySynthesizer"]),
|
||||||
|
):
|
||||||
|
events = await _collect_pipeline_events(**kwargs)
|
||||||
|
|
||||||
|
reasoning_events = [e for e in events if e.get("type") == "reasoning"]
|
||||||
|
assert len(reasoning_events) == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# 5. Stage 1 cluster count in pipeline_stage message
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestStage1Message:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stage1_message_contains_cluster_count(self):
|
||||||
|
"""pipeline_stage stage=1 message must report cluster count."""
|
||||||
|
timeline = TimelineResult(
|
||||||
|
clusters=tuple(),
|
||||||
|
total_entries=10,
|
||||||
|
window_start=None,
|
||||||
|
window_end=None,
|
||||||
|
gap_count=0,
|
||||||
|
burst_count=3,
|
||||||
|
dominant_sources=("syslog",),
|
||||||
|
)
|
||||||
|
classified = _make_classified(timeline)
|
||||||
|
|
||||||
|
mock_reconstructor = MagicMock()
|
||||||
|
mock_reconstructor.return_value.reconstruct.return_value = timeline
|
||||||
|
mock_classifier = MagicMock()
|
||||||
|
mock_classifier.return_value.classify.return_value = classified
|
||||||
|
mock_hypothesizer = MagicMock()
|
||||||
|
mock_hypothesizer.return_value.hypothesize.return_value = []
|
||||||
|
mock_suppressor = MagicMock()
|
||||||
|
mock_suppressor.return_value.suppress.return_value = []
|
||||||
|
mock_synthesizer = MagicMock()
|
||||||
|
mock_synthesizer.return_value.synthesize.return_value = "VERDICT: INFO — nothing found"
|
||||||
|
|
||||||
|
kwargs = _default_pipeline_kwargs()
|
||||||
|
with (
|
||||||
|
patch("app.services.diagnose.pipeline.TimelineReconstructor", mock_reconstructor),
|
||||||
|
patch("app.services.diagnose.pipeline.SeverityClassifier", mock_classifier),
|
||||||
|
patch("app.services.diagnose.pipeline.RootCauseHypothesizer", mock_hypothesizer),
|
||||||
|
patch("app.services.diagnose.pipeline.FalsePositiveSuppressor", mock_suppressor),
|
||||||
|
patch("app.services.diagnose.pipeline.SummarySynthesizer", mock_synthesizer),
|
||||||
|
):
|
||||||
|
events = await _collect_pipeline_events(**kwargs)
|
||||||
|
|
||||||
|
stage1 = next(e for e in events if e.get("type") == "pipeline_stage" and e.get("stage") == 1)
|
||||||
|
# 0 clusters (empty tuple), 3 bursts
|
||||||
|
assert "0" in stage1["message"] # cluster count
|
||||||
|
assert "3" in stage1["message"] # burst count
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_stage1_name_is_timeline(self):
|
||||||
|
"""pipeline_stage stage=1 must have name='timeline'."""
|
||||||
|
mocks = _mock_all_stages()
|
||||||
|
kwargs = _default_pipeline_kwargs()
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.services.diagnose.pipeline.TimelineReconstructor", mocks["app.services.diagnose.pipeline.TimelineReconstructor"]),
|
||||||
|
patch("app.services.diagnose.pipeline.SeverityClassifier", mocks["app.services.diagnose.pipeline.SeverityClassifier"]),
|
||||||
|
patch("app.services.diagnose.pipeline.RootCauseHypothesizer", mocks["app.services.diagnose.pipeline.RootCauseHypothesizer"]),
|
||||||
|
patch("app.services.diagnose.pipeline.FalsePositiveSuppressor", mocks["app.services.diagnose.pipeline.FalsePositiveSuppressor"]),
|
||||||
|
patch("app.services.diagnose.pipeline.SummarySynthesizer", mocks["app.services.diagnose.pipeline.SummarySynthesizer"]),
|
||||||
|
):
|
||||||
|
events = await _collect_pipeline_events(**kwargs)
|
||||||
|
|
||||||
|
stage1 = next(e for e in events if e.get("type") == "pipeline_stage" and e.get("stage") == 1)
|
||||||
|
assert stage1["name"] == "timeline"
|
||||||
285
tests/test_diagnose_synthesizer.py
Normal file
285
tests/test_diagnose_synthesizer.py
Normal file
|
|
@ -0,0 +1,285 @@
|
||||||
|
"""Tests for app/services/diagnose/synthesizer.py — SummarySynthesizer.
|
||||||
|
|
||||||
|
All tests use mocking; no real LLM calls are made.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
from app.context.retriever import RetrievedContext
|
||||||
|
from app.services.diagnose.models import Hypothesis, RankedHypothesis, TimelineResult
|
||||||
|
from app.services.diagnose.synthesizer import SummarySynthesizer
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixture helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _make_hypothesis(
|
||||||
|
hypothesis_id: str = "h1",
|
||||||
|
title: str = "SSH flood from external IPs",
|
||||||
|
description: str = "Repeated failed login attempts from multiple IPs.",
|
||||||
|
confidence: float = 0.87,
|
||||||
|
severity: str = "CRITICAL",
|
||||||
|
) -> Hypothesis:
|
||||||
|
return Hypothesis(
|
||||||
|
hypothesis_id=hypothesis_id,
|
||||||
|
title=title,
|
||||||
|
description=description,
|
||||||
|
confidence=confidence,
|
||||||
|
supporting_cluster_ids=("c1",),
|
||||||
|
runbook_refs=(),
|
||||||
|
severity=severity, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_ranked(
|
||||||
|
hypothesis: Hypothesis | None = None,
|
||||||
|
novelty_score: float = 0.95,
|
||||||
|
similarity_to_known: float = 0.05,
|
||||||
|
suppress: bool = False,
|
||||||
|
suppression_reason: str | None = None,
|
||||||
|
) -> RankedHypothesis:
|
||||||
|
h = hypothesis or _make_hypothesis()
|
||||||
|
return RankedHypothesis(
|
||||||
|
hypothesis=h,
|
||||||
|
novelty_score=novelty_score,
|
||||||
|
similarity_to_known=similarity_to_known,
|
||||||
|
suppress=suppress,
|
||||||
|
suppression_reason=suppression_reason,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_timeline(
|
||||||
|
total_entries: int = 42,
|
||||||
|
n_clusters: int = 3,
|
||||||
|
) -> TimelineResult:
|
||||||
|
return TimelineResult(
|
||||||
|
clusters=tuple(),
|
||||||
|
total_entries=total_entries,
|
||||||
|
window_start="2026-01-01T00:00:00+00:00",
|
||||||
|
window_end="2026-01-01T01:00:00+00:00",
|
||||||
|
gap_count=1,
|
||||||
|
burst_count=2,
|
||||||
|
dominant_sources=("syslog", "auth"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_ctx(chunks: list[dict] | None = None) -> RetrievedContext:
|
||||||
|
return RetrievedContext(
|
||||||
|
facts=[{"category": "network", "key": "host", "value": "heimdall", "source": "facts"}],
|
||||||
|
chunks=chunks or [{"filename": "runbook.md", "text": "Restart sshd if flooded"}],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Test cases
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
class TestSynthesizerWithHypotheses:
|
||||||
|
"""With hypotheses, result must contain VERDICT."""
|
||||||
|
|
||||||
|
def test_returns_verdict_string_with_llm(self):
|
||||||
|
synthesizer = SummarySynthesizer()
|
||||||
|
ranked = [_make_ranked()]
|
||||||
|
timeline = _make_timeline()
|
||||||
|
ctx = _make_ctx()
|
||||||
|
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.status_code = 200
|
||||||
|
mock_resp.json.return_value = {
|
||||||
|
"choices": [{"message": {"content": "VERDICT: CRITICAL — SSH flood (87% confidence)\nTIMELINE: lots of hits."}}]
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("httpx.post", return_value=mock_resp):
|
||||||
|
result = synthesizer.synthesize(
|
||||||
|
ranked=ranked,
|
||||||
|
timeline=timeline,
|
||||||
|
ctx=ctx,
|
||||||
|
query="ssh brute force",
|
||||||
|
llm_url="http://localhost:11434",
|
||||||
|
llm_model="llama3",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "VERDICT" in result
|
||||||
|
|
||||||
|
def test_returns_nonempty_string(self):
|
||||||
|
synthesizer = SummarySynthesizer()
|
||||||
|
ranked = [_make_ranked()]
|
||||||
|
timeline = _make_timeline()
|
||||||
|
ctx = _make_ctx()
|
||||||
|
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.status_code = 200
|
||||||
|
mock_resp.json.return_value = {
|
||||||
|
"choices": [{"message": {"content": "VERDICT: CRITICAL — SSH flood (87% confidence)"}}]
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("httpx.post", return_value=mock_resp):
|
||||||
|
result = synthesizer.synthesize(
|
||||||
|
ranked=ranked,
|
||||||
|
timeline=timeline,
|
||||||
|
ctx=ctx,
|
||||||
|
query="why is auth failing",
|
||||||
|
llm_url="http://localhost:11434",
|
||||||
|
llm_model="llama3",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, str)
|
||||||
|
assert len(result) > 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestSynthesizerSuppressedHypotheses:
|
||||||
|
"""Suppressed hypotheses must be excluded from the LLM prompt."""
|
||||||
|
|
||||||
|
def test_suppressed_hypotheses_excluded_from_prompt(self):
|
||||||
|
suppressed = _make_ranked(
|
||||||
|
hypothesis=_make_hypothesis(
|
||||||
|
hypothesis_id="h2",
|
||||||
|
title="Wazuh alert processing backlog",
|
||||||
|
severity="ERROR",
|
||||||
|
confidence=0.72,
|
||||||
|
),
|
||||||
|
suppress=True,
|
||||||
|
suppression_reason="similar to 2025-04 SSH incident",
|
||||||
|
novelty_score=0.1,
|
||||||
|
)
|
||||||
|
active = _make_ranked(
|
||||||
|
hypothesis=_make_hypothesis(
|
||||||
|
hypothesis_id="h1",
|
||||||
|
title="SSH flood from external IPs",
|
||||||
|
severity="CRITICAL",
|
||||||
|
confidence=0.87,
|
||||||
|
),
|
||||||
|
suppress=False,
|
||||||
|
novelty_score=0.95,
|
||||||
|
)
|
||||||
|
|
||||||
|
captured_messages: list = []
|
||||||
|
|
||||||
|
def fake_post(url, json=None, headers=None, timeout=None):
|
||||||
|
if json and "payload" in json:
|
||||||
|
captured_messages.extend(json["payload"].get("messages", []))
|
||||||
|
elif json and "messages" in json:
|
||||||
|
captured_messages.extend(json.get("messages", []))
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.status_code = 200
|
||||||
|
mock_resp.json.return_value = {
|
||||||
|
"choices": [{"message": {"content": "VERDICT: CRITICAL — SSH flood"}}]
|
||||||
|
}
|
||||||
|
return mock_resp
|
||||||
|
|
||||||
|
synthesizer = SummarySynthesizer()
|
||||||
|
with patch("httpx.post", side_effect=fake_post):
|
||||||
|
synthesizer.synthesize(
|
||||||
|
ranked=[active, suppressed],
|
||||||
|
timeline=_make_timeline(),
|
||||||
|
ctx=_make_ctx(),
|
||||||
|
query="auth failures",
|
||||||
|
llm_url="http://localhost:11434",
|
||||||
|
llm_model="llama3",
|
||||||
|
)
|
||||||
|
|
||||||
|
# The user message should contain the active hypothesis title
|
||||||
|
# and NOT contain the suppressed one (or mark it suppressed)
|
||||||
|
user_content = next(
|
||||||
|
(m["content"] for m in captured_messages if m.get("role") == "user"), ""
|
||||||
|
)
|
||||||
|
assert "SSH flood from external IPs" in user_content
|
||||||
|
# Wazuh should not appear as a standalone top-level hypothesis
|
||||||
|
# (suppressed items are excluded from the active list sent to the LLM)
|
||||||
|
assert "Wazuh alert processing backlog" not in user_content
|
||||||
|
|
||||||
|
|
||||||
|
class TestSynthesizerNoLLM:
|
||||||
|
"""No LLM configured: must return deterministic fallback (not empty)."""
|
||||||
|
|
||||||
|
def test_no_llm_url_returns_fallback(self):
|
||||||
|
synthesizer = SummarySynthesizer()
|
||||||
|
ranked = [_make_ranked()]
|
||||||
|
timeline = _make_timeline()
|
||||||
|
ctx = _make_ctx()
|
||||||
|
|
||||||
|
result = synthesizer.synthesize(
|
||||||
|
ranked=ranked,
|
||||||
|
timeline=timeline,
|
||||||
|
ctx=ctx,
|
||||||
|
query="disk errors",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, str)
|
||||||
|
assert len(result) > 0
|
||||||
|
assert "VERDICT" in result
|
||||||
|
|
||||||
|
def test_no_llm_model_returns_fallback(self):
|
||||||
|
synthesizer = SummarySynthesizer()
|
||||||
|
ranked = [_make_ranked()]
|
||||||
|
|
||||||
|
result = synthesizer.synthesize(
|
||||||
|
ranked=ranked,
|
||||||
|
timeline=_make_timeline(),
|
||||||
|
ctx=_make_ctx(),
|
||||||
|
query="oom killer",
|
||||||
|
llm_url="http://localhost:11434",
|
||||||
|
# llm_model omitted
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "VERDICT" in result
|
||||||
|
assert "SSH flood from external IPs" in result
|
||||||
|
|
||||||
|
def test_llm_failure_returns_fallback(self):
|
||||||
|
synthesizer = SummarySynthesizer()
|
||||||
|
ranked = [_make_ranked()]
|
||||||
|
|
||||||
|
with patch("httpx.post", side_effect=ConnectionError("refused")):
|
||||||
|
result = synthesizer.synthesize(
|
||||||
|
ranked=ranked,
|
||||||
|
timeline=_make_timeline(),
|
||||||
|
ctx=_make_ctx(),
|
||||||
|
query="why is disk full",
|
||||||
|
llm_url="http://localhost:11434",
|
||||||
|
llm_model="llama3",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "VERDICT" in result
|
||||||
|
assert len(result) > 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestSynthesizerEmptyRanked:
|
||||||
|
"""Empty ranked list: must return deterministic fallback text, not raise."""
|
||||||
|
|
||||||
|
def test_empty_ranked_no_llm_returns_fallback(self):
|
||||||
|
synthesizer = SummarySynthesizer()
|
||||||
|
result = synthesizer.synthesize(
|
||||||
|
ranked=[],
|
||||||
|
timeline=_make_timeline(),
|
||||||
|
ctx=_make_ctx(),
|
||||||
|
query="check everything",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, str)
|
||||||
|
assert len(result) > 0
|
||||||
|
assert "VERDICT" in result
|
||||||
|
|
||||||
|
def test_empty_ranked_with_llm_returns_fallback_or_llm_text(self):
|
||||||
|
"""Even with empty ranked, we attempt LLM and return something."""
|
||||||
|
synthesizer = SummarySynthesizer()
|
||||||
|
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.status_code = 200
|
||||||
|
mock_resp.json.return_value = {
|
||||||
|
"choices": [{"message": {"content": "VERDICT: UNKNOWN — no hypotheses generated"}}]
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("httpx.post", return_value=mock_resp):
|
||||||
|
result = synthesizer.synthesize(
|
||||||
|
ranked=[],
|
||||||
|
timeline=_make_timeline(),
|
||||||
|
ctx=_make_ctx(),
|
||||||
|
query="nothing found",
|
||||||
|
llm_url="http://localhost:11434",
|
||||||
|
llm_model="llama3",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert isinstance(result, str)
|
||||||
|
assert len(result) > 0
|
||||||
Loading…
Reference in a new issue