feat: anomaly scoring pipeline (#10)
- Add app/services/anomaly.py: batch scorer using HF text-classification
pipeline; rewrites anomaly_score/anomaly_label/anomaly_scored_at on
log_entries; inserts high-confidence hits into detections table
- Add app/tasks/anomaly_scorer.py: background task (same shape as
glean_scheduler); triggered after each glean cycle when
TURNSTONE_ANOMALY_MODEL is set
- DB schema: add anomaly_score/anomaly_label/anomaly_scored_at columns to
log_entries (idempotent ALTER TABLE migration); add detections table
- Wire scorer into scheduler_loop and glean_scheduler.run_once; no-op when
model env var is empty (safe to leave unconfigured)
- REST endpoints: GET/POST /api/anomaly/status, /api/anomaly/run,
GET /api/anomaly/detections, POST /api/anomaly/detections/{id}/acknowledge
- Reuses Hybrid-BERT label map from diagnose/classifier.py; works with any
HF text-classification model
- 12 new tests; 406/406 passing
Closes: #10
This commit is contained in:
parent
8efd7f6745
commit
6e00bf03d3
6 changed files with 775 additions and 13 deletions
|
|
@ -23,18 +23,21 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
_MAIN_SCHEMA_SQLITE = """
|
||||
CREATE TABLE IF NOT EXISTS log_entries (
|
||||
id TEXT NOT NULL,
|
||||
tenant_id TEXT NOT NULL DEFAULT '',
|
||||
source_id TEXT NOT NULL,
|
||||
sequence INTEGER NOT NULL,
|
||||
timestamp_raw TEXT,
|
||||
timestamp_iso TEXT,
|
||||
ingest_time TEXT NOT NULL,
|
||||
severity TEXT,
|
||||
repeat_count INTEGER DEFAULT 1,
|
||||
out_of_order INTEGER DEFAULT 0,
|
||||
id TEXT NOT NULL,
|
||||
tenant_id TEXT NOT NULL DEFAULT '',
|
||||
source_id TEXT NOT NULL,
|
||||
sequence INTEGER NOT NULL,
|
||||
timestamp_raw TEXT,
|
||||
timestamp_iso TEXT,
|
||||
ingest_time TEXT NOT NULL,
|
||||
severity TEXT,
|
||||
repeat_count INTEGER DEFAULT 1,
|
||||
out_of_order INTEGER DEFAULT 0,
|
||||
matched_patterns TEXT DEFAULT '[]',
|
||||
text TEXT NOT NULL,
|
||||
text TEXT NOT NULL,
|
||||
anomaly_score REAL,
|
||||
anomaly_label TEXT,
|
||||
anomaly_scored_at TEXT,
|
||||
PRIMARY KEY (tenant_id, id)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_source ON log_entries(source_id);
|
||||
|
|
@ -43,6 +46,27 @@ CREATE INDEX IF NOT EXISTS idx_timestamp ON log_entries(timestamp_iso);
|
|||
CREATE INDEX IF NOT EXISTS idx_ts_repeat ON log_entries(timestamp_iso, repeat_count);
|
||||
CREATE INDEX IF NOT EXISTS idx_severity ON log_entries(tenant_id, severity);
|
||||
CREATE INDEX IF NOT EXISTS idx_patterns ON log_entries(matched_patterns);
|
||||
CREATE INDEX IF NOT EXISTS idx_anomaly ON log_entries(tenant_id, anomaly_score);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS detections (
|
||||
id TEXT PRIMARY KEY,
|
||||
tenant_id TEXT NOT NULL DEFAULT '',
|
||||
entry_id TEXT NOT NULL,
|
||||
source_id TEXT NOT NULL,
|
||||
anomaly_label TEXT NOT NULL,
|
||||
anomaly_score REAL NOT NULL,
|
||||
severity TEXT NOT NULL,
|
||||
text TEXT NOT NULL,
|
||||
timestamp_iso TEXT,
|
||||
detected_at TEXT NOT NULL,
|
||||
acknowledged INTEGER NOT NULL DEFAULT 0,
|
||||
acknowledged_at TEXT,
|
||||
notes TEXT NOT NULL DEFAULT ''
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_detections_tenant ON detections(tenant_id, detected_at);
|
||||
CREATE INDEX IF NOT EXISTS idx_detections_ack ON detections(acknowledged);
|
||||
CREATE INDEX IF NOT EXISTS idx_detections_label ON detections(anomaly_label);
|
||||
CREATE INDEX IF NOT EXISTS idx_detections_entry ON detections(entry_id);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS glean_fingerprints (
|
||||
tenant_id TEXT NOT NULL DEFAULT '',
|
||||
|
|
@ -174,6 +198,9 @@ _MAIN_SCHEMA_PG_STMTS = [
|
|||
matched_patterns TEXT DEFAULT '[]',
|
||||
text TEXT NOT NULL,
|
||||
text_tsv tsvector,
|
||||
anomaly_score DOUBLE PRECISION,
|
||||
anomaly_label TEXT,
|
||||
anomaly_scored_at TEXT,
|
||||
PRIMARY KEY (tenant_id, id)
|
||||
)
|
||||
""",
|
||||
|
|
@ -182,6 +209,28 @@ _MAIN_SCHEMA_PG_STMTS = [
|
|||
"CREATE INDEX IF NOT EXISTS idx_severity ON log_entries(tenant_id, severity)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_patterns ON log_entries(matched_patterns)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_fts_gin ON log_entries USING GIN(text_tsv)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_anomaly ON log_entries(tenant_id, anomaly_score)",
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS detections (
|
||||
id TEXT PRIMARY KEY,
|
||||
tenant_id TEXT NOT NULL DEFAULT '',
|
||||
entry_id TEXT NOT NULL,
|
||||
source_id TEXT NOT NULL,
|
||||
anomaly_label TEXT NOT NULL,
|
||||
anomaly_score DOUBLE PRECISION NOT NULL,
|
||||
severity TEXT NOT NULL,
|
||||
text TEXT NOT NULL,
|
||||
timestamp_iso TEXT,
|
||||
detected_at TEXT NOT NULL,
|
||||
acknowledged INTEGER NOT NULL DEFAULT 0,
|
||||
acknowledged_at TEXT,
|
||||
notes TEXT NOT NULL DEFAULT ''
|
||||
)
|
||||
""",
|
||||
"CREATE INDEX IF NOT EXISTS idx_detections_tenant ON detections(tenant_id, detected_at)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_detections_ack ON detections(acknowledged)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_detections_label ON detections(anomaly_label)",
|
||||
"CREATE INDEX IF NOT EXISTS idx_detections_entry ON detections(entry_id)",
|
||||
"""
|
||||
CREATE OR REPLACE FUNCTION _ts_update_text_tsv() RETURNS trigger AS $$
|
||||
BEGIN
|
||||
|
|
@ -336,6 +385,9 @@ _MAIN_MIGRATIONS_SQLITE = [
|
|||
"ALTER TABLE glean_fingerprints ADD COLUMN mtime REAL",
|
||||
"ALTER TABLE glean_fingerprints ADD COLUMN size INTEGER",
|
||||
"ALTER TABLE glean_fingerprints ADD COLUMN gleaned_at TEXT",
|
||||
"ALTER TABLE log_entries ADD COLUMN anomaly_score REAL",
|
||||
"ALTER TABLE log_entries ADD COLUMN anomaly_label TEXT",
|
||||
"ALTER TABLE log_entries ADD COLUMN anomaly_scored_at TEXT",
|
||||
]
|
||||
|
||||
_CONTEXT_MIGRATIONS_SQLITE = [
|
||||
|
|
|
|||
68
app/rest.py
68
app/rest.py
|
|
@ -88,6 +88,8 @@ from app.glean.doc_upload import glean_upload as _glean_upload
|
|||
from app.context.wizard import get_schema as _wizard_schema, advance_step, is_complete, apply_session
|
||||
from app.context.chunker import UnsupportedDocType, FileTooLarge
|
||||
from app.tasks.glean_scheduler import get_state as _glean_state, run_once as _run_glean, scheduler_loop as _scheduler_loop, submit_matched as _submit_matched
|
||||
from app.tasks.anomaly_scorer import get_state as _scorer_state, run_once as _run_scorer
|
||||
from app.services.anomaly import list_detections as _list_detections, acknowledge_detection as _ack_detection
|
||||
from app.glean.mqtt_subscriber import run_mqtt_subscribers as _run_mqtt_subscribers
|
||||
|
||||
DB_PATH = Path(os.environ.get("TURNSTONE_DB", Path(__file__).parent.parent / "data" / "turnstone.db"))
|
||||
|
|
@ -109,6 +111,9 @@ PATTERN_DIR = Path(os.environ.get("TURNSTONE_PATTERNS", Path(__file__).parent.pa
|
|||
PATTERN_FILE = PATTERN_DIR / "default.yaml"
|
||||
GLEAN_INTERVAL = int(os.environ.get("TURNSTONE_GLEAN_INTERVAL", "900"))
|
||||
SUBMIT_ENDPOINT = os.environ.get("TURNSTONE_SUBMIT_ENDPOINT", "").rstrip("/")
|
||||
ANOMALY_MODEL = os.environ.get("TURNSTONE_ANOMALY_MODEL", "")
|
||||
ANOMALY_DEVICE = os.environ.get("TURNSTONE_ANOMALY_DEVICE", "cpu")
|
||||
ANOMALY_THRESHOLD = float(os.environ.get("TURNSTONE_ANOMALY_THRESHOLD", "0.75"))
|
||||
# When set, all /api/ routes require Authorization: Bearer <key>.
|
||||
# Unset (default) means no authentication — suitable for local-only deployments.
|
||||
_API_KEY: str | None = os.environ.get("TURNSTONE_API_KEY") or None
|
||||
|
|
@ -165,6 +170,9 @@ async def _lifespan(app: FastAPI):
|
|||
sources_file, DB_PATH, PATTERN_FILE, GLEAN_INTERVAL,
|
||||
submit_endpoint=SUBMIT_ENDPOINT or None,
|
||||
source_host=SOURCE_HOST,
|
||||
anomaly_model=ANOMALY_MODEL,
|
||||
anomaly_device=ANOMALY_DEVICE,
|
||||
anomaly_threshold=ANOMALY_THRESHOLD,
|
||||
),
|
||||
name="glean-scheduler",
|
||||
)
|
||||
|
|
@ -1318,6 +1326,66 @@ async def debug_search(q: str):
|
|||
app.include_router(_ctx)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Anomaly scoring endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_anomaly = APIRouter(prefix="/turnstone/api/anomaly", dependencies=[Depends(_check_api_key)])
|
||||
|
||||
|
||||
@_anomaly.get("/status")
|
||||
async def anomaly_status():
|
||||
"""Return scorer state and configuration."""
|
||||
state = _scorer_state()
|
||||
return {
|
||||
"model": ANOMALY_MODEL or None,
|
||||
"threshold": ANOMALY_THRESHOLD,
|
||||
"device": ANOMALY_DEVICE,
|
||||
"enabled": bool(ANOMALY_MODEL),
|
||||
**vars(state),
|
||||
}
|
||||
|
||||
|
||||
@_anomaly.post("/run")
|
||||
async def anomaly_run(background_tasks: BackgroundTasks):
|
||||
"""Trigger a manual anomaly scoring pass (runs in background)."""
|
||||
if not ANOMALY_MODEL:
|
||||
raise HTTPException(status_code=400, detail="TURNSTONE_ANOMALY_MODEL not configured")
|
||||
background_tasks.add_task(
|
||||
_run_scorer, DB_PATH, ANOMALY_MODEL, ANOMALY_DEVICE, 256, ANOMALY_THRESHOLD
|
||||
)
|
||||
return {"ok": True, "message": "scorer triggered"}
|
||||
|
||||
|
||||
@_anomaly.get("/detections")
|
||||
async def anomaly_detections(
|
||||
limit: int = Query(100, ge=1, le=1000),
|
||||
unacked_only: bool = Query(False),
|
||||
label: str | None = Query(None),
|
||||
):
|
||||
"""List anomaly detections ordered by detected_at DESC."""
|
||||
loop = asyncio.get_running_loop()
|
||||
rows = await loop.run_in_executor(
|
||||
None, lambda: _list_detections(DB_PATH, limit=limit, unacked_only=unacked_only, label=label)
|
||||
)
|
||||
return {"detections": rows, "total": len(rows)}
|
||||
|
||||
|
||||
@_anomaly.post("/detections/{detection_id}/acknowledge")
|
||||
async def acknowledge_detection(detection_id: str, notes: str = ""):
|
||||
"""Acknowledge a detection (mark as reviewed)."""
|
||||
loop = asyncio.get_running_loop()
|
||||
updated = await loop.run_in_executor(
|
||||
None, lambda: _ack_detection(DB_PATH, detection_id, notes)
|
||||
)
|
||||
if not updated:
|
||||
raise HTTPException(status_code=404, detail="Detection not found")
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
app.include_router(_anomaly)
|
||||
|
||||
|
||||
# Root redirect → /turnstone/
|
||||
@app.get("/")
|
||||
def root_redirect() -> RedirectResponse:
|
||||
|
|
|
|||
291
app/services/anomaly.py
Normal file
291
app/services/anomaly.py
Normal file
|
|
@ -0,0 +1,291 @@
|
|||
"""Anomaly scoring pipeline — batch-score log_entries with a HF classifier.
|
||||
|
||||
Designed to run after each glean cycle (or standalone). When no model is
|
||||
configured the scorer is a no-op and returns immediately, so it is always
|
||||
safe to wire into the glean pipeline.
|
||||
|
||||
Model: any HuggingFace text-classification model. The existing Hybrid-BERT
|
||||
label map (from diagnose/classifier.py) is reused when the model produces
|
||||
NORMAL/SECURITY_ANOMALY/… outputs; other models get a generic severity map.
|
||||
|
||||
Scoring strategy
|
||||
----------------
|
||||
- Query unscored rows in batches (WHERE anomaly_scored_at IS NULL)
|
||||
- Run each entry text through the HF pipeline
|
||||
- Write anomaly_score + anomaly_label + anomaly_scored_at back
|
||||
- INSERT high-confidence hits (score >= threshold) into detections table,
|
||||
skipping duplicates so the scorer is safe to re-run
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from app.db import get_conn, resolve_tenant_id
|
||||
from app.db.dialect import q
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Label maps — reuse Hybrid-BERT vocabulary from diagnose/classifier.py
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_HYBRID_BERT_SEVERITY: dict[str, str] = {
|
||||
"NORMAL": "INFO",
|
||||
"SECURITY_ANOMALY": "ERROR",
|
||||
"SYSTEM_FAILURE": "CRITICAL",
|
||||
"PERFORMANCE_ISSUE": "WARN",
|
||||
"NETWORK_ANOMALY": "WARN",
|
||||
"CONFIG_ERROR": "ERROR",
|
||||
"HARDWARE_ISSUE": "CRITICAL",
|
||||
}
|
||||
|
||||
_GENERIC_SEVERITY: dict[str, str] = {
|
||||
"CRITICAL": "CRITICAL",
|
||||
"ERROR": "ERROR",
|
||||
"WARNING": "WARN",
|
||||
"WARN": "WARN",
|
||||
"INFO": "INFO",
|
||||
"DEBUG": "DEBUG",
|
||||
}
|
||||
|
||||
_ANOMALOUS_LABELS: frozenset[str] = frozenset(
|
||||
{
|
||||
"SECURITY_ANOMALY",
|
||||
"SYSTEM_FAILURE",
|
||||
"PERFORMANCE_ISSUE",
|
||||
"NETWORK_ANOMALY",
|
||||
"CONFIG_ERROR",
|
||||
"HARDWARE_ISSUE",
|
||||
"CRITICAL",
|
||||
"ERROR",
|
||||
}
|
||||
)
|
||||
|
||||
_DEFAULT_THRESHOLD = float(os.environ.get("TURNSTONE_ANOMALY_THRESHOLD", "0.75"))
|
||||
_DEFAULT_MODEL = os.environ.get("TURNSTONE_ANOMALY_MODEL", "")
|
||||
_DEFAULT_DEVICE = os.environ.get("TURNSTONE_ANOMALY_DEVICE", "cpu")
|
||||
_DEFAULT_BATCH = int(os.environ.get("TURNSTONE_ANOMALY_BATCH", "256"))
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ML singleton
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_pipeline: Any | None = None
|
||||
|
||||
|
||||
def _get_pipeline(model_id: str, device: str) -> Any:
|
||||
global _pipeline # noqa: PLW0603
|
||||
if _pipeline is None:
|
||||
from transformers import pipeline as hf_pipeline # type: ignore[import-untyped]
|
||||
_pipeline = hf_pipeline("text-classification", model=model_id, device=device)
|
||||
return _pipeline
|
||||
|
||||
|
||||
def reset_pipeline() -> None:
|
||||
"""Reset the cached pipeline singleton (test helper)."""
|
||||
global _pipeline # noqa: PLW0603
|
||||
_pipeline = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Result types
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScoringResult:
|
||||
scored: int = 0
|
||||
detections: int = 0
|
||||
skipped: bool = False
|
||||
error: str | None = None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _map_label(raw_label: str, score: float) -> tuple[str, str]:
|
||||
"""Return (normalised_label, severity) for a raw model output label."""
|
||||
upper = raw_label.upper()
|
||||
if upper in _HYBRID_BERT_SEVERITY:
|
||||
return upper, _HYBRID_BERT_SEVERITY[upper]
|
||||
sev = _GENERIC_SEVERITY.get(upper, "WARN")
|
||||
return upper, sev
|
||||
|
||||
|
||||
def _fetch_unscored(conn: Any, tenant_id: str, limit: int) -> list[dict]:
|
||||
rows = conn.execute(
|
||||
q("""
|
||||
SELECT id, source_id, text, timestamp_iso, severity
|
||||
FROM log_entries
|
||||
WHERE anomaly_scored_at IS NULL
|
||||
AND (tenant_id = ? OR tenant_id = '')
|
||||
ORDER BY ingest_time DESC
|
||||
LIMIT ?
|
||||
"""),
|
||||
(tenant_id, limit),
|
||||
).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
|
||||
def _write_scores(
|
||||
conn: Any,
|
||||
rows: list[dict],
|
||||
scored_at: str,
|
||||
) -> None:
|
||||
conn.executemany(
|
||||
q("UPDATE log_entries SET anomaly_score = ?, anomaly_label = ?, anomaly_scored_at = ? WHERE id = ?"),
|
||||
[(r["anomaly_score"], r["anomaly_label"], scored_at, r["id"]) for r in rows],
|
||||
)
|
||||
|
||||
|
||||
def _insert_detections(conn: Any, rows: list[dict], tenant_id: str, detected_at: str) -> int:
|
||||
inserted = 0
|
||||
for r in rows:
|
||||
try:
|
||||
conn.execute(
|
||||
q("""
|
||||
INSERT INTO detections
|
||||
(id, tenant_id, entry_id, source_id, anomaly_label, anomaly_score,
|
||||
severity, text, timestamp_iso, detected_at)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
"""),
|
||||
(
|
||||
str(uuid.uuid4()),
|
||||
tenant_id,
|
||||
r["id"],
|
||||
r["source_id"],
|
||||
r["anomaly_label"],
|
||||
r["anomaly_score"],
|
||||
r["severity"],
|
||||
r["text"][:2000],
|
||||
r.get("timestamp_iso"),
|
||||
detected_at,
|
||||
),
|
||||
)
|
||||
inserted += 1
|
||||
except Exception: # noqa: BLE001
|
||||
pass # duplicate entry_id or constraint violation — skip
|
||||
return inserted
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Public API
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def score_unscored(
|
||||
db_path: Path,
|
||||
model_id: str = _DEFAULT_MODEL,
|
||||
device: str = _DEFAULT_DEVICE,
|
||||
batch_size: int = _DEFAULT_BATCH,
|
||||
threshold: float = _DEFAULT_THRESHOLD,
|
||||
) -> ScoringResult:
|
||||
"""Score all unscored log_entries in batches.
|
||||
|
||||
Returns immediately (skipped=True) when model_id is empty — allows
|
||||
unconditional wiring without requiring the model to be configured.
|
||||
"""
|
||||
if not model_id:
|
||||
return ScoringResult(skipped=True)
|
||||
|
||||
try:
|
||||
pipe = _get_pipeline(model_id, device)
|
||||
except Exception as exc:
|
||||
logger.error("Failed to load anomaly model %r: %s", model_id, exc)
|
||||
return ScoringResult(error=str(exc))
|
||||
|
||||
tenant_id = resolve_tenant_id()
|
||||
total_scored = 0
|
||||
total_detections = 0
|
||||
|
||||
while True:
|
||||
with get_conn(db_path) as conn:
|
||||
batch = _fetch_unscored(conn, tenant_id, batch_size)
|
||||
if not batch:
|
||||
break
|
||||
|
||||
texts = [r["text"][:512] for r in batch]
|
||||
try:
|
||||
predictions = pipe(texts, truncation=True, max_length=512)
|
||||
except Exception as exc:
|
||||
logger.error("Inference error on batch of %d: %s", len(batch), exc)
|
||||
return ScoringResult(scored=total_scored, detections=total_detections, error=str(exc))
|
||||
|
||||
scored_at = datetime.now(tz=timezone.utc).isoformat()
|
||||
scored_rows: list[dict] = []
|
||||
detection_rows: list[dict] = []
|
||||
|
||||
for row, pred in zip(batch, predictions):
|
||||
label, severity = _map_label(pred["label"], pred["score"])
|
||||
enriched = {**row, "anomaly_score": pred["score"], "anomaly_label": label, "severity": severity}
|
||||
scored_rows.append(enriched)
|
||||
if label in _ANOMALOUS_LABELS and pred["score"] >= threshold:
|
||||
detection_rows.append(enriched)
|
||||
|
||||
with get_conn(db_path) as conn:
|
||||
_write_scores(conn, scored_rows, scored_at)
|
||||
det_count = _insert_detections(conn, detection_rows, tenant_id, scored_at)
|
||||
conn.commit()
|
||||
|
||||
total_scored += len(scored_rows)
|
||||
total_detections += det_count
|
||||
logger.info(
|
||||
"Scored %d entries, %d detections (threshold=%.2f)",
|
||||
len(scored_rows), det_count, threshold,
|
||||
)
|
||||
|
||||
if len(batch) < batch_size:
|
||||
break
|
||||
|
||||
return ScoringResult(scored=total_scored, detections=total_detections)
|
||||
|
||||
|
||||
def list_detections(
|
||||
db_path: Path,
|
||||
limit: int = 100,
|
||||
unacked_only: bool = False,
|
||||
label: str | None = None,
|
||||
) -> list[dict]:
|
||||
"""Return detections ordered by detected_at DESC."""
|
||||
tenant_id = resolve_tenant_id()
|
||||
conditions = ["(tenant_id = ? OR tenant_id = '')"]
|
||||
params: list[Any] = [tenant_id]
|
||||
|
||||
if unacked_only:
|
||||
conditions.append("acknowledged = 0")
|
||||
if label:
|
||||
conditions.append(q("anomaly_label = ?"))
|
||||
params.append(label.upper())
|
||||
|
||||
where = " AND ".join(conditions)
|
||||
with get_conn(db_path) as conn:
|
||||
rows = conn.execute(
|
||||
q(f"SELECT * FROM detections WHERE {where} ORDER BY detected_at DESC LIMIT ?"), # noqa: S608
|
||||
(*params, limit),
|
||||
).fetchall()
|
||||
return [dict(r) for r in rows]
|
||||
|
||||
|
||||
def acknowledge_detection(db_path: Path, detection_id: str, notes: str = "") -> bool:
|
||||
"""Mark a detection as acknowledged. Returns True if a row was updated."""
|
||||
tenant_id = resolve_tenant_id()
|
||||
acked_at = datetime.now(tz=timezone.utc).isoformat()
|
||||
with get_conn(db_path) as conn:
|
||||
cur = conn.execute(
|
||||
q("""
|
||||
UPDATE detections
|
||||
SET acknowledged = 1, acknowledged_at = ?, notes = ?
|
||||
WHERE id = ? AND (tenant_id = ? OR tenant_id = '')
|
||||
"""),
|
||||
(acked_at, notes, detection_id, tenant_id),
|
||||
)
|
||||
conn.commit()
|
||||
return cur.rowcount > 0
|
||||
114
app/tasks/anomaly_scorer.py
Normal file
114
app/tasks/anomaly_scorer.py
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
"""Background anomaly scoring task.
|
||||
|
||||
Runs score_unscored() after each glean cycle (triggered by glean_scheduler)
|
||||
or on its own interval when TURNSTONE_ANOMALY_INTERVAL is set.
|
||||
|
||||
Set TURNSTONE_ANOMALY_MODEL to a HuggingFace model ID to activate.
|
||||
When the env var is empty (default) the scorer is a no-op.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
|
||||
from app.services.anomaly import ScoringResult, score_unscored
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_INTERVAL = int(os.environ.get("TURNSTONE_ANOMALY_INTERVAL", "0"))
|
||||
|
||||
_lock = asyncio.Lock()
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScorerState:
|
||||
last_run_at: str | None = None
|
||||
last_duration_s: float | None = None
|
||||
last_scored: int = 0
|
||||
last_detections: int = 0
|
||||
last_error: str | None = None
|
||||
run_count: int = 0
|
||||
next_run_at: str | None = None
|
||||
running: bool = False
|
||||
total_scored: int = 0
|
||||
total_detections: int = 0
|
||||
|
||||
|
||||
_state = ScorerState()
|
||||
|
||||
|
||||
def get_state() -> ScorerState:
|
||||
return _state
|
||||
|
||||
|
||||
async def run_once(
|
||||
db_path: Path,
|
||||
model_id: str = "",
|
||||
device: str = "cpu",
|
||||
batch_size: int = 256,
|
||||
threshold: float = 0.75,
|
||||
) -> ScoringResult:
|
||||
"""Score unscored entries once. Skips if already running or model not configured."""
|
||||
if _lock.locked():
|
||||
return ScoringResult(skipped=True, error="scorer already running")
|
||||
|
||||
async with _lock:
|
||||
_state.running = True
|
||||
started = datetime.now(tz=timezone.utc)
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
result: ScoringResult = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: score_unscored(db_path, model_id, device, batch_size, threshold),
|
||||
)
|
||||
duration = (datetime.now(tz=timezone.utc) - started).total_seconds()
|
||||
_state.last_run_at = started.isoformat()
|
||||
_state.last_duration_s = round(duration, 2)
|
||||
_state.last_scored = result.scored
|
||||
_state.last_detections = result.detections
|
||||
_state.last_error = result.error
|
||||
_state.run_count += 1
|
||||
_state.total_scored += result.scored
|
||||
_state.total_detections += result.detections
|
||||
if not result.skipped:
|
||||
logger.info(
|
||||
"Anomaly scorer: %d scored, %d detections in %.1fs",
|
||||
result.scored, result.detections, duration,
|
||||
)
|
||||
return result
|
||||
except Exception as exc:
|
||||
duration = (datetime.now(tz=timezone.utc) - started).total_seconds()
|
||||
_state.last_run_at = started.isoformat()
|
||||
_state.last_duration_s = round(duration, 2)
|
||||
_state.last_error = str(exc)
|
||||
_state.run_count += 1
|
||||
logger.error("Anomaly scorer failed: %s", exc)
|
||||
return ScoringResult(error=str(exc))
|
||||
finally:
|
||||
_state.running = False
|
||||
|
||||
|
||||
async def scorer_loop(
|
||||
db_path: Path,
|
||||
model_id: str,
|
||||
device: str,
|
||||
interval_s: int,
|
||||
batch_size: int = 256,
|
||||
threshold: float = 0.75,
|
||||
) -> None:
|
||||
"""Score unscored entries every interval_s seconds until cancelled."""
|
||||
logger.info("Anomaly scorer loop started — interval %ds, model: %s", interval_s, model_id)
|
||||
while True:
|
||||
await run_once(db_path, model_id, device, batch_size, threshold)
|
||||
next_run = datetime.now(tz=timezone.utc) + timedelta(seconds=interval_s)
|
||||
_state.next_run_at = next_run.isoformat()
|
||||
try:
|
||||
await asyncio.sleep(interval_s)
|
||||
except asyncio.CancelledError:
|
||||
logger.info("Anomaly scorer loop cancelled")
|
||||
_state.next_run_at = None
|
||||
raise
|
||||
|
|
@ -20,6 +20,7 @@ from typing import Any
|
|||
import httpx
|
||||
|
||||
from app.glean.pipeline import glean_sources
|
||||
from app.tasks.anomaly_scorer import run_once as _run_scorer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -123,6 +124,9 @@ async def run_once(
|
|||
submit_endpoint: str | None = None,
|
||||
source_host: str = "unknown",
|
||||
force: bool = False,
|
||||
anomaly_model: str = "",
|
||||
anomaly_device: str = "cpu",
|
||||
anomaly_threshold: float = 0.75,
|
||||
) -> dict[str, Any]:
|
||||
"""Ingest all sources once, then submit matched entries if configured.
|
||||
|
||||
|
|
@ -163,6 +167,9 @@ async def run_once(
|
|||
if submit_endpoint:
|
||||
await submit_matched(db_path, submit_endpoint, source_host, since=_state.last_submitted_at)
|
||||
|
||||
if anomaly_model:
|
||||
await _run_scorer(db_path, anomaly_model, anomaly_device, threshold=anomaly_threshold)
|
||||
|
||||
return {"ok": True, "stats": _state.last_stats, "duration_s": _state.last_duration_s}
|
||||
|
||||
|
||||
|
|
@ -173,13 +180,23 @@ async def scheduler_loop(
|
|||
interval_s: int,
|
||||
submit_endpoint: str | None = None,
|
||||
source_host: str = "unknown",
|
||||
anomaly_model: str = "",
|
||||
anomaly_device: str = "cpu",
|
||||
anomaly_threshold: float = 0.75,
|
||||
) -> None:
|
||||
"""Run glean + optional submission every interval_s seconds until cancelled."""
|
||||
"""Run glean + optional submission + optional anomaly scoring every interval_s seconds."""
|
||||
logger.info("Ingest scheduler started — interval %ds, sources: %s", interval_s, sources_file)
|
||||
if submit_endpoint:
|
||||
logger.info("Submission enabled — endpoint: %s", submit_endpoint)
|
||||
if anomaly_model:
|
||||
logger.info("Anomaly scoring enabled — model: %s", anomaly_model)
|
||||
while True:
|
||||
await run_once(sources_file, db_path, pattern_file, submit_endpoint, source_host)
|
||||
await run_once(
|
||||
sources_file, db_path, pattern_file, submit_endpoint, source_host,
|
||||
anomaly_model=anomaly_model,
|
||||
anomaly_device=anomaly_device,
|
||||
anomaly_threshold=anomaly_threshold,
|
||||
)
|
||||
next_run = datetime.now(tz=timezone.utc) + timedelta(seconds=interval_s)
|
||||
_state.next_run_at = next_run.isoformat()
|
||||
try:
|
||||
|
|
|
|||
220
tests/test_anomaly.py
Normal file
220
tests/test_anomaly.py
Normal file
|
|
@ -0,0 +1,220 @@
|
|||
"""Tests for app/services/anomaly.py — anomaly scoring pipeline."""
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import app.services.anomaly as anomaly_mod
|
||||
from app.db.schema import ensure_schema
|
||||
from app.services.anomaly import (
|
||||
ScoringResult,
|
||||
acknowledge_detection,
|
||||
list_detections,
|
||||
reset_pipeline,
|
||||
score_unscored,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_pipeline():
|
||||
"""Ensure the ML singleton is cleared between tests."""
|
||||
reset_pipeline()
|
||||
yield
|
||||
reset_pipeline()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def db(tmp_path: Path) -> Path:
|
||||
db_path = tmp_path / "t.db"
|
||||
ensure_schema(db_path)
|
||||
return db_path
|
||||
|
||||
|
||||
def _insert_entry(db_path: Path, text: str, entry_id: str | None = None) -> str:
|
||||
eid = entry_id or str(uuid.uuid4())
|
||||
conn = sqlite3.connect(str(db_path))
|
||||
conn.execute(
|
||||
"INSERT INTO log_entries(id, tenant_id, source_id, sequence, ingest_time, text) "
|
||||
"VALUES (?,?,?,?,?,?)",
|
||||
(eid, "", "src", 1, "2026-01-01T00:00:00", text),
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return eid
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# score_unscored
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_score_unscored_no_model_returns_skipped(db: Path):
|
||||
result = score_unscored(db, model_id="")
|
||||
assert result.skipped is True
|
||||
assert result.scored == 0
|
||||
|
||||
|
||||
def test_score_unscored_scores_entries(db: Path, monkeypatch):
|
||||
_insert_entry(db, "kernel panic — OOM killer invoked")
|
||||
_insert_entry(db, "user login successful")
|
||||
|
||||
mock_pipe = MagicMock(return_value=[
|
||||
{"label": "SYSTEM_FAILURE", "score": 0.92},
|
||||
{"label": "NORMAL", "score": 0.88},
|
||||
])
|
||||
monkeypatch.setattr(anomaly_mod, "_pipeline", mock_pipe)
|
||||
|
||||
result = score_unscored(db, model_id="fake-model", batch_size=10)
|
||||
assert result.skipped is False
|
||||
assert result.scored == 2
|
||||
|
||||
|
||||
def test_score_unscored_creates_detection_above_threshold(db: Path, monkeypatch):
|
||||
_insert_entry(db, "segfault in service")
|
||||
|
||||
mock_pipe = MagicMock(return_value=[
|
||||
{"label": "SYSTEM_FAILURE", "score": 0.95},
|
||||
])
|
||||
monkeypatch.setattr(anomaly_mod, "_pipeline", mock_pipe)
|
||||
|
||||
result = score_unscored(db, model_id="fake-model", threshold=0.80)
|
||||
assert result.detections == 1
|
||||
|
||||
detections = list_detections(db)
|
||||
assert len(detections) == 1
|
||||
assert detections[0]["anomaly_label"] == "SYSTEM_FAILURE"
|
||||
assert detections[0]["anomaly_score"] == pytest.approx(0.95)
|
||||
|
||||
|
||||
def test_score_unscored_no_detection_below_threshold(db: Path, monkeypatch):
|
||||
_insert_entry(db, "warning: disk at 80%")
|
||||
|
||||
mock_pipe = MagicMock(return_value=[
|
||||
{"label": "PERFORMANCE_ISSUE", "score": 0.60},
|
||||
])
|
||||
monkeypatch.setattr(anomaly_mod, "_pipeline", mock_pipe)
|
||||
|
||||
result = score_unscored(db, model_id="fake-model", threshold=0.80)
|
||||
assert result.detections == 0
|
||||
assert result.scored == 1
|
||||
|
||||
|
||||
def test_score_unscored_normal_label_never_detection(db: Path, monkeypatch):
|
||||
_insert_entry(db, "service started successfully")
|
||||
|
||||
mock_pipe = MagicMock(return_value=[
|
||||
{"label": "NORMAL", "score": 0.99},
|
||||
])
|
||||
monkeypatch.setattr(anomaly_mod, "_pipeline", mock_pipe)
|
||||
|
||||
result = score_unscored(db, model_id="fake-model", threshold=0.50)
|
||||
assert result.detections == 0
|
||||
|
||||
|
||||
def test_score_unscored_idempotent(db: Path, monkeypatch):
|
||||
"""Entries already scored are not re-scored on subsequent runs."""
|
||||
_insert_entry(db, "first entry")
|
||||
|
||||
call_count = 0
|
||||
|
||||
def _side_effect(texts, **_kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
return [{"label": "NORMAL", "score": 0.90} for _ in texts]
|
||||
|
||||
mock_pipe = MagicMock(side_effect=_side_effect)
|
||||
monkeypatch.setattr(anomaly_mod, "_pipeline", mock_pipe)
|
||||
|
||||
score_unscored(db, model_id="fake-model")
|
||||
score_unscored(db, model_id="fake-model")
|
||||
|
||||
assert call_count == 1 # second run finds no unscored rows
|
||||
|
||||
|
||||
def test_score_unscored_pipeline_error_returns_error(db: Path, monkeypatch):
|
||||
_insert_entry(db, "some log line")
|
||||
|
||||
mock_pipe = MagicMock(side_effect=RuntimeError("CUDA OOM"))
|
||||
monkeypatch.setattr(anomaly_mod, "_pipeline", mock_pipe)
|
||||
|
||||
result = score_unscored(db, model_id="fake-model")
|
||||
assert result.error is not None
|
||||
assert "CUDA OOM" in result.error
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# list_detections / acknowledge_detection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_list_detections_empty(db: Path):
|
||||
assert list_detections(db) == []
|
||||
|
||||
|
||||
def test_list_detections_filters_unacked(db: Path, monkeypatch):
|
||||
_insert_entry(db, "crash")
|
||||
|
||||
mock_pipe = MagicMock(return_value=[{"label": "SYSTEM_FAILURE", "score": 0.91}])
|
||||
monkeypatch.setattr(anomaly_mod, "_pipeline", mock_pipe)
|
||||
score_unscored(db, model_id="fake-model", threshold=0.80)
|
||||
|
||||
all_dets = list_detections(db)
|
||||
assert len(all_dets) == 1
|
||||
unacked = list_detections(db, unacked_only=True)
|
||||
assert len(unacked) == 1
|
||||
|
||||
|
||||
def test_acknowledge_detection(db: Path, monkeypatch):
|
||||
_insert_entry(db, "network anomaly")
|
||||
|
||||
mock_pipe = MagicMock(return_value=[{"label": "NETWORK_ANOMALY", "score": 0.88}])
|
||||
monkeypatch.setattr(anomaly_mod, "_pipeline", mock_pipe)
|
||||
score_unscored(db, model_id="fake-model", threshold=0.80)
|
||||
|
||||
dets = list_detections(db)
|
||||
assert len(dets) == 1
|
||||
det_id = dets[0]["id"]
|
||||
|
||||
updated = acknowledge_detection(db, det_id, notes="benign test traffic")
|
||||
assert updated is True
|
||||
|
||||
unacked = list_detections(db, unacked_only=True)
|
||||
assert len(unacked) == 0
|
||||
|
||||
all_dets = list_detections(db)
|
||||
assert all_dets[0]["acknowledged"] == 1
|
||||
assert all_dets[0]["notes"] == "benign test traffic"
|
||||
|
||||
|
||||
def test_acknowledge_detection_unknown_id(db: Path):
|
||||
updated = acknowledge_detection(db, "nonexistent-id")
|
||||
assert updated is False
|
||||
|
||||
|
||||
def test_list_detections_label_filter(db: Path, monkeypatch):
|
||||
_insert_entry(db, "OOM kill")
|
||||
_insert_entry(db, "network timeout")
|
||||
|
||||
mock_pipe = MagicMock(side_effect=[
|
||||
[{"label": "SYSTEM_FAILURE", "score": 0.93}],
|
||||
[{"label": "NETWORK_ANOMALY", "score": 0.85}],
|
||||
])
|
||||
monkeypatch.setattr(anomaly_mod, "_pipeline", mock_pipe)
|
||||
|
||||
score_unscored(db, model_id="fake-model", batch_size=1, threshold=0.80)
|
||||
score_unscored(db, model_id="fake-model", batch_size=1, threshold=0.80)
|
||||
|
||||
sys_dets = list_detections(db, label="SYSTEM_FAILURE")
|
||||
assert all(d["anomaly_label"] == "SYSTEM_FAILURE" for d in sys_dets)
|
||||
|
||||
net_dets = list_detections(db, label="NETWORK_ANOMALY")
|
||||
assert all(d["anomaly_label"] == "NETWORK_ANOMALY" for d in net_dets)
|
||||
Loading…
Reference in a new issue