turnstone/app/services/diagnose/classifier.py
pyr0ball 2c83247f1e feat(classifier): add Hybrid-BERT label mapping shim (#41)
Adds _HYBRID_BERT_LABEL_MAP to translate the 7-class output vocabulary of
krishnas4415/log-anomaly-detection-models (Hybrid-BERT, MIT) to Turnstone
SeverityLabel. _map_label now checks the Hybrid-BERT map before the standard
map so either model family works via TURNSTONE_CLASSIFIER_MODEL without any
additional code path.

Mapping (confirmed from model config.json):
  normal            → INFO
  security_anomaly  → ERROR
  system_failure    → CRITICAL
  performance_issue → WARN
  network_anomaly   → WARN
  config_error      → ERROR
  hardware_issue    → CRITICAL

Keyword-based CRITICAL promotion and low-confidence DEBUG demotion apply on
top of the base mapping (same rules as the standard vocabulary).

11 new tests covering all 7 Hybrid-BERT labels, case-insensitivity, and
regression on standard-vocabulary labels. 372 tests passing total.

Note: custom loading code for the non-standard .pt checkpoint format is
explicitly out of scope — evaluate better-packaged HF alternatives first
(see #41 for candidate list).

Closes: #41
2026-06-01 16:20:31 -07:00

274 lines
9.3 KiB
Python

"""Stage 2: Severity Classifier — ML with two fallback levels.
Classification strategy (in priority order):
Path A — ML: Hugging Face text-classification pipeline, loaded lazily.
Path B — pattern_tags: Map cluster.pattern_tags through the loaded pattern
severity dict; pick the highest severity across matching tags.
Path C — regex: Call detect_severity() from app.glean.base on the cluster's
representative_text.
Each cluster is classified independently. The ``classifier_used`` field on the
returned ``ClassifiedTimeline`` reflects the primary path (the one that governed
the overall classification session, not individual cluster fallbacks).
"""
from __future__ import annotations
import logging
import os
from pathlib import Path
from typing import Any
from types import MappingProxyType
from app.services.diagnose.models import (
ClassifiedTimeline,
EventCluster,
SeverityLabel,
TimelineResult,
)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Module-level ML singleton — reset to None between tests via the fixture
# ---------------------------------------------------------------------------
_ml_classifier: Any | None = None
# Provenance of the pipeline that _get_ml_classifier loaded itself, stored as
# (model_id, device, pipeline). It lets the loader distinguish "our own
# pipeline, built for a different model/device" (must be reloaded) from "an
# object assigned to _ml_classifier from outside", which is returned untouched.
_ml_classifier_origin: tuple[str, str, Any] | None = None


def _get_ml_classifier(model_id: str, device: str) -> Any:
    """Return the cached HF text-classification pipeline, loading it on demand.

    The pipeline is built on the first call and reused afterwards. Unlike a
    plain "load once" singleton, the cache remembers which *model_id* and
    *device* it was built for: a later request for a different combination
    loads a fresh pipeline instead of silently handing back the wrong model.
    Only one pipeline is kept at a time, so alternating between two models
    reloads on every switch.

    An object assigned directly to ``_ml_classifier`` from outside this
    function is returned as-is regardless of the arguments.

    Parameters
    ----------
    model_id:
        Hugging Face model identifier passed to ``transformers.pipeline``.
    device:
        Device string passed to ``transformers.pipeline``.

    Raises
    ------
    Exception
        Whatever the ``transformers`` import or pipeline construction raises.
        The previously cached pipeline (if any) is left in place in that case.
    """
    global _ml_classifier, _ml_classifier_origin  # noqa: PLW0603

    cached = _ml_classifier
    if cached is not None:
        origin = _ml_classifier_origin
        if origin is None or origin[2] is not cached:
            # Not loaded by this function (injected from outside): trust it.
            return cached
        if origin[:2] == (model_id, device):
            return cached
        # Ours, but built for another model/device: fall through and reload.

    # Imported lazily so the module can be used without transformers installed
    # as long as the ML path is never taken.
    from transformers import pipeline as hf_pipeline  # type: ignore[import-untyped]

    loaded = hf_pipeline("text-classification", model=model_id, device=device)
    _ml_classifier = loaded
    _ml_classifier_origin = (model_id, device, loaded)
    return loaded
# ---------------------------------------------------------------------------
# Label mapping
# ---------------------------------------------------------------------------
# Standard vocabulary: upper-cased raw label -> Turnstone SeverityLabel.
# "WARNING" is accepted as an alias and folded into the canonical "WARN".
_LABEL_MAP: dict[str, SeverityLabel] = {
    "ERROR": "ERROR",
    "WARNING": "WARN",
    "WARN": "WARN",
    "INFO": "INFO",
    "DEBUG": "DEBUG",
    "CRITICAL": "CRITICAL",
}
# Label shim for krishnas4415/log-anomaly-detection-models (Hybrid-BERT, MIT).
# Maps the model's 7-class output vocabulary to Turnstone SeverityLabel.
# Checked against the model config.json — labels confirmed in turnstone#41.
# Keys are upper-case because _map_label upper-cases the raw label before lookup.
_HYBRID_BERT_LABEL_MAP: dict[str, SeverityLabel] = {
    "NORMAL": "INFO",
    "SECURITY_ANOMALY": "ERROR",
    "SYSTEM_FAILURE": "CRITICAL",
    "PERFORMANCE_ISSUE": "WARN",
    "NETWORK_ANOMALY": "WARN",
    "CONFIG_ERROR": "ERROR",
    "HARDWARE_ISSUE": "CRITICAL",
}
# Phrases whose presence in the log text lets _map_label promote a
# high-confidence ERROR to CRITICAL. Entries are lower-case; the text is
# lower-cased before it is compared against them.
_CRITICAL_KEYWORDS: frozenset[str] = frozenset(
    {
        "panic",
        "oom",
        "fatal",
        "critical",
        "kernel panic",
        "out of memory",
        "segfault",
        "segmentation fault",
    }
)
# Numeric rank per severity string, used by _highest_from_tags to pick the
# most severe label (larger = more severe). "WARN" and "WARNING" share a rank;
# None (no severity available) ranks lowest.
_SEVERITY_ORDER: dict[str | None, int] = {
    "CRITICAL": 5,
    "ERROR": 4,
    "WARN": 3,
    "WARNING": 3,
    "INFO": 2,
    "DEBUG": 1,
    None: 0,
}
def _map_label(label: str, score: float, text: str) -> SeverityLabel:
    """Translate a raw model output label to a Turnstone SeverityLabel.

    Handles two model vocabularies:

    - Standard (ERROR/WARN/INFO/CRITICAL/DEBUG) — byviz/bylastic_classification_logs
    - Hybrid-BERT (normal/security_anomaly/…) — krishnas4415/log-anomaly-detection-models

    The label is upper-cased and resolved through the Hybrid-BERT map first,
    then the standard map; anything unrecognised becomes ``"UNKNOWN"``.
    Two adjustments are then applied on top of the base mapping:

    - an ``ERROR`` with ``score > 0.95`` whose *text* mentions a critical
      keyword is promoted to ``CRITICAL``;
    - an ``INFO`` with ``score < 0.4`` is demoted to ``DEBUG``.

    Parameters
    ----------
    label:
        Raw label string emitted by the model (any case).
    score:
        Model confidence for *label*.
    text:
        The log text that was classified; scanned for critical keywords.
    """
    upper = label.upper()
    # Resolve via Hybrid-BERT map first, then standard map, then UNKNOWN.
    base: SeverityLabel = _HYBRID_BERT_LABEL_MAP.get(upper) or _LABEL_MAP.get(upper, "UNKNOWN")  # type: ignore[assignment]
    # The keyword scan is the most expensive test, so it is evaluated last.
    if base == "ERROR" and score > 0.95 and _mentions_critical_keyword(text):
        return "CRITICAL"
    if base == "INFO" and score < 0.4:
        return "DEBUG"
    return base


def _mentions_critical_keyword(text: str) -> bool:
    """Return True when *text* contains a critical keyword at the start of a word.

    Matching is case-insensitive. An occurrence only counts when it is not
    glued onto a preceding letter or digit: "oom" matches "oom-killer",
    "oom_reaper" and "OOMKilled" but not "room" or "headroom", and "fatal"
    does not match "nonfatal". The end of the keyword is deliberately left
    unconstrained so inflected forms such as "panicked" or "segfaulted"
    still count.
    """
    lowered = text.lower()  # lower-case once, not once per keyword
    for keyword in _CRITICAL_KEYWORDS:
        start = lowered.find(keyword)
        while start != -1:
            if start == 0 or not lowered[start - 1].isalnum():
                return True
            # Mid-word hit: keep looking for a later, properly delimited one.
            start = lowered.find(keyword, start + 1)
    return False
def _highest_from_tags(
    tags: tuple[str, ...], severity_map: dict[str, str]
) -> SeverityLabel | None:
    """Return the highest severity among the *tags* that appear in *severity_map*.

    Severities are compared case-insensitively and ``"WARNING"`` is treated
    as ``"WARN"``, so a pattern file may spell them ``error`` / ``Warning``.
    Tags that have no entry in *severity_map* are ignored and never mask a
    later matching tag. When two tags share the top rank, the first one wins.

    A severity string that has no rank in ``_SEVERITY_ORDER`` is kept as a
    last resort: it is returned (upper-cased) only when no tag maps to a
    ranked severity.

    Parameters
    ----------
    tags:
        Pattern names attached to a cluster.
    severity_map:
        Pattern name -> severity string.

    Returns
    -------
    The upper-cased winning severity, or ``None`` when no tag is mapped.
    """
    best: str | None = None
    best_rank = -1
    for tag in tags:
        raw = severity_map.get(tag)
        if raw is None:
            continue  # tag unknown to the pattern file
        # Normalise before ranking so the lookup agrees with the value returned.
        candidate = raw.upper()
        if candidate == "WARNING":
            candidate = "WARN"
        rank = _SEVERITY_ORDER.get(candidate, 0)
        if rank > best_rank:
            best, best_rank = candidate, rank
    return best  # type: ignore[return-value]
# ---------------------------------------------------------------------------
# SeverityClassifier
# ---------------------------------------------------------------------------
class SeverityClassifier:
    """Classify each EventCluster's severity using ML, patterns, or regex fallback.

    For every cluster the paths are tried in order — ML (when a model is
    configured), pattern tags (when a pattern file yielded any severities),
    then the severity regex — and the first path that produces a label wins.

    Parameters
    ----------
    model_id:
        Hugging Face model identifier. When empty (default), ML is skipped.
    device:
        Torch device string passed to the HF pipeline (e.g. ``"cpu"`` or ``"cuda:0"``).
    pattern_file:
        Path to the YAML pattern file. When ``None`` the classifier reads
        ``TURNSTONE_PATTERNS`` env var (same logic as ``app/rest.py``).
    """

    # Class-level default so the flag exists even on instances created
    # without running __init__.
    _ml_load_failed: bool = False

    def __init__(
        self,
        model_id: str = "",
        device: str = "cpu",
        pattern_file: Path | None = None,
    ) -> None:
        self._model_id = model_id
        self._device = device
        self._pattern_file: Path | None = pattern_file
        # Pattern name -> severity string, filled lazily by _ensure_patterns_loaded.
        self._pattern_severity: dict[str, str] = {}
        self._patterns_loaded = False
        # True once obtaining the ML pipeline itself has failed during the
        # current classify() call; reset at the start of each call.
        self._ml_load_failed = False

    # ------------------------------------------------------------------
    # Lazy loaders
    # ------------------------------------------------------------------
    def _resolve_pattern_file(self) -> Path | None:
        """Resolve the pattern file from the constructor argument or env var.

        The constructor argument wins. Otherwise ``TURNSTONE_PATTERNS`` is
        treated as a directory and ``default.yaml`` inside it is used.
        Returns ``None`` when neither is set.
        """
        if self._pattern_file is not None:
            return self._pattern_file
        env_dir = os.environ.get("TURNSTONE_PATTERNS")
        if env_dir:
            return Path(env_dir) / "default.yaml"
        return None

    def _ensure_patterns_loaded(self) -> None:
        """Populate ``_pattern_severity`` from the pattern YAML file (once).

        The loaded flag is set before the file is read, so a load that raises
        propagates to the caller once and is not retried by later calls, which
        then proceed with an empty map.
        """
        if self._patterns_loaded:
            return
        self._patterns_loaded = True
        path = self._resolve_pattern_file()
        if path is None:
            return
        from app.glean.base import load_patterns

        patterns = load_patterns(path)
        self._pattern_severity = {p.name: p.severity for p in patterns}

    # ------------------------------------------------------------------
    # Per-cluster classification helpers
    # ------------------------------------------------------------------
    def _classify_cluster_ml(self, cluster: EventCluster) -> SeverityLabel | None:
        """Attempt ML classification of *cluster*.

        Returns the mapped label of the pipeline's first result, or ``None``
        when the pipeline yields nothing or anything raises. Failures are
        logged as a warning together with the exception traceback. When the
        failure happened while obtaining the pipeline (as opposed to running
        it), ``_ml_load_failed`` is set so ``classify`` can stop retrying.
        """
        pipe = None
        try:
            pipe = _get_ml_classifier(self._model_id, self._device)
            results = pipe(cluster.representative_text)
            if not results:
                return None
            hit = results[0]
            return _map_label(hit["label"], hit["score"], cluster.representative_text)
        except Exception:  # noqa: BLE001
            if pipe is None:
                # The pipeline never became available: retrying for the
                # next cluster would only repeat the same failing load.
                self._ml_load_failed = True
            logger.warning(
                "ML inference failed for cluster %s — falling back",
                cluster.cluster_id,
                exc_info=True,
            )
            return None

    def _classify_cluster_pattern_tags(
        self, cluster: EventCluster
    ) -> SeverityLabel | None:
        """Derive severity from the cluster's pattern_tags. Returns None if no match."""
        return _highest_from_tags(cluster.pattern_tags, self._pattern_severity)

    def _classify_cluster_regex(self, cluster: EventCluster) -> SeverityLabel:
        """Classify by scanning representative_text with the severity regex.

        Falls back to ``"INFO"`` when no severity is detected or the detected
        value is not in the standard label map.
        """
        from app.glean.base import detect_severity

        raw = detect_severity(cluster.representative_text)
        if raw is None:
            return "INFO"
        return _LABEL_MAP.get(raw.upper(), "INFO")  # type: ignore[return-value]

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def classify(self, timeline: TimelineResult) -> ClassifiedTimeline:
        """Classify every cluster in *timeline* and return a ``ClassifiedTimeline``.

        ``classifier_used`` names the primary path for the whole call
        (``"ml"`` when a model is configured, else ``"pattern_tags"`` when
        pattern severities were loaded, else ``"regex"``); individual clusters
        may still have been labelled by a fallback path. ``model_id`` is
        reported only when a model is configured.

        If the ML pipeline cannot be loaded, the failure is logged for the
        first cluster and ML is skipped for the remaining clusters of this
        call instead of repeating the failing load once per cluster.
        """
        self._ensure_patterns_loaded()
        # Determine which primary path governs this session
        ml_available = bool(self._model_id)
        patterns_available = bool(self._pattern_severity)
        if ml_available:
            classifier_used: str = "ml"
        elif patterns_available:
            classifier_used = "pattern_tags"
        else:
            classifier_used = "regex"

        # Every call gets a fresh chance to load the pipeline.
        self._ml_load_failed = False
        try_ml = ml_available

        cluster_severities: dict[str, SeverityLabel] = {}
        for cluster in timeline.clusters:
            severity: SeverityLabel | None = None
            if try_ml:
                severity = self._classify_cluster_ml(cluster)
                if self._ml_load_failed:
                    try_ml = False
            if severity is None and patterns_available:
                severity = self._classify_cluster_pattern_tags(cluster)
            if severity is None:
                severity = self._classify_cluster_regex(cluster)
            cluster_severities[cluster.cluster_id] = severity
        return ClassifiedTimeline(
            timeline=timeline,
            cluster_severities=MappingProxyType(cluster_severities),
            classifier_used=classifier_used,  # type: ignore[arg-type]
            model_id=self._model_id if ml_available else None,
        )