turnstone/tests/test_diagnose_classifier.py
pyr0ball ae0ecac17d feat(classifier): add Hybrid-BERT label mapping shim (#41)
Adds _HYBRID_BERT_LABEL_MAP to translate the 7-class output vocabulary of
krishnas4415/log-anomaly-detection-models (Hybrid-BERT, MIT) to Turnstone
SeverityLabel. _map_label now checks the Hybrid-BERT map before the standard
map so either model family works via TURNSTONE_CLASSIFIER_MODEL without any
additional code path.

Mapping (confirmed from model config.json):
  normal            → INFO
  security_anomaly  → ERROR
  system_failure    → CRITICAL
  performance_issue → WARN
  network_anomaly   → WARN
  config_error      → ERROR
  hardware_issue    → CRITICAL

Keyword-based CRITICAL promotion and low-confidence DEBUG demotion apply on
top of the base mapping (same rules as the standard vocabulary).

11 new tests covering all 7 Hybrid-BERT labels, case-insensitivity, and
regression on standard-vocabulary labels. 372 tests passing total.

Note: custom loading code for the non-standard .pt checkpoint format is
explicitly out of scope — evaluate better-packaged HF alternatives first
(see #41 for candidate list).

Closes: #41
2026-06-01 16:20:31 -07:00

299 lines
11 KiB
Python

"""Tests for app/services/diagnose/classifier.py — SeverityClassifier.
All ML-path tests mock ``transformers.pipeline`` so no model weights are
downloaded during the test suite.
"""
from __future__ import annotations
from dataclasses import FrozenInstanceError
from pathlib import Path
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
import app.services.diagnose.classifier as clf_module
from app.services.diagnose.classifier import SeverityClassifier
from app.services.diagnose.models import ClassifiedTimeline, EventCluster, TimelineResult
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def reset_ml_singleton():
    """Ensure the module-level ML singleton is cleared before and after each test.

    The fixture is ``autouse``, so every test in this module starts with
    ``clf_module._ml_classifier`` set to ``None`` and leaves it that way,
    preventing a classifier cached by one test from leaking into the next.
    """
    clf_module._ml_classifier = None
    yield
    # Teardown: drop whatever the test may have cached on the module.
    clf_module._ml_classifier = None
# ---------------------------------------------------------------------------
# Test-object builders
# ---------------------------------------------------------------------------
def _make_cluster(
    representative_text: str = "test log",
    pattern_tags: tuple[str, ...] = (),
    severity: str = "INFO",
) -> EventCluster:
    """Build a minimal ``EventCluster`` for classifier tests.

    Only the fields the tests vary are parameters; everything else is a
    fixed placeholder. The cluster id is always ``"abc123"``, which is the
    key the tests look up in ``cluster_severities``.

    Args:
        representative_text: Text stored as the cluster's representative line.
        pattern_tags: Pattern names attached to the cluster.
        severity: Initial severity stored on the cluster.

    Returns:
        An ``EventCluster`` with one entry, no timestamps and no burst/gap.
    """
    return EventCluster(
        cluster_id="abc123",
        entries=("e1",),
        start_iso=None,
        end_iso=None,
        duration_seconds=0.0,
        source_ids=("src",),
        pattern_tags=pattern_tags,
        severity=severity,  # type: ignore[arg-type]
        burst=False,
        gap_before_seconds=0.0,
        representative_text=representative_text,
    )
def _make_timeline(clusters: tuple[EventCluster, ...] = ()) -> TimelineResult:
    """Wrap ``clusters`` in a ``TimelineResult`` with placeholder statistics.

    Args:
        clusters: Clusters to place on the timeline; defaults to none.

    Returns:
        A ``TimelineResult`` whose counters are zero, whose window bounds are
        ``None`` and whose ``dominant_sources`` is empty.
    """
    return TimelineResult(
        clusters=clusters,
        total_entries=0,
        window_start=None,
        window_end=None,
        gap_count=0,
        burst_count=0,
        dominant_sources=(),
    )
def _mock_hf_pipeline(label: str, score: float) -> MagicMock:
"""Return a mock HF pipeline callable that always yields one result."""
pipe = MagicMock()
pipe.return_value = [{"label": label, "score": score}]
return pipe
# ---------------------------------------------------------------------------
# Path A — ML classification
# ---------------------------------------------------------------------------
class TestMLPath:
    """Path A: severity is taken from the (mocked) ML pipeline output."""

    @staticmethod
    def _classify(label: str, score: float, text: str) -> ClassifiedTimeline:
        """Classify a one-cluster timeline with the ML loader patched out.

        ``_get_ml_classifier`` is replaced by a mock pipeline that always
        answers with ``label``/``score``; the classifier is built with
        ``model_id="fake/model"``.

        Args:
            label: Label the mock pipeline reports.
            score: Confidence the mock pipeline reports.
            text: Representative text of the single cluster.

        Returns:
            The result of ``SeverityClassifier.classify`` for that timeline.
        """
        pipe = _mock_hf_pipeline(label, score)
        with patch(
            "app.services.diagnose.classifier._get_ml_classifier", return_value=pipe
        ):
            clf = SeverityClassifier(model_id="fake/model")
            # Both construction and classify() stay inside the patch so the
            # mock is in place whenever the loader is looked up.
            return clf.classify(_make_timeline((_make_cluster(text),)))

    def test_ml_error_maps_to_error(self) -> None:
        """ML returning ERROR with score 0.98 → cluster severity ERROR."""
        result = self._classify("ERROR", 0.98, "disk error detected")
        assert result.cluster_severities["abc123"] == "ERROR"
        assert result.classifier_used == "ml"
        assert result.model_id == "fake/model"

    def test_ml_critical_promotion(self) -> None:
        """ERROR + score > 0.95 + 'kernel panic' in text → promoted to CRITICAL."""
        result = self._classify("ERROR", 0.97, "kernel panic: not syncing VFS")
        assert result.cluster_severities["abc123"] == "CRITICAL"

    def test_ml_debug_demotion(self) -> None:
        """INFO + score < 0.4 → demoted to DEBUG."""
        result = self._classify("INFO", 0.3, "routine ping")
        assert result.cluster_severities["abc123"] == "DEBUG"

    def test_ml_warning_maps_to_warn(self) -> None:
        """ML returning WARNING → mapped to WARN."""
        result = self._classify("WARNING", 0.85, "low disk space")
        assert result.cluster_severities["abc123"] == "WARN"
# ---------------------------------------------------------------------------
# Path B — pattern_tags fallback
# ---------------------------------------------------------------------------
class TestPatternTagsPath:
    """Path B: no model configured, severity resolved from a pattern file."""

    def test_pattern_tags_resolve_error_severity(self, tmp_path: Path) -> None:
        """Cluster with pattern_tag 'service_crash_loop' → ERROR from pattern file."""
        pattern_yaml = tmp_path / "default.yaml"
        pattern_yaml.write_text(
            "patterns:\n"
            "  - name: service_crash_loop\n"
            "    pattern: crash\n"
            "    severity: ERROR\n"
            "    description: Service crashed in a loop\n",
            encoding="utf-8",
        )
        # Empty model_id means no ML model is requested for this classifier.
        clf = SeverityClassifier(model_id="", pattern_file=pattern_yaml)
        cluster = _make_cluster(
            representative_text="service crashed",
            pattern_tags=("service_crash_loop",),
        )
        result = clf.classify(_make_timeline((cluster,)))
        assert result.cluster_severities["abc123"] == "ERROR"
        assert result.classifier_used == "pattern_tags"
        assert result.model_id is None
# ---------------------------------------------------------------------------
# Path C — regex fallback
# ---------------------------------------------------------------------------
class TestRegexPath:
    """Path C: neither a model nor a pattern file is given to the classifier."""

    def test_regex_detects_error(self) -> None:
        """No ML, no pattern file: 'ERROR: disk full' → ERROR via regex."""
        clf = SeverityClassifier(model_id="")
        result = clf.classify(
            _make_timeline((_make_cluster("ERROR: disk full"),))
        )
        assert result.cluster_severities["abc123"] == "ERROR"
        assert result.classifier_used == "regex"

    def test_regex_defaults_to_info_when_no_match(self) -> None:
        """No severity keyword in text → defaults to INFO."""
        clf = SeverityClassifier(model_id="")
        result = clf.classify(
            _make_timeline((_make_cluster("mount: disk mounted successfully"),))
        )
        assert result.cluster_severities["abc123"] == "INFO"
# ---------------------------------------------------------------------------
# Fallback behaviour
# ---------------------------------------------------------------------------
class TestImportErrorFallback:
    """Behaviour when loading the ML classifier raises ``ImportError``."""

    def test_transformers_import_error_falls_back_to_pattern_tags(
        self, tmp_path: Path
    ) -> None:
        """ImportError from transformers → clean fallback to pattern_tags path."""
        pattern_yaml = tmp_path / "default.yaml"
        pattern_yaml.write_text(
            "patterns:\n"
            "  - name: auth_failure\n"
            "    pattern: auth\n"
            "    severity: ERROR\n"
            "    description: Auth failure\n",
            encoding="utf-8",
        )

        def _raising_get_ml(*_args: Any, **_kwargs: Any) -> None:
            # Stand-in for the loader on a machine without `transformers`.
            raise ImportError("No module named 'transformers'")

        with patch(
            "app.services.diagnose.classifier._get_ml_classifier",
            side_effect=_raising_get_ml,
        ):
            clf = SeverityClassifier(model_id="fake/model", pattern_file=pattern_yaml)
            cluster = _make_cluster(
                representative_text="auth failed",
                pattern_tags=("auth_failure",),
            )
            result = clf.classify(_make_timeline((cluster,)))
        # The ML path was attempted, so classifier_used is still "ml", but the
        # severity itself was resolved from the pattern_tags entry.
        assert result.classifier_used == "ml"
        assert result.cluster_severities["abc123"] == "ERROR"
# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------
class TestEdgeCases:
    """Empty input and immutability of the classification result."""

    def test_empty_timeline_produces_empty_severities(self) -> None:
        """TimelineResult with no clusters → empty cluster_severities, no crash."""
        clf = SeverityClassifier(model_id="")
        result = clf.classify(_make_timeline())
        assert isinstance(result, ClassifiedTimeline)
        assert result.cluster_severities == {}
        assert result.classifier_used == "regex"

    def test_classified_timeline_is_frozen(self) -> None:
        """ClassifiedTimeline must be frozen (FrozenInstanceError on mutation)."""
        clf = SeverityClassifier(model_id="")
        result = clf.classify(_make_timeline((_make_cluster(),)))
        with pytest.raises(FrozenInstanceError):
            result.classifier_used = "ml"  # type: ignore[misc]
# ---------------------------------------------------------------------------
# Hybrid-BERT label mapping shim (turnstone#41)
# ---------------------------------------------------------------------------
class TestHybridBertLabelMap:
    """_map_label must translate Hybrid-BERT vocabulary to SeverityLabel."""

    def _run(self, label: str, score: float = 0.9, text: str = "log line") -> str:
        """Call ``_map_label(label, score, text)`` and return its result.

        The import is local so the private helper is only looked up when a
        test in this class actually runs.
        """
        from app.services.diagnose.classifier import _map_label

        return _map_label(label, score, text)

    def test_normal_maps_to_info(self) -> None:
        assert self._run("normal") == "INFO"

    def test_security_anomaly_maps_to_error(self) -> None:
        assert self._run("security_anomaly") == "ERROR"

    def test_system_failure_maps_to_critical(self) -> None:
        assert self._run("system_failure") == "CRITICAL"

    def test_performance_issue_maps_to_warn(self) -> None:
        assert self._run("performance_issue") == "WARN"

    def test_network_anomaly_maps_to_warn(self) -> None:
        assert self._run("network_anomaly") == "WARN"

    def test_config_error_maps_to_error(self) -> None:
        assert self._run("config_error") == "ERROR"

    def test_hardware_issue_maps_to_critical(self) -> None:
        assert self._run("hardware_issue") == "CRITICAL"

    def test_hybrid_bert_labels_are_case_insensitive(self) -> None:
        """Upper-case and mixed-case spellings map like the lower-case label."""
        assert self._run("SECURITY_ANOMALY", text="x") == "ERROR"
        assert self._run("Security_Anomaly", text="x") == "ERROR"

    def test_system_failure_critical_promotion_not_doubled(self) -> None:
        """system_failure already maps to CRITICAL — keyword promotion is a no-op."""
        assert self._run("system_failure", score=0.99, text="kernel panic") == "CRITICAL"

    def test_normal_low_confidence_demotes_to_debug(self) -> None:
        """normal + low score → INFO base → DEBUG (same demotion rule as INFO)."""
        assert self._run("normal", score=0.2) == "DEBUG"

    def test_standard_labels_still_work(self) -> None:
        """Existing standard-vocabulary labels must not be broken by the shim."""
        assert self._run("ERROR", text="x") == "ERROR"
        assert self._run("WARNING", text="x") == "WARN"
        assert self._run("CRITICAL", text="x") == "CRITICAL"