Second-pass cybersec classifier using DeBERTa-v3-base-mnli (already cached — no download required). Runs after each anomaly scoring pass on entries flagged by the anomaly scorer or with pattern matches.
Architecture:
- app/services/cybersec.py: zero-shot-classification pipeline with 5 cybersec candidate labels (auth failure, privilege escalation, network intrusion, malware, data exfiltration). Writes ml_score/ml_label/ml_scored_at to log_entries; inserts high-confidence hits into detections with scorer='cybersec'.
- app/tasks/cybersec_scorer.py: async background task (same shape as anomaly_scorer.py).
- REST: GET/POST /turnstone/api/cybersec/status|run|detections. GET /turnstone/api/anomaly/detections now accepts scorer= filter.
Schema: ml_score, ml_label, ml_scored_at added to log_entries; scorer column added to detections (idempotent migrations + DDL for both SQLite and Postgres).
UI: Security Alerts view gains Source dropdown (All / Anomaly / Cybersec) and cybersec scorer status badge. Label dropdown split into optgroups.
Deployment: TURNSTONE_CYBERSEC_MODEL/DEVICE/THRESHOLD vars added to .env.example, docker-compose.yml, docker-standalone.sh.
Tests: 10 new tests — no model, no eligible entries, scoring, detection creation, normal label suppression, threshold filtering, pattern-tag filtering, idempotency, list filtering, scorer column filter. 416/416 passing.
Closes: #9
233 lines
8.6 KiB
Python
233 lines
8.6 KiB
Python
"""Tests for the cybersec zero-shot scoring pipeline."""
|
|
from __future__ import annotations
|
|
|
|
import sqlite3
|
|
import tempfile
|
|
from pathlib import Path
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
|
|
from app.db.schema import ensure_schema
|
|
from app.services.cybersec import (
|
|
CybersecResult,
|
|
CYBERSEC_LABELS,
|
|
_NORMAL_LABEL,
|
|
reset_pipeline,
|
|
score_security_entries,
|
|
list_cybersec_detections,
|
|
)
|
|
import app.services.cybersec as cybersec_mod
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def _reset(tmp_path):
    """Call ``reset_pipeline()`` before and after every test.

    The fixture is autouse, so tests get this bracketing without having to
    request it explicitly.
    """
    # Setup phase: runs before the test body.
    reset_pipeline()
    yield
    # Teardown phase: runs once the test body has finished.
    reset_pipeline()
|
|
|
|
|
|
@pytest.fixture
def db(tmp_path) -> Path:
    """Provide the path of a per-test ``test.db`` prepared by ``ensure_schema``.

    The file lives under pytest's ``tmp_path``, so every test gets its own
    database location.
    """
    db_path = tmp_path.joinpath("test.db")
    ensure_schema(db_path)
    return db_path
