"""Tests for the cybersec zero-shot scoring pipeline.""" from __future__ import annotations import sqlite3 import tempfile from pathlib import Path from unittest.mock import MagicMock import pytest from app.db.schema import ensure_schema from app.services.cybersec import ( CybersecResult, CYBERSEC_LABELS, _NORMAL_LABEL, reset_pipeline, score_security_entries, list_cybersec_detections, ) import app.services.cybersec as cybersec_mod @pytest.fixture(autouse=True) def _reset(tmp_path): reset_pipeline() yield reset_pipeline() @pytest.fixture def db(tmp_path) -> Path: path = tmp_path / "test.db" ensure_schema(path) return path def _insert_entry(db: Path, entry_id: str, text: str, anomaly_label: str | None = None, matched_patterns: str = "[]") -> None: with sqlite3.connect(db) as conn: conn.execute( """INSERT OR IGNORE INTO log_entries (id, tenant_id, source_id, sequence, ingest_time, text, anomaly_label, matched_patterns) VALUES (?, '', 'test-src', 1, '2026-01-01T00:00:00Z', ?, ?, ?)""", (entry_id, text, anomaly_label, matched_patterns), ) conn.commit() # --------------------------------------------------------------------------- # No model configured → skipped # --------------------------------------------------------------------------- def test_no_model_returns_skipped(db): result = score_security_entries(db, model_id="") assert result.skipped is True assert result.scored == 0 # --------------------------------------------------------------------------- # No eligible entries → skipped # --------------------------------------------------------------------------- def test_no_eligible_entries_skipped(db): _insert_entry(db, "e1", "Started nginx.service", anomaly_label=None, matched_patterns="[]") mock_pipe = MagicMock(return_value=[{"labels": [_NORMAL_LABEL], "scores": [0.99]}]) monkeypatch = pytest.MonkeyPatch() monkeypatch.setattr(cybersec_mod, "_pipeline", mock_pipe) result = score_security_entries(db, model_id="fake-model") assert result.skipped is True 
monkeypatch.undo() # --------------------------------------------------------------------------- # Security entry gets scored # --------------------------------------------------------------------------- def test_security_entry_scored(db, monkeypatch): _insert_entry(db, "e1", "Failed password for root from 192.168.1.1 port 22 ssh2", anomaly_label="SECURITY_ANOMALY") mock_pipe = MagicMock(return_value=[{ "labels": ["authentication failure or brute force attack", _NORMAL_LABEL], "scores": [0.85, 0.15], }]) monkeypatch.setattr(cybersec_mod, "_pipeline", mock_pipe) result = score_security_entries(db, model_id="fake-model", threshold=0.70) assert result.scored == 1 assert result.detections == 1 assert result.error is None with sqlite3.connect(db) as conn: conn.row_factory = sqlite3.Row row = conn.execute("SELECT ml_score, ml_label, ml_scored_at FROM log_entries WHERE id='e1'").fetchone() assert row["ml_score"] == pytest.approx(0.85) assert row["ml_label"] == "authentication failure or brute force attack" assert row["ml_scored_at"] is not None # --------------------------------------------------------------------------- # Detection created above threshold # --------------------------------------------------------------------------- def test_detection_inserted_above_threshold(db, monkeypatch): _insert_entry(db, "e1", "sudo: authentication failure", anomaly_label="ERROR") monkeypatch.setattr(cybersec_mod, "_pipeline", MagicMock(return_value=[{ "labels": ["privilege escalation or unauthorized access", _NORMAL_LABEL], "scores": [0.75, 0.25], }])) score_security_entries(db, model_id="fake-model", threshold=0.60) with sqlite3.connect(db) as conn: conn.row_factory = sqlite3.Row dets = conn.execute("SELECT * FROM detections WHERE scorer='cybersec'").fetchall() assert len(dets) == 1 assert dets[0]["anomaly_label"] == "privilege escalation or unauthorized access" assert dets[0]["severity"] == "CRITICAL" # --------------------------------------------------------------------------- 
# Normal label → no detection even above score threshold
# ---------------------------------------------------------------------------


def test_normal_label_no_detection(db, monkeypatch):
    """A confident ``_NORMAL_LABEL`` top prediction must not create a detection."""
    _insert_entry(db, "e1", "Started nginx.service", anomaly_label="INFO",
                  matched_patterns='["service_start"]')
    monkeypatch.setattr(cybersec_mod, "_pipeline", MagicMock(return_value=[{
        "labels": [_NORMAL_LABEL, "network intrusion or port scan"],
        "scores": [0.95, 0.05],
    }]))
    result = score_security_entries(db, model_id="fake-model", threshold=0.60)
    # Guard against a vacuous pass: the entry must actually have been scored,
    # otherwise ``detections == 0`` would hold trivially.
    assert result.scored == 1
    assert result.detections == 0


