"""Anomaly scoring pipeline — batch-score log_entries with a HF classifier.

Designed to run after each glean cycle (or standalone).  When no model is
configured the scorer is a no-op and returns immediately, so it is always safe
to wire into the glean pipeline.

Model: any HuggingFace text-classification model.  The existing Hybrid-BERT
label map (from diagnose/classifier.py) is reused when the model produces
NORMAL/SECURITY_ANOMALY/… outputs; other models get a generic severity map.

Scoring strategy
----------------
- Query unscored rows in batches (WHERE anomaly_scored_at IS NULL)
- Run each entry text through the HF pipeline
- Write anomaly_score + anomaly_label + anomaly_scored_at back
- INSERT high-confidence hits (score >= threshold) into detections table,
  skipping duplicates so the scorer is safe to re-run
"""

from __future__ import annotations

import logging
import os
import time
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any

from app.db import get_conn, resolve_tenant_id
from app.db.dialect import q

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Label maps — reuse Hybrid-BERT vocabulary from diagnose/classifier.py
# ---------------------------------------------------------------------------

_HYBRID_BERT_SEVERITY: dict[str, str] = {
    "NORMAL": "INFO",
    "SECURITY_ANOMALY": "ERROR",
    "SYSTEM_FAILURE": "CRITICAL",
    "PERFORMANCE_ISSUE": "WARN",
    "NETWORK_ANOMALY": "WARN",
    "CONFIG_ERROR": "ERROR",
    "HARDWARE_ISSUE": "CRITICAL",
}

_GENERIC_SEVERITY: dict[str, str] = {
    "CRITICAL": "CRITICAL",
    "ERROR": "ERROR",
    "WARNING": "WARN",
    "WARN": "WARN",
    "INFO": "INFO",
    "DEBUG": "DEBUG",
}

# Labels that may produce a row in the detections table (subject to threshold).
_ANOMALOUS_LABELS: frozenset[str] = frozenset(
    {
        "SECURITY_ANOMALY",
        "SYSTEM_FAILURE",
        "PERFORMANCE_ISSUE",
        "NETWORK_ANOMALY",
        "CONFIG_ERROR",
        "HARDWARE_ISSUE",
        "CRITICAL",
        "ERROR",
    }
)

# Defaults are read from the environment once, at import time.
_DEFAULT_THRESHOLD = float(os.environ.get("TURNSTONE_ANOMALY_THRESHOLD", "0.75"))
_DEFAULT_MODEL = os.environ.get("TURNSTONE_ANOMALY_MODEL", "")
_DEFAULT_DEVICE = os.environ.get("TURNSTONE_ANOMALY_DEVICE", "cpu")
_DEFAULT_BATCH = int(os.environ.get("TURNSTONE_ANOMALY_BATCH", "256"))

# ---------------------------------------------------------------------------
# ML singleton
# ---------------------------------------------------------------------------

_pipeline: Any | None = None
# (model_id, device) that ``_get_pipeline`` loaded the cached pipeline with.
# Stays None when nothing has been loaded here, which includes the case where
# a pipeline object was assigned to ``_pipeline`` directly.
_pipeline_key: tuple[str, str] | None = None


def _get_pipeline(model_id: str, device: str) -> Any:
    """Return the cached text-classification pipeline, loading it on demand.

    The pipeline is rebuilt when it was previously loaded by this function for
    a different ``(model_id, device)`` pair, so a changed configuration is not
    silently served by the old model.  A pipeline placed in ``_pipeline``
    without going through this function is returned as-is.

    ``transformers`` is imported lazily so the module can be imported without
    it when no model is configured.
    """
    global _pipeline, _pipeline_key  # noqa: PLW0603
    requested = (model_id, device)
    stale = _pipeline_key is not None and _pipeline_key != requested
    if _pipeline is None or stale:
        from transformers import pipeline as hf_pipeline  # type: ignore[import-untyped]

        _pipeline = hf_pipeline("text-classification", model=model_id, device=device)
        _pipeline_key = requested
    return _pipeline


def reset_pipeline() -> None:
    """Reset the cached pipeline singleton (test helper)."""
    global _pipeline, _pipeline_key  # noqa: PLW0603
    _pipeline = None
    _pipeline_key = None


# ---------------------------------------------------------------------------
# Result types
# ---------------------------------------------------------------------------


@dataclass
class ScoringResult:
    """Outcome of one ``score_unscored`` run."""

    # Number of log entries whose scores were written back.
    scored: int = 0
    # Number of rows inserted into the detections table.
    detections: int = 0
    # True when the run was a no-op because no model_id was given.
    skipped: bool = False
    # Message of the model-load or inference failure that ended the run, if any.
    error: str | None = None


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------


