turnstone/app/services/anomaly.py
pyr0ball 0693e1fd54 feat: anomaly scoring pipeline (#10)
- Add app/services/anomaly.py: batch scorer using HF text-classification
  pipeline; rewrites anomaly_score/anomaly_label/anomaly_scored_at on
  log_entries; inserts high-confidence hits into detections table
- Add app/tasks/anomaly_scorer.py: background task (same shape as
  glean_scheduler); triggered after each glean cycle when
  TURNSTONE_ANOMALY_MODEL is set
- DB schema: add anomaly_score/anomaly_label/anomaly_scored_at columns to
  log_entries (idempotent ALTER TABLE migration); add detections table
- Wire scorer into scheduler_loop and glean_scheduler.run_once; no-op when
  model env var is empty (safe to leave unconfigured)
- REST endpoints: GET/POST /api/anomaly/status, /api/anomaly/run,
  GET /api/anomaly/detections, POST /api/anomaly/detections/{id}/acknowledge
- Reuses Hybrid-BERT label map from diagnose/classifier.py; works with any
  HF text-classification model
- 12 new tests; 406/406 passing

Closes: #10
2026-06-09 11:15:13 -07:00

291 lines
9.4 KiB
Python

"""Anomaly scoring pipeline — batch-score log_entries with a HF classifier.
Designed to run after each glean cycle (or standalone). When no model is
configured the scorer is a no-op and returns immediately, so it is always
safe to wire into the glean pipeline.
Model: any HuggingFace text-classification model. The existing Hybrid-BERT
label map (from diagnose/classifier.py) is reused when the model produces
NORMAL/SECURITY_ANOMALY/… outputs; other models get a generic severity map.
Scoring strategy
----------------
- Query unscored rows in batches (WHERE anomaly_scored_at IS NULL)
- Run each entry text through the HF pipeline
- Write anomaly_score + anomaly_label + anomaly_scored_at back
- INSERT high-confidence hits (score >= threshold) into detections table,
skipping duplicates so the scorer is safe to re-run
"""
from __future__ import annotations
import logging
import os
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from app.db import get_conn, resolve_tenant_id
from app.db.dialect import q
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Label maps — reuse Hybrid-BERT vocabulary from diagnose/classifier.py
# ---------------------------------------------------------------------------
# Hybrid-BERT class label -> severity bucket.
_HYBRID_BERT_SEVERITY: dict[str, str] = {
    "NORMAL": "INFO",
    "SECURITY_ANOMALY": "ERROR",
    "SYSTEM_FAILURE": "CRITICAL",
    "PERFORMANCE_ISSUE": "WARN",
    "NETWORK_ANOMALY": "WARN",
    "CONFIG_ERROR": "ERROR",
    "HARDWARE_ISSUE": "CRITICAL",
}
# Fallback map for models that emit log-level style labels.
_GENERIC_SEVERITY: dict[str, str] = {
    "CRITICAL": "CRITICAL",
    "ERROR": "ERROR",
    "WARNING": "WARN",
    "WARN": "WARN",
    "INFO": "INFO",
    "DEBUG": "DEBUG",
}
# Labels (from either vocabulary) that may produce a detection when the
# model's confidence reaches the threshold.
_ANOMALOUS_LABELS: frozenset[str] = frozenset(
    {
        "SECURITY_ANOMALY",
        "SYSTEM_FAILURE",
        "PERFORMANCE_ISSUE",
        "NETWORK_ANOMALY",
        "CONFIG_ERROR",
        "HARDWARE_ISSUE",
        "CRITICAL",
        "ERROR",
    }
)


def _env_float(name: str, default: float) -> float:
    """Read environment variable *name* as a float, falling back to *default*.

    Missing or blank values give *default*; unparseable values give *default*
    with a warning. These constants are evaluated at import time, so a typo in
    the environment must not make the module fail to import (the scorer is
    meant to be safe to leave unconfigured).
    """
    raw = os.environ.get(name, "")
    if not raw.strip():
        return default
    try:
        return float(raw)
    except ValueError:
        logger.warning("Ignoring invalid %s=%r; using %r", name, raw, default)
        return default


def _env_int(name: str, default: int) -> int:
    """Read environment variable *name* as an int, falling back to *default*.

    Same fallback rules as ``_env_float``.
    """
    raw = os.environ.get(name, "")
    if not raw.strip():
        return default
    try:
        return int(raw)
    except ValueError:
        logger.warning("Ignoring invalid %s=%r; using %r", name, raw, default)
        return default


# Minimum model confidence for an anomalous label to be recorded as a detection.
_DEFAULT_THRESHOLD = _env_float("TURNSTONE_ANOMALY_THRESHOLD", 0.75)
# Empty model id disables scoring entirely (score_unscored returns skipped=True).
_DEFAULT_MODEL = os.environ.get("TURNSTONE_ANOMALY_MODEL", "")
_DEFAULT_DEVICE = os.environ.get("TURNSTONE_ANOMALY_DEVICE", "cpu")
# Rows fetched and classified per iteration.
_DEFAULT_BATCH = _env_int("TURNSTONE_ANOMALY_BATCH", 256)
# ---------------------------------------------------------------------------
# ML singleton
# ---------------------------------------------------------------------------
_pipeline: Any | None = None
# (model_id, device) the cached pipeline was built for. None when nothing has
# been built here yet, or when a pipeline was assigned to _pipeline directly.
_pipeline_key: tuple[str, str] | None = None


def _get_pipeline(model_id: str, device: str) -> Any:
    """Return a cached HF text-classification pipeline for *model_id* on *device*.

    The pipeline is cached at module level so repeated scoring runs do not
    reload the model. The cache remembers which (model_id, device) it was
    built for; requesting a different pair rebuilds it instead of silently
    returning the previously loaded model. A pipeline assigned to
    ``_pipeline`` directly (no recorded key) is returned as-is.

    ``transformers`` is imported lazily so importing this module does not
    require it when scoring is not configured.
    """
    global _pipeline, _pipeline_key  # noqa: PLW0603
    key = (model_id, device)
    stale = _pipeline_key is not None and _pipeline_key != key
    if _pipeline is None or stale:
        from transformers import pipeline as hf_pipeline  # type: ignore[import-untyped]

        _pipeline = hf_pipeline("text-classification", model=model_id, device=device)
        _pipeline_key = key
    return _pipeline


def reset_pipeline() -> None:
    """Reset the cached pipeline singleton (test helper)."""
    global _pipeline, _pipeline_key  # noqa: PLW0603
    _pipeline = None
    _pipeline_key = None
# ---------------------------------------------------------------------------
# Result types
# ---------------------------------------------------------------------------
@dataclass
class ScoringResult:
    """Outcome of a ``score_unscored`` run.

    ``skipped`` is set when no model id was supplied. ``error`` holds the
    message of a model-load or inference failure. ``scored`` and
    ``detections`` are the totals accumulated before the run finished or
    stopped.
    """

    scored: int = 0  # log_entries rows that received a score
    detections: int = 0  # rows newly inserted into detections
    skipped: bool = False  # True when scoring was disabled (no model)
    error: str | None = None  # failure message when the run aborted
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _map_label(raw_label: str, score: float) -> tuple[str, str]:
    """Normalise a raw classifier label and look up its severity.

    The label is upper-cased. Hybrid-BERT labels are resolved first; any
    other label goes through the generic map, defaulting to ``"WARN"``.
    *score* is accepted but not used.
    """
    label = raw_label.upper()
    hybrid_severity = _HYBRID_BERT_SEVERITY.get(label)
    if hybrid_severity is not None:
        return label, hybrid_severity
    return label, _GENERIC_SEVERITY.get(label, "WARN")
def _fetch_unscored(conn: Any, tenant_id: str, limit: int) -> list[dict]:
    """Return up to *limit* not-yet-scored log_entries, newest ingest first.

    Rows belonging to *tenant_id* and rows with an empty tenant_id are both
    eligible. Each row comes back as a plain dict with keys id, source_id,
    text, timestamp_iso and severity.
    """
    sql = q("""
        SELECT id, source_id, text, timestamp_iso, severity
        FROM log_entries
        WHERE anomaly_scored_at IS NULL
        AND (tenant_id = ? OR tenant_id = '')
        ORDER BY ingest_time DESC
        LIMIT ?
    """)
    cursor = conn.execute(sql, (tenant_id, limit))
    return list(map(dict, cursor.fetchall()))
def _write_scores(
    conn: Any,
    rows: list[dict],
    scored_at: str,
) -> None:
    """Store anomaly_score, anomaly_label and anomaly_scored_at for each row.

    Rows are matched on ``id``; each dict must carry ``id``,
    ``anomaly_score`` and ``anomaly_label``. Every row gets the same
    *scored_at* stamp. No commit is issued here.
    """
    statement = q("UPDATE log_entries SET anomaly_score = ?, anomaly_label = ?, anomaly_scored_at = ? WHERE id = ?")
    params = [
        (row["anomaly_score"], row["anomaly_label"], scored_at, row["id"])
        for row in rows
    ]
    conn.executemany(statement, params)
def _insert_detections(conn: Any, rows: list[dict], tenant_id: str, detected_at: str) -> int:
    """Insert a detections row for each entry in *rows*; return how many were added.

    Entries that already have a detection (matched on entry_id) are skipped
    explicitly, so re-running the scorer does not duplicate detections even
    when the table has no unique constraint on entry_id. Any other failure
    for a row is logged and that row is skipped (best-effort).

    Each dict in *rows* must carry id, source_id, anomaly_label,
    anomaly_score, severity and text; timestamp_iso is optional. The stored
    text is truncated to 2000 characters. No commit is issued here.
    """
    inserted = 0
    for r in rows:
        try:
            # Detect duplicates up front rather than via a constraint error.
            # NOTE(review): on PostgreSQL a failed statement aborts the open
            # transaction, which would also discard the score writes made on
            # this connection — another reason not to use errors for control flow.
            existing = conn.execute(
                q("SELECT 1 FROM detections WHERE entry_id = ?"),
                (r["id"],),
            ).fetchone()
            if existing is not None:
                continue
            conn.execute(
                q("""
                    INSERT INTO detections
                    (id, tenant_id, entry_id, source_id, anomaly_label, anomaly_score,
                    severity, text, timestamp_iso, detected_at)
                    VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
                """),
                (
                    str(uuid.uuid4()),
                    tenant_id,
                    r["id"],
                    r["source_id"],
                    r["anomaly_label"],
                    r["anomaly_score"],
                    r["severity"],
                    r["text"][:2000],
                    r.get("timestamp_iso"),
                    detected_at,
                ),
            )
        except Exception as exc:  # noqa: BLE001 — best-effort per row, but never silent
            logger.warning("Skipping detection for entry %r: %s", r.get("id"), exc)
            continue
        inserted += 1
    return inserted
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def _enrich_batch(
    batch: list[dict],
    predictions: list[dict],
    threshold: float,
) -> tuple[list[dict], list[dict]]:
    """Attach model output to each row; return (scored_rows, detection_rows)."""
    scored_rows: list[dict] = []
    detection_rows: list[dict] = []
    for row, pred in zip(batch, predictions):
        label, severity = _map_label(pred["label"], pred["score"])
        # The model-derived severity replaces the row's stored severity in the
        # enriched copy; that is the value written to any detection.
        enriched = {**row, "anomaly_score": pred["score"], "anomaly_label": label, "severity": severity}
        scored_rows.append(enriched)
        if label in _ANOMALOUS_LABELS and pred["score"] >= threshold:
            detection_rows.append(enriched)
    return scored_rows, detection_rows


def score_unscored(
    db_path: Path,
    model_id: str = _DEFAULT_MODEL,
    device: str = _DEFAULT_DEVICE,
    batch_size: int = _DEFAULT_BATCH,
    threshold: float = _DEFAULT_THRESHOLD,
) -> ScoringResult:
    """Score all unscored log_entries in batches.

    Returns immediately (skipped=True) when model_id is empty — allows
    unconditional wiring without requiring the model to be configured.

    Each batch is fetched, classified, and written back (scores plus any
    detections with an anomalous label and score >= threshold) in a single
    commit. The loop stops when a batch is smaller than batch_size. A
    model-load failure, an inference failure, or a prediction count that does
    not match the batch returns a result with ``error`` set and the totals
    accumulated so far.
    """
    if not model_id:
        return ScoringResult(skipped=True)
    try:
        pipe = _get_pipeline(model_id, device)
    except Exception as exc:
        logger.error("Failed to load anomaly model %r: %s", model_id, exc)
        return ScoringResult(error=str(exc))

    tenant_id = resolve_tenant_id()
    total_scored = 0
    total_detections = 0
    while True:
        with get_conn(db_path) as conn:
            batch = _fetch_unscored(conn, tenant_id, batch_size)
        if not batch:
            break
        # A NULL text is sent as "" so one such row cannot raise here and stay
        # unscored on every run. NOTE(review): confirm whether text can be NULL.
        texts = [(r["text"] or "")[:512] for r in batch]
        try:
            predictions = list(pipe(texts, truncation=True, max_length=512))
        except Exception as exc:
            logger.error("Inference error on batch of %d: %s", len(batch), exc)
            return ScoringResult(scored=total_scored, detections=total_detections, error=str(exc))
        if len(predictions) != len(batch):
            # zip() would silently drop the unmatched rows; they would stay
            # unscored and be fetched again on the next iteration.
            msg = f"Model returned {len(predictions)} predictions for a batch of {len(batch)}"
            logger.error("%s", msg)
            return ScoringResult(scored=total_scored, detections=total_detections, error=msg)

        scored_at = datetime.now(tz=timezone.utc).isoformat()
        scored_rows, detection_rows = _enrich_batch(batch, predictions, threshold)
        with get_conn(db_path) as conn:
            _write_scores(conn, scored_rows, scored_at)
            det_count = _insert_detections(conn, detection_rows, tenant_id, scored_at)
            conn.commit()
        total_scored += len(scored_rows)
        total_detections += det_count
        logger.info(
            "Scored %d entries, %d detections (threshold=%.2f)",
            len(scored_rows), det_count, threshold,
        )
        if len(batch) < batch_size:
            break
    return ScoringResult(scored=total_scored, detections=total_detections)
def list_detections(
    db_path: Path,
    limit: int = 100,
    unacked_only: bool = False,
    label: str | None = None,
) -> list[dict]:
    """Return detections ordered by detected_at DESC.

    Results are limited to the current tenant plus rows with an empty
    tenant_id. *unacked_only* keeps only rows with acknowledged = 0; *label*
    filters on anomaly_label (compared upper-cased). At most *limit* rows are
    returned, each as a plain dict.
    """
    tenant_id = resolve_tenant_id()
    conditions = ["(tenant_id = ? OR tenant_id = '')"]
    params: list[Any] = [tenant_id]
    if unacked_only:
        conditions.append("acknowledged = 0")
    if label:
        # Plain fragment like the others: q() is applied once to the assembled
        # statement below. Wrapping the fragment too sent this placeholder
        # through q() twice, which is only safe if q() is idempotent.
        conditions.append("anomaly_label = ?")
        params.append(label.upper())
    # Only fixed fragments are interpolated; all values are bound parameters.
    where = " AND ".join(conditions)
    with get_conn(db_path) as conn:
        rows = conn.execute(
            q(f"SELECT * FROM detections WHERE {where} ORDER BY detected_at DESC LIMIT ?"),  # noqa: S608
            (*params, limit),
        ).fetchall()
    return [dict(r) for r in rows]
def acknowledge_detection(db_path: Path, detection_id: str, notes: str = "") -> bool:
    """Flag a detection as acknowledged, stamping a UTC time and storing *notes*.

    Only a detection owned by the current tenant (or with an empty
    tenant_id) is touched. Returns True when the UPDATE affected a row,
    False otherwise. The change is committed before returning.
    """
    owner = resolve_tenant_id()
    now_iso = datetime.now(tz=timezone.utc).isoformat()
    update_sql = q("""
        UPDATE detections
        SET acknowledged = 1, acknowledged_at = ?, notes = ?
        WHERE id = ? AND (tenant_id = ? OR tenant_id = '')
    """)
    with get_conn(db_path) as conn:
        cursor = conn.execute(update_sql, (now_iso, notes, detection_id, owner))
        conn.commit()
    return cursor.rowcount > 0