Second-pass cybersec classifier using DeBERTa-v3-base-mnli (already cached — no download required). Runs after each anomaly scoring pass on entries flagged by the anomaly scorer or with pattern matches. Architecture: - app/services/cybersec.py: zero-shot-classification pipeline with 5 cybersec candidate labels (auth failure, privilege escalation, network intrusion, malware, data exfiltration). Writes ml_score/ml_label/ml_scored_at to log_entries; inserts high-confidence hits into detections with scorer='cybersec'. - app/tasks/cybersec_scorer.py: async background task (same shape as anomaly_scorer.py). - REST: GET/POST /turnstone/api/cybersec/status|run|detections. GET /turnstone/api/anomaly/detections now accepts scorer= filter. Schema: ml_score, ml_label, ml_scored_at added to log_entries; scorer column added to detections (idempotent migrations + DDL for both SQLite and Postgres). UI: Security Alerts view gains Source dropdown (All / Anomaly / Cybersec) and cybersec scorer status badge. Label dropdown split into optgroups. Deployment: TURNSTONE_CYBERSEC_MODEL/DEVICE/THRESHOLD vars added to .env.example, docker-compose.yml, docker-standalone.sh. Tests: 10 new tests — no model, no eligible entries, scoring, detection creation, normal label suppression, threshold filtering, pattern-tag filtering, idempotency, list filtering, scorer column filter. 416/416 passing. Closes: #9
241 lines
8.9 KiB
Python
241 lines
8.9 KiB
Python
"""Cybersecurity-focused scoring pipeline using zero-shot classification.

Runs a second-pass analysis on entries that were already flagged by the
anomaly scorer or that have pattern matches.  Uses a zero-shot classification
model (DeBERTa-v3-base-mnli is cached locally) so no fine-tuning is needed.

The scorer writes ml_score / ml_label / ml_scored_at to log_entries and
inserts high-confidence non-normal hits into the detections table tagged
with scorer='cybersec'.

Env vars
--------
TURNSTONE_CYBERSEC_MODEL     — HF model id for zero-shot classification.
                               Recommended: MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli
                               (already cached from the diagnose pipeline).
                               Set to empty string to disable (safe default).
TURNSTONE_CYBERSEC_DEVICE    — 'cpu' (default) or 'cuda'
TURNSTONE_CYBERSEC_THRESHOLD — float confidence floor for detection insertion (default 0.60)
"""
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import uuid
|
|
from dataclasses import dataclass, field
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
from app.db import get_conn, resolve_tenant_id
|
|
from app.db.dialect import q
|
|
|
|
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Candidate labels — cybersec vocabulary for zero-shot inference
# ---------------------------------------------------------------------------

# The benign class.  It is offered to the classifier like any other label so
# that harmless entries have somewhere to land.
_NORMAL_LABEL = "normal system operation"

# Severity associated with each candidate label.  The dict's insertion order
# is also the order in which the labels are exposed via CYBERSEC_LABELS.
_LABEL_SEVERITY: dict[str, str] = {
    "authentication failure or brute force attack": "ERROR",
    "privilege escalation or unauthorized access": "CRITICAL",
    "network intrusion or port scan": "ERROR",
    "malware or suspicious process activity": "CRITICAL",
    "data exfiltration or unusual outbound traffic": "CRITICAL",
    _NORMAL_LABEL: "INFO",
}

# Candidate labels for zero-shot inference, derived from the severity map so
# the two can never drift apart.
CYBERSEC_LABELS: list[str] = list(_LABEL_SEVERITY)
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Pipeline singleton
|
|
# ---------------------------------------------------------------------------
|
|
|
|
_pipeline: Any = None
# The pipeline object this module loaded itself, paired with the
# (model_id, device) it was built for.  Tracked separately from ``_pipeline``
# so that an object assigned to ``_pipeline`` from outside this function
# (e.g. by a test) is never mistaken for a stale model and replaced.
_pipeline_origin: tuple[Any, tuple[str, str]] | None = None


def _get_pipeline(model_id: str, device: str) -> Any:
    """Return the process-wide zero-shot pipeline, loading it when needed.

    The pipeline is cached in the module global ``_pipeline``.  It is
    (re)loaded when nothing is cached yet, or when the cached object was
    loaded by this function for a different ``(model_id, device)`` pair.
    Previously the arguments were ignored after the first load, so changing
    the model or device kept returning the old pipeline.

    Args:
        model_id: Model identifier passed as ``model=`` to
            ``transformers.pipeline``.
        device: ``"cuda"`` selects device index 0; every other value selects
            -1 (CPU).

    Returns:
        The cached or freshly created zero-shot-classification pipeline.

    Raises:
        Whatever importing ``transformers`` or constructing the pipeline
        raises; the previously cached pipeline (if any) is left untouched in
        that case.
    """
    global _pipeline, _pipeline_origin  # noqa: PLW0603

    requested = (model_id, device)
    if _pipeline is not None:
        loaded_here = _pipeline_origin is not None and _pipeline_origin[0] is _pipeline
        if not loaded_here or _pipeline_origin[1] == requested:
            return _pipeline

    # Imported lazily so the module can be imported without transformers
    # being loaded (or even installed) until a pipeline is actually needed.
    from transformers import pipeline  # type: ignore[import-untyped]

    logger.info("loading cybersec zero-shot pipeline: %s on %s", model_id, device)
    loaded = pipeline(
        "zero-shot-classification",
        model=model_id,
        device=0 if device == "cuda" else -1,
    )
    _pipeline = loaded
    _pipeline_origin = (loaded, requested)
    logger.info("cybersec pipeline ready")
    return loaded
