turnstone/app/services/anomaly.py
pyr0ball 502ff54fd0 feat(ui): security alert dedup, clickable criticals, loading shimmer
Security Alerts:
- Client-side duplicate collapsing via anomaly_label + text fingerprint
- ×N count badge chip on collapsed rows; toggle to expand
- Skeleton shimmer rows replace "Loading..." text

Dashboard:
- Clickable Recent Criticals — inline LLM explanation via SSE stream
- ±5 min time window scoped to source_id for useful context
- Explanation cache keyed by entry_id (no re-fetch on re-expand)
- Default diagnose query injected on Diagnose button navigation to
  prevent local models from hallucinating on bare log data
- Stat card and source-health skeleton shimmer loading states

Backend:
- anomaly.py: 4-attempt retry on "database is locked" with 10s backoff
- search.py: migrate build_fts_index to get_conn() (WAL race fix);
  add timeline_events to stats_summary for clickable criticals feature
- theme.css: @keyframes shimmer + .loading-shimmer utility;
  prefers-reduced-motion degrades gracefully to static muted block
2026-06-13 09:32:26 -07:00

305 lines
9.9 KiB
Python

"""Anomaly scoring pipeline — batch-score log_entries with a HF classifier.
Designed to run after each glean cycle (or standalone). When no model is
configured the scorer is a no-op and returns immediately, so it is always
safe to wire into the glean pipeline.
Model: any HuggingFace text-classification model. The existing Hybrid-BERT
label map (from diagnose/classifier.py) is reused when the model produces
NORMAL/SECURITY_ANOMALY/… outputs; other models get a generic severity map.
Scoring strategy
----------------
- Query unscored rows in batches (WHERE anomaly_scored_at IS NULL)
- Run each entry text through the HF pipeline
- Write anomaly_score + anomaly_label + anomaly_scored_at back
- INSERT high-confidence hits (score >= threshold) into detections table,
skipping duplicates so the scorer is safe to re-run
"""
from __future__ import annotations
import logging
import os
import time
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from app.db import get_conn, resolve_tenant_id
from app.db.dialect import q
logger = logging.getLogger(__name__)


def _env_float(name: str, default: float) -> float:
    """Return environment variable *name* parsed as a float, or *default*.

    An unset or blank value falls back to *default* silently; an unparseable
    value falls back with a warning.  Raising ValueError here would happen at
    import time and break every importer of this module, which is meant to be
    safe to wire in unconditionally.
    """
    raw = os.environ.get(name)
    if raw is None or not raw.strip():
        return default
    try:
        return float(raw)
    except ValueError:
        logger.warning("Ignoring invalid %s=%r; using default %r", name, raw, default)
        return default


def _env_int(name: str, default: int, *, minimum: int = 1) -> int:
    """Return environment variable *name* parsed as an int >= *minimum*, or *default*.

    An unset or blank value falls back to *default* silently; an unparseable
    or too-small value falls back with a warning.
    """
    raw = os.environ.get(name)
    if raw is None or not raw.strip():
        return default
    try:
        value = int(raw)
    except ValueError:
        logger.warning("Ignoring invalid %s=%r; using default %r", name, raw, default)
        return default
    if value < minimum:
        logger.warning(
            "Ignoring %s=%r (must be >= %d); using default %r", name, raw, minimum, default
        )
        return default
    return value


