- Add app/services/anomaly.py: batch scorer using an HF text-classification
pipeline; writes anomaly_score/anomaly_label/anomaly_scored_at on not-yet-scored
log_entries rows; inserts high-confidence anomalous hits into the detections table
- Add app/tasks/anomaly_scorer.py: background task (same shape as
glean_scheduler); triggered after each glean cycle when
TURNSTONE_ANOMALY_MODEL is set
- DB schema: add anomaly_score/anomaly_label/anomaly_scored_at columns to
log_entries (idempotent ALTER TABLE migration); add detections table
- Wire scorer into scheduler_loop and glean_scheduler.run_once; no-op when
model env var is empty (safe to leave unconfigured)
- REST endpoints: GET/POST /api/anomaly/status, /api/anomaly/run,
GET /api/anomaly/detections, POST /api/anomaly/detections/{id}/acknowledge
- Reuses Hybrid-BERT label map from diagnose/classifier.py; works with any
HF text-classification model
- 12 new tests; 406/406 passing
Closes: #10
220 lines
6.8 KiB
Python
"""Tests for app/services/anomaly.py — anomaly scoring pipeline."""
|
|
from __future__ import annotations
|
|
|
|
import sqlite3
|
|
import uuid
|
|
from pathlib import Path
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
|
|
import app.services.anomaly as anomaly_mod
|
|
from app.db.schema import ensure_schema
|
|
from app.services.anomaly import (
|
|
ScoringResult,
|
|
acknowledge_detection,
|
|
list_detections,
|
|
reset_pipeline,
|
|
score_unscored,
|
|
)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Fixtures
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def _reset_pipeline():
    """Clear the cached ML pipeline singleton around every test.

    Autouse, so no test can observe a pipeline object left behind by another.
    """
    reset_pipeline()
    try:
        yield
    finally:
        reset_pipeline()
|
|
|
|
|
|
@pytest.fixture
def db(tmp_path: Path) -> Path:
    """Provide a fresh SQLite database file with the application schema applied."""
    database = tmp_path / "t.db"
    ensure_schema(database)
    return database
|
|
|
|
|
|
def _insert_entry(db_path: Path, text: str, entry_id: str | None = None) -> str:
|
|
eid = entry_id or str(uuid.uuid4())
|
|
conn = sqlite3.connect(str(db_path))
|
|
conn.execute(
|
|
"INSERT INTO log_entries(id, tenant_id, source_id, sequence, ingest_time, text) "
|
|
"VALUES (?,?,?,?,?,?)",
|
|
(eid, "", "src", 1, "2026-01-01T00:00:00", text),
|
|
)
|
|
conn.commit()
|
|
conn.close()
|
|
return eid
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# score_unscored
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_score_unscored_no_model_returns_skipped(db: Path):
    """An empty model id disables scoring: the run is skipped and scores nothing."""
    outcome = score_unscored(db, model_id="")

    assert outcome.skipped is True
    assert outcome.scored == 0
|
|
|
|
|
|
def test_score_unscored_scores_entries(db: Path, monkeypatch):
    """Every unscored entry handed to the pipeline is counted as scored."""
    for line in ("kernel panic — OOM killer invoked", "user login successful"):
        _insert_entry(db, line)

    fake_predictions = [
        {"label": "SYSTEM_FAILURE", "score": 0.92},
        {"label": "NORMAL", "score": 0.88},
    ]
    monkeypatch.setattr(anomaly_mod, "_pipeline", MagicMock(return_value=fake_predictions))

    outcome = score_unscored(db, model_id="fake-model", batch_size=10)

    assert outcome.skipped is False
    assert outcome.scored == 2
|
|
|
|
|
|
def test_score_unscored_creates_detection_above_threshold(db: Path, monkeypatch):
    """A non-NORMAL prediction above the threshold is recorded as a detection."""
    _insert_entry(db, "segfault in service")
    monkeypatch.setattr(
        anomaly_mod,
        "_pipeline",
        MagicMock(return_value=[{"label": "SYSTEM_FAILURE", "score": 0.95}]),
    )

    outcome = score_unscored(db, model_id="fake-model", threshold=0.80)
    assert outcome.detections == 1

    stored = list_detections(db)
    assert len(stored) == 1
    first = stored[0]
    assert first["anomaly_label"] == "SYSTEM_FAILURE"
    assert first["anomaly_score"] == pytest.approx(0.95)
|
|
|
|
|
|
def test_score_unscored_no_detection_below_threshold(db: Path, monkeypatch):
    """A prediction under the threshold is scored but not flagged as a detection."""
    _insert_entry(db, "warning: disk at 80%")

    low_confidence = [{"label": "PERFORMANCE_ISSUE", "score": 0.60}]
    monkeypatch.setattr(anomaly_mod, "_pipeline", MagicMock(return_value=low_confidence))

    outcome = score_unscored(db, model_id="fake-model", threshold=0.80)

    assert outcome.detections == 0
    assert outcome.scored == 1
