turnstone/tests/test_diagnose_synthesizer.py
pyr0ball 8cbd981ec7 feat: Stage 5 synthesizer + pipeline orchestrator + feature flag wiring (issue #29)
- Add app/services/diagnose/synthesizer.py: SummarySynthesizer (Stage 5)
  - Builds structured LLM prompt from ranked hypotheses, timeline, RAG context
  - Excludes suppressed hypotheses from the narrative prompt
  - Deterministic fallback when no LLM configured or LLM call fails
  - Same cf-orch task endpoint + direct OpenAI-compat fallback pattern as other stages

- Replace pipeline.py stub with full run_pipeline() async generator
  - Orchestrates all 5 stages via asyncio.to_thread for each synchronous stage
  - Yields typed SSE event dicts: status, pipeline_stage (1-4), hypotheses, reasoning, done
  - Suppressor counts (active vs suppressed) reported in stage 4 event message

- Wire MULTI_AGENT_ENABLED feature flag into diagnose_stream()
  - TURNSTONE_MULTI_AGENT_DIAGNOSE=true routes through run_pipeline()
  - pipeline emits its own done event; legacy path unchanged when flag is false
  - Import of run_pipeline added to __init__.py

- Add 21 new tests (350 -> 371 passing):
  - tests/test_diagnose_synthesizer.py: 8 tests (with/without LLM, suppressed,
    empty ranked, LLM failure fallback)
  - tests/test_diagnose_pipeline.py: 13 tests (flag off, flag on event sequence,
    empty entries, no LLM, stage 1 cluster count message)

Closes: #29
2026-05-25 14:56:25 -07:00

285 lines
9.3 KiB
Python

"""Tests for app/services/diagnose/synthesizer.py — SummarySynthesizer.
All tests use mocking; no real LLM calls are made.
"""
from __future__ import annotations
from unittest.mock import MagicMock, patch
from app.context.retriever import RetrievedContext
from app.services.diagnose.models import Hypothesis, RankedHypothesis, TimelineResult
from app.services.diagnose.synthesizer import SummarySynthesizer
# ---------------------------------------------------------------------------
# Fixture helpers
# ---------------------------------------------------------------------------
def _make_hypothesis(
    hypothesis_id: str = "h1",
    title: str = "SSH flood from external IPs",
    description: str = "Repeated failed login attempts from multiple IPs.",
    confidence: float = 0.87,
    severity: str = "CRITICAL",
) -> Hypothesis:
    """Build a ``Hypothesis`` fixture.

    The hypothesis is always linked to a single supporting cluster (``"c1"``)
    and carries no runbook references; every other field is overridable.
    """
    fields = {
        "hypothesis_id": hypothesis_id,
        "title": title,
        "description": description,
        "confidence": confidence,
        "supporting_cluster_ids": ("c1",),
        "runbook_refs": (),
        # Plain str is passed where the model presumably expects a narrower
        # severity type, hence the type-ignore on the constructor call.
        "severity": severity,
    }
    return Hypothesis(**fields)  # type: ignore[arg-type]
def _make_ranked(
    hypothesis: Hypothesis | None = None,
    novelty_score: float = 0.95,
    similarity_to_known: float = 0.05,
    suppress: bool = False,
    suppression_reason: str | None = None,
) -> RankedHypothesis:
    """Wrap a hypothesis in a ``RankedHypothesis`` fixture.

    Args:
        hypothesis: The hypothesis to rank; when ``None`` a default one from
            ``_make_hypothesis()`` is used.
        novelty_score: Novelty score stored on the ranked entry.
        similarity_to_known: Similarity-to-known-incidents score.
        suppress: Whether the entry is flagged as suppressed.
        suppression_reason: Optional explanation stored alongside ``suppress``.
    """
    # Compare against None explicitly so a caller-supplied hypothesis is always
    # used, independent of how the model type defines truthiness.
    if hypothesis is None:
        hypothesis = _make_hypothesis()
    return RankedHypothesis(
        hypothesis=hypothesis,
        novelty_score=novelty_score,
        similarity_to_known=similarity_to_known,
        suppress=suppress,
        suppression_reason=suppression_reason,
    )
def _make_timeline(
    total_entries: int = 42,
    n_clusters: int = 3,
    *,
    gap_count: int = 1,
    burst_count: int = 2,
) -> TimelineResult:
    """Build a one-hour ``TimelineResult`` fixture with no clusters.

    Args:
        total_entries: Value for ``TimelineResult.total_entries``.
        n_clusters: Accepted for call-site compatibility but currently
            ignored — ``clusters`` is always an empty tuple here.
        gap_count: Value for ``TimelineResult.gap_count``.
        burst_count: Value for ``TimelineResult.burst_count``.
    """
    # NOTE(review): n_clusters is not used to build clusters; deleting it makes
    # the "intentionally unused" status explicit instead of silently ignored.
    del n_clusters
    return TimelineResult(
        clusters=(),
        total_entries=total_entries,
        window_start="2026-01-01T00:00:00+00:00",
        window_end="2026-01-01T01:00:00+00:00",
        gap_count=gap_count,
        burst_count=burst_count,
        dominant_sources=("syslog", "auth"),
    )
def _make_ctx(chunks: list[dict] | None = None) -> RetrievedContext:
    """Build a ``RetrievedContext`` fixture with one host fact.

    Args:
        chunks: Retrieved document chunks. When ``None`` a single runbook
            chunk is used; an explicitly passed empty list is kept as-is so
            tests can exercise the "no RAG chunks" case.
    """
    # `is None` rather than `or`: `chunks or default` would replace an
    # intentional empty list with the default chunk.
    if chunks is None:
        chunks = [{"filename": "runbook.md", "text": "Restart sshd if flooded"}]
    return RetrievedContext(
        facts=[{"category": "network", "key": "host", "value": "heimdall", "source": "facts"}],
        chunks=chunks,
    )
# ---------------------------------------------------------------------------
# Test cases
# ---------------------------------------------------------------------------
class TestSynthesizerWithHypotheses:
    """With active hypotheses and an LLM configured, the result must contain VERDICT."""

    @staticmethod
    def _llm_response(content: str) -> MagicMock:
        """Return a fake HTTP response carrying an OpenAI-style chat completion."""
        resp = MagicMock()
        resp.status_code = 200
        resp.json.return_value = {"choices": [{"message": {"content": content}}]}
        return resp

    def test_returns_verdict_string_with_llm(self):
        synthesizer = SummarySynthesizer()
        mock_resp = self._llm_response(
            "VERDICT: CRITICAL — SSH flood (87% confidence)\nTIMELINE: lots of hits."
        )
        with patch("httpx.post", return_value=mock_resp) as mock_post:
            result = synthesizer.synthesize(
                ranked=[_make_ranked()],
                timeline=_make_timeline(),
                ctx=_make_ctx(),
                query="ssh brute force",
                llm_url="http://localhost:11434",
                llm_model="llama3",
            )
        # The deterministic fallback also contains "VERDICT", so confirm the
        # LLM path was actually exercised rather than silently skipped.
        assert mock_post.called
        assert "VERDICT" in result

    def test_returns_nonempty_string(self):
        synthesizer = SummarySynthesizer()
        mock_resp = self._llm_response("VERDICT: CRITICAL — SSH flood (87% confidence)")
        with patch("httpx.post", return_value=mock_resp) as mock_post:
            result = synthesizer.synthesize(
                ranked=[_make_ranked()],
                timeline=_make_timeline(),
                ctx=_make_ctx(),
                query="why is auth failing",
                llm_url="http://localhost:11434",
                llm_model="llama3",
            )
        assert mock_post.called
        assert isinstance(result, str)
        assert len(result) > 0
class TestSynthesizerSuppressedHypotheses:
    """Suppressed hypotheses must be excluded from the LLM prompt."""

    def test_suppressed_hypotheses_excluded_from_prompt(self):
        suppressed = _make_ranked(
            hypothesis=_make_hypothesis(
                hypothesis_id="h2",
                title="Wazuh alert processing backlog",
                severity="ERROR",
                confidence=0.72,
            ),
            suppress=True,
            suppression_reason="similar to 2025-04 SSH incident",
            novelty_score=0.1,
        )
        active = _make_ranked(
            hypothesis=_make_hypothesis(
                hypothesis_id="h1",
                title="SSH flood from external IPs",
                severity="CRITICAL",
                confidence=0.87,
            ),
            suppress=False,
            novelty_score=0.95,
        )
        captured_messages: list[dict] = []

        def fake_post(url, **kwargs):
            # Accept arbitrary httpx.post keywords (headers, timeout, ...) so
            # an extra argument in the synthesizer doesn't break the fake.
            body = kwargs.get("json") or {}
            if "payload" in body:
                # Request shape with the chat messages nested under "payload".
                captured_messages.extend(body["payload"].get("messages", []))
            elif "messages" in body:
                # Request shape with the chat messages at the top level.
                captured_messages.extend(body.get("messages", []))
            resp = MagicMock()
            resp.status_code = 200
            resp.json.return_value = {
                "choices": [{"message": {"content": "VERDICT: CRITICAL — SSH flood"}}]
            }
            return resp

        synthesizer = SummarySynthesizer()
        with patch("httpx.post", side_effect=fake_post):
            synthesizer.synthesize(
                ranked=[active, suppressed],
                timeline=_make_timeline(),
                ctx=_make_ctx(),
                query="auth failures",
                llm_url="http://localhost:11434",
                llm_model="llama3",
            )

        user_contents = [
            m["content"] for m in captured_messages if m.get("role") == "user"
        ]
        assert user_contents, "synthesizer never sent a user message to the LLM"
        # The active hypothesis must be presented to the LLM...
        assert "SSH flood from external IPs" in user_contents[0]
        # ...while the suppressed one must be left out of every prompt sent.
        assert all("Wazuh alert processing backlog" not in c for c in user_contents)
class TestSynthesizerNoLLM:
    """No usable LLM (unconfigured or failing): must return the deterministic fallback."""

    @staticmethod
    def _no_network():
        """Patch ``httpx.post`` so any HTTP attempt fails like an unreachable LLM.

        Keeps the no-LLM tests hermetic: if the synthesizer tries a network
        call anyway, it gets a ConnectionError (the same failure the
        ``test_llm_failure_returns_fallback`` case expects it to absorb)
        instead of touching a real service.
        """
        return patch("httpx.post", side_effect=ConnectionError("network disabled in tests"))

    def test_no_llm_url_returns_fallback(self):
        synthesizer = SummarySynthesizer()
        with self._no_network():
            result = synthesizer.synthesize(
                ranked=[_make_ranked()],
                timeline=_make_timeline(),
                ctx=_make_ctx(),
                query="disk errors",
            )
        assert isinstance(result, str)
        assert "VERDICT" in result

    def test_no_llm_model_returns_fallback(self):
        synthesizer = SummarySynthesizer()
        with self._no_network():
            result = synthesizer.synthesize(
                ranked=[_make_ranked()],
                timeline=_make_timeline(),
                ctx=_make_ctx(),
                query="oom killer",
                llm_url="http://localhost:11434",
                # llm_model omitted
            )
        assert "VERDICT" in result
        assert "SSH flood from external IPs" in result

    def test_llm_failure_returns_fallback(self):
        synthesizer = SummarySynthesizer()
        with patch("httpx.post", side_effect=ConnectionError("refused")) as mock_post:
            result = synthesizer.synthesize(
                ranked=[_make_ranked()],
                timeline=_make_timeline(),
                ctx=_make_ctx(),
                query="why is disk full",
                llm_url="http://localhost:11434",
                llm_model="llama3",
            )
        # The LLM was actually attempted, and its failure produced the
        # deterministic fallback (which names the active hypothesis).
        assert mock_post.called
        assert "VERDICT" in result
        assert "SSH flood from external IPs" in result
class TestSynthesizerEmptyRanked:
    """Empty ranked list: must return VERDICT-bearing text, not raise."""

    def test_empty_ranked_no_llm_returns_fallback(self):
        synthesizer = SummarySynthesizer()
        # Guard against accidental real network traffic; any HTTP attempt
        # fails like an unreachable LLM and should fall back.
        with patch("httpx.post", side_effect=ConnectionError("network disabled in tests")):
            result = synthesizer.synthesize(
                ranked=[],
                timeline=_make_timeline(),
                ctx=_make_ctx(),
                query="check everything",
            )
        assert isinstance(result, str)
        assert "VERDICT" in result

    def test_empty_ranked_with_llm_returns_fallback_or_llm_text(self):
        """With an LLM configured and nothing ranked, a VERDICT is still returned.

        The synthesizer may either call the LLM or short-circuit to the
        fallback; both the mocked LLM text and the fallback contain "VERDICT".
        """
        synthesizer = SummarySynthesizer()
        mock_resp = MagicMock()
        mock_resp.status_code = 200
        mock_resp.json.return_value = {
            "choices": [{"message": {"content": "VERDICT: UNKNOWN — no hypotheses generated"}}]
        }
        with patch("httpx.post", return_value=mock_resp):
            result = synthesizer.synthesize(
                ranked=[],
                timeline=_make_timeline(),
                ctx=_make_ctx(),
                query="nothing found",
                llm_url="http://localhost:11434",
                llm_model="llama3",
            )
        assert isinstance(result, str)
        assert "VERDICT" in result