# ---------------------------------------------------------------------------
# Label maps — reuse Hybrid-BERT vocabulary from diagnose/classifier.py
# ---------------------------------------------------------------------------
# Hybrid-BERT class label -> log severity.
_HYBRID_BERT_SEVERITY: dict[str, str] = {
    "NORMAL": "INFO",
    "SECURITY_ANOMALY": "ERROR",
    "SYSTEM_FAILURE": "CRITICAL",
    "PERFORMANCE_ISSUE": "WARN",
    "NETWORK_ANOMALY": "WARN",
    "CONFIG_ERROR": "ERROR",
    "HARDWARE_ISSUE": "CRITICAL",
}
# Fallback for models that emit plain log-level labels.
_GENERIC_SEVERITY: dict[str, str] = {
    "CRITICAL": "CRITICAL",
    "ERROR": "ERROR",
    "WARNING": "WARN",
    "WARN": "WARN",
    "INFO": "INFO",
    "DEBUG": "DEBUG",
}
# Labels eligible to become a detection row (still subject to the score threshold).
_ANOMALOUS_LABELS: frozenset[str] = frozenset(
    {
        "SECURITY_ANOMALY",
        "SYSTEM_FAILURE",
        "PERFORMANCE_ISSUE",
        "NETWORK_ANOMALY",
        "CONFIG_ERROR",
        "HARDWARE_ISSUE",
        "CRITICAL",
        "ERROR",
    }
)

# Runtime configuration, read once at import time.
_DEFAULT_THRESHOLD = _env_float("TURNSTONE_ANOMALY_THRESHOLD", 0.75)
# An empty model id makes score_unscored() a no-op.
_DEFAULT_MODEL = os.environ.get("TURNSTONE_ANOMALY_MODEL", "")
_DEFAULT_DEVICE = os.environ.get("TURNSTONE_ANOMALY_DEVICE", "cpu")
# Must be >= 1: a LIMIT of 0 would fetch nothing and silently disable scoring.
_DEFAULT_BATCH = _env_int("TURNSTONE_ANOMALY_BATCH", 256, minimum=1)
# ---------------------------------------------------------------------------
# ML singleton
# ---------------------------------------------------------------------------
# Cached HF pipeline plus the (model_id, device) pair it was built for.  The
# key lets a call with a different model or device rebuild the pipeline
# instead of silently reusing whichever model happened to be loaded first.
_pipeline: Any | None = None
_pipeline_key: tuple[str, str] | None = None


def _get_pipeline(model_id: str, device: str) -> Any:
    """Return a cached text-classification pipeline for *model_id* on *device*.

    ``transformers`` is imported lazily, only when a pipeline actually has to
    be built.  A request for a different (model_id, device) pair than the
    cached one rebuilds and re-caches the pipeline.

    Raises:
        Whatever the ``transformers`` import or ``pipeline(...)`` raises when
        the model cannot be loaded.  In that case the previous cache entry is
        left untouched.
    """
    global _pipeline, _pipeline_key  # noqa: PLW0603
    key = (model_id, device)
    if _pipeline is None or _pipeline_key != key:
        from transformers import pipeline as hf_pipeline  # type: ignore[import-untyped]

        _pipeline = hf_pipeline("text-classification", model=model_id, device=device)
        _pipeline_key = key
    return _pipeline


def reset_pipeline() -> None:
    """Reset the cached pipeline singleton (test helper)."""
    global _pipeline, _pipeline_key  # noqa: PLW0603
    _pipeline = None
    _pipeline_key = None
# ---------------------------------------------------------------------------
# Result types
# ---------------------------------------------------------------------------
@dataclass
class ScoringResult:
    """Outcome of a single score_unscored() run.

    A run scores rows, is skipped (no model configured), or stops early with
    ``error`` populated.  The counters reflect the work completed before any
    early stop.
    """

    scored: int = 0  # log_entries rows whose anomaly columns were written
    detections: int = 0  # detections rows inserted
    skipped: bool = False  # True when no model id was configured
    error: str | None = None  # message of the failure that ended the run, if any
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _map_label(raw_label: str, score: float) -> tuple[str, str]:
    """Normalise a raw model label and look up its severity.

    The Hybrid-BERT vocabulary is consulted first, then the generic log-level
    vocabulary.  A label found in neither maps to ``"WARN"``.  ``score`` is
    currently unused.

    Returns:
        ``(upper-cased label, severity)``.
    """
    label = raw_label.upper()
    for table in (_HYBRID_BERT_SEVERITY, _GENERIC_SEVERITY):
        if label in table:
            return label, table[label]
    return label, "WARN"