def _map_label(raw_label: str, score: float) -> tuple[str, str]:
    """Return (normalised_label, severity) for a raw model output label.

    The label is upper-cased, then looked up in the Hybrid-BERT map and, failing
    that, in the generic map; unknown labels get severity ``"WARN"``.
    ``score`` is accepted for interface compatibility and is not used.
    """
    upper = raw_label.upper()
    severity = _HYBRID_BERT_SEVERITY.get(upper)
    if severity is None:
        severity = _GENERIC_SEVERITY.get(upper, "WARN")
    return upper, severity


def _fetch_unscored(conn: Any, tenant_id: str, limit: int) -> list[dict]:
    """Return up to ``limit`` unscored log entries as dicts, newest ingest first.

    Rows belonging to ``tenant_id`` and rows with an empty tenant_id are both
    included.
    """
    rows = conn.execute(
        q("""
            SELECT id, source_id, text, timestamp_iso, severity
            FROM log_entries
            WHERE anomaly_scored_at IS NULL
              AND (tenant_id = ? OR tenant_id = '')
            ORDER BY ingest_time DESC
            LIMIT ?
        """),
        (tenant_id, limit),
    ).fetchall()
    return [dict(r) for r in rows]


def _write_scores(
    conn: Any,
    rows: list[dict],
    scored_at: str,
) -> None:
    """Write score, label and ``scored_at`` back to log_entries for each row.

    Each row dict must carry ``id``, ``anomaly_score`` and ``anomaly_label``.
    The caller is responsible for committing.
    """
    conn.executemany(
        q("UPDATE log_entries SET anomaly_score = ?, anomaly_label = ?, anomaly_scored_at = ? WHERE id = ?"),
        [(r["anomaly_score"], r["anomaly_label"], scored_at, r["id"]) for r in rows],
    )


def _insert_detections(conn: Any, rows: list[dict], tenant_id: str, detected_at: str) -> int:
    """Insert one detections row per entry in ``rows``; return how many succeeded.

    Inserts are best-effort: a row whose insert raises (for example a duplicate
    entry_id or another constraint violation) is skipped so the scorer can be
    re-run.  Skips are logged at DEBUG level rather than swallowed silently.
    Stored text is capped at 2000 characters.  The caller commits.
    """
    inserted = 0
    for r in rows:
        try:
            conn.execute(
                q("""
                    INSERT INTO detections
                        (id, tenant_id, entry_id, source_id, anomaly_label,
                         anomaly_score, severity, text, timestamp_iso, detected_at)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """),
                (
                    str(uuid.uuid4()),
                    tenant_id,
                    r["id"],
                    r["source_id"],
                    r["anomaly_label"],
                    r["anomaly_score"],
                    r["severity"],
                    (r["text"] or "")[:2000],  # tolerate NULL text
                    r.get("timestamp_iso"),
                    detected_at,
                ),
            )
            inserted += 1
        except Exception:  # noqa: BLE001
            # duplicate entry_id or constraint violation — skip, but leave a trace
            logger.debug(
                "Skipping detection insert for entry %r",
                r.get("id"),
                exc_info=True,
            )
    return inserted


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


