turnstone/tests/test_cybersec.py
pyr0ball b2e2f15d55 feat: cybersec zero-shot scoring pipeline (#9)
Second-pass cybersec classifier using DeBERTa-v3-base-mnli (already
cached — no download required). Runs after each anomaly scoring pass on
entries flagged by the anomaly scorer or with pattern matches.

Architecture:
- app/services/cybersec.py: zero-shot-classification pipeline with 5
  cybersec candidate labels (auth failure, privilege escalation, network
  intrusion, malware, data exfiltration). Writes ml_score/ml_label/
  ml_scored_at to log_entries; inserts high-confidence hits into
  detections with scorer='cybersec'.
- app/tasks/cybersec_scorer.py: async background task (same shape as
  anomaly_scorer.py).
- REST: GET/POST /turnstone/api/cybersec/status|run|detections.
  GET /turnstone/api/anomaly/detections now accepts scorer= filter.

Schema: ml_score, ml_label, ml_scored_at added to log_entries; scorer
column added to detections (idempotent migrations + DDL for both SQLite
and Postgres).

UI: Security Alerts view gains Source dropdown (All / Anomaly / Cybersec)
and cybersec scorer status badge. Label dropdown split into optgroups.

Deployment: TURNSTONE_CYBERSEC_MODEL/DEVICE/THRESHOLD vars added to
.env.example, docker-compose.yml, docker-standalone.sh.

Tests: 10 new tests — no model, no eligible entries, scoring, detection
creation, normal label suppression, threshold filtering, pattern-tag
filtering, idempotency, list filtering, scorer column filter.
416/416 passing.

Closes: #9
2026-06-10 01:03:25 -07:00

233 lines
8.6 KiB
Python

"""Tests for the cybersec zero-shot scoring pipeline."""
from __future__ import annotations
import sqlite3
import tempfile
from pathlib import Path
from unittest.mock import MagicMock
import pytest
from app.db.schema import ensure_schema
from app.services.cybersec import (
CybersecResult,
CYBERSEC_LABELS,
_NORMAL_LABEL,
reset_pipeline,
score_security_entries,
list_cybersec_detections,
)
import app.services.cybersec as cybersec_mod
@pytest.fixture(autouse=True)
def _reset():
    """Call ``reset_pipeline()`` before and after every test in this module.

    Keeps whatever state ``app.services.cybersec`` caches (presumably the
    loaded zero-shot pipeline -- confirm in the service) from leaking between
    tests. The teardown after ``yield`` runs even when the test fails. The
    fixture no longer requests ``tmp_path``: it never used it, and requesting
    it made pytest create an unused temp directory for every test.
    """
    reset_pipeline()
    yield
    reset_pipeline()
@pytest.fixture
def db(tmp_path) -> Path:
    """Return the path of a fresh per-test SQLite file with the schema applied."""
    database_file = tmp_path.joinpath("test.db")
    ensure_schema(database_file)
    return database_file
def _insert_entry(db: Path, entry_id: str, text: str,
                  anomaly_label: str | None = None,
                  matched_patterns: str = "[]") -> None:
    """Insert a single ``log_entries`` row for a test.

    The row gets an empty ``tenant_id``, ``source_id='test-src'``,
    ``sequence=1`` and a fixed ``ingest_time``. ``INSERT OR IGNORE`` means a
    row that conflicts with an existing one is silently skipped.

    Args:
        db: Path to the SQLite database file.
        entry_id: Value for the ``id`` column.
        text: Log line text.
        anomaly_label: Optional anomaly label (``None`` stores NULL).
        matched_patterns: JSON-encoded list of matched pattern names.
    """
    conn = sqlite3.connect(db)
    try:
        # The connection's context manager commits on success and rolls back
        # on error, but it does NOT close the connection -- hence the finally.
        with conn:
            conn.execute(
                """INSERT OR IGNORE INTO log_entries
                (id, tenant_id, source_id, sequence, ingest_time, text,
                anomaly_label, matched_patterns)
                VALUES (?, '', 'test-src', 1, '2026-01-01T00:00:00Z', ?, ?, ?)""",
                (entry_id, text, anomaly_label, matched_patterns),
            )
    finally:
        conn.close()
# ---------------------------------------------------------------------------
# No model configured → skipped
# ---------------------------------------------------------------------------
def test_no_model_returns_skipped(db):
    """With an empty model id the scorer does nothing and reports the run as skipped."""
    outcome = score_security_entries(db, model_id="")
    assert outcome.scored == 0
    assert outcome.skipped is True
# ---------------------------------------------------------------------------
# No eligible entries → skipped
# ---------------------------------------------------------------------------
def test_no_eligible_entries_skipped(db, monkeypatch):
    """A run is skipped when no entry is anomaly-flagged or pattern-matched.

    The inserted entry has no anomaly label and an empty pattern list, so it
    is not eligible for cybersec scoring. The ``monkeypatch`` fixture (instead
    of a hand-built ``pytest.MonkeyPatch``) guarantees the patched
    ``_pipeline`` is restored even if an assertion fails. Previously a failure
    skipped ``undo()`` and leaked the mock into later tests.
    """
    _insert_entry(db, "e1", "Started nginx.service", anomaly_label=None, matched_patterns="[]")
    mock_pipe = MagicMock(return_value=[{"labels": [_NORMAL_LABEL], "scores": [0.99]}])
    monkeypatch.setattr(cybersec_mod, "_pipeline", mock_pipe)
    result = score_security_entries(db, model_id="fake-model")
    assert result.skipped is True
    assert result.scored == 0
# ---------------------------------------------------------------------------
# Security entry gets scored
# ---------------------------------------------------------------------------
def test_security_entry_scored(db, monkeypatch):
    """An anomaly-flagged entry is scored and its ml_* columns are persisted.

    The mocked pipeline ranks the auth-failure label at 0.85, above the 0.70
    threshold. The run should score one entry, create one detection and
    report no error, and the row's ``ml_score``, ``ml_label`` and
    ``ml_scored_at`` columns should be filled in.
    """
    _insert_entry(db, "e1",
                  "Failed password for root from 192.168.1.1 port 22 ssh2",
                  anomaly_label="SECURITY_ANOMALY")
    mock_pipe = MagicMock(return_value=[{
        "labels": ["authentication failure or brute force attack", _NORMAL_LABEL],
        "scores": [0.85, 0.15],
    }])
    monkeypatch.setattr(cybersec_mod, "_pipeline", mock_pipe)
    result = score_security_entries(db, model_id="fake-model", threshold=0.70)
    assert result.scored == 1
    assert result.detections == 1
    assert result.error is None
    # ``with sqlite3.connect(...)`` only manages the transaction and never
    # closes the connection, so open and close it explicitly.
    conn = sqlite3.connect(db)
    try:
        conn.row_factory = sqlite3.Row
        row = conn.execute(
            "SELECT ml_score, ml_label, ml_scored_at FROM log_entries WHERE id='e1'"
        ).fetchone()
    finally:
        conn.close()
    assert row is not None
    assert row["ml_score"] == pytest.approx(0.85)
    assert row["ml_label"] == "authentication failure or brute force attack"
    assert row["ml_scored_at"] is not None
# ---------------------------------------------------------------------------
# Detection created above threshold
# ---------------------------------------------------------------------------
def test_detection_inserted_above_threshold(db, monkeypatch):
    """A cybersec label scoring above the threshold writes one detection row.

    The row is selected by ``scorer='cybersec'``. It should carry the
    predicted label in ``anomaly_label`` and, for the privilege-escalation
    label, ``CRITICAL`` severity.
    """
    _insert_entry(db, "e1", "sudo: authentication failure", anomaly_label="ERROR")
    monkeypatch.setattr(cybersec_mod, "_pipeline", MagicMock(return_value=[{
        "labels": ["privilege escalation or unauthorized access", _NORMAL_LABEL],
        "scores": [0.75, 0.25],
    }]))
    score_security_entries(db, model_id="fake-model", threshold=0.60)
    # The sqlite3 connection context manager does not close the connection;
    # close it explicitly to avoid leaking it.
    conn = sqlite3.connect(db)
    try:
        conn.row_factory = sqlite3.Row
        dets = conn.execute("SELECT * FROM detections WHERE scorer='cybersec'").fetchall()
    finally:
        conn.close()
    assert len(dets) == 1
    assert dets[0]["anomaly_label"] == "privilege escalation or unauthorized access"
    assert dets[0]["severity"] == "CRITICAL"
# ---------------------------------------------------------------------------
# Normal label → no detection even above score threshold
# ---------------------------------------------------------------------------
def test_normal_label_no_detection(db, monkeypatch):
    """When the normal label wins (0.95 here), no detection is recorded."""
    _insert_entry(
        db,
        "e1",
        "Started nginx.service",
        anomaly_label="INFO",
        matched_patterns='["service_start"]',
    )
    classifier_output = {
        "labels": [_NORMAL_LABEL, "network intrusion or port scan"],
        "scores": [0.95, 0.05],
    }
    fake_pipeline = MagicMock(return_value=[classifier_output])
    monkeypatch.setattr(cybersec_mod, "_pipeline", fake_pipeline)
    outcome = score_security_entries(db, model_id="fake-model", threshold=0.60)
    assert outcome.detections == 0
# ---------------------------------------------------------------------------
# Below threshold → scored but no detection
# ---------------------------------------------------------------------------
def test_below_threshold_no_detection(db, monkeypatch):
    """A cybersec label scored under the threshold is counted as scored, not detected."""
    _insert_entry(db, "e1", "Some suspicious text", anomaly_label="WARN")
    ranked = {
        "labels": ["network intrusion or port scan", _NORMAL_LABEL],
        "scores": [0.45, 0.55],
    }
    monkeypatch.setattr(cybersec_mod, "_pipeline", MagicMock(return_value=[ranked]))
    summary = score_security_entries(db, model_id="fake-model", threshold=0.60)
    assert (summary.scored, summary.detections) == (1, 0)
# ---------------------------------------------------------------------------
# Pattern-matched entry (not anomaly-flagged) still gets scored
# ---------------------------------------------------------------------------
def test_pattern_matched_entry_scored(db, monkeypatch):
    """An entry with pattern matches but no anomaly label is still eligible for scoring."""
    _insert_entry(
        db,
        "e1",
        "SSH port forwarding conflict detected",
        anomaly_label=None,
        matched_patterns='["ssh_forward_conflict"]',
    )
    ranked = {
        "labels": ["network intrusion or port scan", _NORMAL_LABEL],
        "scores": [0.70, 0.30],
    }
    monkeypatch.setattr(cybersec_mod, "_pipeline", MagicMock(return_value=[ranked]))
    summary = score_security_entries(db, model_id="fake-model", threshold=0.60)
    assert (summary.scored, summary.detections) == (1, 1)
# ---------------------------------------------------------------------------
# Idempotency — re-run finds nothing unscored
# ---------------------------------------------------------------------------
def test_idempotent_rerun(db, monkeypatch):
    """A second pass over already-scored entries finds nothing to do and is skipped."""
    _insert_entry(db, "e1", "Failed login", anomaly_label="ERROR")
    fake_pipeline = MagicMock(return_value=[{
        "labels": ["authentication failure or brute force attack"],
        "scores": [0.80],
    }])
    monkeypatch.setattr(cybersec_mod, "_pipeline", fake_pipeline)
    scoring_kwargs = {"model_id": "fake-model", "threshold": 0.60}
    score_security_entries(db, **scoring_kwargs)  # first pass scores "e1"
    second_pass = score_security_entries(db, **scoring_kwargs)
    assert second_pass.skipped is True
# ---------------------------------------------------------------------------
# list_cybersec_detections filters to scorer='cybersec'
# ---------------------------------------------------------------------------
def test_list_cybersec_detections(db, monkeypatch):
    """list_cybersec_detections returns the single detection written by a cybersec run."""
    _insert_entry(db, "e1", "Failed login", anomaly_label="ERROR")
    stub = MagicMock()
    stub.return_value = [{
        "labels": ["authentication failure or brute force attack"],
        "scores": [0.90],
    }]
    monkeypatch.setattr(cybersec_mod, "_pipeline", stub)
    score_security_entries(db, model_id="fake-model", threshold=0.60)
    listed = list_cybersec_detections(db)
    assert len(listed) == 1
    first = listed[0]
    assert first["scorer"] == "cybersec"
# ---------------------------------------------------------------------------
# list_detections scorer filter (anomaly service)
# ---------------------------------------------------------------------------
def test_list_detections_scorer_filter(db, monkeypatch):
    """``list_detections(scorer=...)`` separates cybersec from anomaly detections.

    Only one entry is inserted and only the cybersec scorer runs, so the
    database holds exactly one detection. Filtering by ``"cybersec"`` should
    return it, filtering by ``"anomaly"`` should return nothing, and the
    unfiltered listing should contain exactly that one row. The old
    ``>= 1`` check could not catch duplicated or extra rows.
    """
    from app.services.anomaly import list_detections
    _insert_entry(db, "e1", "Failed login", anomaly_label="ERROR")
    monkeypatch.setattr(cybersec_mod, "_pipeline", MagicMock(return_value=[{
        "labels": ["authentication failure or brute force attack"],
        "scores": [0.90],
    }]))
    score_security_entries(db, model_id="fake-model", threshold=0.60)
    all_dets = list_detections(db)
    cybersec_dets = list_detections(db, scorer="cybersec")
    anomaly_dets = list_detections(db, scorer="anomaly")
    assert len(cybersec_dets) == 1
    assert len(anomaly_dets) == 0
    assert len(all_dets) == 1