def _fetch_unscored(conn: Any, tenant_id: str, limit: int) -> list[dict]:
    """Return up to *limit* not-yet-scored log_entries rows as plain dicts.

    Rows owned by *tenant_id* or by the empty tenant ('') are eligible.  The
    most recently ingested rows are returned first.
    """
    sql = q(
        """
        SELECT id, source_id, text, timestamp_iso, severity
        FROM log_entries
        WHERE anomaly_scored_at IS NULL
          AND (tenant_id = ? OR tenant_id = '')
        ORDER BY ingest_time DESC
        LIMIT ?
        """
    )
    cursor = conn.execute(sql, (tenant_id, limit))
    return list(map(dict, cursor.fetchall()))
def _write_scores(
    conn: Any,
    rows: list[dict],
    scored_at: str,
) -> None:
    """Persist each row's anomaly_score / anomaly_label, stamped with *scored_at*.

    Does not commit; the caller owns the transaction.
    """
    params = [
        (row["anomaly_score"], row["anomaly_label"], scored_at, row["id"])
        for row in rows
    ]
    conn.executemany(
        q("UPDATE log_entries SET anomaly_score = ?, anomaly_label = ?, anomaly_scored_at = ? WHERE id = ?"),
        params,
    )
def _is_locked_error(exc: BaseException) -> bool:
    """Return True if *exc* reports a locked database (a transient condition)."""
    return "database is locked" in str(exc).lower()


def _is_integrity_error(exc: BaseException) -> bool:
    """Return True if *exc* is a DB-API ``IntegrityError`` or a subclass of one.

    Matched by class name along the MRO so no particular driver module has to
    be imported here.
    """
    return any(cls.__name__ == "IntegrityError" for cls in type(exc).__mro__)


def _insert_detections(conn: Any, rows: list[dict], tenant_id: str, detected_at: str) -> int:
    """Insert one detections row per entry in *rows*; return how many were inserted.

    Constraint violations are skipped and logged at debug level.  The expected
    case is a duplicate entry_id, and skipping it keeps the scorer safe to re-run.

    A "database is locked" error is re-raised rather than swallowed, so the
    caller's write-retry loop can retry the whole batch.  Swallowing it would
    lose the detection while the entry's scores were still committed.

    Any other failure is logged and that row skipped, keeping the per-row
    insert best-effort.  ``text`` is truncated to 2000 characters.

    Does not commit; the caller owns the transaction.

    NOTE(review): on backends where a failed statement aborts the whole
    transaction (PostgreSQL does), rows after the first failure will also
    fail and be logged.  A per-row SAVEPOINT would be needed there; confirm
    which backends get_conn()/q() target.
    """
    # Loop-invariant: build the dialect-adjusted statement once.
    sql = q(
        """
        INSERT INTO detections
            (id, tenant_id, entry_id, source_id, anomaly_label, anomaly_score,
             severity, text, timestamp_iso, detected_at)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """
    )
    inserted = 0
    for r in rows:
        try:
            conn.execute(
                sql,
                (
                    str(uuid.uuid4()),
                    tenant_id,
                    r["id"],
                    r["source_id"],
                    r["anomaly_label"],
                    r["anomaly_score"],
                    r["severity"],
                    r["text"][:2000],
                    r.get("timestamp_iso"),
                    detected_at,
                ),
            )
        except Exception as exc:  # noqa: BLE001
            if _is_locked_error(exc):
                raise  # let the caller's retry loop handle transient locks
            if _is_integrity_error(exc):
                logger.debug("Skipping detection for entry %s: %s", r.get("id"), exc)
            else:
                logger.warning("Failed to insert detection for entry %s: %s", r.get("id"), exc)
            continue
        inserted += 1
    return inserted
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
# Write-retry policy for "database is locked": up to 4 attempts, 10 s apart.
_WRITE_ATTEMPTS = 4
_WRITE_RETRY_DELAY_S = 10.0


