"""Cybersecurity-focused scoring pipeline using zero-shot classification. Runs a second-pass analysis on entries that were already flagged by the anomaly scorer or that have pattern matches. Uses a zero-shot classification model (DeBERTa-v3-base-mnli is cached locally) so no fine-tuning is needed. The scorer writes ml_score / ml_label / ml_scored_at to log_entries and inserts high-confidence non-normal hits into the detections table tagged with scorer='cybersec'. Env vars -------- TURNSTONE_CYBERSEC_MODEL — HF model id for zero-shot classification. Recommended: MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli (already cached from the diagnose pipeline). Set to empty string to disable (safe default). TURNSTONE_CYBERSEC_DEVICE — 'cpu' (default) or 'cuda' TURNSTONE_CYBERSEC_THRESHOLD — float confidence floor for detection insertion (default 0.60) """ from __future__ import annotations import logging import uuid from dataclasses import dataclass, field from datetime import datetime, timezone from pathlib import Path from typing import Any from app.db import get_conn, resolve_tenant_id from app.db.dialect import q logger = logging.getLogger(__name__) # --------------------------------------------------------------------------- # Candidate labels — cybersec vocabulary for zero-shot inference # --------------------------------------------------------------------------- CYBERSEC_LABELS: list[str] = [ "authentication failure or brute force attack", "privilege escalation or unauthorized access", "network intrusion or port scan", "malware or suspicious process activity", "data exfiltration or unusual outbound traffic", "normal system operation", ] _NORMAL_LABEL = "normal system operation" _LABEL_SEVERITY: dict[str, str] = { "authentication failure or brute force attack": "ERROR", "privilege escalation or unauthorized access": "CRITICAL", "network intrusion or port scan": "ERROR", "malware or suspicious process activity": "CRITICAL", "data exfiltration or unusual outbound traffic":"CRITICAL", "normal system operation": "INFO", } # --------------------------------------------------------------------------- # Pipeline singleton # --------------------------------------------------------------------------- _pipeline: Any = None def _get_pipeline(model_id: str, device: str) -> Any: global _pipeline # noqa: PLW0603 if _pipeline is None: from transformers import pipeline # type: ignore[import-untyped] logger.info("loading cybersec zero-shot pipeline: %s on %s", model_id, device) _pipeline = pipeline( "zero-shot-classification", model=model_id, device=0 if device == "cuda" else -1, ) logger.info("cybersec pipeline ready") return _pipeline def reset_pipeline() -> None: """Clear the cached pipeline — for testing only.""" global _pipeline # noqa: PLW0603 _pipeline = None # --------------------------------------------------------------------------- # Result type # --------------------------------------------------------------------------- @dataclass class CybersecResult: scored: int = 0 detections: int = 0 skipped: bool = False error: str | None = None # --------------------------------------------------------------------------- # Core scoring function # --------------------------------------------------------------------------- def score_security_entries( db_path: Path, model_id: str, device: str = "cpu", batch_size: int = 32, threshold: float = 0.60, ) -> CybersecResult: """Score entries that were anomaly-flagged or pattern-matched. Only entries with ml_scored_at IS NULL are processed (idempotent). Writes ml_score / ml_label / ml_scored_at and inserts high-confidence hits into detections with scorer='cybersec'. """ if not model_id: return CybersecResult(skipped=True) tenant_id = resolve_tenant_id() try: pipe = _get_pipeline(model_id, device) except Exception as exc: logger.error("failed to load cybersec pipeline: %s", exc) return CybersecResult(error=str(exc)) total_scored = 0 total_detections = 0 try: with get_conn(db_path) as conn: # Only score entries that are worth a second look: # anomaly-flagged (non-normal) OR have at least one pattern match. rows = conn.execute( q(""" SELECT id, source_id, text, timestamp_iso FROM log_entries WHERE (tenant_id = ? OR tenant_id = '') AND ml_scored_at IS NULL AND ( (anomaly_label IS NOT NULL AND anomaly_label != 'NORMAL') OR (matched_patterns IS NOT NULL AND matched_patterns != '[]' AND matched_patterns != '') ) LIMIT ? """), (tenant_id, batch_size * 10), ).fetchall() if not rows: return CybersecResult(skipped=True) # Process in chunks to avoid OOM on large backlogs for i in range(0, len(rows), batch_size): chunk = rows[i : i + batch_size] texts = [r["text"] for r in chunk] try: results = pipe(texts, candidate_labels=CYBERSEC_LABELS, multi_label=False) except Exception as exc: logger.warning("zero-shot inference error on chunk %d: %s", i, exc) continue now = datetime.now(tz=timezone.utc).isoformat() with get_conn(db_path) as conn: for row, result in zip(chunk, results): top_label: str = result["labels"][0] top_score: float = result["scores"][0] conn.execute( q(""" UPDATE log_entries SET ml_score = ?, ml_label = ?, ml_scored_at = ? WHERE id = ? AND (tenant_id = ? OR tenant_id = '') """), (top_score, top_label, now, row["id"], tenant_id), ) total_scored += 1 if top_score >= threshold and top_label != _NORMAL_LABEL: severity = _LABEL_SEVERITY.get(top_label, "WARN") try: conn.execute( q(""" INSERT INTO detections (id, tenant_id, entry_id, source_id, anomaly_label, anomaly_score, severity, text, timestamp_iso, detected_at, scorer) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 'cybersec') """), ( str(uuid.uuid4()), tenant_id, row["id"], row["source_id"], top_label, top_score, severity, row["text"], row["timestamp_iso"], now, ), ) total_detections += 1 except Exception: pass # entry may already have a detection — skip conn.commit() except Exception as exc: logger.error("cybersec scoring failed: %s", exc, exc_info=True) return CybersecResult(scored=total_scored, detections=total_detections, error=str(exc)) return CybersecResult(scored=total_scored, detections=total_detections) # --------------------------------------------------------------------------- # Query helpers (used by REST layer) # --------------------------------------------------------------------------- def list_cybersec_detections( db_path: Path, limit: int = 100, unacked_only: bool = False, label: str | None = None, ) -> list[dict]: """Return cybersec detections ordered by detected_at DESC.""" tenant_id = resolve_tenant_id() conditions = ["(tenant_id = ? OR tenant_id = '')", "scorer = 'cybersec'"] params: list[Any] = [tenant_id] if unacked_only: conditions.append("acknowledged = 0") if label: conditions.append(q("anomaly_label = ?")) params.append(label) where = " AND ".join(conditions) with get_conn(db_path) as conn: rows = conn.execute( q(f"SELECT * FROM detections WHERE {where} ORDER BY detected_at DESC LIMIT ?"), # noqa: S608 (*params, limit), ).fetchall() return [dict(r) for r in rows]