|
|
|
|
|
|
def reset_pipeline() -> None:
    """Forget the cached pipeline object.

    Sets the module-level ``_pipeline`` back to ``None``.  Intended for
    testing only.
    """
    global _pipeline  # noqa: PLW0603
    _pipeline = None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Result type
|
|
# ---------------------------------------------------------------------------
|
|
|
|
@dataclass
class CybersecResult:
    """Summary of a single cybersec scoring pass.

    Attributes:
        scored: Number of log entries that received an ml_score / ml_label.
        detections: Number of rows inserted into the detections table.
        skipped: True when the pass had nothing to do.
        error: Text of the exception that aborted the pass, or None.
    """

    scored: int = 0
    detections: int = 0
    skipped: bool = False
    error: str | None = None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Core scoring function
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def _persist_chunk(
    db_path: Path,
    tenant_id: Any,
    chunk: list[Any],
    results: list[dict[str, Any]],
    threshold: float,
) -> tuple[int, int]:
    """Store one chunk of classifications; return ``(scored, detections)``.

    Every entry gets its ml_score / ml_label / ml_scored_at updated.  Entries
    whose top label is not the normal label and whose score reaches
    ``threshold`` additionally get a row in ``detections``.  All writes for
    the chunk are committed together on a dedicated connection.
    """
    # Build the dialect-specific statements once per chunk, not once per row.
    update_sql = q("""
        UPDATE log_entries
        SET ml_score = ?, ml_label = ?, ml_scored_at = ?
        WHERE id = ? AND (tenant_id = ? OR tenant_id = '')
    """)
    insert_sql = q("""
        INSERT INTO detections
            (id, tenant_id, entry_id, source_id, anomaly_label,
             anomaly_score, severity, text, timestamp_iso,
             detected_at, scorer)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 'cybersec')
    """)
    now = datetime.now(tz=timezone.utc).isoformat()
    scored = 0
    detections = 0

    with get_conn(db_path) as conn:
        for row, result in zip(chunk, results):
            # The code relies on the pipeline listing labels best-first.
            top_label: str = result["labels"][0]
            top_score: float = result["scores"][0]

            conn.execute(update_sql, (top_score, top_label, now, row["id"], tenant_id))
            scored += 1

            if top_score >= threshold and top_label != _NORMAL_LABEL:
                severity = _LABEL_SEVERITY.get(top_label, "WARN")
                try:
                    conn.execute(
                        insert_sql,
                        (
                            str(uuid.uuid4()),
                            tenant_id,
                            row["id"],
                            row["source_id"],
                            top_label,
                            top_score,
                            severity,
                            row["text"],
                            row["timestamp_iso"],
                            now,
                        ),
                    )
                except Exception as exc:  # noqa: BLE001
                    # Best effort: the entry may already have a detection.
                    # Still log it, so that a genuine schema or driver problem
                    # is not silently hidden.
                    # NOTE(review): on PostgreSQL a failed statement aborts
                    # the surrounding transaction, so the remaining writes in
                    # this chunk would fail as well — consider a savepoint or
                    # ON CONFLICT DO NOTHING; confirm against the schema.
                    logger.warning(
                        "could not insert cybersec detection for entry %s: %s",
                        row["id"],
                        exc,
                    )
                else:
                    detections += 1

        conn.commit()

    return scored, detections


