turnstone/app/services/cybersec.py
pyr0ball db2e4f85e7 fix(cybersec): clean up debug traceback logging
Replaced manual traceback import with exc_info=True, which is the
idiomatic logging pattern and produces the same output.
2026-06-10 13:20:56 -07:00

241 lines
8.9 KiB
Python

"""Cybersecurity-focused scoring pipeline using zero-shot classification.
Runs a second-pass analysis on entries that were already flagged by the
anomaly scorer or that have pattern matches. Uses a zero-shot classification
model (DeBERTa-v3-base-mnli is cached locally) so no fine-tuning is needed.
The scorer writes ml_score / ml_label / ml_scored_at to log_entries and
inserts high-confidence non-normal hits into the detections table tagged
with scorer='cybersec'.
Env vars
--------
TURNSTONE_CYBERSEC_MODEL — HF model id for zero-shot classification.
Recommended: MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli
(already cached from the diagnose pipeline).
Set to empty string to disable (safe default).
TURNSTONE_CYBERSEC_DEVICE — 'cpu' (default) or 'cuda'
TURNSTONE_CYBERSEC_THRESHOLD — float confidence floor for detection insertion (default 0.60)
"""
from __future__ import annotations
import logging
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
from app.db import get_conn, resolve_tenant_id
from app.db.dialect import q
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Candidate labels — cybersec vocabulary for zero-shot inference
# ---------------------------------------------------------------------------
# The one label that means "nothing suspicious"; it is listed last among the
# candidates and never produces a detection.
_NORMAL_LABEL = "normal system operation"

# Candidate labels handed to the zero-shot classifier, in presentation order.
CYBERSEC_LABELS: list[str] = [
    "authentication failure or brute force attack",
    "privilege escalation or unauthorized access",
    "network intrusion or port scan",
    "malware or suspicious process activity",
    "data exfiltration or unusual outbound traffic",
    _NORMAL_LABEL,
]

# Severity recorded on a detection for each candidate label.
_LABEL_SEVERITY: dict[str, str] = {
    "authentication failure or brute force attack": "ERROR",
    "privilege escalation or unauthorized access": "CRITICAL",
    "network intrusion or port scan": "ERROR",
    "malware or suspicious process activity": "CRITICAL",
    "data exfiltration or unusual outbound traffic": "CRITICAL",
    _NORMAL_LABEL: "INFO",
}
# ---------------------------------------------------------------------------
# Pipeline singleton
# ---------------------------------------------------------------------------
_pipeline: Any = None


def _get_pipeline(model_id: str, device: str) -> Any:
    """Return the module-wide zero-shot pipeline, building it on first use.

    The pipeline is created once and cached in ``_pipeline``; later calls
    return the cached object regardless of the arguments they pass.

    Args:
        model_id: Model identifier forwarded to ``transformers.pipeline``.
        device: ``"cuda"`` maps to device index 0; any other value maps to -1.

    Returns:
        The cached pipeline object.
    """
    global _pipeline  # noqa: PLW0603
    if _pipeline is not None:
        return _pipeline

    # Imported lazily so the module can be loaded without transformers
    # until a pipeline is actually requested.
    from transformers import pipeline as build_pipeline  # type: ignore[import-untyped]

    logger.info("loading cybersec zero-shot pipeline: %s on %s", model_id, device)
    device_index = 0 if device == "cuda" else -1
    _pipeline = build_pipeline(
        "zero-shot-classification",
        model=model_id,
        device=device_index,
    )
    logger.info("cybersec pipeline ready")
    return _pipeline
def reset_pipeline() -> None:
    """Drop the cached pipeline so the next lookup rebuilds it.

    Intended for tests only.
    """
    global _pipeline  # noqa: PLW0603
    _pipeline = None
