Adds _HYBRID_BERT_LABEL_MAP to translate the 7-class output vocabulary of krishnas4415/log-anomaly-detection-models (Hybrid-BERT, MIT) to Turnstone SeverityLabel. _map_label now checks the Hybrid-BERT map before the standard map, so either model family works via TURNSTONE_CLASSIFIER_MODEL without any additional code path.

Mapping (confirmed from the model's config.json):

    normal            → INFO
    security_anomaly  → ERROR
    system_failure    → CRITICAL
    performance_issue → WARN
    network_anomaly   → WARN
    config_error      → ERROR
    hardware_issue    → CRITICAL

Keyword-based CRITICAL promotion and low-confidence DEBUG demotion apply on top of the base mapping (same rules as the standard vocabulary).

11 new tests cover all 7 Hybrid-BERT labels, case-insensitivity, and regression on standard-vocabulary labels; 372 tests pass in total.

Note: custom loading code for the non-standard .pt checkpoint format is explicitly out of scope — evaluate better-packaged HF alternatives first (see #41 for the candidate list).

Closes: #41
299 lines
11 KiB
Python
299 lines
11 KiB
Python
"""Tests for app/services/diagnose/classifier.py — SeverityClassifier.
|
|
|
|
All ML-path tests patch ``_get_ml_classifier`` to return a mock pipeline, so no
model weights are downloaded during the test suite.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
from dataclasses import FrozenInstanceError
|
|
from pathlib import Path
|
|
from typing import Any
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
import app.services.diagnose.classifier as clf_module
|
|
from app.services.diagnose.classifier import SeverityClassifier
|
|
from app.services.diagnose.models import ClassifiedTimeline, EventCluster, TimelineResult
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def reset_ml_singleton():
    """Run every test with the classifier module's ML singleton unset.

    ``clf_module._ml_classifier`` is cleared before the test body runs and
    again during teardown, so a value cached by one test cannot leak into the
    next one.
    """

    def _clear() -> None:
        clf_module._ml_classifier = None

    _clear()  # setup
    yield
    _clear()  # teardown
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Test-object builders
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def _make_cluster(
    representative_text: str = "test log",
    pattern_tags: tuple[str, ...] = (),
    severity: str = "INFO",
) -> EventCluster:
    """Build a minimal single-entry ``EventCluster`` for classifier tests.

    Only the fields the tests vary are parameters; every other field is a fixed
    placeholder. All clusters share ``cluster_id="abc123"``, which is the key the
    tests read back from ``ClassifiedTimeline.cluster_severities``.
    """
    return EventCluster(
        cluster_id="abc123",
        # Fields the caller controls.
        representative_text=representative_text,
        pattern_tags=pattern_tags,
        severity=severity,  # type: ignore[arg-type]
        # Fixed placeholders.
        entries=("e1",),
        source_ids=("src",),
        start_iso=None,
        end_iso=None,
        duration_seconds=0.0,
        gap_before_seconds=0.0,
        burst=False,
    )
|
|
|
|
|
|
def _make_timeline(clusters: tuple[EventCluster, ...] = ()) -> TimelineResult:
    """Wrap *clusters* in a ``TimelineResult`` with empty/zero summary fields.

    Callers in this module only vary the clusters, so counts, window bounds and
    dominant sources are fixed placeholders.
    """
    return TimelineResult(
        clusters=clusters,
        window_start=None,
        window_end=None,
        total_entries=0,
        gap_count=0,
        burst_count=0,
        dominant_sources=(),
    )
|
|
|
|
|
|
def _mock_hf_pipeline(label: str, score: float) -> MagicMock:
|
|
"""Return a mock HF pipeline callable that always yields one result."""
|
|
pipe = MagicMock()
|
|
pipe.return_value = [{"label": label, "score": score}]
|
|
return pipe
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Path A — ML classification
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestMLPath:
    """Path A: a mocked HF pipeline supplies the label and score.

    Each test also asserts the mock pipeline was called. ``classifier_used ==
    "ml"`` alone does not prove the ML path ran: the ImportError fallback test
    below shows it stays ``"ml"`` when ML fails. Some inputs ("disk error
    detected", "kernel panic") could also reach the expected severity through
    the keyword fallbacks.
    """

    def test_ml_error_maps_to_error(self) -> None:
        """ML returning ERROR with score 0.98 → cluster severity ERROR."""
        pipe = _mock_hf_pipeline("ERROR", 0.98)
        with patch(
            "app.services.diagnose.classifier._get_ml_classifier", return_value=pipe
        ):
            clf = SeverityClassifier(model_id="fake/model")
            result = clf.classify(
                _make_timeline((_make_cluster("disk error detected"),))
            )

        assert result.cluster_severities["abc123"] == "ERROR"
        assert result.classifier_used == "ml"
        assert result.model_id == "fake/model"
        pipe.assert_called()

    def test_ml_critical_promotion(self) -> None:
        """ERROR + score > 0.95 + 'kernel panic' in text → promoted to CRITICAL."""
        pipe = _mock_hf_pipeline("ERROR", 0.97)
        with patch(
            "app.services.diagnose.classifier._get_ml_classifier", return_value=pipe
        ):
            clf = SeverityClassifier(model_id="fake/model")
            result = clf.classify(
                _make_timeline((_make_cluster("kernel panic: not syncing VFS"),))
            )

        assert result.cluster_severities["abc123"] == "CRITICAL"
        pipe.assert_called()

    def test_ml_debug_demotion(self) -> None:
        """INFO + score < 0.4 → demoted to DEBUG."""
        pipe = _mock_hf_pipeline("INFO", 0.3)
        with patch(
            "app.services.diagnose.classifier._get_ml_classifier", return_value=pipe
        ):
            clf = SeverityClassifier(model_id="fake/model")
            result = clf.classify(_make_timeline((_make_cluster("routine ping"),)))

        assert result.cluster_severities["abc123"] == "DEBUG"
        pipe.assert_called()

    def test_ml_warning_maps_to_warn(self) -> None:
        """ML returning WARNING → mapped to WARN."""
        pipe = _mock_hf_pipeline("WARNING", 0.85)
        with patch(
            "app.services.diagnose.classifier._get_ml_classifier", return_value=pipe
        ):
            clf = SeverityClassifier(model_id="fake/model")
            result = clf.classify(_make_timeline((_make_cluster("low disk space"),)))

        assert result.cluster_severities["abc123"] == "WARN"
        pipe.assert_called()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Path B — pattern_tags fallback
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestPatternTagsPath:
    """Path B: no model configured; severity comes from the pattern file."""

    def test_pattern_tags_resolve_error_severity(self, tmp_path: Path) -> None:
        """Cluster with pattern_tag 'service_crash_loop' → ERROR from pattern file."""
        pattern_yaml = tmp_path / "default.yaml"
        # The list item's keys must line up with ``name`` (past the "-"). With
        # the keys at the dash's column, the document is not valid YAML.
        pattern_yaml.write_text(
            "patterns:\n"
            "  - name: service_crash_loop\n"
            "    pattern: crash\n"
            "    severity: ERROR\n"
            "    description: Service crashed in a loop\n",
            encoding="utf-8",
        )
        clf = SeverityClassifier(model_id="", pattern_file=pattern_yaml)
        cluster = _make_cluster(
            representative_text="service crashed",
            pattern_tags=("service_crash_loop",),
        )
        result = clf.classify(_make_timeline((cluster,)))

        assert result.cluster_severities["abc123"] == "ERROR"
        assert result.classifier_used == "pattern_tags"
        assert result.model_id is None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Path C — regex fallback
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestRegexPath:
    """Path C: no model and no pattern file, so keyword regexes decide."""

    def test_regex_detects_error(self) -> None:
        """No ML, no pattern file: 'ERROR: disk full' → ERROR via regex."""
        clf = SeverityClassifier(model_id="")
        result = clf.classify(_make_timeline((_make_cluster("ERROR: disk full"),)))

        assert result.cluster_severities["abc123"] == "ERROR"
        assert result.classifier_used == "regex"

    def test_regex_defaults_to_info_when_no_match(self) -> None:
        """No severity keyword in text → defaults to INFO on the regex path."""
        clf = SeverityClassifier(model_id="")
        result = clf.classify(
            _make_timeline((_make_cluster("mount: disk mounted successfully"),))
        )

        assert result.cluster_severities["abc123"] == "INFO"
        # INFO is also _make_cluster's preset severity, so also pin which
        # classification path produced it.
        assert result.classifier_used == "regex"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fallback behaviour
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestImportErrorFallback:
    """Behaviour when loading the ML pipeline raises ImportError."""

    def test_transformers_import_error_falls_back_to_pattern_tags(
        self, tmp_path: Path
    ) -> None:
        """ImportError from transformers → severity resolved via pattern_tags.

        ``classifier_used`` is still expected to be ``"ml"`` because the ML path
        was attempted. Only the per-cluster severity comes from the pattern
        file.
        """
        from typing import NoReturn

        pattern_yaml = tmp_path / "default.yaml"
        # The list item's keys must line up with ``name`` (past the "-") for
        # this to be valid YAML.
        pattern_yaml.write_text(
            "patterns:\n"
            "  - name: auth_failure\n"
            "    pattern: auth\n"
            "    severity: ERROR\n"
            "    description: Auth failure\n",
            encoding="utf-8",
        )

        def _raising_get_ml(*_args: Any, **_kwargs: Any) -> NoReturn:
            # Simulate the pipeline loader failing because transformers is absent.
            raise ImportError("No module named 'transformers'")

        with patch(
            "app.services.diagnose.classifier._get_ml_classifier",
            side_effect=_raising_get_ml,
        ):
            clf = SeverityClassifier(model_id="fake/model", pattern_file=pattern_yaml)
            cluster = _make_cluster(
                representative_text="auth failed",
                pattern_tags=("auth_failure",),
            )
            result = clf.classify(_make_timeline((cluster,)))

        # ML was attempted (classifier_used == "ml") but pattern_tags resolved it.
        # NOTE(review): reporting "ml" after the ML path failed may mislead
        # consumers of ClassifiedTimeline. Confirm this is the intended contract.
        assert result.classifier_used == "ml"
        assert result.cluster_severities["abc123"] == "ERROR"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Edge cases
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestEdgeCases:
    """Boundary inputs and immutability of the classifier's output."""

    def test_empty_timeline_produces_empty_severities(self) -> None:
        """TimelineResult with no clusters → empty cluster_severities, no crash."""
        classified = SeverityClassifier(model_id="").classify(_make_timeline())

        assert isinstance(classified, ClassifiedTimeline)
        assert classified.cluster_severities == {}
        assert classified.classifier_used == "regex"

    def test_classified_timeline_is_frozen(self) -> None:
        """ClassifiedTimeline must be frozen (FrozenInstanceError on mutation)."""
        single_cluster = _make_timeline((_make_cluster(),))
        classified = SeverityClassifier(model_id="").classify(single_cluster)

        with pytest.raises(FrozenInstanceError):
            classified.classifier_used = "ml"  # type: ignore[misc]
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Hybrid-BERT label mapping shim (turnstone#41)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
class TestHybridBertLabelMap:
    """_map_label must translate Hybrid-BERT vocabulary to SeverityLabel."""

    def _run(self, label: str, score: float = 0.9, text: str = "log line") -> str:
        """Call ``_map_label(label, score, text)`` and return its severity.

        The default score of 0.9 sits between the DEBUG-demotion (< 0.4) and
        CRITICAL-promotion (> 0.95) thresholds described in ``TestMLPath``.
        The import lives here, so a missing ``_map_label`` fails only these
        tests rather than collection of the whole module.
        """
        from app.services.diagnose.classifier import _map_label

        return _map_label(label, score, text)

    def test_normal_maps_to_info(self) -> None:
        assert self._run("normal") == "INFO"

    def test_security_anomaly_maps_to_error(self) -> None:
        assert self._run("security_anomaly") == "ERROR"

    def test_system_failure_maps_to_critical(self) -> None:
        assert self._run("system_failure") == "CRITICAL"

    def test_performance_issue_maps_to_warn(self) -> None:
        assert self._run("performance_issue") == "WARN"

    def test_network_anomaly_maps_to_warn(self) -> None:
        assert self._run("network_anomaly") == "WARN"

    def test_config_error_maps_to_error(self) -> None:
        assert self._run("config_error") == "ERROR"

    def test_hardware_issue_maps_to_critical(self) -> None:
        assert self._run("hardware_issue") == "CRITICAL"

    def test_hybrid_bert_labels_are_case_insensitive(self) -> None:
        assert self._run("SECURITY_ANOMALY", text="x") == "ERROR"
        assert self._run("Security_Anomaly", text="x") == "ERROR"
        # A second label, so the check is not tied to one map entry.
        assert self._run("NETWORK_ANOMALY", text="x") == "WARN"

    def test_system_failure_critical_promotion_not_doubled(self) -> None:
        """system_failure already maps to CRITICAL — keyword promotion is a no-op."""
        assert self._run("system_failure", score=0.99, text="kernel panic") == "CRITICAL"

    def test_normal_low_confidence_demotes_to_debug(self) -> None:
        """normal + low score → INFO base → DEBUG (same demotion rule as INFO)."""
        assert self._run("normal", score=0.2) == "DEBUG"

    def test_standard_labels_still_work(self) -> None:
        """Existing standard-vocabulary labels must not be broken by the shim."""
        assert self._run("ERROR", text="x") == "ERROR"
        assert self._run("WARNING", text="x") == "WARN"
        assert self._run("CRITICAL", text="x") == "CRITICAL"
|