turnstone/tests/test_diagnose_pipeline.py
pyr0ball 8cbd981ec7 feat: Stage 5 synthesizer + pipeline orchestrator + feature flag wiring (issue #29)
- Add app/services/diagnose/synthesizer.py: SummarySynthesizer (Stage 5)
  - Builds structured LLM prompt from ranked hypotheses, timeline, RAG context
  - Excludes suppressed hypotheses from the narrative prompt
  - Deterministic fallback when no LLM configured or LLM call fails
  - Same cf-orch task endpoint + direct OpenAI-compat fallback pattern as other stages

- Replace pipeline.py stub with full run_pipeline() async generator
  - Orchestrates all 5 stages via asyncio.to_thread for each synchronous stage
  - Yields typed SSE event dicts: status, pipeline_stage (1-4), hypotheses, reasoning, done
  - Suppressor counts (active vs suppressed) reported in stage 4 event message

- Wire MULTI_AGENT_ENABLED feature flag into diagnose_stream()
  - TURNSTONE_MULTI_AGENT_DIAGNOSE=true routes through run_pipeline()
  - pipeline emits its own done event; legacy path unchanged when flag is false
  - Import of run_pipeline added to __init__.py

- Add 21 new tests (350 -> 371 passing):
  - tests/test_diagnose_synthesizer.py: 8 tests (with/without LLM, suppressed,
    empty ranked, LLM failure fallback)
  - tests/test_diagnose_pipeline.py: 13 tests (flag off, flag on event sequence,
    empty entries, no LLM, stage 1 cluster count message)

Closes: #29
2026-05-25 14:56:25 -07:00

489 lines
22 KiB
Python

"""Tests for app/services/diagnose/pipeline.py and __init__.py feature flag wiring.
All tests use mocking; no real LLM, ML, or DB calls are made.
"""
from __future__ import annotations
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
from app.context.retriever import RetrievedContext
from app.services.diagnose.models import (
ClassifiedTimeline,
Hypothesis,
RankedHypothesis,
TimelineResult,
)
from app.services.search import SearchResult
# ---------------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------------
def _make_search_result(
    entry_id: str = "e1",
    source_id: str = "syslog",
    timestamp_iso: str | None = "2026-01-01T00:00:00+00:00",
    severity: str | None = "ERROR",
    text: str = "ssh: invalid user",
) -> SearchResult:
    """Build one SearchResult fixture.

    Only the commonly varied fields are parameters. Sequence, repeat count,
    out-of-order flag, matched patterns and rank are fixed values.
    """
    # Fields every fixture shares; rebuilt per call so the list is never shared.
    fixed_fields: dict[str, Any] = {
        "sequence": 1,
        "repeat_count": 1,
        "out_of_order": False,
        "matched_patterns": ["ssh_fail"],
        "rank": 1.0,
    }
    return SearchResult(
        entry_id=entry_id,
        source_id=source_id,
        timestamp_iso=timestamp_iso,
        severity=severity,
        text=text,
        **fixed_fields,
    )
def _make_ctx() -> RetrievedContext:
    """Return an empty retrieval context with no facts and no chunks."""
    no_facts: list = []
    no_chunks: list = []
    return RetrievedContext(facts=no_facts, chunks=no_chunks)
def _make_timeline(n_clusters: int = 2) -> TimelineResult:
    """Build a TimelineResult fixture spanning one hour with a single burst.

    NOTE(review): ``n_clusters`` is accepted for call compatibility but is
    ignored. The fixture always has an empty ``clusters`` tuple, because no
    cluster objects are constructed here.
    """
    del n_clusters  # intentionally unused; see docstring
    start = "2026-01-01T00:00:00+00:00"
    end = "2026-01-01T01:00:00+00:00"
    return TimelineResult(
        clusters=(),
        total_entries=5,
        window_start=start,
        window_end=end,
        gap_count=0,
        burst_count=1,
        dominant_sources=("syslog",),
    )
def _make_classified(timeline: TimelineResult | None = None) -> ClassifiedTimeline:
    """Wrap a timeline in a regex-classified ClassifiedTimeline fixture.

    Args:
        timeline: Timeline to wrap. ``None`` (the default) builds one with
            ``_make_timeline()``. The check is ``is None`` rather than
            truthiness, so a caller-supplied timeline is always used as-is.

    Returns:
        A ClassifiedTimeline with no per-cluster severities, classifier
        ``"regex"`` and no model id.
    """
    tl = timeline if timeline is not None else _make_timeline()
    return ClassifiedTimeline(
        timeline=tl,
        cluster_severities={},
        classifier_used="regex",
        model_id=None,
    )
def _make_hypothesis(
    hypothesis_id: str = "h1",
    title: str = "SSH flood",
    confidence: float = 0.87,
    severity: str = "CRITICAL",
) -> Hypothesis:
    """Build a Hypothesis fixture supported by cluster ``c1`` with no runbook refs."""
    supporting_ids = ("c1",)
    no_runbooks: tuple = ()
    return Hypothesis(
        hypothesis_id=hypothesis_id,
        title=title,
        description="Multiple failed SSH attempts.",
        confidence=confidence,
        supporting_cluster_ids=supporting_ids,
        runbook_refs=no_runbooks,
        # Tests pass a plain str here; the model's declared type is narrower.
        severity=severity,  # type: ignore[arg-type]
    )
def _make_ranked(hypothesis: Hypothesis | None = None, suppress: bool = False) -> RankedHypothesis:
    """Wrap a hypothesis in a RankedHypothesis fixture.

    Args:
        hypothesis: Hypothesis to rank. ``None`` builds the default via
            ``_make_hypothesis()``. Compared with ``is None`` so a supplied
            object is never replaced based on its truthiness.
        suppress: When true, the result is marked suppressed and carries the
            reason ``"similar to known"``; otherwise the reason is ``None``.

    Returns:
        A RankedHypothesis with fixed novelty (0.95) and similarity (0.05).
    """
    h = hypothesis if hypothesis is not None else _make_hypothesis()
    return RankedHypothesis(
        hypothesis=h,
        novelty_score=0.95,
        similarity_to_known=0.05,
        suppress=suppress,
        suppression_reason="similar to known" if suppress else None,
    )
# ---------------------------------------------------------------------------
# Helper: collect all events from run_pipeline
# ---------------------------------------------------------------------------
async def _collect_pipeline_events(**kwargs) -> list[dict[str, Any]]:
    """Drain ``run_pipeline(**kwargs)`` and return every yielded event in order."""
    from app.services.diagnose.pipeline import run_pipeline

    return [event async for event in run_pipeline(**kwargs)]
def _default_pipeline_kwargs(entries=None, db_path=None) -> dict:
    """Return keyword arguments for ``run_pipeline`` with no LLM configured.

    Args:
        entries: Search results to feed the pipeline. ``None`` uses a single
            default entry. An explicit empty list is passed through unchanged,
            so callers can exercise the empty-input path. (A truthiness
            fallback would silently replace ``[]`` with the default entry.)
        db_path: Database path. ``None`` uses a placeholder path.

    Returns:
        A dict of run_pipeline keyword arguments. ``llm_url``, ``llm_model``
        and ``llm_api_key`` are all ``None``.
    """
    if entries is None:
        entries = [_make_search_result()]
    if db_path is None:
        db_path = Path("/tmp/fake.db")
    return dict(
        db_path=db_path,
        entries=entries,
        ctx=_make_ctx(),
        query="ssh brute force",
        since="2026-01-01T00:00:00+00:00",
        until="2026-01-01T01:00:00+00:00",
        llm_url=None,
        llm_model=None,
        llm_api_key=None,
    )
# ---------------------------------------------------------------------------
# Mock factories for all 5 stage classes
# ---------------------------------------------------------------------------
def _mock_all_stages(
    hypotheses=None,
    ranked=None,
    synthesis_text="VERDICT: CRITICAL — SSH flood (87% confidence)",
):
    """Build mock stand-ins for the five pipeline stage classes.

    Returns a dict mapping each dotted patch target to a MagicMock *class*.
    Each target is an attribute of ``app.services.diagnose.pipeline``.
    Instantiating a mock and calling its stage method returns the canned
    result. ``None`` for ``hypotheses`` or ``ranked`` selects a single
    default item; an explicit empty list is kept as-is.
    """

    def _stage(method_name: str, result: Any) -> MagicMock:
        # Configure cls_mock().<method_name>(...) to return ``result``.
        cls_mock = MagicMock()
        getattr(cls_mock.return_value, method_name).return_value = result
        return cls_mock

    timeline = _make_timeline()
    hyps = [_make_hypothesis()] if hypotheses is None else hypotheses
    rnk = [_make_ranked()] if ranked is None else ranked
    return {
        "app.services.diagnose.pipeline.TimelineReconstructor": _stage("reconstruct", timeline),
        "app.services.diagnose.pipeline.SeverityClassifier": _stage("classify", _make_classified(timeline)),
        "app.services.diagnose.pipeline.RootCauseHypothesizer": _stage("hypothesize", hyps),
        "app.services.diagnose.pipeline.FalsePositiveSuppressor": _stage("suppress", rnk),
        "app.services.diagnose.pipeline.SummarySynthesizer": _stage("synthesize", synthesis_text),
    }
# ---------------------------------------------------------------------------
# 1. Feature flag off: legacy summarize() path runs, not run_pipeline
# ---------------------------------------------------------------------------
class TestFeatureFlagOff:
    """With MULTI_AGENT_ENABLED=False, diagnose_stream stays on the legacy path."""

    @staticmethod
    async def _drain(stream) -> list[dict[str, Any]]:
        """Collect every event yielded by an async event stream."""
        return [event async for event in stream]

    @pytest.mark.asyncio
    async def test_legacy_path_when_flag_off(self):
        """With MULTI_AGENT_ENABLED=False, run_pipeline is never called."""
        from app.services import diagnose as diagnose_module

        found = [_make_search_result()]
        with (
            patch.object(diagnose_module, "MULTI_AGENT_ENABLED", False),
            patch("app.services.diagnose.search", return_value=found),
            patch("app.services.diagnose.entries_in_window", return_value=[]),
            patch("app.services.diagnose.retrieve_context", return_value=_make_ctx()),
            patch("app.services.diagnose.format_context_block", return_value=None),
            patch("app.services.diagnose.run_pipeline") as pipeline_mock,
            patch("app.services.diagnose.summarize", return_value=None),
        ):
            stream = diagnose_module.diagnose_stream(
                db_path=Path("/tmp/fake.db"),
                query="ssh failures",
                llm_url=None,
                llm_model=None,
            )
            events = await self._drain(stream)

        # The multi-agent pipeline must be bypassed entirely.
        pipeline_mock.assert_not_called()
        # The SSE stream must still terminate with a done event.
        kinds = [event["type"] for event in events]
        assert "done" in kinds
        assert kinds[-1] == "done"

    @pytest.mark.asyncio
    async def test_legacy_done_event_is_last(self):
        """Legacy path: done is always the last event."""
        from app.services import diagnose as diagnose_module

        with (
            patch.object(diagnose_module, "MULTI_AGENT_ENABLED", False),
            patch("app.services.diagnose.search", return_value=[]),
            patch("app.services.diagnose.entries_in_window", return_value=[]),
            patch("app.services.diagnose.retrieve_context", return_value=_make_ctx()),
            patch("app.services.diagnose.format_context_block", return_value=None),
        ):
            stream = diagnose_module.diagnose_stream(
                db_path=Path("/tmp/fake.db"),
                query="check logs",
            )
            events = await self._drain(stream)

        assert events[-1] == {"type": "done"}
# ---------------------------------------------------------------------------
# 2. Feature flag on, all stages mocked: verify SSE event sequence
# ---------------------------------------------------------------------------
class TestFeatureFlagOn:
    """Flag on: SSE event sequence from run_pipeline and its diagnose_stream wiring."""

    @staticmethod
    def _patch_stages(mocks: dict[str, MagicMock]):
        """Patch every stage class in ``mocks`` within one context manager.

        ``mocks`` maps dotted targets to mocks, as returned by
        ``_mock_all_stages``. This assumes every target is an attribute of
        ``app.services.diagnose.pipeline``, which holds for that helper.
        """
        return patch.multiple(
            "app.services.diagnose.pipeline",
            **{target.rpartition(".")[2]: mock for target, mock in mocks.items()},
        )

    @pytest.mark.asyncio
    async def test_pipeline_stage_events_in_order(self):
        """pipeline_stage events must be emitted stages 1→2→3→4 in order."""
        mocks = _mock_all_stages()
        kwargs = _default_pipeline_kwargs()
        with self._patch_stages(mocks):
            events = await _collect_pipeline_events(**kwargs)
        stages = [e["stage"] for e in events if e.get("type") == "pipeline_stage"]
        assert stages == [1, 2, 3, 4]

    @pytest.mark.asyncio
    async def test_hypotheses_event_after_stage4(self):
        """hypotheses event must appear after pipeline_stage stage=4."""
        mocks = _mock_all_stages()
        kwargs = _default_pipeline_kwargs()
        with self._patch_stages(mocks):
            events = await _collect_pipeline_events(**kwargs)
        stage4_idx = next(
            i for i, e in enumerate(events)
            if e.get("type") == "pipeline_stage" and e.get("stage") == 4
        )
        hyp_idx = next(i for i, e in enumerate(events) if e.get("type") == "hypotheses")
        assert hyp_idx > stage4_idx

    @pytest.mark.asyncio
    async def test_reasoning_event_emitted(self):
        """reasoning event must be present when synthesizer returns text."""
        mocks = _mock_all_stages(synthesis_text="VERDICT: CRITICAL — SSH flood")
        kwargs = _default_pipeline_kwargs()
        with self._patch_stages(mocks):
            events = await _collect_pipeline_events(**kwargs)
        reasoning_events = [e for e in events if e.get("type") == "reasoning"]
        assert len(reasoning_events) == 1
        assert "VERDICT" in reasoning_events[0]["text"]

    @pytest.mark.asyncio
    async def test_done_event_is_last(self):
        """done must always be the last event in the pipeline sequence."""
        mocks = _mock_all_stages()
        kwargs = _default_pipeline_kwargs()
        with self._patch_stages(mocks):
            events = await _collect_pipeline_events(**kwargs)
        assert events[-1] == {"type": "done"}

    @pytest.mark.asyncio
    async def test_pipeline_wired_from_diagnose_stream(self):
        """diagnose_stream routes through run_pipeline when flag is on."""
        from app.services import diagnose as diagnose_module

        entries = [_make_search_result()]

        async def fake_pipeline(**kwargs):
            yield {"type": "status", "message": "Building timeline…"}
            yield {"type": "pipeline_stage", "stage": 1, "name": "timeline", "message": "Built 1 clusters, 0 bursts"}
            yield {"type": "done"}

        with (
            patch.object(diagnose_module, "MULTI_AGENT_ENABLED", True),
            patch("app.services.diagnose.search", return_value=entries),
            patch("app.services.diagnose.entries_in_window", return_value=[]),
            patch("app.services.diagnose.retrieve_context", return_value=_make_ctx()),
            patch("app.services.diagnose.format_context_block", return_value=None),
            patch("app.services.diagnose.run_pipeline", side_effect=fake_pipeline) as mock_pipeline,
            # Patched so the test can prove the legacy summarizer is bypassed.
            patch("app.services.diagnose.summarize") as mock_summarize,
        ):
            events = []
            async for event in diagnose_module.diagnose_stream(
                db_path=Path("/tmp/fake.db"),
                query="ssh failures",
            ):
                events.append(event)

        mock_pipeline.assert_called_once()
        # The legacy summarize() must not run when the pipeline is active.
        mock_summarize.assert_not_called()
        types = [e["type"] for e in events]
        assert "pipeline_stage" in types
        assert types[-1] == "done"
        # Exactly one done: the pipeline's own; none appended by the legacy path.
        assert types.count("done") == 1
# ---------------------------------------------------------------------------
# 3. Empty entries: pipeline completes with done
# ---------------------------------------------------------------------------
class TestEmptyEntries:
    """Pipeline fed an empty entry list still runs every stage and finishes."""

    @staticmethod
    def _patch_stages(mocks: dict[str, MagicMock]):
        """Patch every stage class in ``mocks`` within one context manager.

        This assumes every target is an attribute of
        ``app.services.diagnose.pipeline``, as produced by ``_mock_all_stages``.
        """
        return patch.multiple(
            "app.services.diagnose.pipeline",
            **{target.rpartition(".")[2]: mock for target, mock in mocks.items()},
        )

    @staticmethod
    def _empty_kwargs() -> dict:
        """Return default pipeline kwargs with ``entries`` forced to ``[]``.

        The empty list is assigned after construction. This guarantees the
        pipeline really receives no entries, however the kwargs helper treats
        falsy arguments.
        """
        kwargs = _default_pipeline_kwargs()
        kwargs["entries"] = []
        return kwargs

    @pytest.mark.asyncio
    async def test_empty_entries_pipeline_completes(self):
        """Pipeline with entries=[] must still complete and emit done."""
        mocks = _mock_all_stages(hypotheses=[], ranked=[])
        kwargs = self._empty_kwargs()
        with self._patch_stages(mocks):
            events = await _collect_pipeline_events(**kwargs)
        types = [e["type"] for e in events]
        assert "done" in types
        assert types[-1] == "done"

    @pytest.mark.asyncio
    async def test_empty_entries_all_stage_events_present(self):
        """Even with empty entries, all 4 pipeline_stage events are emitted."""
        mocks = _mock_all_stages(hypotheses=[], ranked=[])
        kwargs = self._empty_kwargs()
        with self._patch_stages(mocks):
            events = await _collect_pipeline_events(**kwargs)
        stage_events = [e for e in events if e.get("type") == "pipeline_stage"]
        assert len(stage_events) == 4
# ---------------------------------------------------------------------------
# 4. No LLM: Stage 3 and Stage 5 return empty/fallback; done still emitted
# ---------------------------------------------------------------------------
class TestNoLLM:
    """No LLM configured: stages 3 and 5 yield empty or fallback output, and done still arrives."""

    @staticmethod
    def _patch_stages(mocks):
        """Apply every stage-class mock from ``_mock_all_stages`` at once.

        All targets live in ``app.services.diagnose.pipeline``, so the attribute
        name is the last dotted component of each key.
        """
        attrs = {}
        for dotted_target, cls_mock in mocks.items():
            attrs[dotted_target.rpartition(".")[2]] = cls_mock
        return patch.multiple("app.services.diagnose.pipeline", **attrs)

    @pytest.mark.asyncio
    async def test_no_llm_pipeline_completes_with_done(self):
        """No llm_url/llm_model: pipeline runs all stages and emits done."""
        mocks = _mock_all_stages(
            hypotheses=[],
            ranked=[],
            synthesis_text="VERDICT: UNKNOWN — no hypotheses generated",
        )
        # The default kwargs already set llm_url and llm_model to None.
        kwargs = _default_pipeline_kwargs()
        with self._patch_stages(mocks):
            events = await _collect_pipeline_events(**kwargs)
        assert events[-1] == {"type": "done"}

    @pytest.mark.asyncio
    async def test_no_llm_no_reasoning_event_when_synthesis_empty(self):
        """When synthesizer returns empty string, no reasoning event is emitted."""
        mocks = _mock_all_stages(synthesis_text="")
        kwargs = _default_pipeline_kwargs()
        with self._patch_stages(mocks):
            events = await _collect_pipeline_events(**kwargs)
        assert not any(e.get("type") == "reasoning" for e in events)
# ---------------------------------------------------------------------------
# 5. Stage 1 cluster count in pipeline_stage message
# ---------------------------------------------------------------------------
class TestStage1Message:
    """Contents of the stage-1 (timeline) pipeline_stage event."""

    @pytest.mark.asyncio
    async def test_stage1_message_contains_cluster_count(self):
        """pipeline_stage stage=1 message must report cluster and burst counts."""
        import re

        timeline = TimelineResult(
            clusters=tuple(),
            total_entries=10,
            window_start=None,
            window_end=None,
            gap_count=0,
            burst_count=3,
            dominant_sources=("syslog",),
        )
        classified = _make_classified(timeline)
        mock_reconstructor = MagicMock()
        mock_reconstructor.return_value.reconstruct.return_value = timeline
        mock_classifier = MagicMock()
        mock_classifier.return_value.classify.return_value = classified
        mock_hypothesizer = MagicMock()
        mock_hypothesizer.return_value.hypothesize.return_value = []
        mock_suppressor = MagicMock()
        mock_suppressor.return_value.suppress.return_value = []
        mock_synthesizer = MagicMock()
        mock_synthesizer.return_value.synthesize.return_value = "VERDICT: INFO — nothing found"
        kwargs = _default_pipeline_kwargs()
        with patch.multiple(
            "app.services.diagnose.pipeline",
            TimelineReconstructor=mock_reconstructor,
            SeverityClassifier=mock_classifier,
            RootCauseHypothesizer=mock_hypothesizer,
            FalsePositiveSuppressor=mock_suppressor,
            SummarySynthesizer=mock_synthesizer,
        ):
            events = await _collect_pipeline_events(**kwargs)
        stage1 = next(e for e in events if e.get("type") == "pipeline_stage" and e.get("stage") == 1)
        # Compare whole numbers rather than substrings: with total_entries=10,
        # a plain `"0" in message` would also match "10".
        numbers = re.findall(r"\d+", stage1["message"])
        assert "0" in numbers  # cluster count (clusters is an empty tuple)
        assert "3" in numbers  # burst count

    @pytest.mark.asyncio
    async def test_stage1_name_is_timeline(self):
        """pipeline_stage stage=1 must have name='timeline'."""
        mocks = _mock_all_stages()
        kwargs = _default_pipeline_kwargs()
        # Every _mock_all_stages target lives in app.services.diagnose.pipeline.
        with patch.multiple(
            "app.services.diagnose.pipeline",
            **{target.rpartition(".")[2]: mock for target, mock in mocks.items()},
        ):
            events = await _collect_pipeline_events(**kwargs)
        stage1 = next(e for e in events if e.get("type") == "pipeline_stage" and e.get("stage") == 1)
        assert stage1["name"] == "timeline"