"""Anomaly scoring pipeline — batch-score log_entries with a HF classifier. Designed to run after each glean cycle (or standalone). When no model is configured the scorer is a no-op and returns immediately, so it is always safe to wire into the glean pipeline. Model: any HuggingFace text-classification model. The existing Hybrid-BERT label map (from diagnose/classifier.py) is reused when the model produces NORMAL/SECURITY_ANOMALY/… outputs; other models get a generic severity map. Scoring strategy ---------------- - Query unscored rows in batches (WHERE anomaly_scored_at IS NULL) - Run each entry text through the HF pipeline - Write anomaly_score + anomaly_label + anomaly_scored_at back - INSERT high-confidence hits (score >= threshold) into detections table, skipping duplicates so the scorer is safe to re-run """ from __future__ import annotations import logging import os import uuid from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path from typing import Any from app.db import get_conn, resolve_tenant_id from app.db.dialect import q logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Label maps — reuse Hybrid-BERT vocabulary from diagnose/classifier.py # --------------------------------------------------------------------------- _HYBRID_BERT_SEVERITY: dict[str, str] = { "NORMAL": "INFO", "SECURITY_ANOMALY": "ERROR", "SYSTEM_FAILURE": "CRITICAL", "PERFORMANCE_ISSUE": "WARN", "NETWORK_ANOMALY": "WARN", "CONFIG_ERROR": "ERROR", "HARDWARE_ISSUE": "CRITICAL", } _GENERIC_SEVERITY: dict[str, str] = { "CRITICAL": "CRITICAL", "ERROR": "ERROR", "WARNING": "WARN", "WARN": "WARN", "INFO": "INFO", "DEBUG": "DEBUG", } _ANOMALOUS_LABELS: frozenset[str] = frozenset( { "SECURITY_ANOMALY", "SYSTEM_FAILURE", "PERFORMANCE_ISSUE", "NETWORK_ANOMALY", "CONFIG_ERROR", "HARDWARE_ISSUE", "CRITICAL", "ERROR", } ) _DEFAULT_THRESHOLD = float(os.environ.get("TURNSTONE_ANOMALY_THRESHOLD", "0.75")) _DEFAULT_MODEL = os.environ.get("TURNSTONE_ANOMALY_MODEL", "") _DEFAULT_DEVICE = os.environ.get("TURNSTONE_ANOMALY_DEVICE", "cpu") _DEFAULT_BATCH = int(os.environ.get("TURNSTONE_ANOMALY_BATCH", "256")) # --------------------------------------------------------------------------- # ML singleton # --------------------------------------------------------------------------- _pipeline: Any | None = None def _get_pipeline(model_id: str, device: str) -> Any: global _pipeline # noqa: PLW0603 if _pipeline is None: from transformers import pipeline as hf_pipeline # type: ignore[import-untyped] _pipeline = hf_pipeline("text-classification", model=model_id, device=device) return _pipeline def reset_pipeline() -> None: """Reset the cached pipeline singleton (test helper).""" global _pipeline # noqa: PLW0603 _pipeline = None # --------------------------------------------------------------------------- # Result types # --------------------------------------------------------------------------- @dataclass class ScoringResult: scored: int = 0 detections: int = 0 skipped: bool = False error: str | None = None # --------------------------------------------------------------------------- # Internal helpers # --------------------------------------------------------------------------- def _map_label(raw_label: str, score: float) -> tuple[str, str]: """Return (normalised_label, severity) for a raw model output label.""" upper = raw_label.upper() if upper in _HYBRID_BERT_SEVERITY: return upper, _HYBRID_BERT_SEVERITY[upper] sev = _GENERIC_SEVERITY.get(upper, "WARN") return upper, sev def _fetch_unscored(conn: Any, tenant_id: str, limit: int) -> list[dict]: rows = conn.execute( q(""" SELECT id, source_id, text, timestamp_iso, severity FROM log_entries WHERE anomaly_scored_at IS NULL AND (tenant_id = ? OR tenant_id = '') ORDER BY ingest_time DESC LIMIT ? """), (tenant_id, limit), ).fetchall() return [dict(r) for r in rows] def _write_scores( conn: Any, rows: list[dict], scored_at: str, ) -> None: conn.executemany( q("UPDATE log_entries SET anomaly_score = ?, anomaly_label = ?, anomaly_scored_at = ? WHERE id = ?"), [(r["anomaly_score"], r["anomaly_label"], scored_at, r["id"]) for r in rows], ) def _insert_detections(conn: Any, rows: list[dict], tenant_id: str, detected_at: str) -> int: inserted = 0 for r in rows: try: conn.execute( q(""" INSERT INTO detections (id, tenant_id, entry_id, source_id, anomaly_label, anomaly_score, severity, text, timestamp_iso, detected_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """), ( str(uuid.uuid4()), tenant_id, r["id"], r["source_id"], r["anomaly_label"], r["anomaly_score"], r["severity"], r["text"][:2000], r.get("timestamp_iso"), detected_at, ), ) inserted += 1 except Exception: # noqa: BLE001 pass # duplicate entry_id or constraint violation — skip return inserted # --------------------------------------------------------------------------- # Public API # --------------------------------------------------------------------------- def score_unscored( db_path: Path, model_id: str = _DEFAULT_MODEL, device: str = _DEFAULT_DEVICE, batch_size: int = _DEFAULT_BATCH, threshold: float = _DEFAULT_THRESHOLD, ) -> ScoringResult: """Score all unscored log_entries in batches. Returns immediately (skipped=True) when model_id is empty — allows unconditional wiring without requiring the model to be configured. """ if not model_id: return ScoringResult(skipped=True) try: pipe = _get_pipeline(model_id, device) except Exception as exc: logger.error("Failed to load anomaly model %r: %s", model_id, exc) return ScoringResult(error=str(exc)) tenant_id = resolve_tenant_id() total_scored = 0 total_detections = 0 while True: with get_conn(db_path) as conn: batch = _fetch_unscored(conn, tenant_id, batch_size) if not batch: break texts = [r["text"][:512] for r in batch] try: predictions = pipe(texts, truncation=True, max_length=512) except Exception as exc: logger.error("Inference error on batch of %d: %s", len(batch), exc) return ScoringResult(scored=total_scored, detections=total_detections, error=str(exc)) scored_at = datetime.now(tz=timezone.utc).isoformat() scored_rows: list[dict] = [] detection_rows: list[dict] = [] for row, pred in zip(batch, predictions): label, severity = _map_label(pred["label"], pred["score"]) enriched = {**row, "anomaly_score": pred["score"], "anomaly_label": label, "severity": severity} scored_rows.append(enriched) if label in _ANOMALOUS_LABELS and pred["score"] >= threshold: detection_rows.append(enriched) with get_conn(db_path) as conn: _write_scores(conn, scored_rows, scored_at) det_count = _insert_detections(conn, detection_rows, tenant_id, scored_at) conn.commit() total_scored += len(scored_rows) total_detections += det_count logger.info( "Scored %d entries, %d detections (threshold=%.2f)", len(scored_rows), det_count, threshold, ) if len(batch) < batch_size: break return ScoringResult(scored=total_scored, detections=total_detections) def list_detections( db_path: Path, limit: int = 100, unacked_only: bool = False, label: str | None = None, scorer: str | None = None, ) -> list[dict]: """Return detections ordered by detected_at DESC.""" tenant_id = resolve_tenant_id() conditions = ["(tenant_id = ? OR tenant_id = '')"] params: list[Any] = [tenant_id] if unacked_only: conditions.append("acknowledged = 0") if label: conditions.append(q("anomaly_label = ?")) params.append(label.upper()) if scorer: conditions.append(q("scorer = ?")) params.append(scorer.lower()) where = " AND ".join(conditions) with get_conn(db_path) as conn: rows = conn.execute( q(f"SELECT * FROM detections WHERE {where} ORDER BY detected_at DESC LIMIT ?"), # noqa: S608 (*params, limit), ).fetchall() return [dict(r) for r in rows] def acknowledge_detection(db_path: Path, detection_id: str, notes: str = "") -> bool: """Mark a detection as acknowledged. Returns True if a row was updated.""" tenant_id = resolve_tenant_id() acked_at = datetime.now(tz=timezone.utc).isoformat() with get_conn(db_path) as conn: cur = conn.execute( q(""" UPDATE detections SET acknowledged = 1, acknowledged_at = ?, notes = ? WHERE id = ? AND (tenant_id = ? OR tenant_id = '') """), (acked_at, notes, detection_id, tenant_id), ) conn.commit() return cur.rowcount > 0