|
|
|
|
|
|
def _insert_entry(db: Path, entry_id: str, text: str,
|
|
anomaly_label: str | None = None,
|
|
matched_patterns: str = "[]") -> None:
|
|
with sqlite3.connect(db) as conn:
|
|
conn.execute(
|
|
"""INSERT OR IGNORE INTO log_entries
|
|
(id, tenant_id, source_id, sequence, ingest_time, text,
|
|
anomaly_label, matched_patterns)
|
|
VALUES (?, '', 'test-src', 1, '2026-01-01T00:00:00Z', ?, ?, ?)""",
|
|
(entry_id, text, anomaly_label, matched_patterns),
|
|
)
|
|
conn.commit()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# No model configured → skipped
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_no_model_returns_skipped(db):
    """An empty ``model_id`` must give a skipped result with zero scored entries."""
    outcome = score_security_entries(db, model_id="")

    assert outcome.skipped is True
    assert outcome.scored == 0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# No eligible entries → skipped
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_no_eligible_entries_skipped(db):
    """A run over an entry with no anomaly label and no patterns is skipped.

    The only row in the database has ``anomaly_label=None`` and an empty
    ``matched_patterns`` list, and the scoring call is expected to report
    ``skipped``.
    """
    _insert_entry(db, "e1", "Started nginx.service", anomaly_label=None, matched_patterns="[]")
    mock_pipe = MagicMock(return_value=[{"labels": [_NORMAL_LABEL], "scores": [0.99]}])

    # MonkeyPatch.context() undoes the patch when the block exits, including
    # when the assertion fails, so the fake pipeline never leaks into other tests.
    with pytest.MonkeyPatch.context() as monkeypatch:
        monkeypatch.setattr(cybersec_mod, "_pipeline", mock_pipe)
        result = score_security_entries(db, model_id="fake-model")
        assert result.skipped is True
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Security entry gets scored
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_security_entry_scored(db, monkeypatch):
    """An anomaly-labelled entry is scored and its ml_* columns are written.

    With the pipeline stubbed to rank the authentication-failure label first
    at 0.85 and a threshold of 0.70, the run must report one scored entry, one
    detection and no error, and the row must carry the score, the label and a
    non-NULL ``ml_scored_at``.
    """
    _insert_entry(db, "e1",
                  "Failed password for root from 192.168.1.1 port 22 ssh2",
                  anomaly_label="SECURITY_ANOMALY")

    mock_pipe = MagicMock(return_value=[{
        "labels": ["authentication failure or brute force attack", _NORMAL_LABEL],
        "scores": [0.85, 0.15],
    }])
    monkeypatch.setattr(cybersec_mod, "_pipeline", mock_pipe)

    result = score_security_entries(db, model_id="fake-model", threshold=0.70)
    assert result.scored == 1
    assert result.detections == 1
    assert result.error is None

    # sqlite3's connection context manager only manages the transaction; close
    # the connection explicitly so the test does not leak a database handle.
    conn = sqlite3.connect(db)
    try:
        conn.row_factory = sqlite3.Row
        row = conn.execute("SELECT ml_score, ml_label, ml_scored_at FROM log_entries WHERE id='e1'").fetchone()
    finally:
        conn.close()

    assert row["ml_score"] == pytest.approx(0.85)
    assert row["ml_label"] == "authentication failure or brute force attack"
    assert row["ml_scored_at"] is not None
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Detection created above threshold
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_detection_inserted_above_threshold(db, monkeypatch):
    """A top score above the threshold creates one ``scorer='cybersec'`` detection.

    The stubbed pipeline ranks the privilege-escalation label first at 0.75
    against a threshold of 0.60; the resulting detection row must carry that
    label and severity ``CRITICAL``.
    """
    _insert_entry(db, "e1", "sudo: authentication failure", anomaly_label="ERROR")

    monkeypatch.setattr(cybersec_mod, "_pipeline", MagicMock(return_value=[{
        "labels": ["privilege escalation or unauthorized access", _NORMAL_LABEL],
        "scores": [0.75, 0.25],
    }]))

    score_security_entries(db, model_id="fake-model", threshold=0.60)

    # sqlite3's connection context manager only manages the transaction; close
    # the connection explicitly so the test does not leak a database handle.
    conn = sqlite3.connect(db)
    try:
        conn.row_factory = sqlite3.Row
        dets = conn.execute("SELECT * FROM detections WHERE scorer='cybersec'").fetchall()
    finally:
        conn.close()

    assert len(dets) == 1
    assert dets[0]["anomaly_label"] == "privilege escalation or unauthorized access"
    assert dets[0]["severity"] == "CRITICAL"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Normal label → no detection even above score threshold
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_normal_label_no_detection(db, monkeypatch):
    """A top-ranked normal label yields no detection, even at 0.95 vs. a 0.60 threshold."""
    _insert_entry(
        db,
        "e1",
        "Started nginx.service",
        anomaly_label="INFO",
        matched_patterns='["service_start"]',
    )

    classification = {
        "labels": [_NORMAL_LABEL, "network intrusion or port scan"],
        "scores": [0.95, 0.05],
    }
    fake_pipeline = MagicMock(return_value=[classification])
    monkeypatch.setattr(cybersec_mod, "_pipeline", fake_pipeline)

    outcome = score_security_entries(db, model_id="fake-model", threshold=0.60)

    assert outcome.detections == 0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Below threshold → scored but no detection
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_below_threshold_no_detection(db, monkeypatch):
    """A first-listed label scoring 0.45 against a 0.60 threshold is scored but not detected."""
    _insert_entry(db, "e1", "Some suspicious text", anomaly_label="WARN")

    classification = {
        "labels": ["network intrusion or port scan", _NORMAL_LABEL],
        "scores": [0.45, 0.55],
    }
    fake_pipeline = MagicMock(return_value=[classification])
    monkeypatch.setattr(cybersec_mod, "_pipeline", fake_pipeline)

    outcome = score_security_entries(db, model_id="fake-model", threshold=0.60)

    assert outcome.scored == 1
    assert outcome.detections == 0
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Pattern-matched entry (not anomaly-flagged) still gets scored
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_pattern_matched_entry_scored(db, monkeypatch):
    """An entry with a matched pattern but no anomaly label is still scored and detected."""
    _insert_entry(
        db,
        "e1",
        "SSH port forwarding conflict detected",
        anomaly_label=None,
        matched_patterns='["ssh_forward_conflict"]',
    )

    classification = {
        "labels": ["network intrusion or port scan", _NORMAL_LABEL],
        "scores": [0.70, 0.30],
    }
    fake_pipeline = MagicMock(return_value=[classification])
    monkeypatch.setattr(cybersec_mod, "_pipeline", fake_pipeline)

    outcome = score_security_entries(db, model_id="fake-model", threshold=0.60)

    assert outcome.scored == 1
    assert outcome.detections == 1
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Idempotency — re-run finds nothing unscored
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_idempotent_rerun(db, monkeypatch):
    """A second scoring pass over the same database must report ``skipped``."""
    _insert_entry(db, "e1", "Failed login", anomaly_label="ERROR")

    classification = {
        "labels": ["authentication failure or brute force attack"],
        "scores": [0.80],
    }
    fake_pipeline = MagicMock(return_value=[classification])
    monkeypatch.setattr(cybersec_mod, "_pipeline", fake_pipeline)

    # First pass: its result is deliberately not inspected.
    score_security_entries(db, model_id="fake-model", threshold=0.60)
    # Second pass with identical arguments is the one under test.
    second_pass = score_security_entries(db, model_id="fake-model", threshold=0.60)

    assert second_pass.skipped is True
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# list_cybersec_detections filters to scorer='cybersec'
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_list_cybersec_detections(db, monkeypatch):
    """After one detection is produced, the listing returns a single ``cybersec`` row."""
    _insert_entry(db, "e1", "Failed login", anomaly_label="ERROR")

    classification = {
        "labels": ["authentication failure or brute force attack"],
        "scores": [0.90],
    }
    fake_pipeline = MagicMock(return_value=[classification])
    monkeypatch.setattr(cybersec_mod, "_pipeline", fake_pipeline)
    score_security_entries(db, model_id="fake-model", threshold=0.60)

    listed = list_cybersec_detections(db)

    assert len(listed) == 1
    assert listed[0]["scorer"] == "cybersec"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# list_detections scorer filter (anomaly service)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_list_detections_scorer_filter(db, monkeypatch):
    """``list_detections`` from the anomaly service honours its ``scorer`` argument.

    With exactly one cybersec detection present, filtering on ``"cybersec"``
    returns it, filtering on ``"anomaly"`` returns nothing, and the unfiltered
    call returns at least one row.
    """
    from app.services.anomaly import list_detections

    _insert_entry(db, "e1", "Failed login", anomaly_label="ERROR")

    classification = {
        "labels": ["authentication failure or brute force attack"],
        "scores": [0.90],
    }
    fake_pipeline = MagicMock(return_value=[classification])
    monkeypatch.setattr(cybersec_mod, "_pipeline", fake_pipeline)
    score_security_entries(db, model_id="fake-model", threshold=0.60)

    unfiltered = list_detections(db)
    only_cybersec = list_detections(db, scorer="cybersec")
    only_anomaly = list_detections(db, scorer="anomaly")

    assert len(only_cybersec) == 1
    assert len(only_anomaly) == 0
    assert len(unfiltered) >= 1
|