|
|
|
|
|
|
def test_score_unscored_normal_label_never_detection(db: Path, monkeypatch):
    """A NORMAL prediction never produces a detection, however confident it is.

    ``scored`` is checked as well. Without that check the test would also pass
    if the entry were never scored at all, because ``detections`` is 0 either way.
    """
    _insert_entry(db, "service started successfully")

    mock_pipe = MagicMock(return_value=[
        {"label": "NORMAL", "score": 0.99},
    ])
    monkeypatch.setattr(anomaly_mod, "_pipeline", mock_pipe)

    # 0.99 is far above the 0.50 threshold, so only the label can suppress it.
    result = score_unscored(db, model_id="fake-model", threshold=0.50)

    assert result.scored == 1
    assert result.detections == 0
|
|
|
|
|
|
def test_score_unscored_idempotent(db: Path, monkeypatch):
    """Entries already scored are not re-scored on subsequent runs.

    Across two runs the pipeline must be invoked exactly once. The second run
    should find no unscored rows and therefore score nothing.
    """
    _insert_entry(db, "first entry")

    def _side_effect(texts, **_kwargs):
        # One NORMAL prediction per input text, whatever the batch size.
        return [{"label": "NORMAL", "score": 0.90} for _ in texts]

    # MagicMock records its own call_count, so no hand-rolled counter is needed.
    mock_pipe = MagicMock(side_effect=_side_effect)
    monkeypatch.setattr(anomaly_mod, "_pipeline", mock_pipe)

    first = score_unscored(db, model_id="fake-model")
    second = score_unscored(db, model_id="fake-model")

    assert first.scored == 1
    assert second.scored == 0
    assert mock_pipe.call_count == 1  # second run finds no unscored rows
|
|
|
|
|
|
def test_score_unscored_pipeline_error_returns_error(db: Path, monkeypatch):
    """An exception raised by the model is reported on the result, not raised."""
    _insert_entry(db, "some log line")

    exploding_pipe = MagicMock(side_effect=RuntimeError("CUDA OOM"))
    monkeypatch.setattr(anomaly_mod, "_pipeline", exploding_pipe)

    outcome = score_unscored(db, model_id="fake-model")

    assert outcome.error is not None
    assert "CUDA OOM" in outcome.error
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# list_detections / acknowledge_detection
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def test_list_detections_empty(db: Path):
    """A freshly created database has no detections."""
    detections = list_detections(db)
    assert detections == []
|
|
|
|
|
|
def test_list_detections_filters_unacked(db: Path, monkeypatch):
    """A new detection shows up both in the full list and in the unacknowledged list."""
    _insert_entry(db, "crash")
    monkeypatch.setattr(
        anomaly_mod,
        "_pipeline",
        MagicMock(return_value=[{"label": "SYSTEM_FAILURE", "score": 0.91}]),
    )
    score_unscored(db, model_id="fake-model", threshold=0.80)

    everything = list_detections(db)
    pending = list_detections(db, unacked_only=True)

    assert len(everything) == 1
    assert len(pending) == 1
|
|
|
|
|
|
def test_acknowledge_detection(db: Path, monkeypatch):
    """Acknowledging a detection removes it from the unacked view and stores notes."""
    _insert_entry(db, "network anomaly")
    monkeypatch.setattr(
        anomaly_mod,
        "_pipeline",
        MagicMock(return_value=[{"label": "NETWORK_ANOMALY", "score": 0.88}]),
    )
    score_unscored(db, model_id="fake-model", threshold=0.80)

    before = list_detections(db)
    assert len(before) == 1
    target_id = before[0]["id"]

    assert acknowledge_detection(db, target_id, notes="benign test traffic") is True

    assert len(list_detections(db, unacked_only=True)) == 0

    after = list_detections(db)[0]
    assert after["acknowledged"] == 1
    assert after["notes"] == "benign test traffic"
|
|
|
|
|
|
def test_acknowledge_detection_unknown_id(db: Path):
    """Acknowledging an id that does not exist reports that nothing was updated."""
    assert acknowledge_detection(db, "nonexistent-id") is False
|
|
|
|
|
|
def test_list_detections_label_filter(db: Path, monkeypatch):
    """``list_detections(label=...)`` returns only detections carrying that label.

    Each filtered list is also required to be non-empty. ``all(...)`` over an
    empty list is True, so without the length checks a broken filter that
    returned nothing would pass.
    """
    _insert_entry(db, "OOM kill")
    _insert_entry(db, "network timeout")

    # One prediction list per pipeline call; batch_size=1 puts one entry per call.
    mock_pipe = MagicMock(side_effect=[
        [{"label": "SYSTEM_FAILURE", "score": 0.93}],
        [{"label": "NETWORK_ANOMALY", "score": 0.85}],
    ])
    monkeypatch.setattr(anomaly_mod, "_pipeline", mock_pipe)

    # Called twice so both entries end up scored even if one run handles only
    # a single batch.
    score_unscored(db, model_id="fake-model", batch_size=1, threshold=0.80)
    score_unscored(db, model_id="fake-model", batch_size=1, threshold=0.80)

    # Both scores exceed the 0.80 threshold and neither label is NORMAL.
    assert len(list_detections(db)) == 2

    sys_dets = list_detections(db, label="SYSTEM_FAILURE")
    assert len(sys_dets) == 1
    assert all(d["anomaly_label"] == "SYSTEM_FAILURE" for d in sys_dets)

    net_dets = list_detections(db, label="NETWORK_ANOMALY")
    assert len(net_dets) == 1
    assert all(d["anomaly_label"] == "NETWORK_ANOMALY" for d in net_dets)
|