# ---------------------------------------------------------------------------
# Below threshold → scored but no detection
# ---------------------------------------------------------------------------


def test_below_threshold_no_detection(db, monkeypatch):
    """A security label whose score is under ``threshold`` is scored but not reported."""
    _insert_entry(db, "e1", "Some suspicious text", anomaly_label="WARN")
    monkeypatch.setattr(cybersec_mod, "_pipeline", MagicMock(return_value=[{
        "labels": ["network intrusion or port scan", _NORMAL_LABEL],
        "scores": [0.45, 0.55],
    }]))
    result = score_security_entries(db, model_id="fake-model", threshold=0.60)
    assert result.scored == 1
    assert result.detections == 0


# ---------------------------------------------------------------------------
# Pattern-matched entry (not anomaly-flagged) still gets scored
# ---------------------------------------------------------------------------


def test_pattern_matched_entry_scored(db, monkeypatch):
    """Non-empty ``matched_patterns`` alone makes an entry eligible for scoring."""
    _insert_entry(db, "e1", "SSH port forwarding conflict detected", anomaly_label=None,
                  matched_patterns='["ssh_forward_conflict"]')
    monkeypatch.setattr(cybersec_mod, "_pipeline", MagicMock(return_value=[{
        "labels": ["network intrusion or port scan", _NORMAL_LABEL],
        "scores": [0.70, 0.30],
    }]))
    result = score_security_entries(db, model_id="fake-model", threshold=0.60)
    assert result.scored == 1
    assert result.detections == 1


# ---------------------------------------------------------------------------
# Idempotency — re-run finds nothing unscored
# ---------------------------------------------------------------------------


def test_idempotent_rerun(db, monkeypatch):
    """Entries scored by a first run are not picked up again by a second run."""
    _insert_entry(db, "e1", "Failed login", anomaly_label="ERROR")
    monkeypatch.setattr(cybersec_mod, "_pipeline", MagicMock(return_value=[{
        "labels": ["authentication failure or brute force attack"],
        "scores": [0.80],
    }]))
    result1 = score_security_entries(db, model_id="fake-model", threshold=0.60)
    # Without this check, a first run that silently skipped would make the
    # rerun assertions below pass vacuously.
    assert result1.scored == 1
    result2 = score_security_entries(db, model_id="fake-model", threshold=0.60)
    assert result2.skipped is True
    assert result2.scored == 0


# ---------------------------------------------------------------------------
# list_cybersec_detections filters to scorer='cybersec'
# ---------------------------------------------------------------------------


def test_list_cybersec_detections(db, monkeypatch):
    """``list_cybersec_detections`` returns the detection written by the scorer."""
    _insert_entry(db, "e1", "Failed login", anomaly_label="ERROR")
    monkeypatch.setattr(cybersec_mod, "_pipeline", MagicMock(return_value=[{
        "labels": ["authentication failure or brute force attack"],
        "scores": [0.90],
    }]))
    score_security_entries(db, model_id="fake-model", threshold=0.60)
    rows = list_cybersec_detections(db)
    assert len(rows) == 1
    assert rows[0]["scorer"] == "cybersec"


# ---------------------------------------------------------------------------
# list_detections scorer filter (anomaly service)
# ---------------------------------------------------------------------------


def test_list_detections_scorer_filter(db, monkeypatch):
    """``list_detections(scorer=...)`` separates cybersec from anomaly detections."""
    from app.services.anomaly import list_detections

    _insert_entry(db, "e1", "Failed login", anomaly_label="ERROR")
    monkeypatch.setattr(cybersec_mod, "_pipeline", MagicMock(return_value=[{
        "labels": ["authentication failure or brute force attack"],
        "scores": [0.90],
    }]))
    score_security_entries(db, model_id="fake-model", threshold=0.60)
    all_dets = list_detections(db)
    cybersec_dets = list_detections(db, scorer="cybersec")
    anomaly_dets = list_detections(db, scorer="anomaly")
    assert len(cybersec_dets) == 1
    assert len(anomaly_dets) == 0
    # The unfiltered listing must include at least the cybersec detection.
    assert len(all_dets) >= 1