def _classify_batch(
    batch: list[dict],
    predictions: list[dict],
    threshold: float,
) -> tuple[list[dict], list[dict]]:
    """Merge model predictions into *batch*; return (scored_rows, detection_rows).

    Each scored row is a copy of its input row plus ``anomaly_score``,
    ``anomaly_label`` and a model-derived ``severity``, which replaces the
    row's original severity.  Detection rows are the subset whose label is
    anomalous and whose score is at least *threshold*.
    """
    scored_rows: list[dict] = []
    detection_rows: list[dict] = []
    for row, pred in zip(batch, predictions):
        label, severity = _map_label(pred["label"], pred["score"])
        enriched = {
            **row,
            "anomaly_score": pred["score"],
            "anomaly_label": label,
            "severity": severity,
        }
        scored_rows.append(enriched)
        if label in _ANOMALOUS_LABELS and pred["score"] >= threshold:
            detection_rows.append(enriched)
    return scored_rows, detection_rows


def _write_batch_once(
    db_path: Path,
    scored_rows: list[dict],
    detection_rows: list[dict],
    tenant_id: str,
    scored_at: str,
) -> int:
    """Write scores and detections in one transaction; return detections inserted."""
    with get_conn(db_path) as conn:
        _write_scores(conn, scored_rows, scored_at)
        det_count = _insert_detections(conn, detection_rows, tenant_id, scored_at)
        conn.commit()
    return det_count


def _write_batch(
    db_path: Path,
    scored_rows: list[dict],
    detection_rows: list[dict],
    tenant_id: str,
    scored_at: str,
) -> int:
    """Run _write_batch_once, retrying while the database reports it is locked.

    Non-lock errors propagate immediately.  A lock that persists through the
    final attempt propagates as well.
    """
    for attempt in range(1, _WRITE_ATTEMPTS):
        try:
            return _write_batch_once(db_path, scored_rows, detection_rows, tenant_id, scored_at)
        except Exception as exc:
            if "database is locked" not in str(exc).lower():
                raise
            logger.warning(
                "DB locked, retrying write in %gs (attempt %d/%d)",
                _WRITE_RETRY_DELAY_S,
                attempt,
                _WRITE_ATTEMPTS,
            )
            time.sleep(_WRITE_RETRY_DELAY_S)
    # Final attempt: any error, including a persistent lock, propagates.
    return _write_batch_once(db_path, scored_rows, detection_rows, tenant_id, scored_at)


def score_unscored(
    db_path: Path,
    model_id: str = _DEFAULT_MODEL,
    device: str = _DEFAULT_DEVICE,
    batch_size: int = _DEFAULT_BATCH,
    threshold: float = _DEFAULT_THRESHOLD,
) -> ScoringResult:
    """Score all unscored log_entries in batches.

    Returns immediately (skipped=True) when model_id is empty — allows
    unconditional wiring without requiring the model to be configured.

    Args:
        db_path: Database location passed to get_conn.
        model_id: HuggingFace model id; empty disables scoring.
        device: Device passed to the HF pipeline.
        batch_size: Rows fetched and classified per iteration.
        threshold: Minimum score for an anomalous label to become a detection.

    Returns:
        ScoringResult with running totals.  ``error`` is set, and the run stops
        early, when the model fails to load or inference on a batch fails.

    Raises:
        Database errors from the write step: any non-lock error, or a lock
        that persists through every retry.
    """
    if not model_id:
        return ScoringResult(skipped=True)

    try:
        pipe = _get_pipeline(model_id, device)
    except Exception as exc:
        logger.error("Failed to load anomaly model %r: %s", model_id, exc)
        return ScoringResult(error=str(exc))

    tenant_id = resolve_tenant_id()
    total_scored = 0
    total_detections = 0

    while True:
        with get_conn(db_path) as conn:
            batch = _fetch_unscored(conn, tenant_id, batch_size)
        if not batch:
            break

        # Character-level cut is a cheap bound before tokenisation; the
        # pipeline call additionally truncates to max_length tokens.
        texts = [r["text"][:512] for r in batch]
        try:
            predictions = pipe(texts, truncation=True, max_length=512)
        except Exception as exc:
            logger.error("Inference error on batch of %d: %s", len(batch), exc)
            return ScoringResult(scored=total_scored, detections=total_detections, error=str(exc))

        # zip() would silently drop unmatched rows.  They would stay unscored
        # and be re-fetched, and an empty result on a full batch would loop
        # forever.
        if len(predictions) != len(batch):
            msg = f"model returned {len(predictions)} predictions for {len(batch)} entries"
            logger.error("Inference error on batch of %d: %s", len(batch), msg)
            return ScoringResult(scored=total_scored, detections=total_detections, error=msg)

        scored_at = datetime.now(tz=timezone.utc).isoformat()
        scored_rows, detection_rows = _classify_batch(batch, predictions, threshold)
        det_count = _write_batch(db_path, scored_rows, detection_rows, tenant_id, scored_at)

        total_scored += len(scored_rows)
        total_detections += det_count
        logger.info(
            "Scored %d entries, %d detections (threshold=%.2f)",
            len(scored_rows), det_count, threshold,
        )

        # A short batch means the backlog is drained.
        if len(batch) < batch_size:
            break

    return ScoringResult(scored=total_scored, detections=total_detections)