# ---------------------------------------------------------------------------
# Result type
# ---------------------------------------------------------------------------
@dataclass
class CybersecResult:
scored: int = 0
detections: int = 0
skipped: bool = False
error: str | None = None
# ---------------------------------------------------------------------------
# Core scoring function
# ---------------------------------------------------------------------------
def score_security_entries(
    db_path: Path,
    model_id: str,
    device: str = "cpu",
    batch_size: int = 32,
    threshold: float = 0.60,
) -> CybersecResult:
    """Score entries that were anomaly-flagged or pattern-matched.

    Only entries with ml_scored_at IS NULL are processed (idempotent).
    Writes ml_score / ml_label / ml_scored_at and inserts high-confidence
    hits into detections with scorer='cybersec'.

    Args:
        db_path: Database location handed to ``get_conn``.
        model_id: Model id for the zero-shot pipeline; an empty value
            disables scoring.
        device: Device name forwarded to ``_get_pipeline``.
        batch_size: Number of texts per inference call. At most
            ``batch_size * 10`` entries are fetched per invocation.
        threshold: Minimum top-label score for a non-normal label to be
            recorded as a detection.

    Returns:
        A ``CybersecResult``. ``skipped`` is True when scoring is disabled
        or no entries are pending. ``error`` is set when loading the
        pipeline or a database operation fails; in the latter case the
        counters cover the chunks committed before the failure.
    """
    if not model_id:
        return CybersecResult(skipped=True)

    tenant_id = resolve_tenant_id()

    try:
        pipe = _get_pipeline(model_id, device)
    except Exception as exc:
        logger.error("failed to load cybersec pipeline: %s", exc, exc_info=True)
        return CybersecResult(error=str(exc))

    total_scored = 0
    total_detections = 0

    try:
        # Only score entries that are worth a second look:
        # anomaly-flagged (non-normal) OR have at least one pattern match.
        # The read connection is released before the (slow) inference starts.
        with get_conn(db_path) as conn:
            rows = conn.execute(
                q("""
                    SELECT id, source_id, text, timestamp_iso
                    FROM log_entries
                    WHERE (tenant_id = ? OR tenant_id = '')
                      AND ml_scored_at IS NULL
                      AND (
                        (anomaly_label IS NOT NULL AND anomaly_label != 'NORMAL')
                        OR (matched_patterns IS NOT NULL AND matched_patterns != '[]' AND matched_patterns != '')
                      )
                    LIMIT ?
                """),
                (tenant_id, batch_size * 10),
            ).fetchall()

        if not rows:
            return CybersecResult(skipped=True)

        # Process in chunks to avoid OOM on large backlogs
        for start in range(0, len(rows), batch_size):
            chunk = rows[start : start + batch_size]
            texts = [r["text"] for r in chunk]
            try:
                results = pipe(texts, candidate_labels=CYBERSEC_LABELS, multi_label=False)
            except Exception as exc:
                # Leave the chunk unscored; it stays eligible for a later run.
                logger.warning("zero-shot inference error on chunk %d: %s", start, exc)
                continue

            # Defensive normalisation: a bare dict (single result) is wrapped
            # so rows and results can be paired one-to-one.
            results = [results] if isinstance(results, dict) else list(results)
            if len(results) != len(chunk):
                # Pairing would be unreliable, so do not write anything.
                logger.warning(
                    "zero-shot pipeline returned %d results for %d texts on chunk %d; skipping",
                    len(results),
                    len(chunk),
                    start,
                )
                continue

            now = datetime.now(tz=timezone.utc).isoformat()
            chunk_scored = 0
            chunk_detections = 0
            with get_conn(db_path) as conn:
                for row, result in zip(chunk, results):
                    top_label: str = result["labels"][0]
                    top_score: float = result["scores"][0]
                    conn.execute(
                        q("""
                            UPDATE log_entries
                            SET ml_score = ?, ml_label = ?, ml_scored_at = ?
                            WHERE id = ? AND (tenant_id = ? OR tenant_id = '')
                        """),
                        (top_score, top_label, now, row["id"], tenant_id),
                    )
                    chunk_scored += 1

                    if top_score < threshold or top_label == _NORMAL_LABEL:
                        continue

                    severity = _LABEL_SEVERITY.get(top_label, "WARN")
                    try:
                        conn.execute(
                            q("""
                                INSERT INTO detections
                                    (id, tenant_id, entry_id, source_id, anomaly_label,
                                     anomaly_score, severity, text, timestamp_iso,
                                     detected_at, scorer)
                                VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 'cybersec')
                            """),
                            (
                                str(uuid.uuid4()),
                                tenant_id,
                                row["id"],
                                row["source_id"],
                                top_label,
                                top_score,
                                severity,
                                row["text"],
                                row["timestamp_iso"],
                                now,
                            ),
                        )
                    except Exception:
                        # Best effort: the entry may already have a detection.
                        # Keep going, but leave a trace instead of hiding it.
                        logger.debug(
                            "cybersec detection insert skipped for entry %s",
                            row["id"],
                            exc_info=True,
                        )
                    else:
                        chunk_detections += 1
                conn.commit()

            # Count only what was actually committed.
            total_scored += chunk_scored
            total_detections += chunk_detections
    except Exception as exc:
        logger.error("cybersec scoring failed: %s", exc, exc_info=True)
        return CybersecResult(scored=total_scored, detections=total_detections, error=str(exc))

    return CybersecResult(scored=total_scored, detections=total_detections)
# ---------------------------------------------------------------------------
# Query helpers (used by REST layer)
# ---------------------------------------------------------------------------
def list_cybersec_detections(
    db_path: Path,
    limit: int = 100,
    unacked_only: bool = False,
    label: str | None = None,
) -> list[dict]:
    """Return cybersec detections ordered by detected_at DESC.

    Args:
        db_path: Database location handed to ``get_conn``.
        limit: Maximum number of rows to return.
        unacked_only: When True, only rows with ``acknowledged = 0``.
        label: When non-empty, only rows whose ``anomaly_label`` equals it.

    Returns:
        One ``dict`` per matching detections row with ``scorer = 'cybersec'``.
    """
    tenant_id = resolve_tenant_id()
    conditions = ["(tenant_id = ? OR tenant_id = '')", "scorer = 'cybersec'"]
    params: list[Any] = [tenant_id]
    if unacked_only:
        conditions.append("acknowledged = 0")
    if label:
        # Keep the raw '?' here: the complete statement goes through q()
        # exactly once below, like every other condition.
        conditions.append("anomaly_label = ?")
        params.append(label)
    where = " AND ".join(conditions)

    # Only fixed fragments are interpolated; all values are bound parameters.
    sql = q(f"SELECT * FROM detections WHERE {where} ORDER BY detected_at DESC LIMIT ?")  # noqa: S608
    with get_conn(db_path) as conn:
        rows = conn.execute(sql, (*params, limit)).fetchall()
    return [dict(r) for r in rows]