def score_security_entries(
    db_path: Path,
    model_id: str,
    device: str = "cpu",
    batch_size: int = 32,
    threshold: float = 0.60,
) -> CybersecResult:
    """Score entries that were anomaly-flagged or pattern-matched.

    Only entries with ml_scored_at IS NULL are processed (idempotent).
    Writes ml_score / ml_label / ml_scored_at and inserts high-confidence
    hits into detections with scorer='cybersec'.

    Args:
        db_path: Database location handed to ``get_conn``.
        model_id: Zero-shot model id; a falsy value disables scoring.
        device: Forwarded to ``_get_pipeline`` ('cpu' or 'cuda').
        batch_size: Entries per inference call.  At most ``batch_size * 10``
            entries are fetched per invocation.  Must be >= 1.
        threshold: Minimum top-label score for a detection to be inserted.

    Returns:
        A ``CybersecResult``.  ``skipped`` is set when no model is configured
        or no eligible entries exist.  ``error`` is set (never raised) when
        loading the model or talking to the database fails; ``scored`` and
        ``detections`` then count only chunks that were committed.
    """
    if not model_id:
        return CybersecResult(skipped=True)
    if batch_size < 1:
        # A zero step would raise inside range(); a negative LIMIT means
        # "no limit" on SQLite.  Reject both explicitly.
        return CybersecResult(error=f"batch_size must be >= 1, got {batch_size}")

    tenant_id = resolve_tenant_id()
    try:
        pipe = _get_pipeline(model_id, device)
    except Exception as exc:
        logger.error("failed to load cybersec pipeline: %s", exc)
        return CybersecResult(error=str(exc))

    total_scored = 0
    total_detections = 0

    try:
        # Only score entries that are worth a second look:
        # anomaly-flagged (non-normal) OR have at least one pattern match.
        # The read connection is released before inference starts, so it is
        # not held open while the model runs and a second connection writes.
        with get_conn(db_path) as conn:
            rows = conn.execute(
                q("""
                    SELECT id, source_id, text, timestamp_iso
                    FROM log_entries
                    WHERE (tenant_id = ? OR tenant_id = '')
                      AND ml_scored_at IS NULL
                      AND (
                        (anomaly_label IS NOT NULL AND anomaly_label != 'NORMAL')
                        OR (matched_patterns IS NOT NULL AND matched_patterns != '[]' AND matched_patterns != '')
                      )
                    LIMIT ?
                """),
                (tenant_id, batch_size * 10),
            ).fetchall()

        if not rows:
            return CybersecResult(skipped=True)

        # Process in chunks to avoid OOM on large backlogs
        for start in range(0, len(rows), batch_size):
            chunk = rows[start : start + batch_size]
            texts = [r["text"] for r in chunk]

            try:
                results = pipe(texts, candidate_labels=CYBERSEC_LABELS, multi_label=False)
            except Exception as exc:
                logger.warning("zero-shot inference error on chunk %d: %s", start, exc)
                continue

            # Defensive: tolerate a bare dict coming back for a one-item chunk.
            if isinstance(results, dict):
                results = [results]
            if len(results) != len(chunk):
                # zip() would silently pair results with the wrong entries.
                logger.warning(
                    "zero-shot result count mismatch on chunk %d: expected %d, got %d",
                    start,
                    len(chunk),
                    len(results),
                )
                continue

            scored, detections = _persist_chunk(db_path, tenant_id, chunk, results, threshold)
            # Counted only after the chunk's commit succeeded.
            total_scored += scored
            total_detections += detections

    except Exception as exc:
        logger.exception("cybersec scoring failed: %s", exc)
        return CybersecResult(scored=total_scored, detections=total_detections, error=str(exc))

    return CybersecResult(scored=total_scored, detections=total_detections)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Query helpers (used by REST layer)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def list_cybersec_detections(
    db_path: Path,
    limit: int = 100,
    unacked_only: bool = False,
    label: str | None = None,
) -> list[dict]:
    """Return cybersec detections ordered by detected_at DESC.

    Args:
        db_path: Database location handed to ``get_conn``.
        limit: Maximum number of rows, bound to the SQL ``LIMIT``.
        unacked_only: When true, keep only rows with ``acknowledged = 0``.
        label: When non-empty, keep only rows whose ``anomaly_label`` equals
            it.  ``None`` and ``""`` both mean "no label filter".

    Returns:
        One plain ``dict`` per matching row of ``detections`` (all columns),
        restricted to ``scorer = 'cybersec'`` and to the resolved tenant or
        rows with an empty ``tenant_id``.
    """
    tenant_id = resolve_tenant_id()

    # Every fragment keeps the raw "?" placeholder.  The assembled statement
    # goes through q() exactly once below; previously the label fragment was
    # translated on its own and then a second time with the whole query.
    conditions = ["(tenant_id = ? OR tenant_id = '')", "scorer = 'cybersec'"]
    params: list[Any] = [tenant_id]

    if unacked_only:
        conditions.append("acknowledged = 0")
    if label:
        conditions.append("anomaly_label = ?")
        params.append(label)

    # Only fixed fragments defined above are interpolated; every caller
    # supplied value is bound as a parameter.
    where = " AND ".join(conditions)
    sql = q(f"SELECT * FROM detections WHERE {where} ORDER BY detected_at DESC LIMIT ?")  # noqa: S608

    with get_conn(db_path) as conn:
        rows = conn.execute(sql, (*params, limit)).fetchall()
    return [dict(r) for r in rows]
|