def list_detections(
    db_path: Path,
    limit: int = 100,
    unacked_only: bool = False,
    label: str | None = None,
    scorer: str | None = None,
) -> list[dict]:
    """Return detections for the current tenant, ordered by detected_at DESC.

    Rows owned by the resolved tenant or by the empty tenant ('') are included.

    Args:
        db_path: Database location passed to get_conn.
        limit: Maximum number of rows to return.
        unacked_only: Only return rows with ``acknowledged = 0``.
        label: Optional anomaly_label filter, matched upper-cased.
        scorer: Optional scorer filter, matched lower-cased.

    Returns:
        One dict per matching detections row.
    """
    tenant_id = resolve_tenant_id()
    # Fragments stay in raw '?' placeholder form and the assembled statement
    # goes through q() exactly once.  Running q() on a fragment and again on
    # the whole query is only correct if q() is idempotent.
    conditions = ["(tenant_id = ? OR tenant_id = '')"]
    params: list[Any] = [tenant_id]
    if unacked_only:
        conditions.append("acknowledged = 0")
    if label:
        conditions.append("anomaly_label = ?")
        params.append(label.upper())
    if scorer:
        conditions.append("scorer = ?")
        params.append(scorer.lower())
    where = " AND ".join(conditions)
    # The f-string splices only the fixed fragments above; every caller-supplied
    # value is a bound parameter.
    sql = q(f"SELECT * FROM detections WHERE {where} ORDER BY detected_at DESC LIMIT ?")  # noqa: S608
    with get_conn(db_path) as conn:
        rows = conn.execute(sql, (*params, limit)).fetchall()
    return [dict(r) for r in rows]
def acknowledge_detection(db_path: Path, detection_id: str, notes: str = "") -> bool:
    """Flag one detection as acknowledged, stamping the current UTC time and *notes*.

    Only a detection owned by the current tenant (or the empty tenant) is
    affected.  The change is committed before returning.

    Returns:
        True when the UPDATE matched at least one row.
    """
    tenant = resolve_tenant_id()
    stamp = datetime.now(tz=timezone.utc).isoformat()
    statement = q(
        """
        UPDATE detections
        SET acknowledged = 1, acknowledged_at = ?, notes = ?
        WHERE id = ? AND (tenant_id = ? OR tenant_id = '')
        """
    )
    with get_conn(db_path) as conn:
        cursor = conn.execute(statement, (stamp, notes, detection_id, tenant))
        conn.commit()
        changed = cursor.rowcount
    return changed > 0