def score_unscored(
    db_path: Path,
    model_id: str = _DEFAULT_MODEL,
    device: str = _DEFAULT_DEVICE,
    batch_size: int = _DEFAULT_BATCH,
    threshold: float = _DEFAULT_THRESHOLD,
) -> ScoringResult:
    """Score all unscored log_entries in batches.

    Returns immediately (skipped=True) when model_id is empty — allows
    unconditional wiring without requiring the model to be configured.

    Args:
        db_path: Database location passed to ``get_conn``.
        model_id: HuggingFace model identifier; empty means "not configured".
        device: Device string handed to the HF pipeline.
        batch_size: Maximum number of entries fetched and scored per round.
        threshold: Minimum score for an anomalous label to become a detection.

    Returns:
        A ``ScoringResult`` with running totals.  ``error`` is set (and the run
        stops) when the model fails to load, inference fails, or the model
        returns a different number of predictions than inputs.

    Raises:
        Exception: whatever the database write raises, if it is not a
            "database is locked" error or the lock persists through 4 attempts.
    """
    if not model_id:
        return ScoringResult(skipped=True)

    try:
        pipe = _get_pipeline(model_id, device)
    except Exception as exc:
        logger.error("Failed to load anomaly model %r: %s", model_id, exc)
        return ScoringResult(error=str(exc))

    tenant_id = resolve_tenant_id()
    total_scored = 0
    total_detections = 0

    while True:
        with get_conn(db_path) as conn:
            batch = _fetch_unscored(conn, tenant_id, batch_size)
        if not batch:
            break

        # NULL text would otherwise raise TypeError on slicing; score it as "".
        texts = [(r["text"] or "")[:512] for r in batch]
        try:
            predictions = list(pipe(texts, truncation=True, max_length=512))
        except Exception as exc:
            logger.error("Inference error on batch of %d: %s", len(batch), exc)
            return ScoringResult(scored=total_scored, detections=total_detections, error=str(exc))

        # zip() below would silently drop unmatched rows; those rows would stay
        # unscored and be fetched again on the next round, so stop instead.
        if len(predictions) != len(batch):
            message = f"Model returned {len(predictions)} predictions for {len(batch)} entries"
            logger.error("%s", message)
            return ScoringResult(scored=total_scored, detections=total_detections, error=message)

        scored_at = datetime.now(tz=timezone.utc).isoformat()
        scored_rows: list[dict] = []
        detection_rows: list[dict] = []

        for row, pred in zip(batch, predictions):
            label, severity = _map_label(pred["label"], pred["score"])
            # The mapped severity replaces the severity read from log_entries.
            enriched = {**row, "anomaly_score": pred["score"], "anomaly_label": label, "severity": severity}
            scored_rows.append(enriched)
            if label in _ANOMALOUS_LABELS and pred["score"] >= threshold:
                detection_rows.append(enriched)

        det_count = 0
        for _attempt in range(4):
            try:
                with get_conn(db_path) as conn:
                    _write_scores(conn, scored_rows, scored_at)
                    det_count = _insert_detections(conn, detection_rows, tenant_id, scored_at)
                    conn.commit()
                break
            except Exception as exc:
                # Only lock contention is retried; anything else, or a lock that
                # survives the final attempt, propagates to the caller.
                if "database is locked" in str(exc).lower() and _attempt < 3:
                    logger.warning("DB locked, retrying write in 10s (attempt %d/4)", _attempt + 1)
                    time.sleep(10)
                else:
                    raise

        total_scored += len(scored_rows)
        total_detections += det_count
        logger.info(
            "Scored %d entries, %d detections (threshold=%.2f)",
            len(scored_rows),
            det_count,
            threshold,
        )

        # A short batch means the backlog is drained; skip the extra empty query.
        if len(batch) < batch_size:
            break

    return ScoringResult(scored=total_scored, detections=total_detections)


def list_detections(
    db_path: Path,
    limit: int = 100,
    unacked_only: bool = False,
    label: str | None = None,
    scorer: str | None = None,
) -> list[dict]:
    """Return detections ordered by detected_at DESC.

    Args:
        db_path: Database location passed to ``get_conn``.
        limit: Maximum number of rows returned.
        unacked_only: When true, only rows with ``acknowledged = 0``.
        label: Optional anomaly_label filter; compared upper-cased.
        scorer: Optional scorer filter; compared lower-cased.

    Rows for the resolved tenant and rows with an empty tenant_id are included.
    """
    tenant_id = resolve_tenant_id()
    conditions = ["(tenant_id = ? OR tenant_id = '')"]
    params: list[Any] = [tenant_id]

    if unacked_only:
        conditions.append("acknowledged = 0")
    if label:
        conditions.append(q("anomaly_label = ?"))
        params.append(label.upper())
    if scorer:
        conditions.append(q("scorer = ?"))
        params.append(scorer.lower())

    # Only fixed SQL fragments are interpolated; all values are bound parameters.
    where = " AND ".join(conditions)
    with get_conn(db_path) as conn:
        rows = conn.execute(
            q(f"SELECT * FROM detections WHERE {where} ORDER BY detected_at DESC LIMIT ?"),  # noqa: S608
            (*params, limit),
        ).fetchall()
    return [dict(r) for r in rows]


def acknowledge_detection(db_path: Path, detection_id: str, notes: str = "") -> bool:
    """Mark a detection as acknowledged.  Returns True if a row was updated.

    Sets ``acknowledged = 1``, stamps ``acknowledged_at`` with the current UTC
    time and stores ``notes``.  Only a detection owned by the resolved tenant,
    or one with an empty tenant_id, is matched.
    """
    tenant_id = resolve_tenant_id()
    acked_at = datetime.now(tz=timezone.utc).isoformat()
    with get_conn(db_path) as conn:
        cur = conn.execute(
            q("""
                UPDATE detections
                SET acknowledged = 1, acknowledged_at = ?, notes = ?
                WHERE id = ? AND (tenant_id = ? OR tenant_id = '')
            """),
            (acked_at, notes, detection_id, tenant_id),
        )
        conn.commit()
    return cur.rowcount > 0