"""Tests for app/services/diagnose/pipeline.py and __init__.py feature flag wiring. All tests use mocking; no real LLM, ML, or DB calls are made. """ from __future__ import annotations from pathlib import Path from typing import Any from unittest.mock import MagicMock, patch import pytest from app.context.retriever import RetrievedContext from app.services.diagnose.models import ( ClassifiedTimeline, Hypothesis, RankedHypothesis, TimelineResult, ) from app.services.search import SearchResult # --------------------------------------------------------------------------- # Shared helpers # --------------------------------------------------------------------------- def _make_search_result( entry_id: str = "e1", source_id: str = "syslog", timestamp_iso: str | None = "2026-01-01T00:00:00+00:00", severity: str | None = "ERROR", text: str = "ssh: invalid user", ) -> SearchResult: return SearchResult( entry_id=entry_id, source_id=source_id, sequence=1, timestamp_iso=timestamp_iso, severity=severity, repeat_count=1, out_of_order=False, matched_patterns=["ssh_fail"], text=text, rank=1.0, ) def _make_ctx() -> RetrievedContext: return RetrievedContext(facts=[], chunks=[]) def _make_timeline(n_clusters: int = 2) -> TimelineResult: return TimelineResult( clusters=tuple(), total_entries=5, window_start="2026-01-01T00:00:00+00:00", window_end="2026-01-01T01:00:00+00:00", gap_count=0, burst_count=1, dominant_sources=("syslog",), ) def _make_classified(timeline: TimelineResult | None = None) -> ClassifiedTimeline: tl = timeline or _make_timeline() return ClassifiedTimeline( timeline=tl, cluster_severities={}, classifier_used="regex", model_id=None, ) def _make_hypothesis( hypothesis_id: str = "h1", title: str = "SSH flood", confidence: float = 0.87, severity: str = "CRITICAL", ) -> Hypothesis: return Hypothesis( hypothesis_id=hypothesis_id, title=title, description="Multiple failed SSH attempts.", confidence=confidence, supporting_cluster_ids=("c1",), runbook_refs=(), severity=severity, # type: ignore[arg-type] ) def _make_ranked(hypothesis: Hypothesis | None = None, suppress: bool = False) -> RankedHypothesis: h = hypothesis or _make_hypothesis() return RankedHypothesis( hypothesis=h, novelty_score=0.95, similarity_to_known=0.05, suppress=suppress, suppression_reason="similar to known" if suppress else None, ) # --------------------------------------------------------------------------- # Helper: collect all events from run_pipeline # --------------------------------------------------------------------------- async def _collect_pipeline_events(**kwargs) -> list[dict[str, Any]]: """Run run_pipeline and collect all yielded events into a list.""" from app.services.diagnose.pipeline import run_pipeline events = [] async for event in run_pipeline(**kwargs): events.append(event) return events def _default_pipeline_kwargs(entries=None, db_path=None) -> dict: return dict( db_path=db_path or Path("/tmp/fake.db"), entries=entries or [_make_search_result()], ctx=_make_ctx(), query="ssh brute force", since="2026-01-01T00:00:00+00:00", until="2026-01-01T01:00:00+00:00", llm_url=None, llm_model=None, llm_api_key=None, ) # --------------------------------------------------------------------------- # Mock factories for all 5 stage classes # --------------------------------------------------------------------------- def _mock_all_stages( hypotheses=None, ranked=None, synthesis_text="VERDICT: CRITICAL — SSH flood (87% confidence)", ): """Return a dict of patch targets and their mock return values.""" timeline = _make_timeline() classified = _make_classified(timeline) hyps = hypotheses if hypotheses is not None else [_make_hypothesis()] rnk = ranked if ranked is not None else [_make_ranked()] mock_reconstructor = MagicMock() mock_reconstructor.return_value.reconstruct.return_value = timeline mock_classifier = MagicMock() mock_classifier.return_value.classify.return_value = classified mock_hypothesizer = MagicMock() mock_hypothesizer.return_value.hypothesize.return_value = hyps mock_suppressor = MagicMock() mock_suppressor.return_value.suppress.return_value = rnk mock_synthesizer = MagicMock() mock_synthesizer.return_value.synthesize.return_value = synthesis_text return { "app.services.diagnose.pipeline.TimelineReconstructor": mock_reconstructor, "app.services.diagnose.pipeline.SeverityClassifier": mock_classifier, "app.services.diagnose.pipeline.RootCauseHypothesizer": mock_hypothesizer, "app.services.diagnose.pipeline.FalsePositiveSuppressor": mock_suppressor, "app.services.diagnose.pipeline.SummarySynthesizer": mock_synthesizer, } # --------------------------------------------------------------------------- # 1. Feature flag off: legacy summarize() path runs, not run_pipeline # --------------------------------------------------------------------------- class TestFeatureFlagOff: @pytest.mark.asyncio async def test_legacy_path_when_flag_off(self): """With MULTI_AGENT_ENABLED=False, run_pipeline is never called.""" from app.services import diagnose as diagnose_module entries = [_make_search_result()] with ( patch.object(diagnose_module, "MULTI_AGENT_ENABLED", False), patch("app.services.diagnose.search", return_value=entries), patch("app.services.diagnose.entries_in_window", return_value=[]), patch("app.services.diagnose.retrieve_context", return_value=_make_ctx()), patch("app.services.diagnose.format_context_block", return_value=None), patch("app.services.diagnose.run_pipeline") as mock_pipeline, patch("app.services.diagnose.summarize", return_value=None), ): events = [] async for event in diagnose_module.diagnose_stream( db_path=Path("/tmp/fake.db"), query="ssh failures", llm_url=None, llm_model=None, ): events.append(event) # run_pipeline must NOT have been called mock_pipeline.assert_not_called() # SSE sequence must end with done types = [e["type"] for e in events] assert "done" in types assert types[-1] == "done" @pytest.mark.asyncio async def test_legacy_done_event_is_last(self): """Legacy path: done is always the last event.""" from app.services import diagnose as diagnose_module with ( patch.object(diagnose_module, "MULTI_AGENT_ENABLED", False), patch("app.services.diagnose.search", return_value=[]), patch("app.services.diagnose.entries_in_window", return_value=[]), patch("app.services.diagnose.retrieve_context", return_value=_make_ctx()), patch("app.services.diagnose.format_context_block", return_value=None), ): events = [] async for event in diagnose_module.diagnose_stream( db_path=Path("/tmp/fake.db"), query="check logs", ): events.append(event) assert events[-1] == {"type": "done"} # --------------------------------------------------------------------------- # 2. Feature flag on, all stages mocked: verify SSE event sequence # --------------------------------------------------------------------------- class TestFeatureFlagOn: @pytest.mark.asyncio async def test_pipeline_stage_events_in_order(self): """pipeline_stage events must be emitted stages 1→2→3→4 in order.""" mocks = _mock_all_stages() kwargs = _default_pipeline_kwargs() with ( patch("app.services.diagnose.pipeline.TimelineReconstructor", mocks["app.services.diagnose.pipeline.TimelineReconstructor"]), patch("app.services.diagnose.pipeline.SeverityClassifier", mocks["app.services.diagnose.pipeline.SeverityClassifier"]), patch("app.services.diagnose.pipeline.RootCauseHypothesizer", mocks["app.services.diagnose.pipeline.RootCauseHypothesizer"]), patch("app.services.diagnose.pipeline.FalsePositiveSuppressor", mocks["app.services.diagnose.pipeline.FalsePositiveSuppressor"]), patch("app.services.diagnose.pipeline.SummarySynthesizer", mocks["app.services.diagnose.pipeline.SummarySynthesizer"]), ): events = await _collect_pipeline_events(**kwargs) stage_events = [e for e in events if e.get("type") == "pipeline_stage"] stages = [e["stage"] for e in stage_events] assert stages == [1, 2, 3, 4] @pytest.mark.asyncio async def test_hypotheses_event_after_stage4(self): """hypotheses event must appear after pipeline_stage stage=4.""" mocks = _mock_all_stages() kwargs = _default_pipeline_kwargs() with ( patch("app.services.diagnose.pipeline.TimelineReconstructor", mocks["app.services.diagnose.pipeline.TimelineReconstructor"]), patch("app.services.diagnose.pipeline.SeverityClassifier", mocks["app.services.diagnose.pipeline.SeverityClassifier"]), patch("app.services.diagnose.pipeline.RootCauseHypothesizer", mocks["app.services.diagnose.pipeline.RootCauseHypothesizer"]), patch("app.services.diagnose.pipeline.FalsePositiveSuppressor", mocks["app.services.diagnose.pipeline.FalsePositiveSuppressor"]), patch("app.services.diagnose.pipeline.SummarySynthesizer", mocks["app.services.diagnose.pipeline.SummarySynthesizer"]), ): events = await _collect_pipeline_events(**kwargs) stage4_idx = next( i for i, e in enumerate(events) if e.get("type") == "pipeline_stage" and e.get("stage") == 4 ) hyp_idx = next(i for i, e in enumerate(events) if e.get("type") == "hypotheses") assert hyp_idx > stage4_idx @pytest.mark.asyncio async def test_reasoning_event_emitted(self): """reasoning event must be present when synthesizer returns text.""" mocks = _mock_all_stages(synthesis_text="VERDICT: CRITICAL — SSH flood") kwargs = _default_pipeline_kwargs() with ( patch("app.services.diagnose.pipeline.TimelineReconstructor", mocks["app.services.diagnose.pipeline.TimelineReconstructor"]), patch("app.services.diagnose.pipeline.SeverityClassifier", mocks["app.services.diagnose.pipeline.SeverityClassifier"]), patch("app.services.diagnose.pipeline.RootCauseHypothesizer", mocks["app.services.diagnose.pipeline.RootCauseHypothesizer"]), patch("app.services.diagnose.pipeline.FalsePositiveSuppressor", mocks["app.services.diagnose.pipeline.FalsePositiveSuppressor"]), patch("app.services.diagnose.pipeline.SummarySynthesizer", mocks["app.services.diagnose.pipeline.SummarySynthesizer"]), ): events = await _collect_pipeline_events(**kwargs) reasoning_events = [e for e in events if e.get("type") == "reasoning"] assert len(reasoning_events) == 1 assert "VERDICT" in reasoning_events[0]["text"] @pytest.mark.asyncio async def test_done_event_is_last(self): """done must always be the last event in the pipeline sequence.""" mocks = _mock_all_stages() kwargs = _default_pipeline_kwargs() with ( patch("app.services.diagnose.pipeline.TimelineReconstructor", mocks["app.services.diagnose.pipeline.TimelineReconstructor"]), patch("app.services.diagnose.pipeline.SeverityClassifier", mocks["app.services.diagnose.pipeline.SeverityClassifier"]), patch("app.services.diagnose.pipeline.RootCauseHypothesizer", mocks["app.services.diagnose.pipeline.RootCauseHypothesizer"]), patch("app.services.diagnose.pipeline.FalsePositiveSuppressor", mocks["app.services.diagnose.pipeline.FalsePositiveSuppressor"]), patch("app.services.diagnose.pipeline.SummarySynthesizer", mocks["app.services.diagnose.pipeline.SummarySynthesizer"]), ): events = await _collect_pipeline_events(**kwargs) assert events[-1] == {"type": "done"} @pytest.mark.asyncio async def test_pipeline_wired_from_diagnose_stream(self): """diagnose_stream routes through run_pipeline when flag is on.""" from app.services import diagnose as diagnose_module entries = [_make_search_result()] async def fake_pipeline(**kwargs): yield {"type": "status", "message": "Building timeline…"} yield {"type": "pipeline_stage", "stage": 1, "name": "timeline", "message": "Built 1 clusters, 0 bursts"} yield {"type": "done"} with ( patch.object(diagnose_module, "MULTI_AGENT_ENABLED", True), patch("app.services.diagnose.search", return_value=entries), patch("app.services.diagnose.entries_in_window", return_value=[]), patch("app.services.diagnose.retrieve_context", return_value=_make_ctx()), patch("app.services.diagnose.format_context_block", return_value=None), patch("app.services.diagnose.run_pipeline", side_effect=fake_pipeline), ): events = [] async for event in diagnose_module.diagnose_stream( db_path=Path("/tmp/fake.db"), query="ssh failures", ): events.append(event) types = [e["type"] for e in events] assert "pipeline_stage" in types assert types[-1] == "done" # Legacy summarize() must NOT have been called — done event came from pipeline assert types.count("done") == 1 # --------------------------------------------------------------------------- # 3. Empty entries: pipeline completes with done # --------------------------------------------------------------------------- class TestEmptyEntries: @pytest.mark.asyncio async def test_empty_entries_pipeline_completes(self): """Pipeline with entries=[] must still complete and emit done.""" mocks = _mock_all_stages(hypotheses=[], ranked=[]) kwargs = _default_pipeline_kwargs(entries=[]) with ( patch("app.services.diagnose.pipeline.TimelineReconstructor", mocks["app.services.diagnose.pipeline.TimelineReconstructor"]), patch("app.services.diagnose.pipeline.SeverityClassifier", mocks["app.services.diagnose.pipeline.SeverityClassifier"]), patch("app.services.diagnose.pipeline.RootCauseHypothesizer", mocks["app.services.diagnose.pipeline.RootCauseHypothesizer"]), patch("app.services.diagnose.pipeline.FalsePositiveSuppressor", mocks["app.services.diagnose.pipeline.FalsePositiveSuppressor"]), patch("app.services.diagnose.pipeline.SummarySynthesizer", mocks["app.services.diagnose.pipeline.SummarySynthesizer"]), ): events = await _collect_pipeline_events(**kwargs) types = [e["type"] for e in events] assert "done" in types assert types[-1] == "done" @pytest.mark.asyncio async def test_empty_entries_all_stage_events_present(self): """Even with empty entries, all 4 pipeline_stage events are emitted.""" mocks = _mock_all_stages(hypotheses=[], ranked=[]) kwargs = _default_pipeline_kwargs(entries=[]) with ( patch("app.services.diagnose.pipeline.TimelineReconstructor", mocks["app.services.diagnose.pipeline.TimelineReconstructor"]), patch("app.services.diagnose.pipeline.SeverityClassifier", mocks["app.services.diagnose.pipeline.SeverityClassifier"]), patch("app.services.diagnose.pipeline.RootCauseHypothesizer", mocks["app.services.diagnose.pipeline.RootCauseHypothesizer"]), patch("app.services.diagnose.pipeline.FalsePositiveSuppressor", mocks["app.services.diagnose.pipeline.FalsePositiveSuppressor"]), patch("app.services.diagnose.pipeline.SummarySynthesizer", mocks["app.services.diagnose.pipeline.SummarySynthesizer"]), ): events = await _collect_pipeline_events(**kwargs) stage_events = [e for e in events if e.get("type") == "pipeline_stage"] assert len(stage_events) == 4 # --------------------------------------------------------------------------- # 4. No LLM: Stage 3 and Stage 5 return empty/fallback; done still emitted # --------------------------------------------------------------------------- class TestNoLLM: @pytest.mark.asyncio async def test_no_llm_pipeline_completes_with_done(self): """No llm_url/llm_model: pipeline runs all stages and emits done.""" mocks = _mock_all_stages(hypotheses=[], ranked=[], synthesis_text="VERDICT: UNKNOWN — no hypotheses generated") kwargs = _default_pipeline_kwargs() # llm_url and llm_model already None in default kwargs with ( patch("app.services.diagnose.pipeline.TimelineReconstructor", mocks["app.services.diagnose.pipeline.TimelineReconstructor"]), patch("app.services.diagnose.pipeline.SeverityClassifier", mocks["app.services.diagnose.pipeline.SeverityClassifier"]), patch("app.services.diagnose.pipeline.RootCauseHypothesizer", mocks["app.services.diagnose.pipeline.RootCauseHypothesizer"]), patch("app.services.diagnose.pipeline.FalsePositiveSuppressor", mocks["app.services.diagnose.pipeline.FalsePositiveSuppressor"]), patch("app.services.diagnose.pipeline.SummarySynthesizer", mocks["app.services.diagnose.pipeline.SummarySynthesizer"]), ): events = await _collect_pipeline_events(**kwargs) assert events[-1] == {"type": "done"} @pytest.mark.asyncio async def test_no_llm_no_reasoning_event_when_synthesis_empty(self): """When synthesizer returns empty string, no reasoning event is emitted.""" mocks = _mock_all_stages(synthesis_text="") kwargs = _default_pipeline_kwargs() with ( patch("app.services.diagnose.pipeline.TimelineReconstructor", mocks["app.services.diagnose.pipeline.TimelineReconstructor"]), patch("app.services.diagnose.pipeline.SeverityClassifier", mocks["app.services.diagnose.pipeline.SeverityClassifier"]), patch("app.services.diagnose.pipeline.RootCauseHypothesizer", mocks["app.services.diagnose.pipeline.RootCauseHypothesizer"]), patch("app.services.diagnose.pipeline.FalsePositiveSuppressor", mocks["app.services.diagnose.pipeline.FalsePositiveSuppressor"]), patch("app.services.diagnose.pipeline.SummarySynthesizer", mocks["app.services.diagnose.pipeline.SummarySynthesizer"]), ): events = await _collect_pipeline_events(**kwargs) reasoning_events = [e for e in events if e.get("type") == "reasoning"] assert len(reasoning_events) == 0 # --------------------------------------------------------------------------- # 5. Stage 1 cluster count in pipeline_stage message # --------------------------------------------------------------------------- class TestStage1Message: @pytest.mark.asyncio async def test_stage1_message_contains_cluster_count(self): """pipeline_stage stage=1 message must report cluster count.""" timeline = TimelineResult( clusters=tuple(), total_entries=10, window_start=None, window_end=None, gap_count=0, burst_count=3, dominant_sources=("syslog",), ) classified = _make_classified(timeline) mock_reconstructor = MagicMock() mock_reconstructor.return_value.reconstruct.return_value = timeline mock_classifier = MagicMock() mock_classifier.return_value.classify.return_value = classified mock_hypothesizer = MagicMock() mock_hypothesizer.return_value.hypothesize.return_value = [] mock_suppressor = MagicMock() mock_suppressor.return_value.suppress.return_value = [] mock_synthesizer = MagicMock() mock_synthesizer.return_value.synthesize.return_value = "VERDICT: INFO — nothing found" kwargs = _default_pipeline_kwargs() with ( patch("app.services.diagnose.pipeline.TimelineReconstructor", mock_reconstructor), patch("app.services.diagnose.pipeline.SeverityClassifier", mock_classifier), patch("app.services.diagnose.pipeline.RootCauseHypothesizer", mock_hypothesizer), patch("app.services.diagnose.pipeline.FalsePositiveSuppressor", mock_suppressor), patch("app.services.diagnose.pipeline.SummarySynthesizer", mock_synthesizer), ): events = await _collect_pipeline_events(**kwargs) stage1 = next(e for e in events if e.get("type") == "pipeline_stage" and e.get("stage") == 1) # 0 clusters (empty tuple), 3 bursts assert "0" in stage1["message"] # cluster count assert "3" in stage1["message"] # burst count @pytest.mark.asyncio async def test_stage1_name_is_timeline(self): """pipeline_stage stage=1 must have name='timeline'.""" mocks = _mock_all_stages() kwargs = _default_pipeline_kwargs() with ( patch("app.services.diagnose.pipeline.TimelineReconstructor", mocks["app.services.diagnose.pipeline.TimelineReconstructor"]), patch("app.services.diagnose.pipeline.SeverityClassifier", mocks["app.services.diagnose.pipeline.SeverityClassifier"]), patch("app.services.diagnose.pipeline.RootCauseHypothesizer", mocks["app.services.diagnose.pipeline.RootCauseHypothesizer"]), patch("app.services.diagnose.pipeline.FalsePositiveSuppressor", mocks["app.services.diagnose.pipeline.FalsePositiveSuppressor"]), patch("app.services.diagnose.pipeline.SummarySynthesizer", mocks["app.services.diagnose.pipeline.SummarySynthesizer"]), ): events = await _collect_pipeline_events(**kwargs) stage1 = next(e for e in events if e.get("type") == "pipeline_stage" and e.get("stage") == 1) assert stage1["name"] == "timeline"