diff --git a/.env.example b/.env.example index fff4a27..a1d2d91 100644 --- a/.env.example +++ b/.env.example @@ -23,6 +23,15 @@ # Remote endpoint to push diagnostic bundles for escalation. # TURNSTONE_BUNDLE_ENDPOINT=https://example.com/api/bundles +# --- Log corpus export to Avocet (optional) --- +# Push ERROR/CRITICAL entries and labeled incidents to the Avocet corpus endpoint +# for logreading fine-tune training. Requires a consent token issued by CF. +# Contact alan@circuitforge.tech to register your node and receive a token. +# Watermarks are stored at data/corpus_watermark.txt and data/incident_watermark.txt. +# AVOCET_CORPUS_ENDPOINT=https://avocet.circuitforge.tech/api/corpus/log-batch +# AVOCET_CONSENT_TOKEN=your-uuid-token-here +# TURNSTONE_SOURCE_HOST=my-server-name # defaults to system hostname if unset + # --- Periodic batch glean --- # Seconds between automatic glean runs from sources.yaml. Set to 0 to disable. # TURNSTONE_GLEAN_INTERVAL=900 @@ -42,6 +51,32 @@ # TURNSTONE_EMBED_MODEL=BAAI/bge-small-en-v1.5 # TURNSTONE_EMBED_DEVICE=cpu +# --- Cybersec scoring pipeline (zero-shot, second-pass on flagged entries) --- +# Runs a zero-shot classifier on entries already flagged by the anomaly scorer +# or that have pattern matches — a focused second opinion using cybersec vocabulary. +# The DeBERTa-v3-base-mnli model (required by the diagnose pipeline) is the recommended +# zero-shot classifier — it produces human-readable cybersec labels with no fine-tuning. +# TURNSTONE_CYBERSEC_MODEL=MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli +# TURNSTONE_CYBERSEC_DEVICE=cpu +# TURNSTONE_CYBERSEC_THRESHOLD=0.60 # lower than anomaly threshold (zero-shot is calibrated differently) + +# --- Anomaly scoring pipeline (IDS / watchdog) --- +# Batch-scores every ingested log entry after each glean cycle. +# Any HuggingFace text-classification model works; the byviz classifier (already +# required by the diagnose pipeline) is the recommended starting point. +# Detections above the threshold are inserted into the detections table and +# surfaced in the Security Alerts tab. +# +# Set TURNSTONE_ANOMALY_MODEL to enable; leave unset to disable (safe default). +# TURNSTONE_ANOMALY_MODEL=byviz/bylastic_classification_logs +# TURNSTONE_ANOMALY_DEVICE=cpu # or "cuda" / "mps" for GPU inference +# TURNSTONE_ANOMALY_THRESHOLD=0.80 # confidence floor for detection insertion +# TURNSTONE_ANOMALY_INTERVAL=0 # standalone loop (0 = glean-triggered only) +# +# HuggingFace model cache — share with the host to avoid re-downloading models. +# HF_HOME=/hf_cache # inside container (set in docker-compose) +# HF_CACHE_PATH=/Library/Assets/LLM # host bind-mount source (docker-compose only) + # --- Air-gapped / offline deployment --- # Set to 1 to block all HuggingFace hub network access at runtime. # Pre-download models to ~/.cache/huggingface/ before deploying — see docs/air-gapped-deployment.md. diff --git a/app/context/store.py b/app/context/store.py index 1ffa08a..a030570 100644 --- a/app/context/store.py +++ b/app/context/store.py @@ -1,12 +1,13 @@ """Context fact and document CRUD — MIT licensed.""" from __future__ import annotations -import sqlite3 import uuid from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path +from app.db import get_conn, resolve_tenant_id + @dataclass(frozen=True) class ContextFact: @@ -28,19 +29,8 @@ class ContextDocument: uploaded_at: str -def _connect(db_path: Path) -> sqlite3.Connection: - # timeout=30: retry for up to 30 s when another writer (e.g. the glean - # collector) holds a WAL write lock. PRAGMA busy_timeout is a SQLite-level - # hint that operates after the connection is open; the Python sqlite3 module's - # own retry loop is controlled solely by this timeout= argument. - conn = sqlite3.connect(str(db_path), timeout=30.0) - conn.execute("PRAGMA journal_mode=WAL") - conn.execute("PRAGMA foreign_keys=ON") - conn.row_factory = sqlite3.Row - return conn - - def add_fact(db_path: Path, category: str, key: str, value: str, source: str | None = None) -> ContextFact: + tid = resolve_tenant_id() fact = ContextFact( id=str(uuid.uuid4()), category=category, @@ -49,27 +39,28 @@ def add_fact(db_path: Path, category: str, key: str, value: str, source: str | N source=source, created_at=datetime.now(timezone.utc).isoformat(), ) - conn = _connect(db_path) - conn.execute( - "INSERT INTO context_facts(id, category, key, value, source, created_at) VALUES (?,?,?,?,?,?)", - (fact.id, fact.category, fact.key, fact.value, fact.source, fact.created_at), - ) - conn.commit() - conn.close() + with get_conn(db_path) as conn: + conn.execute( + "INSERT INTO context_facts(id, tenant_id, category, key, value, source, created_at) VALUES (?,?,?,?,?,?,?)", + (fact.id, tid, fact.category, fact.key, fact.value, fact.source, fact.created_at), + ) + conn.commit() return fact def list_facts(db_path: Path, category: str | None = None) -> list[ContextFact]: - conn = _connect(db_path) - if category: - rows = conn.execute( - "SELECT * FROM context_facts WHERE category=? ORDER BY created_at", (category,) - ).fetchall() - else: - rows = conn.execute( - "SELECT * FROM context_facts ORDER BY category, created_at" - ).fetchall() - conn.close() + tid = resolve_tenant_id() + with get_conn(db_path) as conn: + if category: + rows = conn.execute( + "SELECT * FROM context_facts WHERE category=? AND (tenant_id=? OR tenant_id='') ORDER BY created_at", + (category, tid), + ).fetchall() + else: + rows = conn.execute( + "SELECT * FROM context_facts WHERE (tenant_id=? OR tenant_id='') ORDER BY category, created_at", + (tid,), + ).fetchall() return [ ContextFact( id=r["id"], category=r["category"], key=r["key"], @@ -80,10 +71,13 @@ def list_facts(db_path: Path, category: str | None = None) -> list[ContextFact]: def delete_fact(db_path: Path, fact_id: str) -> bool: - conn = _connect(db_path) - cursor = conn.execute("DELETE FROM context_facts WHERE id=?", (fact_id,)) - conn.commit() - conn.close() + tid = resolve_tenant_id() + with get_conn(db_path) as conn: + cursor = conn.execute( + "DELETE FROM context_facts WHERE id=? AND (tenant_id=? OR tenant_id='')", + (fact_id, tid), + ) + conn.commit() return cursor.rowcount > 0 @@ -94,6 +88,7 @@ def add_document( full_text: str, file_size: int | None = None, ) -> ContextDocument: + tid = resolve_tenant_id() doc = ContextDocument( id=str(uuid.uuid4()), filename=filename, @@ -102,24 +97,24 @@ def add_document( file_size=file_size, uploaded_at=datetime.now(timezone.utc).isoformat(), ) - conn = _connect(db_path) - conn.execute( - "INSERT INTO context_documents(id, filename, doc_type, full_text, file_size, uploaded_at)" - " VALUES (?,?,?,?,?,?)", - (doc.id, doc.filename, doc.doc_type, doc.full_text, doc.file_size, doc.uploaded_at), - ) - conn.commit() - conn.close() + with get_conn(db_path) as conn: + conn.execute( + "INSERT INTO context_documents(id, tenant_id, filename, doc_type, full_text, file_size, uploaded_at)" + " VALUES (?,?,?,?,?,?,?)", + (doc.id, tid, doc.filename, doc.doc_type, doc.full_text, doc.file_size, doc.uploaded_at), + ) + conn.commit() return doc def list_documents(db_path: Path) -> list[ContextDocument]: - conn = _connect(db_path) - rows = conn.execute( - "SELECT id, filename, doc_type, full_text, file_size, uploaded_at" - " FROM context_documents ORDER BY uploaded_at DESC" - ).fetchall() - conn.close() + tid = resolve_tenant_id() + with get_conn(db_path) as conn: + rows = conn.execute( + "SELECT id, filename, doc_type, full_text, file_size, uploaded_at" + " FROM context_documents WHERE (tenant_id=? OR tenant_id='') ORDER BY uploaded_at DESC", + (tid,), + ).fetchall() return [ ContextDocument( id=r["id"], filename=r["filename"], doc_type=r["doc_type"], @@ -130,8 +125,11 @@ def list_documents(db_path: Path) -> list[ContextDocument]: def delete_document(db_path: Path, doc_id: str) -> bool: - conn = _connect(db_path) - cursor = conn.execute("DELETE FROM context_documents WHERE id=?", (doc_id,)) - conn.commit() - conn.close() + tid = resolve_tenant_id() + with get_conn(db_path) as conn: + cursor = conn.execute( + "DELETE FROM context_documents WHERE id=? AND (tenant_id=? OR tenant_id='')", + (doc_id, tid), + ) + conn.commit() return cursor.rowcount > 0 diff --git a/app/db/__init__.py b/app/db/__init__.py new file mode 100644 index 0000000..5823b7b --- /dev/null +++ b/app/db/__init__.py @@ -0,0 +1,36 @@ +"""Turnstone database abstraction — unified SQLite / Postgres interface. + +Public API: + BACKEND — Backend.SQLITE or Backend.POSTGRES + get_conn(path) — context manager yielding a DbConn + resolve_tenant_id() — this node's tenant ID (env or hostname) + q(sql) — rewrite ? placeholders to %s for Postgres + frag — SQL fragment helpers (insert_or_ignore, source_group_expr, …) + ensure_schema — idempotent schema init + close_pool — call during shutdown when using Postgres +""" +from app.db.backend import BACKEND, Backend +from app.db.conn import DbConn, close_pool, get_conn +from app.db.dialect import frag, q +from app.db.schema import ( + ensure_context_schema, + ensure_incidents_schema, + ensure_schema, + migrate_incidents_to_dedicated_db, +) +from app.db.tenant import resolve_tenant_id + +__all__ = [ + "BACKEND", + "Backend", + "DbConn", + "close_pool", + "get_conn", + "frag", + "q", + "ensure_schema", + "ensure_context_schema", + "ensure_incidents_schema", + "migrate_incidents_to_dedicated_db", + "resolve_tenant_id", +] diff --git a/app/db/backend.py b/app/db/backend.py new file mode 100644 index 0000000..2e86839 --- /dev/null +++ b/app/db/backend.py @@ -0,0 +1,20 @@ +"""Backend detection — SQLITE (default) or POSTGRES based on DATABASE_URL.""" +from __future__ import annotations + +import os +from enum import Enum + + +class Backend(Enum): + SQLITE = "sqlite" + POSTGRES = "postgres" + + +def _detect() -> Backend: + url = os.environ.get("DATABASE_URL", "") + if url.startswith(("postgresql://", "postgres://", "postgresql+psycopg://")): + return Backend.POSTGRES + return Backend.SQLITE + + +BACKEND: Backend = _detect() diff --git a/app/db/conn.py b/app/db/conn.py new file mode 100644 index 0000000..30e0e8b --- /dev/null +++ b/app/db/conn.py @@ -0,0 +1,137 @@ +"""Uniform connection wrapper over sqlite3 and psycopg3. + +Usage: + with get_conn(db_path) as conn: + conn.execute("SELECT ...", (param,)) + conn.commit() + +For Postgres, db_path is ignored — all connections go through the shared pool. +The pool is initialized lazily on first use from DATABASE_URL. +""" +from __future__ import annotations + +import logging +import os +import sqlite3 +from contextlib import contextmanager +from pathlib import Path +from typing import Any, Generator + +from app.db.backend import BACKEND, Backend + +logger = logging.getLogger(__name__) + +_pool: Any = None # psycopg_pool.ConnectionPool, typed as Any to avoid import-time errors + + +class _NopCursor: + """Returned when a PRAGMA or other SQLite-only statement is skipped on Postgres.""" + rowcount = 0 + + def fetchall(self) -> list: + return [] + + def fetchone(self) -> None: + return None + + def __iter__(self): + return iter([]) + + +class DbConn: + """Wraps a raw sqlite3 or psycopg connection with a uniform execute API. + + Row access is always dict-like: + - SQLite: conn.row_factory = sqlite3.Row (supports row["col"] and row[0]) + - Postgres: row_factory = dict_row (returns plain dicts) + """ + + __slots__ = ("_c", "_backend") + + def __init__(self, raw: Any, backend: Backend) -> None: + self._c = raw + self._backend = backend + + def _prep(self, sql: str) -> str | None: + """Return None to skip (PRAGMA on Postgres), else return ready-to-execute SQL.""" + stripped = sql.strip() + if self._backend == Backend.POSTGRES and stripped.lower().startswith("pragma"): + return None + if self._backend == Backend.POSTGRES: + return stripped.replace("?", "%s") + return stripped + + def execute(self, sql: str, params: Any = ()) -> Any: + prepared = self._prep(sql) + if prepared is None: + return _NopCursor() + return self._c.execute(prepared, params) + + def executemany(self, sql: str, params_seq: Any) -> Any: + prepared = self._prep(sql) + if prepared is None: + return _NopCursor() + return self._c.executemany(prepared, params_seq) + + def commit(self) -> None: + self._c.commit() + + def close(self) -> None: + self._c.close() + + def __enter__(self) -> "DbConn": + return self + + def __exit__(self, *_: Any) -> None: + self.close() + + +def _get_pool() -> Any: + global _pool + if _pool is not None: + return _pool + try: + from psycopg_pool import ConnectionPool # type: ignore[import] + url = os.environ["DATABASE_URL"] + _pool = ConnectionPool(url, min_size=2, max_size=10, open=True) + logger.info("Postgres connection pool opened (DATABASE_URL set)") + return _pool + except ImportError as exc: + raise RuntimeError( + "psycopg[binary,pool] is required for Postgres backend. " + "Run: pip install 'psycopg[binary,pool]'" + ) from exc + except KeyError: + raise RuntimeError("DATABASE_URL must be set when using Postgres backend") from None + + +@contextmanager +def get_conn(db_path: Path | None = None) -> Generator[DbConn, None, None]: + """Yield a DbConn backed by sqlite3 (db_path required) or the Postgres pool.""" + if BACKEND == Backend.POSTGRES: + pool = _get_pool() + from psycopg.rows import dict_row # type: ignore[import] + with pool.connection() as raw: + raw.row_factory = dict_row + yield DbConn(raw, BACKEND) + else: + if db_path is None: + raise ValueError("db_path is required for SQLite backend") + raw = sqlite3.connect(str(db_path), timeout=90.0) + raw.row_factory = sqlite3.Row + try: + raw.execute("PRAGMA journal_mode=WAL") + raw.execute("PRAGMA busy_timeout=90000") + raw.execute("PRAGMA foreign_keys=ON") + yield DbConn(raw, BACKEND) + finally: + raw.close() + + +def close_pool() -> None: + """Close the Postgres connection pool — call during application shutdown.""" + global _pool + if _pool is not None: + _pool.close() + _pool = None + logger.info("Postgres connection pool closed") diff --git a/app/db/dialect.py b/app/db/dialect.py new file mode 100644 index 0000000..70f018a --- /dev/null +++ b/app/db/dialect.py @@ -0,0 +1,93 @@ +"""Per-backend SQL fragments and placeholder rewriting. + +All production SQL should be written with SQLite-style `?` placeholders. +Call q(sql) before passing to execute/executemany — it rewrites to %s for +Postgres and leaves SQLite queries untouched. +""" +from __future__ import annotations + +from app.db.backend import BACKEND, Backend + + +def q(sql: str) -> str: + """Rewrite ? placeholders to %s for Postgres; no-op for SQLite.""" + if BACKEND == Backend.POSTGRES: + return sql.replace("?", "%s") + return sql + + +class _Fragments: + """SQL fragments that differ between backends.""" + + @property + def insert_or_ignore(self) -> str: + return "INSERT" if BACKEND == Backend.POSTGRES else "INSERT OR IGNORE" + + @property + def on_conflict_ignore(self) -> str: + # Caller must substitute the column name(s) at use time when using Postgres. + # For log_entries: ON CONFLICT (tenant_id, id) DO NOTHING + # For generic use this property is a no-op sentinel; prefer insert_ignore_into(). + return "" + + def insert_ignore_entries(self) -> str: + """Full INSERT ... ON CONFLICT clause for log_entries.""" + if BACKEND == Backend.POSTGRES: + return "INSERT INTO log_entries" + return "INSERT OR IGNORE INTO log_entries" + + def entries_conflict_clause(self) -> str: + if BACKEND == Backend.POSTGRES: + return "ON CONFLICT (tenant_id, id) DO NOTHING" + return "" + + def fingerprint_upsert(self) -> str: + if BACKEND == Backend.POSTGRES: + return ( + "INSERT INTO glean_fingerprints (tenant_id, path, mtime, size, gleaned_at)" + " VALUES (%s, %s, %s, %s, %s)" + " ON CONFLICT (tenant_id, path)" + " DO UPDATE SET mtime=EXCLUDED.mtime, size=EXCLUDED.size, gleaned_at=EXCLUDED.gleaned_at" + ) + return ( + "INSERT OR REPLACE INTO glean_fingerprints (tenant_id, path, mtime, size, gleaned_at)" + " VALUES (?,?,?,?,?)" + ) + + def source_group_expr(self, col: str = "source_id") -> str: + """SQL expression that collapses prefix:host:unit → prefix:host stem.""" + if BACKEND == Backend.POSTGRES: + return f""" + CASE + WHEN array_length(string_to_array({col}, ':'), 1) >= 3 + THEN split_part({col}, ':', 1) || ':' || split_part({col}, ':', 2) + ELSE {col} + END + """ + return f""" + CASE + WHEN INSTR(SUBSTR({col}, INSTR({col}, ':')+1), ':') > 0 + THEN SUBSTR({col}, 1, + INSTR({col}, ':') + + INSTR(SUBSTR({col}, INSTR({col}, ':')+1), ':') + - 1) + ELSE {col} + END + """ + + def fts_match_clause(self) -> str: + """WHERE clause fragment for FTS query. Caller supplies the query param.""" + if BACKEND == Backend.POSTGRES: + return "text_tsv @@ websearch_to_tsquery('english', %s)" + return "log_fts MATCH ?" + + def fts_rank_expr(self) -> str: + """ORDER BY expression for FTS rank (best match first). Postgres needs the query twice.""" + if BACKEND == Backend.POSTGRES: + # ts_rank returns 0..1 where higher is better; pass the query again as param + return "ts_rank(text_tsv, websearch_to_tsquery('english', %s)) DESC" + # FTS5 rank is negative BM25; ASC = most-negative = best match + return "rank ASC" + + +frag = _Fragments() diff --git a/app/db/schema.py b/app/db/schema.py new file mode 100644 index 0000000..311a321 --- /dev/null +++ b/app/db/schema.py @@ -0,0 +1,522 @@ +"""Schema creation and idempotent migrations for all Turnstone databases. + +Three logical databases (main, context, incidents) map to: + - SQLite: three separate .db files (avoids write-lock contention) + - Postgres: three table-groups in one physical DB (row-level locking makes separation unnecessary) + +All ensure_* functions are idempotent: safe to call on every startup. +""" +from __future__ import annotations + +import logging +import sqlite3 +from pathlib import Path + +from app.db.backend import BACKEND, Backend +from app.db.conn import get_conn + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# SQLite DDL — kept as executescript strings (SQLite only) +# --------------------------------------------------------------------------- + +_MAIN_SCHEMA_SQLITE = """ +CREATE TABLE IF NOT EXISTS log_entries ( + id TEXT NOT NULL, + tenant_id TEXT NOT NULL DEFAULT '', + source_id TEXT NOT NULL, + sequence INTEGER NOT NULL, + timestamp_raw TEXT, + timestamp_iso TEXT, + ingest_time TEXT NOT NULL, + severity TEXT, + repeat_count INTEGER DEFAULT 1, + out_of_order INTEGER DEFAULT 0, + matched_patterns TEXT DEFAULT '[]', + text TEXT NOT NULL, + anomaly_score REAL, + anomaly_label TEXT, + anomaly_scored_at TEXT, + ml_score REAL, + ml_label TEXT, + ml_scored_at TEXT, + PRIMARY KEY (tenant_id, id) +); +CREATE INDEX IF NOT EXISTS idx_source ON log_entries(source_id); +CREATE INDEX IF NOT EXISTS idx_tenant_src ON log_entries(tenant_id, source_id); +CREATE INDEX IF NOT EXISTS idx_timestamp ON log_entries(timestamp_iso); +CREATE INDEX IF NOT EXISTS idx_ts_repeat ON log_entries(timestamp_iso, repeat_count); +CREATE INDEX IF NOT EXISTS idx_severity ON log_entries(tenant_id, severity); +CREATE INDEX IF NOT EXISTS idx_patterns ON log_entries(matched_patterns); +CREATE INDEX IF NOT EXISTS idx_anomaly ON log_entries(tenant_id, anomaly_score); +CREATE INDEX IF NOT EXISTS idx_ml_scored ON log_entries(tenant_id, ml_scored_at); + +CREATE TABLE IF NOT EXISTS detections ( + id TEXT PRIMARY KEY, + tenant_id TEXT NOT NULL DEFAULT '', + entry_id TEXT NOT NULL, + source_id TEXT NOT NULL, + anomaly_label TEXT NOT NULL, + anomaly_score REAL NOT NULL, + severity TEXT NOT NULL, + text TEXT NOT NULL, + timestamp_iso TEXT, + detected_at TEXT NOT NULL, + acknowledged INTEGER NOT NULL DEFAULT 0, + acknowledged_at TEXT, + notes TEXT NOT NULL DEFAULT '', + scorer TEXT NOT NULL DEFAULT 'anomaly' +); +CREATE INDEX IF NOT EXISTS idx_detections_tenant ON detections(tenant_id, detected_at); +CREATE INDEX IF NOT EXISTS idx_detections_ack ON detections(acknowledged); +CREATE INDEX IF NOT EXISTS idx_detections_label ON detections(anomaly_label); +CREATE INDEX IF NOT EXISTS idx_detections_entry ON detections(entry_id); +CREATE INDEX IF NOT EXISTS idx_detections_scorer ON detections(scorer); + +CREATE TABLE IF NOT EXISTS glean_fingerprints ( + tenant_id TEXT NOT NULL DEFAULT '', + path TEXT NOT NULL, + mtime REAL NOT NULL, + size INTEGER NOT NULL, + gleaned_at TEXT NOT NULL, + PRIMARY KEY (tenant_id, path) +); + +CREATE TABLE IF NOT EXISTS incidents ( + id TEXT PRIMARY KEY, + tenant_id TEXT NOT NULL DEFAULT '', + label TEXT NOT NULL, + issue_type TEXT NOT NULL DEFAULT '', + started_at TEXT, + ended_at TEXT, + notes TEXT NOT NULL DEFAULT '', + created_at TEXT NOT NULL, + severity TEXT NOT NULL DEFAULT 'medium' +); +CREATE INDEX IF NOT EXISTS idx_incidents_time ON incidents(started_at, ended_at); +CREATE INDEX IF NOT EXISTS idx_incidents_tenant ON incidents(tenant_id); + +CREATE TABLE IF NOT EXISTS received_bundles ( + id TEXT PRIMARY KEY, + tenant_id TEXT NOT NULL DEFAULT '', + source_host TEXT NOT NULL, + issue_type TEXT NOT NULL DEFAULT '', + label TEXT NOT NULL, + severity TEXT NOT NULL DEFAULT 'medium', + started_at TEXT, + bundled_at TEXT NOT NULL, + entry_count INTEGER NOT NULL DEFAULT 0, + bundle_json TEXT NOT NULL +); +CREATE INDEX IF NOT EXISTS idx_bundles_bundled ON received_bundles(bundled_at); +CREATE INDEX IF NOT EXISTS idx_bundles_type ON received_bundles(issue_type); + +CREATE TABLE IF NOT EXISTS sent_bundles ( + id TEXT PRIMARY KEY, + tenant_id TEXT NOT NULL DEFAULT '', + incident_id TEXT NOT NULL, + exported_at TEXT NOT NULL, + sanitized INTEGER NOT NULL DEFAULT 0, + entry_count INTEGER NOT NULL DEFAULT 0, + bundle_json TEXT NOT NULL +); +CREATE INDEX IF NOT EXISTS idx_sent_bundles_incident ON sent_bundles(incident_id); +CREATE INDEX IF NOT EXISTS idx_sent_bundles_time ON sent_bundles(exported_at); + +CREATE TABLE IF NOT EXISTS blocklist_candidates ( + id TEXT PRIMARY KEY, + tenant_id TEXT NOT NULL DEFAULT '', + domain_or_ip TEXT NOT NULL, + source_device_ip TEXT, + source_device_name TEXT, + first_seen TEXT NOT NULL, + last_seen TEXT NOT NULL, + hit_count INTEGER DEFAULT 1, + status TEXT DEFAULT 'pending', + pushed_at TEXT, + log_evidence TEXT DEFAULT '[]', + matched_rule TEXT, + llm_score REAL, + llm_reason TEXT +); +CREATE INDEX IF NOT EXISTS idx_blocklist_device ON blocklist_candidates(source_device_ip); +CREATE INDEX IF NOT EXISTS idx_blocklist_status ON blocklist_candidates(status); +CREATE INDEX IF NOT EXISTS idx_blocklist_domain ON blocklist_candidates(domain_or_ip); +CREATE INDEX IF NOT EXISTS idx_blocklist_tenant ON blocklist_candidates(tenant_id); +""" + +_CONTEXT_SCHEMA_SQLITE = """ +CREATE TABLE IF NOT EXISTS context_facts ( + id TEXT PRIMARY KEY, + tenant_id TEXT NOT NULL DEFAULT '', + category TEXT NOT NULL, + key TEXT NOT NULL, + value TEXT NOT NULL, + source TEXT, + created_at TEXT NOT NULL +); +CREATE INDEX IF NOT EXISTS idx_facts_category ON context_facts(category); +CREATE INDEX IF NOT EXISTS idx_facts_key ON context_facts(key); +CREATE INDEX IF NOT EXISTS idx_facts_tenant ON context_facts(tenant_id); + +CREATE TABLE IF NOT EXISTS context_documents ( + id TEXT PRIMARY KEY, + tenant_id TEXT NOT NULL DEFAULT '', + filename TEXT NOT NULL, + doc_type TEXT NOT NULL, + full_text TEXT NOT NULL, + file_size INTEGER, + uploaded_at TEXT NOT NULL +); +CREATE INDEX IF NOT EXISTS idx_docs_tenant ON context_documents(tenant_id); + +CREATE TABLE IF NOT EXISTS context_chunks ( + id TEXT PRIMARY KEY, + tenant_id TEXT NOT NULL DEFAULT '', + document_id TEXT NOT NULL REFERENCES context_documents(id) ON DELETE CASCADE, + chunk_index INTEGER NOT NULL, + text TEXT NOT NULL, + embedding BLOB +); +CREATE INDEX IF NOT EXISTS idx_chunks_doc ON context_chunks(document_id); +CREATE INDEX IF NOT EXISTS idx_chunks_tenant ON context_chunks(tenant_id); +""" + + +# --------------------------------------------------------------------------- +# Postgres DDL — executed statement-by-statement +# --------------------------------------------------------------------------- + +_MAIN_SCHEMA_PG_STMTS = [ + """ + CREATE TABLE IF NOT EXISTS log_entries ( + id TEXT NOT NULL, + tenant_id TEXT NOT NULL DEFAULT '', + source_id TEXT NOT NULL, + sequence INTEGER NOT NULL, + timestamp_raw TEXT, + timestamp_iso TEXT, + ingest_time TEXT NOT NULL, + severity TEXT, + repeat_count INTEGER DEFAULT 1, + out_of_order INTEGER DEFAULT 0, + matched_patterns TEXT DEFAULT '[]', + text TEXT NOT NULL, + text_tsv tsvector, + anomaly_score DOUBLE PRECISION, + anomaly_label TEXT, + anomaly_scored_at TEXT, + ml_score DOUBLE PRECISION, + ml_label TEXT, + ml_scored_at TEXT, + PRIMARY KEY (tenant_id, id) + ) + """, + "CREATE INDEX IF NOT EXISTS idx_tenant_src ON log_entries(tenant_id, source_id)", + "CREATE INDEX IF NOT EXISTS idx_timestamp ON log_entries(timestamp_iso)", + "CREATE INDEX IF NOT EXISTS idx_severity ON log_entries(tenant_id, severity)", + "CREATE INDEX IF NOT EXISTS idx_patterns ON log_entries(matched_patterns)", + "CREATE INDEX IF NOT EXISTS idx_fts_gin ON log_entries USING GIN(text_tsv)", + "CREATE INDEX IF NOT EXISTS idx_anomaly ON log_entries(tenant_id, anomaly_score)", + "CREATE INDEX IF NOT EXISTS idx_ml_scored ON log_entries(tenant_id, ml_scored_at)", + """ + CREATE TABLE IF NOT EXISTS detections ( + id TEXT PRIMARY KEY, + tenant_id TEXT NOT NULL DEFAULT '', + entry_id TEXT NOT NULL, + source_id TEXT NOT NULL, + anomaly_label TEXT NOT NULL, + anomaly_score DOUBLE PRECISION NOT NULL, + severity TEXT NOT NULL, + text TEXT NOT NULL, + timestamp_iso TEXT, + detected_at TEXT NOT NULL, + acknowledged INTEGER NOT NULL DEFAULT 0, + acknowledged_at TEXT, + notes TEXT NOT NULL DEFAULT '', + scorer TEXT NOT NULL DEFAULT 'anomaly' + ) + """, + "CREATE INDEX IF NOT EXISTS idx_detections_tenant ON detections(tenant_id, detected_at)", + "CREATE INDEX IF NOT EXISTS idx_detections_ack ON detections(acknowledged)", + "CREATE INDEX IF NOT EXISTS idx_detections_label ON detections(anomaly_label)", + "CREATE INDEX IF NOT EXISTS idx_detections_entry ON detections(entry_id)", + "CREATE INDEX IF NOT EXISTS idx_detections_scorer ON detections(scorer)", + """ + CREATE OR REPLACE FUNCTION _ts_update_text_tsv() RETURNS trigger AS $$ + BEGIN + NEW.text_tsv := to_tsvector('english', COALESCE(NEW.text, '')); + RETURN NEW; + END; + $$ LANGUAGE plpgsql + """, + """ + DO $$ BEGIN + IF NOT EXISTS ( + SELECT 1 FROM pg_trigger WHERE tgname = 'trig_log_entries_tsv' + ) THEN + CREATE TRIGGER trig_log_entries_tsv + BEFORE INSERT OR UPDATE OF text ON log_entries + FOR EACH ROW EXECUTE FUNCTION _ts_update_text_tsv(); + END IF; + END $$ + """, + """ + CREATE TABLE IF NOT EXISTS glean_fingerprints ( + tenant_id TEXT NOT NULL DEFAULT '', + path TEXT NOT NULL, + mtime DOUBLE PRECISION NOT NULL, + size BIGINT NOT NULL, + gleaned_at TEXT NOT NULL, + PRIMARY KEY (tenant_id, path) + ) + """, + """ + CREATE TABLE IF NOT EXISTS incidents ( + id TEXT PRIMARY KEY, + tenant_id TEXT NOT NULL DEFAULT '', + label TEXT NOT NULL, + issue_type TEXT NOT NULL DEFAULT '', + started_at TEXT, + ended_at TEXT, + notes TEXT NOT NULL DEFAULT '', + created_at TEXT NOT NULL, + severity TEXT NOT NULL DEFAULT 'medium' + ) + """, + "CREATE INDEX IF NOT EXISTS idx_incidents_time ON incidents(started_at, ended_at)", + "CREATE INDEX IF NOT EXISTS idx_incidents_tenant ON incidents(tenant_id)", + """ + CREATE TABLE IF NOT EXISTS received_bundles ( + id TEXT PRIMARY KEY, + tenant_id TEXT NOT NULL DEFAULT '', + source_host TEXT NOT NULL, + issue_type TEXT NOT NULL DEFAULT '', + label TEXT NOT NULL, + severity TEXT NOT NULL DEFAULT 'medium', + started_at TEXT, + bundled_at TEXT NOT NULL, + entry_count INTEGER NOT NULL DEFAULT 0, + bundle_json TEXT NOT NULL + ) + """, + "CREATE INDEX IF NOT EXISTS idx_bundles_bundled ON received_bundles(bundled_at)", + "CREATE INDEX IF NOT EXISTS idx_bundles_type ON received_bundles(issue_type)", + """ + CREATE TABLE IF NOT EXISTS sent_bundles ( + id TEXT PRIMARY KEY, + tenant_id TEXT NOT NULL DEFAULT '', + incident_id TEXT NOT NULL, + exported_at TEXT NOT NULL, + sanitized INTEGER NOT NULL DEFAULT 0, + entry_count INTEGER NOT NULL DEFAULT 0, + bundle_json TEXT NOT NULL + ) + """, + "CREATE INDEX IF NOT EXISTS idx_sent_bundles_incident ON sent_bundles(incident_id)", + "CREATE INDEX IF NOT EXISTS idx_sent_bundles_time ON sent_bundles(exported_at)", + """ + CREATE TABLE IF NOT EXISTS blocklist_candidates ( + id TEXT PRIMARY KEY, + tenant_id TEXT NOT NULL DEFAULT '', + domain_or_ip TEXT NOT NULL, + source_device_ip TEXT, + source_device_name TEXT, + first_seen TEXT NOT NULL, + last_seen TEXT NOT NULL, + hit_count INTEGER DEFAULT 1, + status TEXT DEFAULT 'pending', + pushed_at TEXT, + log_evidence TEXT DEFAULT '[]', + matched_rule TEXT, + llm_score DOUBLE PRECISION, + llm_reason TEXT + ) + """, + "CREATE INDEX IF NOT EXISTS idx_blocklist_device ON blocklist_candidates(source_device_ip)", + "CREATE INDEX IF NOT EXISTS idx_blocklist_status ON blocklist_candidates(status)", + "CREATE INDEX IF NOT EXISTS idx_blocklist_domain ON blocklist_candidates(domain_or_ip)", + "CREATE INDEX IF NOT EXISTS idx_blocklist_tenant ON blocklist_candidates(tenant_id)", +] + +_CONTEXT_SCHEMA_PG_STMTS = [ + """ + CREATE TABLE IF NOT EXISTS context_facts ( + id TEXT PRIMARY KEY, + tenant_id TEXT NOT NULL DEFAULT '', + category TEXT NOT NULL, + key TEXT NOT NULL, + value TEXT NOT NULL, + source TEXT, + created_at TEXT NOT NULL + ) + """, + "CREATE INDEX IF NOT EXISTS idx_facts_category ON context_facts(category)", + "CREATE INDEX IF NOT EXISTS idx_facts_key ON context_facts(key)", + "CREATE INDEX IF NOT EXISTS idx_facts_tenant ON context_facts(tenant_id)", + """ + CREATE TABLE IF NOT EXISTS context_documents ( + id TEXT PRIMARY KEY, + tenant_id TEXT NOT NULL DEFAULT '', + filename TEXT NOT NULL, + doc_type TEXT NOT NULL, + full_text TEXT NOT NULL, + file_size BIGINT, + uploaded_at TEXT NOT NULL + ) + """, + "CREATE INDEX IF NOT EXISTS idx_docs_tenant ON context_documents(tenant_id)", + """ + CREATE TABLE IF NOT EXISTS context_chunks ( + id TEXT PRIMARY KEY, + tenant_id TEXT NOT NULL DEFAULT '', + document_id TEXT NOT NULL REFERENCES context_documents(id) ON DELETE CASCADE, + chunk_index INTEGER NOT NULL, + text TEXT NOT NULL, + embedding BYTEA + ) + """, + "CREATE INDEX IF NOT EXISTS idx_chunks_doc ON context_chunks(document_id)", + "CREATE INDEX IF NOT EXISTS idx_chunks_tenant ON context_chunks(tenant_id)", +] + + +# --------------------------------------------------------------------------- +# SQLite additive column migrations — applied after CREATE TABLE on every boot +# --------------------------------------------------------------------------- + +_MAIN_MIGRATIONS_SQLITE = [ + "ALTER TABLE log_entries ADD COLUMN tenant_id TEXT NOT NULL DEFAULT ''", + "ALTER TABLE incidents ADD COLUMN issue_type TEXT NOT NULL DEFAULT ''", + "ALTER TABLE incidents ADD COLUMN tenant_id TEXT NOT NULL DEFAULT ''", + "ALTER TABLE received_bundles ADD COLUMN tenant_id TEXT NOT NULL DEFAULT ''", + "ALTER TABLE sent_bundles ADD COLUMN tenant_id TEXT NOT NULL DEFAULT ''", + "ALTER TABLE blocklist_candidates ADD COLUMN tenant_id TEXT NOT NULL DEFAULT ''", + "ALTER TABLE glean_fingerprints ADD COLUMN tenant_id TEXT NOT NULL DEFAULT ''", + "ALTER TABLE glean_fingerprints ADD COLUMN mtime REAL", + "ALTER TABLE glean_fingerprints ADD COLUMN size INTEGER", + "ALTER TABLE glean_fingerprints ADD COLUMN gleaned_at TEXT", + "ALTER TABLE log_entries ADD COLUMN anomaly_score REAL", + "ALTER TABLE log_entries ADD COLUMN anomaly_label TEXT", + "ALTER TABLE log_entries ADD COLUMN anomaly_scored_at TEXT", + "ALTER TABLE log_entries ADD COLUMN ml_score REAL", + "ALTER TABLE log_entries ADD COLUMN ml_label TEXT", + "ALTER TABLE log_entries ADD COLUMN ml_scored_at TEXT", + "ALTER TABLE detections ADD COLUMN scorer TEXT NOT NULL DEFAULT 'anomaly'", +] + +_CONTEXT_MIGRATIONS_SQLITE = [ + "ALTER TABLE context_facts ADD COLUMN tenant_id TEXT NOT NULL DEFAULT ''", + "ALTER TABLE context_documents ADD COLUMN tenant_id TEXT NOT NULL DEFAULT ''", + "ALTER TABLE context_chunks ADD COLUMN tenant_id TEXT NOT NULL DEFAULT ''", +] + + +def _run_sqlite_migrations(conn: sqlite3.Connection, stmts: list[str]) -> None: + for stmt in stmts: + try: + conn.execute(stmt) + except sqlite3.OperationalError: + pass # column already exists or table not present yet — both are fine + + +def _run_pg_stmts(stmts: list[str]) -> None: + """Execute Postgres DDL statements — each in its own transaction for IF NOT EXISTS safety.""" + from psycopg import connect as pg_connect # type: ignore[import] + import os + url = os.environ["DATABASE_URL"] + with pg_connect(url, autocommit=True) as conn: + for stmt in stmts: + stripped = stmt.strip() + if stripped: + conn.execute(stripped) + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +def ensure_schema(db_path: Path) -> None: + """Ensure main log/incidents/blocklist tables exist. Idempotent.""" + if BACKEND == Backend.POSTGRES: + _run_pg_stmts(_MAIN_SCHEMA_PG_STMTS) + logger.debug("Postgres main schema verified") + return + + conn = sqlite3.connect(str(db_path), timeout=30.0) + conn.execute("PRAGMA journal_mode=WAL") + # Migrations first: add tenant_id to existing tables BEFORE index creation touches it + _run_sqlite_migrations(conn, _MAIN_MIGRATIONS_SQLITE) + conn.commit() + conn.executescript(_MAIN_SCHEMA_SQLITE) + conn.close() + logger.debug("SQLite main schema verified at %s", db_path) + + +def ensure_context_schema(db_path: Path) -> None: + """Ensure context KB tables exist. Idempotent.""" + if BACKEND == Backend.POSTGRES: + _run_pg_stmts(_CONTEXT_SCHEMA_PG_STMTS) + logger.debug("Postgres context schema verified") + return + + conn = sqlite3.connect(str(db_path), timeout=30.0) + conn.execute("PRAGMA journal_mode=WAL") + conn.execute("PRAGMA foreign_keys=ON") + _run_sqlite_migrations(conn, _CONTEXT_MIGRATIONS_SQLITE) + conn.commit() + conn.executescript(_CONTEXT_SCHEMA_SQLITE) + conn.close() + logger.debug("SQLite context schema verified at %s", db_path) + + +def migrate_incidents_to_dedicated_db(main_db: Path, incidents_db: Path) -> int: + """One-shot migration: copy incidents/bundles rows from main DB to incidents DB. + + Safe to call on every startup — rows already in incidents_db are skipped. + No-op for Postgres (single DB, no migration needed). + """ + if BACKEND == Backend.POSTGRES: + return 0 + + src = sqlite3.connect(str(main_db), timeout=30.0) + src.row_factory = sqlite3.Row + dst = sqlite3.connect(str(incidents_db), timeout=30.0) + migrated = 0 + for table in ("incidents", "received_bundles", "sent_bundles"): + try: + rows = src.execute(f"SELECT * FROM {table}").fetchall() # noqa: S608 + except sqlite3.OperationalError: + continue + if not rows: + continue + cols = ", ".join(rows[0].keys()) + placeholders = ", ".join("?" * len(rows[0].keys())) + dst.executemany( + f"INSERT OR IGNORE INTO {table} ({cols}) VALUES ({placeholders})", # noqa: S608 + [tuple(r) for r in rows], + ) + migrated += len(rows) + dst.commit() + src.close() + dst.close() + return migrated + + +def ensure_incidents_schema(db_path: Path) -> None: + """Ensure incidents/bundles tables exist. Idempotent. + + For Postgres, incidents live in the same DB as log_entries (already created by + ensure_schema), so this is a no-op — the tables were created above. + """ + if BACKEND == Backend.POSTGRES: + return + + conn = sqlite3.connect(str(db_path), timeout=30.0) + conn.execute("PRAGMA journal_mode=WAL") + _run_sqlite_migrations(conn, _MAIN_MIGRATIONS_SQLITE) + conn.commit() + conn.executescript(_MAIN_SCHEMA_SQLITE) + conn.close() + logger.debug("SQLite incidents schema verified at %s", db_path) diff --git a/app/db/tenant.py b/app/db/tenant.py new file mode 100644 index 0000000..5d2542e --- /dev/null +++ b/app/db/tenant.py @@ -0,0 +1,12 @@ +"""Tenant ID resolution — TURNSTONE_TENANT_ID env var, hostname fallback.""" +from __future__ import annotations + +import os +import socket +from functools import lru_cache + + +@lru_cache(maxsize=1) +def resolve_tenant_id() -> str: + """Return this node's tenant ID. Result is cached after first call.""" + return os.environ.get("TURNSTONE_TENANT_ID") or socket.gethostname() diff --git a/app/glean/doc_upload.py b/app/glean/doc_upload.py index c2d4d9a..0cfd604 100644 --- a/app/glean/doc_upload.py +++ b/app/glean/doc_upload.py @@ -1,18 +1,19 @@ """Upload adapter: processes file bytes and writes to context store — MIT licensed.""" from __future__ import annotations -import sqlite3 import uuid from pathlib import Path from typing import Any from app.context.chunker import process_upload from app.context.store import add_document, add_fact +from app.db import get_conn, resolve_tenant_id def glean_upload(db_path: Path, filename: str, content: bytes) -> dict[str, Any]: """Process an uploaded file and write to context store. Returns result summary.""" doc_type, facts, chunks = process_upload(filename, content) + tid = resolve_tenant_id() doc = add_document( db_path, @@ -25,15 +26,13 @@ def glean_upload(db_path: Path, filename: str, content: bytes) -> dict[str, Any] for fact in facts: add_fact(db_path, fact.category, fact.key, fact.value, source="upload") - conn = sqlite3.connect(str(db_path), timeout=30.0) - conn.execute("PRAGMA journal_mode=WAL") - for i, chunk_text in enumerate(chunks): - conn.execute( - "INSERT INTO context_chunks(id, document_id, chunk_index, text) VALUES (?,?,?,?)", - (str(uuid.uuid4()), doc.id, i, chunk_text), - ) - conn.commit() - conn.close() + with get_conn(db_path) as conn: + for i, chunk_text in enumerate(chunks): + conn.execute( + "INSERT INTO context_chunks(id, tenant_id, document_id, chunk_index, text) VALUES (?,?,?,?,?)", + (str(uuid.uuid4()), tid, doc.id, i, chunk_text), + ) + conn.commit() return { "document_id": doc.id, diff --git a/app/glean/pipeline.py b/app/glean/pipeline.py index 38bd0f1..d6a99a6 100644 --- a/app/glean/pipeline.py +++ b/app/glean/pipeline.py @@ -1,12 +1,24 @@ -"""Glean pipeline: auto-detect format, parse, write to SQLite.""" +"""Glean pipeline: auto-detect format, parse, write to SQLite or Postgres.""" from __future__ import annotations import json import logging import re -import sqlite3 +import sqlite3 # still used in migrate_incidents_to_dedicated_db (SQLite-only migration) from pathlib import Path -from typing import Iterator +from typing import Any, Iterator + +from app.db import ( + frag, + get_conn, + resolve_tenant_id, +) +from app.db.schema import ( + ensure_context_schema, + ensure_incidents_schema, + ensure_schema, + migrate_incidents_to_dedicated_db, +) import yaml @@ -169,127 +181,13 @@ CREATE INDEX IF NOT EXISTS idx_chunks_doc ON context_chunks(document_id); """ -def ensure_schema(db_path: Path) -> None: - """Create all tables and apply additive migrations. Safe to call on every startup.""" - conn = sqlite3.connect(str(db_path), timeout=30.0) - conn.execute("PRAGMA journal_mode=WAL") - conn.executescript(_SCHEMA) - # Additive column migrations — ALTER TABLE silently skips if column exists - for stmt in [ - "ALTER TABLE incidents ADD COLUMN issue_type TEXT NOT NULL DEFAULT ''", - ]: - try: - conn.execute(stmt) - except sqlite3.OperationalError: - pass - conn.commit() - conn.close() +# ensure_schema / ensure_context_schema / ensure_incidents_schema / migrate_incidents_to_dedicated_db +# are now implemented in app/db/schema.py and re-exported via app/db/__init__.py. +# The imports at the top of this file bring them in; these names are kept as module-level +# symbols so existing callers (rest.py, tests) still find them here without changes. -def ensure_context_schema(db_path: Path) -> None: - """Create context KB tables in a dedicated database file. - - Using a separate file from the main log DB means context fact writes never - contend with the high-throughput glean scheduler, which can hold the main - DB write lock for seconds at a time when flushing large journal batches. - """ - conn = sqlite3.connect(str(db_path), timeout=30.0) - conn.execute("PRAGMA journal_mode=WAL") - conn.execute("PRAGMA foreign_keys=ON") - conn.executescript(_CONTEXT_SCHEMA) - conn.commit() - conn.close() - - -_INCIDENTS_SCHEMA = """ -CREATE TABLE IF NOT EXISTS incidents ( - id TEXT PRIMARY KEY, - label TEXT NOT NULL, - issue_type TEXT NOT NULL DEFAULT '', - started_at TEXT, - ended_at TEXT, - notes TEXT NOT NULL DEFAULT '', - created_at TEXT NOT NULL, - severity TEXT NOT NULL DEFAULT 'medium' -); -CREATE INDEX IF NOT EXISTS idx_incidents_time ON incidents(started_at, ended_at); - -CREATE TABLE IF NOT EXISTS received_bundles ( - id TEXT PRIMARY KEY, - source_host TEXT NOT NULL, - issue_type TEXT NOT NULL DEFAULT '', - label TEXT NOT NULL, - severity TEXT NOT NULL DEFAULT 'medium', - started_at TEXT, - bundled_at TEXT NOT NULL, - entry_count INTEGER NOT NULL DEFAULT 0, - bundle_json TEXT NOT NULL -); -CREATE INDEX IF NOT EXISTS idx_bundles_bundled ON received_bundles(bundled_at); -CREATE INDEX IF NOT EXISTS idx_bundles_type ON received_bundles(issue_type); - -CREATE TABLE IF NOT EXISTS sent_bundles ( - id TEXT PRIMARY KEY, - incident_id TEXT NOT NULL, - exported_at TEXT NOT NULL, - sanitized INTEGER NOT NULL DEFAULT 0, - entry_count INTEGER NOT NULL DEFAULT 0, - bundle_json TEXT NOT NULL -); -CREATE INDEX IF NOT EXISTS idx_sent_bundles_incident ON sent_bundles(incident_id); -CREATE INDEX IF NOT EXISTS idx_sent_bundles_time ON sent_bundles(exported_at); -""" - - -def ensure_incidents_schema(db_path: Path) -> None: - """Create incidents tables in a dedicated database file. - - Using a separate file from the main log DB means incident writes never - contend with the FTS5 bulk-insert write lock held by the glean scheduler. - Mirrors the context_facts split (CONTEXT_DB_PATH / turnstone-context.db). - """ - conn = sqlite3.connect(str(db_path), timeout=30.0) - conn.execute("PRAGMA journal_mode=WAL") - conn.executescript(_INCIDENTS_SCHEMA) - for stmt in [ - "ALTER TABLE incidents ADD COLUMN issue_type TEXT NOT NULL DEFAULT ''", - ]: - try: - conn.execute(stmt) - except sqlite3.OperationalError: - pass - conn.commit() - conn.close() - - -def migrate_incidents_to_dedicated_db(main_db: Path, incidents_db: Path) -> int: - """One-shot migration: copy incidents/bundles rows from main DB to incidents DB. - - Safe to call on every startup — rows already present in incidents_db are - skipped via INSERT OR IGNORE. Returns the count of rows migrated. - """ - src = sqlite3.connect(str(main_db), timeout=30.0) - src.row_factory = sqlite3.Row - dst = sqlite3.connect(str(incidents_db), timeout=30.0) - migrated = 0 - for table in ("incidents", "received_bundles", "sent_bundles"): - try: - rows = src.execute(f"SELECT * FROM {table}").fetchall() # noqa: S608 - except sqlite3.OperationalError: - continue - if not rows: - continue - cols = ", ".join(rows[0].keys()) - placeholders = ", ".join("?" * len(rows[0].keys())) - dst.executemany( - f"INSERT OR IGNORE INTO {table} ({cols}) VALUES ({placeholders})", # noqa: S608 - [tuple(r) for r in rows], - ) - migrated += len(rows) - dst.commit() - src.close() - dst.close() - return migrated +# _INCIDENTS_SCHEMA and its ensure_/migrate_ functions moved to app/db/schema.py def _fingerprint(path: Path) -> tuple[float, int]: @@ -298,36 +196,28 @@ def _fingerprint(path: Path) -> tuple[float, int]: return st.st_mtime, st.st_size -def _fp_unchanged(conn: sqlite3.Connection, path: Path, mtime: float, size: int) -> bool: - """Return True only when the stored fingerprint exactly matches (mtime, size). - - A smaller size (log rotation) or a larger size (new lines appended) both - return False so the caller re-gleams the file. - """ +def _fp_unchanged(conn: Any, path: Path, mtime: float, size: int) -> bool: + """Return True only when the stored fingerprint exactly matches (mtime, size).""" + tid = resolve_tenant_id() row = conn.execute( - "SELECT mtime, size FROM glean_fingerprints WHERE path = ?", - (str(path),), + "SELECT mtime, size FROM glean_fingerprints WHERE path = ? AND (tenant_id = ? OR tenant_id = '')", + (str(path), tid), ).fetchone() if row is None: return False - return row[0] == mtime and row[1] == size + return row["mtime"] == mtime and row["size"] == size def _save_fingerprint( - conn: sqlite3.Connection, + conn: Any, path: Path, mtime: float, size: int, gleaned_at: str, ) -> None: """Upsert the fingerprint for *path* after a successful glean.""" - conn.execute( - """ - INSERT OR REPLACE INTO glean_fingerprints (path, mtime, size, gleaned_at) - VALUES (?, ?, ?, ?) - """, - (str(path), mtime, size, gleaned_at), - ) + tid = resolve_tenant_id() + conn.execute(frag.fingerprint_upsert(), (tid, str(path), mtime, size, gleaned_at)) def _detect_format(first_line: str) -> str: @@ -400,18 +290,22 @@ def _parse_file( yield from plaintext.parse(all_lines(), source_id, compiled, ingest_time) -def _write_batch(conn: sqlite3.Connection, batch: list[RetrievedEntry]) -> None: - conn.executemany( - """ - INSERT OR IGNORE INTO log_entries - (id, source_id, sequence, timestamp_raw, timestamp_iso, +def _write_batch(conn: Any, batch: list[RetrievedEntry]) -> None: + tid = resolve_tenant_id() + conflict = frag.entries_conflict_clause() + sql = f""" + {frag.insert_ignore_entries()} + (tenant_id, id, source_id, sequence, timestamp_raw, timestamp_iso, ingest_time, severity, repeat_count, out_of_order, matched_patterns, text) - VALUES (?,?,?,?,?,?,?,?,?,?,?) - """, + VALUES (?,?,?,?,?,?,?,?,?,?,?,?) + {conflict} + """ + conn.executemany( + sql, [ ( - e.entry_id, e.source_id, e.sequence, + tid, e.entry_id, e.source_id, e.sequence, e.timestamp_raw, e.timestamp_iso, e.ingest_time, e.severity, e.repeat_count, int(e.out_of_order), json.dumps(list(e.matched_patterns)), e.text, @@ -435,46 +329,41 @@ def _glean_files( ingest_time = now_iso() source_id_map = source_id_map or {} - conn = sqlite3.connect(str(db_path), timeout=30.0) - conn.execute("PRAGMA journal_mode=WAL") - conn.executescript(_SCHEMA) - conn.commit() + ensure_schema(db_path) - stats: dict[str, int] = {} - skipped: list[str] = [] + with get_conn(db_path) as conn: + stats: dict[str, int] = {} + skipped: list[str] = [] - for log_file in files: - source_id = source_id_map.get(log_file, log_file.stem) + for log_file in files: + source_id = source_id_map.get(log_file, log_file.stem) - # Fingerprint check — skip files whose mtime+size haven't changed. - mtime, size = _fingerprint(log_file) - if not force and _fp_unchanged(conn, log_file, mtime, size): - logger.debug("Skipping unchanged file: %s", log_file.name) - skipped.append(log_file.name) - stats[source_id] = stats.get(source_id, 0) - continue + mtime, size = _fingerprint(log_file) + if not force and _fp_unchanged(conn, log_file, mtime, size): + logger.debug("Skipping unchanged file: %s", log_file.name) + skipped.append(log_file.name) + stats[source_id] = stats.get(source_id, 0) + continue - count = 0 - batch: list[RetrievedEntry] = [] - for entry in _parse_file(log_file, compiled, ingest_time, source_id=source_id): - batch.append(entry) - if len(batch) >= batch_size: + count = 0 + batch: list[RetrievedEntry] = [] + for entry in _parse_file(log_file, compiled, ingest_time, source_id=source_id): + batch.append(entry) + if len(batch) >= batch_size: + _write_batch(conn, batch) + conn.commit() + count += len(batch) + batch.clear() + if batch: _write_batch(conn, batch) conn.commit() count += len(batch) - batch.clear() - if batch: - _write_batch(conn, batch) + + _save_fingerprint(conn, log_file, mtime, size, ingest_time) conn.commit() - count += len(batch) - _save_fingerprint(conn, log_file, mtime, size, ingest_time) - conn.commit() - - stats[source_id] = stats.get(source_id, 0) + count - logger.info("Gleaned %d entries from %s (source: %s)", count, log_file.name, source_id) - - conn.close() + stats[source_id] = stats.get(source_id, 0) + count + logger.info("Gleaned %d entries from %s (source: %s)", count, log_file.name, source_id) if skipped: logger.info("Skipped %d unchanged file(s): %s", len(skipped), ", ".join(skipped)) @@ -493,7 +382,7 @@ def _stream_and_write( source_id: str, compiled: list[tuple[LogPattern, object]], ingest_time: str, - conn: sqlite3.Connection, + conn: Any, batch_size: int, ) -> int: """Stream *cmd* output through *parser* and write entries to *conn*. @@ -525,7 +414,7 @@ def _glean_ssh_source( src: dict, # type: ignore[type-arg] compiled: list[tuple[LogPattern, object]], ingest_time: str, - conn: sqlite3.Connection, + conn: Any, batch_size: int, ) -> dict[str, int]: """Open one SSHTransport connection for *src* and glean all its glean items. @@ -618,15 +507,9 @@ def glean_ssh_source( compiled = _compile(load_patterns(effective_pattern_file)) ingest_time = now_iso() - conn = sqlite3.connect(str(db_path), timeout=30.0) - conn.execute("PRAGMA journal_mode=WAL") - conn.executescript(_SCHEMA) - conn.commit() - - try: + ensure_schema(db_path) + with get_conn(db_path) as conn: stats = _glean_ssh_source(src, compiled, ingest_time, conn, batch_size) - finally: - conn.close() logger.info("Rebuilding FTS index after SSH source glean...") build_fts_index(db_path) @@ -645,7 +528,7 @@ def glean_dir( Pass ``force=True`` to bypass fingerprint checks and re-glean all files regardless of whether they have changed since the last run. """ - files = sorted(corpus_dir.glob("*.jsonl")) + sorted(corpus_dir.glob("*.log")) + files = sorted(corpus_dir.rglob("*.jsonl")) + sorted(corpus_dir.rglob("*.log")) return _glean_files(files, db_path, pattern_file, batch_size, force=force) @@ -740,18 +623,13 @@ def glean_sources( compiled = _compile(load_patterns(effective_pattern_file)) ingest_time = now_iso() - conn = sqlite3.connect(str(db_path), timeout=30.0) - conn.execute("PRAGMA journal_mode=WAL") - conn.executescript(_SCHEMA) - conn.commit() - - try: + ensure_schema(db_path) + with get_conn(db_path) as conn: for src in ssh_sources: ssh_stats = _glean_ssh_source(src, compiled, ingest_time, conn, batch_size) for k, v in ssh_stats.items(): stats[k] = stats.get(k, 0) + v - finally: - conn.close() + conn.commit() # Rebuild FTS only when SSH sources added entries (_glean_files already # rebuilds when local sources are present; safe to call again if both ran). diff --git a/app/glean/plaintext.py b/app/glean/plaintext.py index a205fc0..65e36cd 100644 --- a/app/glean/plaintext.py +++ b/app/glean/plaintext.py @@ -32,10 +32,11 @@ def _extract_ts(line: str) -> tuple[str, str]: if m: ts_raw = m.group("ts") try: - # Strip fractional seconds / TZ for strptime compat + # Strip fractional seconds / TZ for strptime compat. + # Normalise ISO 8601 T-separator to space so strptime format matches. clean = re.sub(r"(\.\d+)?([Zz]|[+-]\d{2}:?\d{2})?$", "", ts_raw).strip() clean = clean.replace("T", " ") - dt = datetime.strptime(clean, fmt) + dt = datetime.strptime(clean, fmt.replace("T", " ")) if dt.year == 1900: dt = dt.replace(year=datetime.now().year) dt = dt.astimezone(timezone.utc) diff --git a/app/mcp_server.py b/app/mcp_server.py index 607a3ca..38b55ec 100644 --- a/app/mcp_server.py +++ b/app/mcp_server.py @@ -11,7 +11,7 @@ from __future__ import annotations import logging import os -import sqlite3 +import sqlite3 # still used for the pre-index-check on SQLite backend import sys from pathlib import Path @@ -53,15 +53,15 @@ _index_ready = False def _ensure_index() -> None: - """Build FTS index on first use; skip if already present.""" + """Build FTS index on first use; skip if already present (SQLite only).""" global _index_ready if _index_ready: return try: - conn = sqlite3.connect(str(DB_PATH), timeout=30.0) - count = conn.execute("SELECT COUNT(*) FROM log_fts").fetchone()[0] - conn.close() + raw = sqlite3.connect(str(DB_PATH), timeout=30.0) + count = raw.execute("SELECT COUNT(*) FROM log_fts").fetchone()[0] + raw.close() if count > 0: _index_ready = True logger.info("FTS index present (%d entries)", count) @@ -93,7 +93,7 @@ def search_logs( Example: '"connection refused" OR "connection lost"' severity: Filter by level — EMERGENCY, ALERT, CRITICAL, ERROR, WARN, NOTICE, INFO, DEBUG. source: Partial match on source_id. Format is 'corpus:host:service'. - Example: 'example-node:caddy' matches all Caddy entries from example-node. + Example: 'myserver:caddy' matches all Caddy entries from myserver. pattern: Filter by named pattern tag applied at glean time. Known tags: auth_failure, connection_lost, oom, segfault, disk_full, timeout, caddy_tls_error, caddy_config_error, caddy_auth_error, diff --git a/app/rest.py b/app/rest.py index 9efe9df..246b5cc 100644 --- a/app/rest.py +++ b/app/rest.py @@ -12,6 +12,7 @@ import hmac import json import logging import os +import re import time # Offline mode: must be set before any HuggingFace library is imported. @@ -35,7 +36,8 @@ from fastapi.responses import FileResponse, RedirectResponse, StreamingResponse from fastapi.staticfiles import StaticFiles from pydantic import BaseModel -from app.glean.pipeline import ensure_schema, ensure_context_schema, ensure_incidents_schema, migrate_incidents_to_dedicated_db, glean_file as _glean_file, glean_ssh_source as _glean_ssh_source +from app.db import close_pool, ensure_schema, ensure_context_schema, ensure_incidents_schema, migrate_incidents_to_dedicated_db +from app.glean.pipeline import glean_file as _glean_file, glean_ssh_source as _glean_ssh_source from app.glean.base import load_compiled_patterns, now_iso from app.glean.tautulli import parse_webhook as _parse_tautulli from app.glean.wazuh import is_wazuh_alert as _is_wazuh_alert, parse as _parse_wazuh @@ -87,6 +89,10 @@ from app.glean.doc_upload import glean_upload as _glean_upload from app.context.wizard import get_schema as _wizard_schema, advance_step, is_complete, apply_session from app.context.chunker import UnsupportedDocType, FileTooLarge from app.tasks.glean_scheduler import get_state as _glean_state, run_once as _run_glean, scheduler_loop as _scheduler_loop, submit_matched as _submit_matched +from app.tasks.anomaly_scorer import get_state as _scorer_state, run_once as _run_scorer +from app.tasks.cybersec_scorer import get_state as _cybersec_state, run_once as _run_cybersec +from app.services.anomaly import list_detections as _list_detections, acknowledge_detection as _ack_detection +from app.services.cybersec import list_cybersec_detections as _list_cybersec, CYBERSEC_LABELS from app.glean.mqtt_subscriber import run_mqtt_subscribers as _run_mqtt_subscribers DB_PATH = Path(os.environ.get("TURNSTONE_DB", Path(__file__).parent.parent / "data" / "turnstone.db")) @@ -108,6 +114,13 @@ PATTERN_DIR = Path(os.environ.get("TURNSTONE_PATTERNS", Path(__file__).parent.pa PATTERN_FILE = PATTERN_DIR / "default.yaml" GLEAN_INTERVAL = int(os.environ.get("TURNSTONE_GLEAN_INTERVAL", "900")) SUBMIT_ENDPOINT = os.environ.get("TURNSTONE_SUBMIT_ENDPOINT", "").rstrip("/") +ANOMALY_MODEL = os.environ.get("TURNSTONE_ANOMALY_MODEL", "") +ANOMALY_DEVICE = os.environ.get("TURNSTONE_ANOMALY_DEVICE", "cpu") +ANOMALY_THRESHOLD = float(os.environ.get("TURNSTONE_ANOMALY_THRESHOLD", "0.75")) +CYBERSEC_MODEL = os.environ.get("TURNSTONE_CYBERSEC_MODEL", "") +CYBERSEC_DEVICE = os.environ.get("TURNSTONE_CYBERSEC_DEVICE", "cpu") +CYBERSEC_THRESHOLD = float(os.environ.get("TURNSTONE_CYBERSEC_THRESHOLD", "0.60")) +AUTO_INCIDENT = os.environ.get("TURNSTONE_AUTO_INCIDENT", "true").lower() not in ("0", "false", "no") # When set, all /api/ routes require Authorization: Bearer . # Unset (default) means no authentication — suitable for local-only deployments. _API_KEY: str | None = os.environ.get("TURNSTONE_API_KEY") or None @@ -164,6 +177,14 @@ async def _lifespan(app: FastAPI): sources_file, DB_PATH, PATTERN_FILE, GLEAN_INTERVAL, submit_endpoint=SUBMIT_ENDPOINT or None, source_host=SOURCE_HOST, + anomaly_model=ANOMALY_MODEL, + anomaly_device=ANOMALY_DEVICE, + anomaly_threshold=ANOMALY_THRESHOLD, + cybersec_model=CYBERSEC_MODEL, + cybersec_device=CYBERSEC_DEVICE, + cybersec_threshold=CYBERSEC_THRESHOLD, + incidents_db_path=INCIDENTS_DB_PATH, + auto_incident=AUTO_INCIDENT, ), name="glean-scheduler", ) @@ -185,6 +206,7 @@ async def _lifespan(app: FastAPI): await task except asyncio.CancelledError: pass + close_pool() # no-op if SQLite backend app = FastAPI(title="Turnstone API", version="0.6.2", docs_url="/turnstone/docs", redoc_url=None, lifespan=_lifespan) @@ -256,6 +278,10 @@ class DiagnoseRequest(BaseModel): source: str | None = None +class SourceSuggestRequest(BaseModel): + query: str + + class SeverityOverride(BaseModel): name: str pattern: str @@ -502,6 +528,55 @@ async def diagnose_post_stream(body: DiagnoseRequest) -> StreamingResponse: ) +_SUGGEST_STOPWORDS = frozenset({ + "the", "and", "that", "this", "with", "have", "from", "they", + "been", "their", "what", "when", "there", "some", "would", "make", + "like", "into", "time", "look", "just", "know", "take", "year", + "your", "good", "some", "could", "them", "then", "very", "also", + "back", "after", "work", "need", "even", "much", "most", "tell", + "does", "more", "once", "help", "seem", "here", "about", "issue", + "thing", "logs", "error", "again", "still", "these", "those", + "getting", "having", "trying", "going", "where", "which", "cant", + "now", "set", "kind", "weird", "stable", "huge", "real", "nice", +}) + + +@router.post("/api/sources/suggest") +def suggest_sources(body: SourceSuggestRequest) -> dict: + """Return source IDs ranked by relevance to a natural-language problem description.""" + all_sources = _list_sources(DB_PATH) + query_tokens = { + t.lower() + for t in re.findall(r"[a-zA-Z]+", body.query) + if len(t) > 2 and t.lower() not in _SUGGEST_STOPWORDS + } + + suggestions = [] + for src in all_sources: + src_id: str = src["source_id"] + # Tokenise source ID: split on colon, dash, underscore, digits + parts = { + p.lower() + for seg in re.split(r"[:\-_\d]+", src_id) + for p in [seg.strip()] + if len(p) > 2 + } + matched = query_tokens & parts + if matched: + score = round(len(matched) / max(len(parts), 1), 3) + suggestions.append({ + "source_id": src_id, + "score": score, + "matched_tokens": sorted(matched), + }) + + suggestions.sort(key=lambda x: x["score"], reverse=True) + return { + "suggested": suggestions, + "all_source_ids": [s["source_id"] for s in all_sources], + } + + @router.get("/api/settings") def get_settings() -> dict: return _load_prefs() @@ -993,7 +1068,7 @@ def get_incident_endpoint(incident_id: str) -> dict: incident = get_incident(INCIDENTS_DB_PATH, incident_id) if not incident: raise HTTPException(status_code=404, detail="Incident not found") - entries = get_incident_entries(INCIDENTS_DB_PATH, incident) + entries = get_incident_entries(DB_PATH, incident) return { **dataclasses.asdict(incident), "entries": [dataclasses.asdict(e) for e in entries], @@ -1012,7 +1087,7 @@ def get_incident_bundle(incident_id: str, sanitize: bool = False) -> dict: incident = get_incident(INCIDENTS_DB_PATH, incident_id) if not incident: raise HTTPException(status_code=404, detail="Incident not found") - bundle = build_bundle(INCIDENTS_DB_PATH, incident, source_host=SOURCE_HOST, sanitize=sanitize) + bundle = build_bundle(DB_PATH, incident, source_host=SOURCE_HOST, sanitize=sanitize) record_sent_bundle(INCIDENTS_DB_PATH, incident_id, bundle, sanitized=sanitize) return bundle @@ -1030,7 +1105,7 @@ def send_incident_bundle(incident_id: str, sanitize: bool = False) -> dict: incident = get_incident(INCIDENTS_DB_PATH, incident_id) if not incident: raise HTTPException(status_code=404, detail="Incident not found") - bundle = build_bundle(INCIDENTS_DB_PATH, incident, source_host=SOURCE_HOST, sanitize=sanitize) + bundle = build_bundle(DB_PATH, incident, source_host=SOURCE_HOST, sanitize=sanitize) record_sent_bundle(INCIDENTS_DB_PATH, incident_id, bundle, sanitized=sanitize) payload = json.dumps(bundle).encode() req = urllib.request.Request( @@ -1316,6 +1391,115 @@ async def debug_search(q: str): app.include_router(_ctx) +# --------------------------------------------------------------------------- +# Anomaly scoring endpoints +# --------------------------------------------------------------------------- + +_anomaly = APIRouter(prefix="/turnstone/api/anomaly", dependencies=[Depends(_check_api_key)]) + + +@_anomaly.get("/status") +async def anomaly_status(): + """Return scorer state and configuration.""" + state = _scorer_state() + return { + "model": ANOMALY_MODEL or None, + "threshold": ANOMALY_THRESHOLD, + "device": ANOMALY_DEVICE, + "enabled": bool(ANOMALY_MODEL), + **vars(state), + } + + +@_anomaly.post("/run") +async def anomaly_run(background_tasks: BackgroundTasks): + """Trigger a manual anomaly scoring pass (runs in background).""" + if not ANOMALY_MODEL: + raise HTTPException(status_code=400, detail="TURNSTONE_ANOMALY_MODEL not configured") + background_tasks.add_task( + _run_scorer, DB_PATH, ANOMALY_MODEL, ANOMALY_DEVICE, 256, ANOMALY_THRESHOLD + ) + return {"ok": True, "message": "scorer triggered"} + + +@_anomaly.get("/detections") +async def anomaly_detections( + limit: int = Query(100, ge=1, le=1000), + unacked_only: bool = Query(False), + label: str | None = Query(None), + scorer: str | None = Query(None), +): + """List detections ordered by detected_at DESC. Optionally filter by scorer ('anomaly'|'cybersec').""" + loop = asyncio.get_running_loop() + rows = await loop.run_in_executor( + None, lambda: _list_detections(DB_PATH, limit=limit, unacked_only=unacked_only, label=label, scorer=scorer) + ) + return {"detections": rows, "total": len(rows)} + + +@_anomaly.post("/detections/{detection_id}/acknowledge") +async def acknowledge_detection(detection_id: str, notes: str = ""): + """Acknowledge a detection (mark as reviewed).""" + loop = asyncio.get_running_loop() + updated = await loop.run_in_executor( + None, lambda: _ack_detection(DB_PATH, detection_id, notes) + ) + if not updated: + raise HTTPException(status_code=404, detail="Detection not found") + return {"ok": True} + + +app.include_router(_anomaly) + + +# --------------------------------------------------------------------------- +# Cybersec scoring endpoints +# --------------------------------------------------------------------------- + +_cybersec_router = APIRouter(prefix="/turnstone/api/cybersec", dependencies=[Depends(_check_api_key)]) + + +@_cybersec_router.get("/status") +async def cybersec_status(): + """Return cybersec scorer state and configuration.""" + return { + "model": CYBERSEC_MODEL or None, + "threshold": CYBERSEC_THRESHOLD, + "device": CYBERSEC_DEVICE, + "enabled": bool(CYBERSEC_MODEL), + "candidate_labels": CYBERSEC_LABELS, + **_cybersec_state(), + } + + +@_cybersec_router.post("/run") +async def cybersec_run(background_tasks: BackgroundTasks): + """Trigger a manual cybersec scoring pass (runs in background).""" + if not CYBERSEC_MODEL: + raise HTTPException(status_code=400, detail="TURNSTONE_CYBERSEC_MODEL not configured") + background_tasks.add_task( + _run_cybersec, DB_PATH, CYBERSEC_MODEL, CYBERSEC_DEVICE, 32, CYBERSEC_THRESHOLD + ) + return {"ok": True, "message": "cybersec scorer triggered"} + + +@_cybersec_router.get("/detections") +async def cybersec_detections( + limit: int = Query(100, ge=1, le=1000), + unacked_only: bool = Query(False), + label: str | None = Query(None), +): + """List cybersec detections ordered by detected_at DESC.""" + loop = asyncio.get_running_loop() + rows = await loop.run_in_executor( + None, lambda: _list_cybersec(DB_PATH, limit=limit, unacked_only=unacked_only, label=label) + ) + return {"detections": rows, "total": len(rows)} + + +app.include_router(_cybersec_router) + + # Root redirect → /turnstone/ @app.get("/") def root_redirect() -> RedirectResponse: diff --git a/app/services/anomaly.py b/app/services/anomaly.py new file mode 100644 index 0000000..4dbc21b --- /dev/null +++ b/app/services/anomaly.py @@ -0,0 +1,305 @@ +"""Anomaly scoring pipeline — batch-score log_entries with a HF classifier. + +Designed to run after each glean cycle (or standalone). When no model is +configured the scorer is a no-op and returns immediately, so it is always +safe to wire into the glean pipeline. + +Model: any HuggingFace text-classification model. The existing Hybrid-BERT +label map (from diagnose/classifier.py) is reused when the model produces +NORMAL/SECURITY_ANOMALY/… outputs; other models get a generic severity map. + +Scoring strategy +---------------- +- Query unscored rows in batches (WHERE anomaly_scored_at IS NULL) +- Run each entry text through the HF pipeline +- Write anomaly_score + anomaly_label + anomaly_scored_at back +- INSERT high-confidence hits (score >= threshold) into detections table, + skipping duplicates so the scorer is safe to re-run +""" +from __future__ import annotations + +import logging +import os +import time +import uuid +from dataclasses import dataclass +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +from app.db import get_conn, resolve_tenant_id +from app.db.dialect import q + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Label maps — reuse Hybrid-BERT vocabulary from diagnose/classifier.py +# --------------------------------------------------------------------------- + +_HYBRID_BERT_SEVERITY: dict[str, str] = { + "NORMAL": "INFO", + "SECURITY_ANOMALY": "ERROR", + "SYSTEM_FAILURE": "CRITICAL", + "PERFORMANCE_ISSUE": "WARN", + "NETWORK_ANOMALY": "WARN", + "CONFIG_ERROR": "ERROR", + "HARDWARE_ISSUE": "CRITICAL", +} + +_GENERIC_SEVERITY: dict[str, str] = { + "CRITICAL": "CRITICAL", + "ERROR": "ERROR", + "WARNING": "WARN", + "WARN": "WARN", + "INFO": "INFO", + "DEBUG": "DEBUG", +} + +_ANOMALOUS_LABELS: frozenset[str] = frozenset( + { + "SECURITY_ANOMALY", + "SYSTEM_FAILURE", + "PERFORMANCE_ISSUE", + "NETWORK_ANOMALY", + "CONFIG_ERROR", + "HARDWARE_ISSUE", + "CRITICAL", + "ERROR", + } +) + +_DEFAULT_THRESHOLD = float(os.environ.get("TURNSTONE_ANOMALY_THRESHOLD", "0.75")) +_DEFAULT_MODEL = os.environ.get("TURNSTONE_ANOMALY_MODEL", "") +_DEFAULT_DEVICE = os.environ.get("TURNSTONE_ANOMALY_DEVICE", "cpu") +_DEFAULT_BATCH = int(os.environ.get("TURNSTONE_ANOMALY_BATCH", "256")) + +# --------------------------------------------------------------------------- +# ML singleton +# --------------------------------------------------------------------------- + +_pipeline: Any | None = None + + +def _get_pipeline(model_id: str, device: str) -> Any: + global _pipeline # noqa: PLW0603 + if _pipeline is None: + from transformers import pipeline as hf_pipeline # type: ignore[import-untyped] + _pipeline = hf_pipeline("text-classification", model=model_id, device=device) + return _pipeline + + +def reset_pipeline() -> None: + """Reset the cached pipeline singleton (test helper).""" + global _pipeline # noqa: PLW0603 + _pipeline = None + + +# --------------------------------------------------------------------------- +# Result types +# --------------------------------------------------------------------------- + + +@dataclass +class ScoringResult: + scored: int = 0 + detections: int = 0 + skipped: bool = False + error: str | None = None + + +# --------------------------------------------------------------------------- +# Internal helpers +# --------------------------------------------------------------------------- + + +def _map_label(raw_label: str, score: float) -> tuple[str, str]: + """Return (normalised_label, severity) for a raw model output label.""" + upper = raw_label.upper() + if upper in _HYBRID_BERT_SEVERITY: + return upper, _HYBRID_BERT_SEVERITY[upper] + sev = _GENERIC_SEVERITY.get(upper, "WARN") + return upper, sev + + +def _fetch_unscored(conn: Any, tenant_id: str, limit: int) -> list[dict]: + rows = conn.execute( + q(""" + SELECT id, source_id, text, timestamp_iso, severity + FROM log_entries + WHERE anomaly_scored_at IS NULL + AND (tenant_id = ? OR tenant_id = '') + ORDER BY ingest_time DESC + LIMIT ? + """), + (tenant_id, limit), + ).fetchall() + return [dict(r) for r in rows] + + +def _write_scores( + conn: Any, + rows: list[dict], + scored_at: str, +) -> None: + conn.executemany( + q("UPDATE log_entries SET anomaly_score = ?, anomaly_label = ?, anomaly_scored_at = ? WHERE id = ?"), + [(r["anomaly_score"], r["anomaly_label"], scored_at, r["id"]) for r in rows], + ) + + +def _insert_detections(conn: Any, rows: list[dict], tenant_id: str, detected_at: str) -> int: + inserted = 0 + for r in rows: + try: + conn.execute( + q(""" + INSERT INTO detections + (id, tenant_id, entry_id, source_id, anomaly_label, anomaly_score, + severity, text, timestamp_iso, detected_at) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) + """), + ( + str(uuid.uuid4()), + tenant_id, + r["id"], + r["source_id"], + r["anomaly_label"], + r["anomaly_score"], + r["severity"], + r["text"][:2000], + r.get("timestamp_iso"), + detected_at, + ), + ) + inserted += 1 + except Exception: # noqa: BLE001 + pass # duplicate entry_id or constraint violation — skip + return inserted + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def score_unscored( + db_path: Path, + model_id: str = _DEFAULT_MODEL, + device: str = _DEFAULT_DEVICE, + batch_size: int = _DEFAULT_BATCH, + threshold: float = _DEFAULT_THRESHOLD, +) -> ScoringResult: + """Score all unscored log_entries in batches. + + Returns immediately (skipped=True) when model_id is empty — allows + unconditional wiring without requiring the model to be configured. + """ + if not model_id: + return ScoringResult(skipped=True) + + try: + pipe = _get_pipeline(model_id, device) + except Exception as exc: + logger.error("Failed to load anomaly model %r: %s", model_id, exc) + return ScoringResult(error=str(exc)) + + tenant_id = resolve_tenant_id() + total_scored = 0 + total_detections = 0 + + while True: + with get_conn(db_path) as conn: + batch = _fetch_unscored(conn, tenant_id, batch_size) + if not batch: + break + + texts = [r["text"][:512] for r in batch] + try: + predictions = pipe(texts, truncation=True, max_length=512) + except Exception as exc: + logger.error("Inference error on batch of %d: %s", len(batch), exc) + return ScoringResult(scored=total_scored, detections=total_detections, error=str(exc)) + + scored_at = datetime.now(tz=timezone.utc).isoformat() + scored_rows: list[dict] = [] + detection_rows: list[dict] = [] + + for row, pred in zip(batch, predictions): + label, severity = _map_label(pred["label"], pred["score"]) + enriched = {**row, "anomaly_score": pred["score"], "anomaly_label": label, "severity": severity} + scored_rows.append(enriched) + if label in _ANOMALOUS_LABELS and pred["score"] >= threshold: + detection_rows.append(enriched) + + for _attempt in range(4): + try: + with get_conn(db_path) as conn: + _write_scores(conn, scored_rows, scored_at) + det_count = _insert_detections(conn, detection_rows, tenant_id, scored_at) + conn.commit() + break + except Exception as exc: + if "database is locked" in str(exc).lower() and _attempt < 3: + logger.warning("DB locked, retrying write in 10s (attempt %d/4)", _attempt + 1) + time.sleep(10) + else: + raise + + total_scored += len(scored_rows) + total_detections += det_count + logger.info( + "Scored %d entries, %d detections (threshold=%.2f)", + len(scored_rows), det_count, threshold, + ) + + if len(batch) < batch_size: + break + + return ScoringResult(scored=total_scored, detections=total_detections) + + +def list_detections( + db_path: Path, + limit: int = 100, + unacked_only: bool = False, + label: str | None = None, + scorer: str | None = None, +) -> list[dict]: + """Return detections ordered by detected_at DESC.""" + tenant_id = resolve_tenant_id() + conditions = ["(tenant_id = ? OR tenant_id = '')"] + params: list[Any] = [tenant_id] + + if unacked_only: + conditions.append("acknowledged = 0") + if label: + conditions.append(q("anomaly_label = ?")) + params.append(label.upper()) + if scorer: + conditions.append(q("scorer = ?")) + params.append(scorer.lower()) + + where = " AND ".join(conditions) + with get_conn(db_path) as conn: + rows = conn.execute( + q(f"SELECT * FROM detections WHERE {where} ORDER BY detected_at DESC LIMIT ?"), # noqa: S608 + (*params, limit), + ).fetchall() + return [dict(r) for r in rows] + + +def acknowledge_detection(db_path: Path, detection_id: str, notes: str = "") -> bool: + """Mark a detection as acknowledged. Returns True if a row was updated.""" + tenant_id = resolve_tenant_id() + acked_at = datetime.now(tz=timezone.utc).isoformat() + with get_conn(db_path) as conn: + cur = conn.execute( + q(""" + UPDATE detections + SET acknowledged = 1, acknowledged_at = ?, notes = ? + WHERE id = ? AND (tenant_id = ? OR tenant_id = '') + """), + (acked_at, notes, detection_id, tenant_id), + ) + conn.commit() + return cur.rowcount > 0 diff --git a/app/services/blocklist.py b/app/services/blocklist.py index 998014a..ea984a3 100644 --- a/app/services/blocklist.py +++ b/app/services/blocklist.py @@ -4,10 +4,12 @@ from __future__ import annotations import dataclasses import json import re -import sqlite3 import uuid from datetime import datetime, timezone from pathlib import Path +from typing import Any + +from app.db import get_conn, resolve_tenant_id import yaml @@ -91,26 +93,26 @@ def _now_iso() -> str: return datetime.now(timezone.utc).isoformat() -def _row_to_candidate(row: tuple) -> BlocklistCandidate: +def _row_to_candidate(row: Any) -> BlocklistCandidate: return BlocklistCandidate( - id=row[0], - domain_or_ip=row[1], - source_device_ip=row[2], - source_device_name=row[3], - first_seen=row[4], - last_seen=row[5], - hit_count=row[6], - status=row[7], - pushed_at=row[8], - log_evidence=json.loads(row[9] or "[]"), - matched_rule=row[10], - llm_score=row[11], - llm_reason=row[12], + id=row["id"], + domain_or_ip=row["domain_or_ip"], + source_device_ip=row["source_device_ip"], + source_device_name=row["source_device_name"], + first_seen=row["first_seen"], + last_seen=row["last_seen"], + hit_count=row["hit_count"], + status=row["status"], + pushed_at=row["pushed_at"], + log_evidence=json.loads(row["log_evidence"] or "[]"), + matched_rule=row["matched_rule"], + llm_score=row["llm_score"], + llm_reason=row["llm_reason"], ) def _upsert_candidate( - conn: sqlite3.Connection, + conn: Any, domain_or_ip: str, source_device_ip: str | None, source_device_name: str | None, @@ -119,26 +121,29 @@ def _upsert_candidate( now: str, ) -> bool: """Insert or update a candidate. Returns True if a new row was created.""" + tid = resolve_tenant_id() row = conn.execute( "SELECT id, hit_count, log_evidence FROM blocklist_candidates " - "WHERE domain_or_ip = ? AND source_device_ip IS ?", - (domain_or_ip, source_device_ip), + "WHERE domain_or_ip = ? AND source_device_ip IS ? AND (tenant_id = ? OR tenant_id = '')", + (domain_or_ip, source_device_ip, tid), ).fetchone() if row is None: conn.execute( """INSERT INTO blocklist_candidates - (id, domain_or_ip, source_device_ip, source_device_name, + (id, tenant_id, domain_or_ip, source_device_ip, source_device_name, first_seen, last_seen, hit_count, status, pushed_at, log_evidence, matched_rule) - VALUES (?, ?, ?, ?, ?, ?, 1, 'pending', NULL, ?, ?)""", + VALUES (?, ?, ?, ?, ?, ?, ?, 1, 'pending', NULL, ?, ?)""", ( - str(uuid.uuid4()), domain_or_ip, source_device_ip, source_device_name, + str(uuid.uuid4()), tid, domain_or_ip, source_device_ip, source_device_name, now, now, json.dumps([entry_id]), matched_rule, ), ) return True - existing_id, hit_count, existing_evidence = row + existing_id = row["id"] + hit_count = row["hit_count"] + existing_evidence = row["log_evidence"] evidence = json.loads(existing_evidence or "[]") if entry_id not in evidence: evidence.append(entry_id) @@ -172,14 +177,16 @@ def run_scan( now = _now_iso() count = 0 - conn = sqlite3.connect(str(db_path), timeout=30.0) - try: + tid = resolve_tenant_id() + with get_conn(db_path) as conn: rows = conn.execute( - f"SELECT id, text FROM log_entries WHERE source_id IN ({placeholders})", - router_source_ids, + f"SELECT id, text FROM log_entries WHERE source_id IN ({placeholders}) AND (tenant_id = ? OR tenant_id = '')", # noqa: S608 + (*router_source_ids, tid), ).fetchall() - for entry_id, text in rows: + for row in rows: + entry_id, text = row["id"], row["text"] + # rest of loop body follows unchanged src_ip: str | None = None dst: str | None = None @@ -204,8 +211,6 @@ def run_scan( count += 1 conn.commit() - finally: - conn.close() return count @@ -226,26 +231,27 @@ def list_candidates( status: str | None = None, device_ip: str | None = None, ) -> list[BlocklistCandidate]: - conn = sqlite3.connect(str(db_path), timeout=30.0) - try: - query = f"{_CANDIDATE_SELECT} WHERE 1=1" - params: list = [] - if status and status != "all": - query += " AND status = ?" - params.append(status) - if device_ip: - query += " AND source_device_ip = ?" - params.append(device_ip) - query += " ORDER BY last_seen DESC" - rows = conn.execute(query, params).fetchall() - finally: - conn.close() + tid = resolve_tenant_id() + conditions = ["(tenant_id = ? OR tenant_id = '')"] + params: list = [tid] + if status and status != "all": + conditions.append("status = ?") + params.append(status) + if device_ip: + conditions.append("source_device_ip = ?") + params.append(device_ip) + where = " AND ".join(conditions) + with get_conn(db_path) as conn: + rows = conn.execute( + f"{_CANDIDATE_SELECT} WHERE {where} ORDER BY last_seen DESC", # noqa: S608 + params, + ).fetchall() return [_row_to_candidate(r) for r in rows] -def _get_candidate(conn: sqlite3.Connection, candidate_id: str) -> BlocklistCandidate: +def _get_candidate(conn: Any, candidate_id: str) -> BlocklistCandidate: row = conn.execute( - f"{_CANDIDATE_SELECT} WHERE id=?", + f"{_CANDIDATE_SELECT} WHERE id=?", # noqa: S608 (candidate_id,), ).fetchone() if row is None: @@ -255,43 +261,31 @@ def _get_candidate(conn: sqlite3.Connection, candidate_id: str) -> BlocklistCand def get_candidate(db_path: Path, candidate_id: str) -> BlocklistCandidate: """Fetch a single candidate by ID. Raises KeyError if not found.""" - conn = sqlite3.connect(str(db_path), timeout=30.0) - try: + with get_conn(db_path) as conn: return _get_candidate(conn, candidate_id) - finally: - conn.close() def update_candidate_status(db_path: Path, candidate_id: str, new_status: str) -> BlocklistCandidate: if new_status not in _VALID_STATUSES: raise ValueError(f"Invalid status {new_status!r}. Must be one of {_VALID_STATUSES}") - conn = sqlite3.connect(str(db_path), timeout=30.0) - try: + with get_conn(db_path) as conn: conn.execute("UPDATE blocklist_candidates SET status=? WHERE id=?", (new_status, candidate_id)) conn.commit() return _get_candidate(conn, candidate_id) - finally: - conn.close() def mark_pushed(db_path: Path, candidate_id: str) -> BlocklistCandidate: - conn = sqlite3.connect(str(db_path), timeout=30.0) - try: + with get_conn(db_path) as conn: conn.execute( "UPDATE blocklist_candidates SET status='pushed', pushed_at=? WHERE id=?", (_now_iso(), candidate_id), ) conn.commit() return _get_candidate(conn, candidate_id) - finally: - conn.close() def mark_unblocked(db_path: Path, candidate_id: str) -> BlocklistCandidate: - conn = sqlite3.connect(str(db_path), timeout=30.0) - try: + with get_conn(db_path) as conn: conn.execute("UPDATE blocklist_candidates SET status='unblocked' WHERE id=?", (candidate_id,)) conn.commit() return _get_candidate(conn, candidate_id) - finally: - conn.close() diff --git a/app/services/cybersec.py b/app/services/cybersec.py new file mode 100644 index 0000000..a769b0d --- /dev/null +++ b/app/services/cybersec.py @@ -0,0 +1,241 @@ +"""Cybersecurity-focused scoring pipeline using zero-shot classification. + +Runs a second-pass analysis on entries that were already flagged by the +anomaly scorer or that have pattern matches. Uses a zero-shot classification +model (DeBERTa-v3-base-mnli is cached locally) so no fine-tuning is needed. + +The scorer writes ml_score / ml_label / ml_scored_at to log_entries and +inserts high-confidence non-normal hits into the detections table tagged +with scorer='cybersec'. + +Env vars +-------- +TURNSTONE_CYBERSEC_MODEL — HF model id for zero-shot classification. + Recommended: MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli + (already cached from the diagnose pipeline). + Set to empty string to disable (safe default). +TURNSTONE_CYBERSEC_DEVICE — 'cpu' (default) or 'cuda' +TURNSTONE_CYBERSEC_THRESHOLD — float confidence floor for detection insertion (default 0.60) +""" +from __future__ import annotations + +import logging +import uuid +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +from app.db import get_conn, resolve_tenant_id +from app.db.dialect import q + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Candidate labels — cybersec vocabulary for zero-shot inference +# --------------------------------------------------------------------------- + +CYBERSEC_LABELS: list[str] = [ + "authentication failure or brute force attack", + "privilege escalation or unauthorized access", + "network intrusion or port scan", + "malware or suspicious process activity", + "data exfiltration or unusual outbound traffic", + "normal system operation", +] + +_NORMAL_LABEL = "normal system operation" + +_LABEL_SEVERITY: dict[str, str] = { + "authentication failure or brute force attack": "ERROR", + "privilege escalation or unauthorized access": "CRITICAL", + "network intrusion or port scan": "ERROR", + "malware or suspicious process activity": "CRITICAL", + "data exfiltration or unusual outbound traffic":"CRITICAL", + "normal system operation": "INFO", +} + +# --------------------------------------------------------------------------- +# Pipeline singleton +# --------------------------------------------------------------------------- + +_pipeline: Any = None + + +def _get_pipeline(model_id: str, device: str) -> Any: + global _pipeline # noqa: PLW0603 + if _pipeline is None: + from transformers import pipeline # type: ignore[import-untyped] + logger.info("loading cybersec zero-shot pipeline: %s on %s", model_id, device) + _pipeline = pipeline( + "zero-shot-classification", + model=model_id, + device=0 if device == "cuda" else -1, + ) + logger.info("cybersec pipeline ready") + return _pipeline + + +def reset_pipeline() -> None: + """Clear the cached pipeline — for testing only.""" + global _pipeline # noqa: PLW0603 + _pipeline = None + + +# --------------------------------------------------------------------------- +# Result type +# --------------------------------------------------------------------------- + +@dataclass +class CybersecResult: + scored: int = 0 + detections: int = 0 + skipped: bool = False + error: str | None = None + + +# --------------------------------------------------------------------------- +# Core scoring function +# --------------------------------------------------------------------------- + +def score_security_entries( + db_path: Path, + model_id: str, + device: str = "cpu", + batch_size: int = 32, + threshold: float = 0.60, +) -> CybersecResult: + """Score entries that were anomaly-flagged or pattern-matched. + + Only entries with ml_scored_at IS NULL are processed (idempotent). + Writes ml_score / ml_label / ml_scored_at and inserts high-confidence + hits into detections with scorer='cybersec'. + """ + if not model_id: + return CybersecResult(skipped=True) + + tenant_id = resolve_tenant_id() + try: + pipe = _get_pipeline(model_id, device) + except Exception as exc: + logger.error("failed to load cybersec pipeline: %s", exc) + return CybersecResult(error=str(exc)) + + total_scored = 0 + total_detections = 0 + + try: + with get_conn(db_path) as conn: + # Only score entries that are worth a second look: + # anomaly-flagged (non-normal) OR have at least one pattern match. + rows = conn.execute( + q(""" + SELECT id, source_id, text, timestamp_iso + FROM log_entries + WHERE (tenant_id = ? OR tenant_id = '') + AND ml_scored_at IS NULL + AND ( + (anomaly_label IS NOT NULL AND anomaly_label != 'NORMAL') + OR (matched_patterns IS NOT NULL AND matched_patterns != '[]' AND matched_patterns != '') + ) + LIMIT ? + """), + (tenant_id, batch_size * 10), + ).fetchall() + + if not rows: + return CybersecResult(skipped=True) + + # Process in chunks to avoid OOM on large backlogs + for i in range(0, len(rows), batch_size): + chunk = rows[i : i + batch_size] + texts = [r["text"] for r in chunk] + + try: + results = pipe(texts, candidate_labels=CYBERSEC_LABELS, multi_label=False) + except Exception as exc: + logger.warning("zero-shot inference error on chunk %d: %s", i, exc) + continue + + now = datetime.now(tz=timezone.utc).isoformat() + + with get_conn(db_path) as conn: + for row, result in zip(chunk, results): + top_label: str = result["labels"][0] + top_score: float = result["scores"][0] + + conn.execute( + q(""" + UPDATE log_entries + SET ml_score = ?, ml_label = ?, ml_scored_at = ? + WHERE id = ? AND (tenant_id = ? OR tenant_id = '') + """), + (top_score, top_label, now, row["id"], tenant_id), + ) + total_scored += 1 + + if top_score >= threshold and top_label != _NORMAL_LABEL: + severity = _LABEL_SEVERITY.get(top_label, "WARN") + try: + conn.execute( + q(""" + INSERT INTO detections + (id, tenant_id, entry_id, source_id, anomaly_label, + anomaly_score, severity, text, timestamp_iso, + detected_at, scorer) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 'cybersec') + """), + ( + str(uuid.uuid4()), + tenant_id, + row["id"], + row["source_id"], + top_label, + top_score, + severity, + row["text"], + row["timestamp_iso"], + now, + ), + ) + total_detections += 1 + except Exception: + pass # entry may already have a detection — skip + + conn.commit() + + except Exception as exc: + logger.error("cybersec scoring failed: %s", exc, exc_info=True) + return CybersecResult(scored=total_scored, detections=total_detections, error=str(exc)) + + return CybersecResult(scored=total_scored, detections=total_detections) + + +# --------------------------------------------------------------------------- +# Query helpers (used by REST layer) +# --------------------------------------------------------------------------- + +def list_cybersec_detections( + db_path: Path, + limit: int = 100, + unacked_only: bool = False, + label: str | None = None, +) -> list[dict]: + """Return cybersec detections ordered by detected_at DESC.""" + tenant_id = resolve_tenant_id() + conditions = ["(tenant_id = ? OR tenant_id = '')", "scorer = 'cybersec'"] + params: list[Any] = [tenant_id] + + if unacked_only: + conditions.append("acknowledged = 0") + if label: + conditions.append(q("anomaly_label = ?")) + params.append(label) + + where = " AND ".join(conditions) + with get_conn(db_path) as conn: + rows = conn.execute( + q(f"SELECT * FROM detections WHERE {where} ORDER BY detected_at DESC LIMIT ?"), # noqa: S608 + (*params, limit), + ).fetchall() + return [dict(r) for r in rows] diff --git a/app/services/incidents.py b/app/services/incidents.py index 1d71422..9094de5 100644 --- a/app/services/incidents.py +++ b/app/services/incidents.py @@ -3,10 +3,10 @@ from __future__ import annotations import json import re -import sqlite3 import uuid from pathlib import Path +from app.db import get_conn, resolve_tenant_id from app.glean.base import now_iso from app.services.models import Incident, ReceivedBundle, SentBundle from app.services.search import SearchResult, entries_in_window, search @@ -26,7 +26,7 @@ def _redact_text(text: str) -> str: return text -def _row_to_incident(row: sqlite3.Row) -> Incident: +def _row_to_incident(row) -> Incident: return Incident( id=row["id"], label=row["label"], @@ -39,7 +39,7 @@ def _row_to_incident(row: sqlite3.Row) -> Incident: ) -def _row_to_bundle(row: sqlite3.Row) -> ReceivedBundle: +def _row_to_bundle(row) -> ReceivedBundle: return ReceivedBundle( id=row["id"], source_host=row["source_host"], @@ -62,6 +62,7 @@ def create_incident( notes: str = "", severity: str = "medium", ) -> Incident: + tid = resolve_tenant_id() incident = Incident( id=str(uuid.uuid4()), label=label, @@ -72,47 +73,45 @@ def create_incident( created_at=now_iso(), severity=severity, ) - conn = sqlite3.connect(str(db_path), timeout=30.0) - conn.execute("PRAGMA journal_mode=WAL") - conn.execute( - "INSERT INTO incidents (id, label, issue_type, started_at, ended_at, notes, created_at, severity) " - "VALUES (?, ?, ?, ?, ?, ?, ?, ?)", - (incident.id, incident.label, incident.issue_type, incident.started_at, - incident.ended_at, incident.notes, incident.created_at, incident.severity), - ) - conn.commit() - conn.close() + with get_conn(db_path) as conn: + conn.execute( + "INSERT INTO incidents (id, tenant_id, label, issue_type, started_at, ended_at, notes, created_at, severity) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", + (incident.id, tid, incident.label, incident.issue_type, incident.started_at, + incident.ended_at, incident.notes, incident.created_at, incident.severity), + ) + conn.commit() return incident def list_incidents(db_path: Path) -> list[Incident]: - conn = sqlite3.connect(str(db_path), timeout=30.0) - conn.execute("PRAGMA journal_mode=WAL") - conn.row_factory = sqlite3.Row - rows = conn.execute( - "SELECT * FROM incidents ORDER BY created_at DESC" - ).fetchall() - conn.close() + tid = resolve_tenant_id() + with get_conn(db_path) as conn: + rows = conn.execute( + "SELECT * FROM incidents WHERE (tenant_id = ? OR tenant_id = '') ORDER BY created_at DESC", + (tid,), + ).fetchall() return [_row_to_incident(r) for r in rows] def get_incident(db_path: Path, incident_id: str) -> Incident | None: - conn = sqlite3.connect(str(db_path), timeout=30.0) - conn.execute("PRAGMA journal_mode=WAL") - conn.row_factory = sqlite3.Row - row = conn.execute( - "SELECT * FROM incidents WHERE id = ?", (incident_id,) - ).fetchone() - conn.close() + tid = resolve_tenant_id() + with get_conn(db_path) as conn: + row = conn.execute( + "SELECT * FROM incidents WHERE id = ? AND (tenant_id = ? OR tenant_id = '')", + (incident_id, tid), + ).fetchone() return _row_to_incident(row) if row else None def delete_incident(db_path: Path, incident_id: str) -> bool: - conn = sqlite3.connect(str(db_path), timeout=30.0) - conn.execute("PRAGMA journal_mode=WAL") - cur = conn.execute("DELETE FROM incidents WHERE id = ?", (incident_id,)) - conn.commit() - conn.close() + tid = resolve_tenant_id() + with get_conn(db_path) as conn: + cur = conn.execute( + "DELETE FROM incidents WHERE id = ? AND (tenant_id = ? OR tenant_id = '')", + (incident_id, tid), + ) + conn.commit() return cur.rowcount > 0 @@ -191,6 +190,7 @@ def build_bundle( def record_sent_bundle(db_path: Path, incident_id: str, bundle: dict, sanitized: bool) -> SentBundle: """Log an outgoing bundle export to the sent_bundles table.""" + tid = resolve_tenant_id() record = SentBundle( id=str(uuid.uuid4()), incident_id=incident_id, @@ -199,28 +199,25 @@ def record_sent_bundle(db_path: Path, incident_id: str, bundle: dict, sanitized: entry_count=len(bundle.get("log_entries", [])), bundle_json=json.dumps(bundle), ) - conn = sqlite3.connect(str(db_path), timeout=30.0) - conn.execute("PRAGMA journal_mode=WAL") - conn.execute( - "INSERT INTO sent_bundles (id, incident_id, exported_at, sanitized, entry_count, bundle_json) " - "VALUES (?, ?, ?, ?, ?, ?)", - (record.id, record.incident_id, record.exported_at, int(record.sanitized), - record.entry_count, record.bundle_json), - ) - conn.commit() - conn.close() + with get_conn(db_path) as conn: + conn.execute( + "INSERT INTO sent_bundles (id, tenant_id, incident_id, exported_at, sanitized, entry_count, bundle_json) " + "VALUES (?, ?, ?, ?, ?, ?, ?)", + (record.id, tid, record.incident_id, record.exported_at, + int(record.sanitized), record.entry_count, record.bundle_json), + ) + conn.commit() return record def list_sent_bundles(db_path: Path) -> list[SentBundle]: - conn = sqlite3.connect(str(db_path), timeout=30.0) - conn.execute("PRAGMA journal_mode=WAL") - conn.row_factory = sqlite3.Row - rows = conn.execute( - "SELECT id, incident_id, exported_at, sanitized, entry_count, bundle_json " - "FROM sent_bundles ORDER BY exported_at DESC" - ).fetchall() - conn.close() + tid = resolve_tenant_id() + with get_conn(db_path) as conn: + rows = conn.execute( + "SELECT id, incident_id, exported_at, sanitized, entry_count, bundle_json " + "FROM sent_bundles WHERE (tenant_id = ? OR tenant_id = '') ORDER BY exported_at DESC", + (tid,), + ).fetchall() return [ SentBundle( id=r["id"], @@ -236,6 +233,7 @@ def list_sent_bundles(db_path: Path) -> list[SentBundle]: def store_bundle(db_path: Path, bundle: dict) -> ReceivedBundle: """Store an incoming bundle from a remote Turnstone instance.""" + tid = resolve_tenant_id() inc = bundle.get("incident", {}) record = ReceivedBundle( id=str(uuid.uuid4()), @@ -248,38 +246,34 @@ def store_bundle(db_path: Path, bundle: dict) -> ReceivedBundle: entry_count=len(bundle.get("log_entries", [])), bundle_json=json.dumps(bundle), ) - conn = sqlite3.connect(str(db_path), timeout=30.0) - conn.execute("PRAGMA journal_mode=WAL") - conn.execute( - "INSERT INTO received_bundles " - "(id, source_host, issue_type, label, severity, started_at, bundled_at, entry_count, bundle_json) " - "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)", - (record.id, record.source_host, record.issue_type, record.label, - record.severity, record.started_at, record.bundled_at, record.entry_count, record.bundle_json), - ) - conn.commit() - conn.close() + with get_conn(db_path) as conn: + conn.execute( + "INSERT INTO received_bundles " + "(id, tenant_id, source_host, issue_type, label, severity, started_at, bundled_at, entry_count, bundle_json) " + "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + (record.id, tid, record.source_host, record.issue_type, record.label, + record.severity, record.started_at, record.bundled_at, record.entry_count, record.bundle_json), + ) + conn.commit() return record def list_bundles(db_path: Path) -> list[ReceivedBundle]: - conn = sqlite3.connect(str(db_path), timeout=30.0) - conn.execute("PRAGMA journal_mode=WAL") - conn.row_factory = sqlite3.Row - rows = conn.execute( - "SELECT id, source_host, issue_type, label, severity, started_at, bundled_at, entry_count, bundle_json " - "FROM received_bundles ORDER BY bundled_at DESC" - ).fetchall() - conn.close() + tid = resolve_tenant_id() + with get_conn(db_path) as conn: + rows = conn.execute( + "SELECT id, source_host, issue_type, label, severity, started_at, bundled_at, entry_count, bundle_json " + "FROM received_bundles WHERE (tenant_id = ? OR tenant_id = '') ORDER BY bundled_at DESC", + (tid,), + ).fetchall() return [_row_to_bundle(r) for r in rows] def get_bundle(db_path: Path, bundle_id: str) -> ReceivedBundle | None: - conn = sqlite3.connect(str(db_path), timeout=30.0) - conn.execute("PRAGMA journal_mode=WAL") - conn.row_factory = sqlite3.Row - row = conn.execute( - "SELECT * FROM received_bundles WHERE id = ?", (bundle_id,) - ).fetchone() - conn.close() + tid = resolve_tenant_id() + with get_conn(db_path) as conn: + row = conn.execute( + "SELECT * FROM received_bundles WHERE id = ? AND (tenant_id = ? OR tenant_id = '')", + (bundle_id, tid), + ).fetchone() return _row_to_bundle(row) if row else None diff --git a/app/services/llm.py b/app/services/llm.py index 0b04098..44c42ff 100644 --- a/app/services/llm.py +++ b/app/services/llm.py @@ -88,7 +88,7 @@ def summarize( logger.debug("Task endpoint unavailable (%s) — falling back to direct model", exc) # Fallback: OpenAI-compat endpoint with explicit model name (local instances, - # example-node, or any cf-orch that doesn't have task assignments loaded). + # or any cf-orch node that doesn't have task assignments loaded). try: resp = httpx.post( f"{llm_url.rstrip('/')}/v1/chat/completions", diff --git a/app/services/search.py b/app/services/search.py index 56b2c0a..90ad4d7 100644 --- a/app/services/search.py +++ b/app/services/search.py @@ -1,4 +1,8 @@ -"""FTS5-based log search with optional hybrid BM25 + vector re-ranking.""" +"""FTS-based log search with optional hybrid BM25 + vector re-ranking. + +SQLite backend: FTS5 virtual table with Porter stemmer. +Postgres backend: tsvector column with GIN index + websearch_to_tsquery. +""" from __future__ import annotations import json @@ -6,8 +10,11 @@ import logging import re import sqlite3 from dataclasses import dataclass +from datetime import datetime, timedelta, timezone from pathlib import Path +from app.db import BACKEND, Backend, frag, get_conn, resolve_tenant_id + logger = logging.getLogger(__name__) @@ -28,48 +35,47 @@ class SearchResult: def build_fts_index(db_path: Path) -> None: """Build (or rebuild) the FTS5 index from log_entries. Safe to re-run. - Drops and recreates the table if the schema is stale (missing sequence column). + For Postgres, the tsvector column is maintained by a trigger — this is a no-op. """ - conn = sqlite3.connect(str(db_path), timeout=30.0) - conn.execute("PRAGMA journal_mode=WAL") + if BACKEND == Backend.POSTGRES: + return - # Check whether existing table has the sequence column; rebuild if not. - needs_rebuild = False - try: - conn.execute("SELECT sequence FROM log_fts LIMIT 0") - except sqlite3.OperationalError: - needs_rebuild = True + with get_conn(db_path) as conn: + needs_rebuild = False + try: + conn.execute("SELECT sequence FROM log_fts LIMIT 0") + except Exception: + needs_rebuild = True - if needs_rebuild: - conn.execute("DROP TABLE IF EXISTS log_fts") + if needs_rebuild: + conn.execute("DROP TABLE IF EXISTS log_fts") + conn.commit() - conn.executescript(""" - CREATE VIRTUAL TABLE IF NOT EXISTS log_fts USING fts5( - text, - entry_id UNINDEXED, - source_id UNINDEXED, - sequence UNINDEXED, - severity UNINDEXED, - timestamp_iso UNINDEXED, - matched_patterns UNINDEXED, - repeat_count UNINDEXED, - out_of_order UNINDEXED, - tokenize = 'porter ascii' - ); - """) - # Only insert rows not already indexed - conn.execute(""" - INSERT INTO log_fts(text, entry_id, source_id, sequence, severity, - timestamp_iso, matched_patterns, - repeat_count, out_of_order) - SELECT e.text, e.id, e.source_id, e.sequence, e.severity, - e.timestamp_iso, e.matched_patterns, - e.repeat_count, e.out_of_order - FROM log_entries e - WHERE e.id NOT IN (SELECT entry_id FROM log_fts WHERE entry_id IS NOT NULL) - """) - conn.commit() - conn.close() + conn.execute(""" + CREATE VIRTUAL TABLE IF NOT EXISTS log_fts USING fts5( + text, + entry_id UNINDEXED, + source_id UNINDEXED, + sequence UNINDEXED, + severity UNINDEXED, + timestamp_iso UNINDEXED, + matched_patterns UNINDEXED, + repeat_count UNINDEXED, + out_of_order UNINDEXED, + tokenize = 'porter ascii' + ) + """) + conn.execute(""" + INSERT INTO log_fts(text, entry_id, source_id, sequence, severity, + timestamp_iso, matched_patterns, + repeat_count, out_of_order) + SELECT e.text, e.id, e.source_id, e.sequence, e.severity, + e.timestamp_iso, e.matched_patterns, + e.repeat_count, e.out_of_order + FROM log_entries e + WHERE e.id NOT IN (SELECT entry_id FROM log_fts WHERE entry_id IS NOT NULL) + """) + conn.commit() def _sanitize_fts_query(raw: str, or_mode: bool = False) -> str: @@ -198,14 +204,44 @@ def _bm25_search( include_repeats: bool = False, or_mode: bool = False, ) -> list[SearchResult]: - """Pure BM25 FTS5 search — internal helper used by both search() and _hybrid_search().""" - conn = sqlite3.connect(str(db_path), timeout=30.0) - conn.execute("PRAGMA journal_mode=WAL") - conn.row_factory = sqlite3.Row + """FTS search — BM25 via FTS5 (SQLite) or tsvector (Postgres).""" + tid = resolve_tenant_id() + if BACKEND == Backend.POSTGRES: + return _pg_fts_search( + db_path, query, tid, + severity=severity, source_filter=source_filter, + pattern_filter=pattern_filter, since=since, until=until, + limit=limit, include_repeats=include_repeats, + ) + + return _sqlite_fts_search( + db_path, query, tid, + severity=severity, source_filter=source_filter, + pattern_filter=pattern_filter, since=since, until=until, + limit=limit, include_repeats=include_repeats, or_mode=or_mode, + ) + + +def _sqlite_fts_search( + db_path: Path, + query: str, + tid: str, + severity: str | None, + source_filter: str | None, + pattern_filter: str | None, + since: str | None, + until: str | None, + limit: int, + include_repeats: bool, + or_mode: bool, +) -> list[SearchResult]: fts_query = _sanitize_fts_query(query, or_mode=or_mode) - conditions = ["log_fts MATCH ?"] - params: list = [fts_query] + conditions = [ + "log_fts MATCH ?", + "(e.tenant_id = ? OR e.tenant_id = '')", + ] + params: list = [fts_query, tid] if severity: conditions.append("severity = ?") @@ -223,29 +259,33 @@ def _bm25_search( conditions.append("timestamp_iso <= ?") params.append(until) if not include_repeats: - conditions.append("repeat_count = 1") + conditions.append("f.repeat_count = 1") where = " AND ".join(conditions) params.append(limit) + raw = sqlite3.connect(str(db_path), timeout=30.0) + raw.row_factory = sqlite3.Row try: - rows = conn.execute( + rows = raw.execute( f""" - SELECT entry_id, source_id, sequence, timestamp_iso, severity, - repeat_count, out_of_order, matched_patterns, text, rank - FROM log_fts + SELECT f.entry_id, f.source_id, f.sequence, f.timestamp_iso, f.severity, + f.repeat_count, f.out_of_order, f.matched_patterns, f.text, f.rank + FROM log_fts f + JOIN log_entries e ON e.id = f.entry_id WHERE {where} - ORDER BY rank + ORDER BY f.rank LIMIT ? """, params, ).fetchall() - except sqlite3.OperationalError as e: - logger.warning("FTS query failed (%s) — index may not be built yet", e) - conn.close() + except sqlite3.OperationalError as exc: + logger.warning("FTS query failed (%s) — index may not be built yet", exc) return [] + finally: + raw.close() - results = [ + return [ SearchResult( entry_id=r["entry_id"], source_id=r["source_id"], @@ -256,12 +296,83 @@ def _bm25_search( out_of_order=bool(r["out_of_order"]), matched_patterns=json.loads(r["matched_patterns"] or "[]"), text=r["text"], - rank=r["rank"], + rank=float(r["rank"]), + ) + for r in rows + ] + + +def _pg_fts_search( + db_path: Path, + query: str, + tid: str, + severity: str | None, + source_filter: str | None, + pattern_filter: str | None, + since: str | None, + until: str | None, + limit: int, + include_repeats: bool, +) -> list[SearchResult]: + """Postgres FTS via tsvector column and websearch_to_tsquery.""" + tsq = "websearch_to_tsquery('english', %s)" + conditions = [ + f"text_tsv @@ {tsq}", + "(tenant_id = %s OR tenant_id = '')", + ] + params: list = [query, tid] + + if severity: + conditions.append("severity = %s") + params.append(severity.upper()) + if source_filter: + conditions.append("source_id LIKE %s") + params.append(f"%{source_filter}%") + if pattern_filter: + conditions.append("matched_patterns LIKE %s") + params.append(f'%"{pattern_filter}"%') + if since: + conditions.append("timestamp_iso >= %s") + params.append(since) + if until: + conditions.append("timestamp_iso <= %s") + params.append(until) + if not include_repeats: + conditions.append("repeat_count = 1") + + where = " AND ".join(conditions) + # ts_rank needs the tsquery again — append it then the limit + params.extend([query, limit]) + + with get_conn(db_path) as conn: + rows = conn.execute( + f""" + SELECT id AS entry_id, source_id, sequence, timestamp_iso, severity, + repeat_count, out_of_order, matched_patterns, text, + ts_rank(text_tsv, {tsq}) AS rank + FROM log_entries + WHERE {where} + ORDER BY rank DESC + LIMIT %s + """, + params, + ).fetchall() + + return [ + SearchResult( + entry_id=r["entry_id"], + source_id=r["source_id"], + sequence=r["sequence"], + timestamp_iso=r["timestamp_iso"], + severity=r["severity"], + repeat_count=r["repeat_count"], + out_of_order=bool(r["out_of_order"]), + matched_patterns=json.loads(r["matched_patterns"] or "[]"), + text=r["text"], + rank=float(r["rank"]), ) for r in rows ] - conn.close() - return results def entries_in_window( @@ -282,12 +393,12 @@ def entries_in_window( (e.g. network-syslog) don't crowd out lower-volume but more interesting ones. Errors/warnings are ranked first within each source partition. """ - conn = sqlite3.connect(str(db_path), timeout=30.0) - conn.execute("PRAGMA journal_mode=WAL") - conn.row_factory = sqlite3.Row - - conditions: list[str] = ["repeat_count = 1"] - params: list = [] + tid = resolve_tenant_id() + conditions: list[str] = [ + "repeat_count = 1", + "(tenant_id = ? OR tenant_id = '')", + ] + params: list = [tid] if since: conditions.append("timestamp_iso >= ?") @@ -305,8 +416,7 @@ def entries_in_window( where = " AND ".join(conditions) if per_source_cap is not None: - # Use a window function to cap rows per source, errors/warnings first. - query = f""" + sql = f""" WITH ranked AS ( SELECT id as entry_id, source_id, sequence, timestamp_iso, severity, repeat_count, out_of_order, matched_patterns, text, 0.0 as rank, @@ -333,7 +443,7 @@ def entries_in_window( """ params.extend([per_source_cap, limit]) else: - query = f""" + sql = f""" SELECT id as entry_id, source_id, sequence, timestamp_iso, severity, repeat_count, out_of_order, matched_patterns, text, 0.0 as rank FROM log_entries @@ -343,8 +453,8 @@ def entries_in_window( """ params.append(limit) - rows = conn.execute(query, params).fetchall() - conn.close() + with get_conn(db_path) as conn: + rows = conn.execute(sql, params).fetchall() return [ SearchResult( @@ -357,7 +467,7 @@ def entries_in_window( out_of_order=bool(r["out_of_order"]), matched_patterns=json.loads(r["matched_patterns"] or "[]"), text=r["text"], - rank=r["rank"], + rank=float(r["rank"]), ) for r in rows ] @@ -376,16 +486,14 @@ def recent_source_errors( Bypasses FTS ranking so text content doesn't affect which errors surface. Used by diagnose when FTS keyword search returns nothing for a known source. """ - conn = sqlite3.connect(str(db_path), timeout=30.0) - conn.execute("PRAGMA journal_mode=WAL") - conn.row_factory = sqlite3.Row - + tid = resolve_tenant_id() conditions = [ "source_id LIKE ?", "severity = ?", "repeat_count = 1", + "(tenant_id = ? OR tenant_id = '')", ] - params: list = [f"%{source_filter}%", severity.upper()] + params: list = [f"%{source_filter}%", severity.upper(), tid] if since: conditions.append("timestamp_iso >= ?") @@ -397,18 +505,18 @@ def recent_source_errors( params.append(limit) where = " AND ".join(conditions) - rows = conn.execute( - f""" - SELECT id as entry_id, source_id, sequence, timestamp_iso, severity, - repeat_count, out_of_order, matched_patterns, text, 0.0 as rank - FROM log_entries - WHERE {where} - ORDER BY timestamp_iso DESC - LIMIT ? - """, - params, - ).fetchall() - conn.close() + with get_conn(db_path) as conn: + rows = conn.execute( + f""" + SELECT id as entry_id, source_id, sequence, timestamp_iso, severity, + repeat_count, out_of_order, matched_patterns, text, 0.0 as rank + FROM log_entries + WHERE {where} + ORDER BY timestamp_iso DESC + LIMIT ? + """, + params, + ).fetchall() return [ SearchResult( @@ -421,7 +529,7 @@ def recent_source_errors( out_of_order=bool(r["out_of_order"]), matched_patterns=json.loads(r["matched_patterns"] or "[]"), text=r["text"], - rank=r["rank"], + rank=float(r["rank"]), ) for r in rows ] @@ -436,37 +544,34 @@ def list_sources(db_path: Path) -> list[dict]: returned as-is. ``unit_count`` reports how many distinct sub-units were merged into each row. """ - conn = sqlite3.connect(str(db_path), timeout=30.0) - conn.execute("PRAGMA journal_mode=WAL") - rows = conn.execute(""" - SELECT - CASE - WHEN INSTR(SUBSTR(source_id, INSTR(source_id, ':')+1), ':') > 0 - THEN SUBSTR(source_id, 1, - INSTR(source_id, ':') - + INSTR(SUBSTR(source_id, INSTR(source_id, ':')+1), ':') - - 1) - ELSE source_id - END AS group_id, - COUNT(DISTINCT source_id) AS unit_count, - COUNT(*) AS entry_count, - MIN(timestamp_iso) AS earliest, - MAX(timestamp_iso) AS latest, - SUM(CASE WHEN severity IN ('ERROR','CRITICAL','EMERGENCY','ALERT') - THEN 1 ELSE 0 END) AS error_count - FROM log_entries - GROUP BY group_id - ORDER BY entry_count DESC - """).fetchall() - conn.close() + tid = resolve_tenant_id() + group_expr = frag.source_group_expr("source_id") + with get_conn(db_path) as conn: + rows = conn.execute( + f""" + SELECT + {group_expr} AS group_id, + COUNT(DISTINCT source_id) AS unit_count, + COUNT(*) AS entry_count, + MIN(timestamp_iso) AS earliest, + MAX(timestamp_iso) AS latest, + SUM(CASE WHEN severity IN ('ERROR','CRITICAL','EMERGENCY','ALERT') + THEN 1 ELSE 0 END) AS error_count + FROM log_entries + WHERE (tenant_id = ? OR tenant_id = '') + GROUP BY group_id + ORDER BY entry_count DESC + """, + (tid,), + ).fetchall() return [ { - "source_id": r[0], - "unit_count": r[1], - "entry_count": r[2], - "earliest": r[3], - "latest": r[4], - "error_count": r[5], + "source_id": r["group_id"], + "unit_count": r["unit_count"], + "entry_count": r["entry_count"], + "earliest": r["earliest"], + "latest": r["latest"], + "error_count": r["error_count"], } for r in rows ] @@ -498,47 +603,80 @@ def stats_summary(db_path: Path, window_hours: int = 24, severity_overrides: lis Queries plain log_entries (not FTS) so it works even before the index is built. """ rules = _compile_overrides(severity_overrides or []) + tid = resolve_tenant_id() + group_expr = frag.source_group_expr("source_id") + since_iso = ( + datetime.now(timezone.utc) - timedelta(hours=window_hours) + ).strftime("%Y-%m-%dT%H:%M:%S") - conn = sqlite3.connect(str(db_path), timeout=30.0) - conn.execute("PRAGMA journal_mode=WAL") - conn.row_factory = sqlite3.Row + with get_conn(db_path) as conn: + row = conn.execute( + """ + SELECT + COUNT(*) AS total, + SUM(CASE WHEN severity = 'CRITICAL' THEN 1 ELSE 0 END) AS criticals, + SUM(CASE WHEN severity IN ('ERROR','CRITICAL','EMERGENCY','ALERT') THEN 1 ELSE 0 END) AS errors + FROM log_entries + WHERE timestamp_iso >= ? + AND repeat_count = 1 + AND (tenant_id = ? OR tenant_id = '') + """, + (since_iso, tid), + ).fetchone() + total_24h = int(row["total"] or 0) + criticals_24h = int(row["criticals"] or 0) + errors_24h = int(row["errors"] or 0) - since_expr = f"strftime('%Y-%m-%dT%H:%M:%S', 'now', '-{window_hours} hours')" + source_rows = conn.execute( + f""" + SELECT + {group_expr} AS group_id, + COUNT(*) AS entry_count, + SUM(CASE WHEN severity IN ('ERROR','CRITICAL','EMERGENCY','ALERT') THEN 1 ELSE 0 END) AS error_count, + MAX(timestamp_iso) AS latest + FROM log_entries + WHERE timestamp_iso >= ? + AND repeat_count = 1 + AND (tenant_id = ? OR tenant_id = '') + GROUP BY group_id + ORDER BY error_count DESC, entry_count DESC + """, + (since_iso, tid), + ).fetchall() - # Overall counts in window - row = conn.execute(f""" - SELECT - COUNT(*) AS total, - SUM(CASE WHEN severity = 'CRITICAL' THEN 1 ELSE 0 END) AS criticals, - SUM(CASE WHEN severity IN ('ERROR','CRITICAL','EMERGENCY','ALERT') THEN 1 ELSE 0 END) AS errors - FROM log_entries - WHERE timestamp_iso >= {since_expr} - AND repeat_count = 1 - """).fetchone() - total_24h = int(row["total"] or 0) - criticals_24h = int(row["criticals"] or 0) - errors_24h = int(row["errors"] or 0) + crit_rows = conn.execute( + """ + SELECT id as entry_id, source_id, timestamp_iso, severity, text + FROM log_entries + WHERE severity = 'CRITICAL' + AND repeat_count = 1 + AND (tenant_id = ? OR tenant_id = '') + ORDER BY timestamp_iso DESC + LIMIT 25 + """, + (tid,), + ).fetchall() + + timeline_rows = conn.execute( + """ + SELECT id as entry_id, source_id, timestamp_iso, severity, text + FROM log_entries + WHERE severity IN ('CRITICAL','ERROR','WARN','WARNING','EMERGENCY','ALERT') + AND timestamp_iso >= ? + AND timestamp_iso IS NOT NULL + AND repeat_count = 1 + AND (tenant_id = ? OR tenant_id = '') + ORDER BY timestamp_iso DESC + LIMIT 300 + """, + (since_iso, tid), + ).fetchall() + + last_row = conn.execute( + "SELECT MAX(ingest_time) AS t FROM log_entries WHERE (tenant_id = ? OR tenant_id = '')", + (tid,), + ).fetchone() - # Per-source breakdown — grouped by prefix:host stem (same logic as list_sources). - source_rows = conn.execute(f""" - SELECT - CASE - WHEN INSTR(SUBSTR(source_id, INSTR(source_id, ':')+1), ':') > 0 - THEN SUBSTR(source_id, 1, - INSTR(source_id, ':') - + INSTR(SUBSTR(source_id, INSTR(source_id, ':')+1), ':') - - 1) - ELSE source_id - END AS group_id, - COUNT(*) AS entry_count, - SUM(CASE WHEN severity IN ('ERROR','CRITICAL','EMERGENCY','ALERT') THEN 1 ELSE 0 END) AS error_count, - MAX(timestamp_iso) AS latest - FROM log_entries - WHERE timestamp_iso >= {since_expr} - AND repeat_count = 1 - GROUP BY group_id - ORDER BY error_count DESC, entry_count DESC - """).fetchall() source_health = [ { "source_id": r["group_id"], @@ -549,16 +687,6 @@ def stats_summary(db_path: Path, window_hours: int = 24, severity_overrides: lis for r in source_rows ] - # Fetch candidate criticals (fetch more so filtering doesn't leave us with too few) - crit_rows = conn.execute(""" - SELECT id as entry_id, source_id, timestamp_iso, severity, text - FROM log_entries - WHERE severity = 'CRITICAL' AND repeat_count = 1 - ORDER BY timestamp_iso DESC - LIMIT 25 - """).fetchall() - - # Apply overrides: skip entries whose effective severity is no longer CRITICAL suppressed = 0 recent_criticals = [] for r in crit_rows: @@ -576,10 +704,18 @@ def stats_summary(db_path: Path, window_hours: int = 24, severity_overrides: lis else: suppressed += 1 - last_row = conn.execute("SELECT MAX(ingest_time) AS t FROM log_entries").fetchone() - last_gleaned: str | None = last_row["t"] if last_row else None + timeline_events = [ + { + "entry_id": r["entry_id"], + "source_id": r["source_id"], + "timestamp_iso": r["timestamp_iso"], + "severity": r["severity"], + "text": r["text"], + } + for r in timeline_rows + ] - conn.close() + last_gleaned: str | None = last_row["t"] if last_row else None return { "window_hours": window_hours, @@ -590,6 +726,7 @@ def stats_summary(db_path: Path, window_hours: int = 24, severity_overrides: lis "recent_criticals": recent_criticals, "suppressed_criticals": suppressed, "last_gleaned": last_gleaned, + "timeline_events": timeline_events, } diff --git a/app/tasks/anomaly_scorer.py b/app/tasks/anomaly_scorer.py new file mode 100644 index 0000000..e952b62 --- /dev/null +++ b/app/tasks/anomaly_scorer.py @@ -0,0 +1,114 @@ +"""Background anomaly scoring task. + +Runs score_unscored() after each glean cycle (triggered by glean_scheduler) +or on its own interval when TURNSTONE_ANOMALY_INTERVAL is set. + +Set TURNSTONE_ANOMALY_MODEL to a HuggingFace model ID to activate. +When the env var is empty (default) the scorer is a no-op. +""" +from __future__ import annotations + +import asyncio +import logging +import os +from dataclasses import dataclass, field +from datetime import datetime, timedelta, timezone +from pathlib import Path + +from app.services.anomaly import ScoringResult, score_unscored + +logger = logging.getLogger(__name__) + +_DEFAULT_INTERVAL = int(os.environ.get("TURNSTONE_ANOMALY_INTERVAL", "0")) + +_lock = asyncio.Lock() + + +@dataclass +class ScorerState: + last_run_at: str | None = None + last_duration_s: float | None = None + last_scored: int = 0 + last_detections: int = 0 + last_error: str | None = None + run_count: int = 0 + next_run_at: str | None = None + running: bool = False + total_scored: int = 0 + total_detections: int = 0 + + +_state = ScorerState() + + +def get_state() -> ScorerState: + return _state + + +async def run_once( + db_path: Path, + model_id: str = "", + device: str = "cpu", + batch_size: int = 256, + threshold: float = 0.75, +) -> ScoringResult: + """Score unscored entries once. Skips if already running or model not configured.""" + if _lock.locked(): + return ScoringResult(skipped=True, error="scorer already running") + + async with _lock: + _state.running = True + started = datetime.now(tz=timezone.utc) + try: + loop = asyncio.get_running_loop() + result: ScoringResult = await loop.run_in_executor( + None, + lambda: score_unscored(db_path, model_id, device, batch_size, threshold), + ) + duration = (datetime.now(tz=timezone.utc) - started).total_seconds() + _state.last_run_at = started.isoformat() + _state.last_duration_s = round(duration, 2) + _state.last_scored = result.scored + _state.last_detections = result.detections + _state.last_error = result.error + _state.run_count += 1 + _state.total_scored += result.scored + _state.total_detections += result.detections + if not result.skipped: + logger.info( + "Anomaly scorer: %d scored, %d detections in %.1fs", + result.scored, result.detections, duration, + ) + return result + except Exception as exc: + duration = (datetime.now(tz=timezone.utc) - started).total_seconds() + _state.last_run_at = started.isoformat() + _state.last_duration_s = round(duration, 2) + _state.last_error = str(exc) + _state.run_count += 1 + logger.error("Anomaly scorer failed: %s", exc) + return ScoringResult(error=str(exc)) + finally: + _state.running = False + + +async def scorer_loop( + db_path: Path, + model_id: str, + device: str, + interval_s: int, + batch_size: int = 256, + threshold: float = 0.75, +) -> None: + """Score unscored entries every interval_s seconds until cancelled.""" + logger.info("Anomaly scorer loop started — interval %ds, model: %s", interval_s, model_id) + while True: + await run_once(db_path, model_id, device, batch_size, threshold) + next_run = datetime.now(tz=timezone.utc) + timedelta(seconds=interval_s) + _state.next_run_at = next_run.isoformat() + try: + await asyncio.sleep(interval_s) + except asyncio.CancelledError: + logger.info("Anomaly scorer loop cancelled") + _state.next_run_at = None + raise diff --git a/app/tasks/cybersec_scorer.py b/app/tasks/cybersec_scorer.py new file mode 100644 index 0000000..6b3ca4c --- /dev/null +++ b/app/tasks/cybersec_scorer.py @@ -0,0 +1,84 @@ +"""Background task wrapper for the cybersec zero-shot scoring pipeline.""" +from __future__ import annotations + +import asyncio +import logging +from dataclasses import dataclass, field +from datetime import datetime, timezone +from pathlib import Path + +from app.services.cybersec import score_security_entries + +logger = logging.getLogger(__name__) + +_lock = asyncio.Lock() + + +@dataclass +class CybersecState: + last_run_at: str | None = None + last_duration_s: float | None = None + last_scored: int = 0 + last_detections: int = 0 + last_error: str | None = None + run_count: int = 0 + running: bool = False + total_scored: int = 0 + total_detections: int = 0 + + +_state = CybersecState() + + +def get_state() -> dict: + return { + "last_run_at": _state.last_run_at, + "last_duration_s":_state.last_duration_s, + "last_scored": _state.last_scored, + "last_detections":_state.last_detections, + "last_error": _state.last_error, + "run_count": _state.run_count, + "running": _state.running, + "total_scored": _state.total_scored, + "total_detections": _state.total_detections, + } + + +async def run_once( + db_path: Path, + model_id: str, + device: str = "cpu", + batch_size: int = 32, + threshold: float = 0.60, +) -> None: + """Single cybersec scoring pass — no-op if already running or no model set.""" + if not model_id or _lock.locked(): + return + + async with _lock: + _state.running = True + started = datetime.now(tz=timezone.utc) + try: + loop = asyncio.get_running_loop() + result = await loop.run_in_executor( + None, + lambda: score_security_entries(db_path, model_id, device, batch_size, threshold), + ) + elapsed = (datetime.now(tz=timezone.utc) - started).total_seconds() + _state.last_run_at = started.isoformat() + _state.last_duration_s = elapsed + _state.last_scored = result.scored + _state.last_detections = result.detections + _state.last_error = result.error + _state.run_count += 1 + _state.total_scored += result.scored + _state.total_detections += result.detections + if result.error: + logger.error("cybersec scorer error: %s", result.error) + elif not result.skipped: + logger.info( + "cybersec scorer: scored=%d detections=%d in %.1fs", + result.scored, result.detections, elapsed, + ) + finally: + _state.running = False diff --git a/app/tasks/glean_scheduler.py b/app/tasks/glean_scheduler.py index 02c6567..edf9255 100644 --- a/app/tasks/glean_scheduler.py +++ b/app/tasks/glean_scheduler.py @@ -11,7 +11,7 @@ from __future__ import annotations import asyncio import json import logging -import sqlite3 +from app.db import get_conn, resolve_tenant_id from dataclasses import dataclass, field from datetime import datetime, timedelta, timezone from pathlib import Path @@ -20,6 +20,9 @@ from typing import Any import httpx from app.glean.pipeline import glean_sources +from app.tasks.anomaly_scorer import run_once as _run_scorer +from app.tasks.cybersec_scorer import run_once as _run_cybersec +from app.tasks.incident_detector import run_once as _run_incident_detector logger = logging.getLogger(__name__) @@ -49,9 +52,8 @@ def get_state() -> IngestState: def _query_matched_since(db_path: Path, since: str | None) -> list[dict]: """Return entries with non-empty matched_patterns, optionally filtered by ingest_time.""" - conn = sqlite3.connect(str(db_path), timeout=30.0) - conn.row_factory = sqlite3.Row - try: + tid = resolve_tenant_id() + with get_conn(db_path) as conn: if since: rows = conn.execute( """ @@ -59,11 +61,13 @@ def _query_matched_since(db_path: Path, since: str | None) -> list[dict]: ingest_time, severity, repeat_count, out_of_order, matched_patterns, text FROM log_entries - WHERE matched_patterns != '[]' AND ingest_time > ? + WHERE matched_patterns != '[]' + AND ingest_time > ? + AND (tenant_id = ? OR tenant_id = '') ORDER BY ingest_time LIMIT 5000 """, - (since,), + (since, tid), ).fetchall() else: rows = conn.execute( @@ -73,13 +77,13 @@ def _query_matched_since(db_path: Path, since: str | None) -> list[dict]: matched_patterns, text FROM log_entries WHERE matched_patterns != '[]' + AND (tenant_id = ? OR tenant_id = '') ORDER BY ingest_time DESC LIMIT 5000 """, + (tid,), ).fetchall() - return [dict(r) for r in rows] - finally: - conn.close() + return [dict(r) for r in rows] async def submit_matched( @@ -122,6 +126,14 @@ async def run_once( submit_endpoint: str | None = None, source_host: str = "unknown", force: bool = False, + anomaly_model: str = "", + anomaly_device: str = "cpu", + anomaly_threshold: float = 0.75, + cybersec_model: str = "", + cybersec_device: str = "cpu", + cybersec_threshold: float = 0.60, + incidents_db_path: Path | None = None, + auto_incident: bool = True, ) -> dict[str, Any]: """Ingest all sources once, then submit matched entries if configured. @@ -162,6 +174,18 @@ async def run_once( if submit_endpoint: await submit_matched(db_path, submit_endpoint, source_host, since=_state.last_submitted_at) + if anomaly_model: + await _run_scorer(db_path, anomaly_model, anomaly_device, threshold=anomaly_threshold) + + if cybersec_model: + await _run_cybersec(db_path, cybersec_model, cybersec_device, threshold=cybersec_threshold) + + if auto_incident and incidents_db_path: + glean_started_iso = _state.last_run_at + result = await _run_incident_detector(db_path, incidents_db_path, since=glean_started_iso) + if result["created"]: + logger.info("Incident detector: %d incident(s) auto-created", result["created"]) + return {"ok": True, "stats": _state.last_stats, "duration_s": _state.last_duration_s} @@ -172,13 +196,37 @@ async def scheduler_loop( interval_s: int, submit_endpoint: str | None = None, source_host: str = "unknown", + anomaly_model: str = "", + anomaly_device: str = "cpu", + anomaly_threshold: float = 0.75, + cybersec_model: str = "", + cybersec_device: str = "cpu", + cybersec_threshold: float = 0.60, + incidents_db_path: Path | None = None, + auto_incident: bool = True, ) -> None: - """Run glean + optional submission every interval_s seconds until cancelled.""" + """Run glean + optional submission + optional anomaly/cybersec scoring every interval_s seconds.""" logger.info("Ingest scheduler started — interval %ds, sources: %s", interval_s, sources_file) if submit_endpoint: logger.info("Submission enabled — endpoint: %s", submit_endpoint) + if anomaly_model: + logger.info("Anomaly scoring enabled — model: %s", anomaly_model) + if cybersec_model: + logger.info("Cybersec scoring enabled — model: %s", cybersec_model) + if auto_incident and incidents_db_path: + logger.info("Auto-incident detection enabled") while True: - await run_once(sources_file, db_path, pattern_file, submit_endpoint, source_host) + await run_once( + sources_file, db_path, pattern_file, submit_endpoint, source_host, + anomaly_model=anomaly_model, + anomaly_device=anomaly_device, + anomaly_threshold=anomaly_threshold, + cybersec_model=cybersec_model, + cybersec_device=cybersec_device, + cybersec_threshold=cybersec_threshold, + incidents_db_path=incidents_db_path, + auto_incident=auto_incident, + ) next_run = datetime.now(tz=timezone.utc) + timedelta(seconds=interval_s) _state.next_run_at = next_run.isoformat() try: diff --git a/app/tasks/incident_detector.py b/app/tasks/incident_detector.py new file mode 100644 index 0000000..6a62b2f --- /dev/null +++ b/app/tasks/incident_detector.py @@ -0,0 +1,188 @@ +"""Post-glean automatic incident detection. + +After each batch glean, scan entries ingested since the last run for +ERROR/CRITICAL clusters. If a source produces >= threshold errors within +window_s seconds, auto-create an incident unless one already exists for +that source in that time window. + +Environment variables (all optional): + TURNSTONE_AUTO_INCIDENT_THRESHOLD integer, default 5 + TURNSTONE_AUTO_INCIDENT_WINDOW seconds, default 600 (10 min) +""" +from __future__ import annotations + +import asyncio +import logging +import os +from collections import defaultdict +from datetime import datetime, timezone +from pathlib import Path + +from app.db import get_conn, resolve_tenant_id +from app.services.incidents import create_incident + +logger = logging.getLogger(__name__) + +_THRESHOLD = int(os.environ.get("TURNSTONE_AUTO_INCIDENT_THRESHOLD", "5")) +_WINDOW_S = int(os.environ.get("TURNSTONE_AUTO_INCIDENT_WINDOW", "600")) + +# Severity rank — used to pick the cluster's worst severity +_SEV_RANK = {"CRITICAL": 3, "ERROR": 2, "WARN": 1, "INFO": 0, "DEBUG": 0} + + +def _query_recent_errors(db_path: Path, since: str | None) -> list[dict]: + tid = resolve_tenant_id() + with get_conn(db_path) as conn: + if since: + rows = conn.execute( + """ + SELECT source_id, timestamp_iso, severity + FROM log_entries + WHERE severity IN ('ERROR', 'CRITICAL') + AND ingest_time > ? + AND (tenant_id = ? OR tenant_id = '') + ORDER BY source_id, timestamp_iso ASC + """, + (since, tid), + ).fetchall() + else: + rows = conn.execute( + """ + SELECT source_id, timestamp_iso, severity + FROM log_entries + WHERE severity IN ('ERROR', 'CRITICAL') + AND (tenant_id = ? OR tenant_id = '') + ORDER BY source_id, timestamp_iso ASC + LIMIT 10000 + """, + (tid,), + ).fetchall() + return [dict(r) for r in rows] + + +def _parse_ts(iso: str | None) -> float | None: + """Parse ISO timestamp to epoch seconds; return None on failure.""" + if not iso: + return None + try: + dt = datetime.fromisoformat(iso.replace("Z", "+00:00")) + return dt.timestamp() + except (ValueError, TypeError): + return None + + +def _find_clusters( + events: list[dict], window_s: int, threshold: int +) -> list[tuple[str, str, str]]: + """Return (started_at_iso, ended_at_iso, worst_severity) for each cluster.""" + # Filter to events with parseable timestamps, sorted ascending + timed = [] + for e in events: + t = _parse_ts(e["timestamp_iso"]) + if t is not None: + timed.append((t, e["timestamp_iso"], e["severity"])) + timed.sort() + + clusters: list[tuple[str, str, str]] = [] + i = 0 + while i < len(timed): + j = i + while j < len(timed) and timed[j][0] - timed[i][0] <= window_s: + j += 1 + count = j - i + if count >= threshold: + worst = max((timed[k][2] for k in range(i, j)), key=lambda s: _SEV_RANK.get(s, 0)) + clusters.append((timed[i][1], timed[j - 1][1], worst)) + i = j # skip past the cluster to avoid overlap + else: + i += 1 + return clusters + + +def _incident_exists_for_cluster( + incidents_db_path: Path, source_id: str, started_at: str, ended_at: str +) -> bool: + """Return True if an auto-incident for this source already covers the window.""" + issue_type = f"auto:{source_id}" + start_ts = _parse_ts(started_at) + end_ts = _parse_ts(ended_at) + if start_ts is None or end_ts is None: + return False + tid = resolve_tenant_id() + with get_conn(incidents_db_path) as conn: + rows = conn.execute( + """ + SELECT started_at, ended_at FROM incidents + WHERE issue_type = ? + AND (tenant_id = ? OR tenant_id = '') + """, + (issue_type, tid), + ).fetchall() + for row in rows: + ex_start = _parse_ts(row["started_at"]) + ex_end = _parse_ts(row["ended_at"]) + if ex_start is None or ex_end is None: + continue + # Overlap check: two intervals [a,b] and [c,d] overlap when a<=d and b>=c + if ex_start <= end_ts and ex_end >= start_ts: + return True + return False + + +def detect_and_create( + db_path: Path, + incidents_db_path: Path, + since: str | None, + threshold: int = _THRESHOLD, + window_s: int = _WINDOW_S, +) -> dict[str, int]: + """Detect error clusters and create incidents. Returns {"created": N}.""" + entries = _query_recent_errors(db_path, since) + if not entries: + return {"created": 0} + + by_source: dict[str, list[dict]] = defaultdict(list) + for e in entries: + by_source[e["source_id"]].append(e) + + created = 0 + for source_id, events in by_source.items(): + clusters = _find_clusters(events, window_s, threshold) + for started_at, ended_at, worst_sev in clusters: + if _incident_exists_for_cluster(incidents_db_path, source_id, started_at, ended_at): + continue + n = len(events) # event count for this source in the glean window + sev_label = "critical" if worst_sev == "CRITICAL" else "high" + create_incident( + incidents_db_path, + label=f"Auto: {source_id} — {n} errors", + issue_type=f"auto:{source_id}", + started_at=started_at, + ended_at=ended_at, + notes="Auto-detected error cluster. Review and label as needed.", + severity=sev_label, + ) + logger.info( + "Auto-incident created: source=%s window=[%s, %s] severity=%s", + source_id, started_at, ended_at, sev_label, + ) + created += 1 + + if created: + logger.info("Incident detector: %d new incident(s) created", created) + return {"created": created} + + +async def run_once( + db_path: Path, + incidents_db_path: Path, + since: str | None, + threshold: int = _THRESHOLD, + window_s: int = _WINDOW_S, +) -> dict[str, int]: + """Async wrapper — runs detection in a thread to avoid blocking the event loop.""" + loop = asyncio.get_running_loop() + return await loop.run_in_executor( + None, + lambda: detect_and_create(db_path, incidents_db_path, since, threshold, window_s), + ) diff --git a/app/watch/watcher.py b/app/watch/watcher.py index 1108087..c397ae9 100644 --- a/app/watch/watcher.py +++ b/app/watch/watcher.py @@ -8,7 +8,6 @@ from __future__ import annotations import json import logging -import sqlite3 import subprocess import threading from dataclasses import dataclass, field @@ -21,17 +20,16 @@ import yaml from app.glean import journald as journald_parser, syslog as syslog_parser from app.glean import plaintext as plaintext_parser, servarr as servarr_parser, plex as plex_parser from app.glean import qbittorrent as qbit_parser, caddy as caddy_parser -from app.glean.pipeline import _detect_format +from app.db import get_conn +from app.db.schema import ensure_schema +from app.glean.pipeline import _detect_format, _write_batch from app.glean.base import _compile, load_patterns, now_iso -from app.glean.pipeline import _write_batch, _SCHEMA -from app.services.search import build_fts_index from app.services.models import RetrievedEntry logger = logging.getLogger(__name__) FLUSH_INTERVAL_SEC = 10 FLUSH_BATCH_SIZE = 100 -FTS_SYNC_EVERY_N_FLUSHES = 3 # sync FTS every ~30s under normal load # ── Config ──────────────────────────────────────────────────────────────────── @@ -111,10 +109,7 @@ class WatchSource: patterns = load_patterns(self.pattern_file) compiled = _compile(patterns) - conn = sqlite3.connect(str(self.db_path), timeout=30.0) - conn.execute("PRAGMA journal_mode=WAL") - conn.executescript(_SCHEMA) - conn.commit() + ensure_schema(self.db_path) try: cmd = self._build_command() @@ -127,12 +122,10 @@ class WatchSource: text=True, bufsize=1, ) - self._drain(conn, compiled) + self._drain(compiled) except Exception as exc: self._error = str(exc) logger.error("Watch source %r crashed: %s", self.config.source_id, exc) - finally: - conn.close() def _build_command(self) -> list[str] | None: t = self.config.source_type @@ -193,7 +186,7 @@ class WatchSource: return [] - def _drain(self, conn: sqlite3.Connection, compiled) -> None: + def _drain(self, compiled) -> None: """Read lines from the subprocess and flush to DB periodically.""" assert self._proc is not None buffer: list[str] = [] @@ -221,29 +214,28 @@ class WatchSource: should_flush = len(buffer) >= FLUSH_BATCH_SIZE or elapsed >= FLUSH_INTERVAL_SEC if buffer and should_flush: - flush_count = self._flush(conn, buffer, compiled, flush_count) + flush_count = self._flush(buffer, compiled, flush_count) buffer.clear() last_flush = datetime.now(tz=timezone.utc) # Flush remainder if buffer: - self._flush(conn, buffer, compiled, flush_count) + self._flush(buffer, compiled, flush_count) - def _flush(self, conn: sqlite3.Connection, lines: list[str], compiled, flush_count: int) -> int: + def _flush(self, lines: list[str], compiled, flush_count: int) -> int: ingest_time = now_iso() try: entries = self._parse_lines(lines, ingest_time, compiled) if entries: - _write_batch(conn, entries) - conn.commit() + with get_conn(self.db_path) as conn: + _write_batch(conn, entries) + conn.commit() self._entry_count += len(entries) self._last_event = now_iso() if entries: self._last_event = entries[-1].timestamp_iso or self._last_event flush_count += 1 - if flush_count % FTS_SYNC_EVERY_N_FLUSHES == 0: - build_fts_index(self.db_path) except Exception as exc: logger.warning("Flush error for %r: %s", self.config.source_id, exc) return flush_count diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..2e064a4 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,68 @@ +version: "3.9" + +# Turnstone with external Postgres DB. +# Data lives in the named volume `turnstone_pgdata` — survives image rebuilds. +# To adopt an EXISTING Postgres install, set DATABASE_URL to point at it and +# remove the `db` service and `depends_on` blocks. +# +# Quick start: +# docker compose up -d +# # Then open http://localhost:8520 + +services: + db: + image: postgres:16-alpine + restart: unless-stopped + environment: + POSTGRES_DB: turnstone + POSTGRES_USER: turnstone + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-turnstone_dev} + volumes: + - turnstone_pgdata:/var/lib/postgresql/data + healthcheck: + test: ["CMD-SHELL", "pg_isready -U turnstone -d turnstone"] + interval: 5s + timeout: 5s + retries: 5 + + turnstone: + build: . + restart: unless-stopped + ports: + - "${TURNSTONE_PORT:-8520}:8520" + depends_on: + db: + condition: service_healthy + environment: + # Backend selection — comment out DATABASE_URL to fall back to SQLite + DATABASE_URL: postgresql://turnstone:${POSTGRES_PASSWORD:-turnstone_dev}@db:5432/turnstone + TURNSTONE_TENANT_ID: ${TURNSTONE_TENANT_ID:-} + TURNSTONE_API_KEY: ${TURNSTONE_API_KEY:-} + TURNSTONE_GLEAN_INTERVAL: ${TURNSTONE_GLEAN_INTERVAL:-900} + TURNSTONE_SOURCE_HOST: ${TURNSTONE_SOURCE_HOST:-} + TURNSTONE_SUBMIT_ENDPOINT: ${TURNSTONE_SUBMIT_ENDPOINT:-} + # --- Multi-agent diagnose pipeline --- + TURNSTONE_MULTI_AGENT_DIAGNOSE: ${TURNSTONE_MULTI_AGENT_DIAGNOSE:-false} + TURNSTONE_CLASSIFIER_MODEL: ${TURNSTONE_CLASSIFIER_MODEL:-} + TURNSTONE_EMBED_BACKEND: ${TURNSTONE_EMBED_BACKEND:-} + TURNSTONE_EMBED_MODEL: ${TURNSTONE_EMBED_MODEL:-} + TURNSTONE_EMBED_DEVICE: ${TURNSTONE_EMBED_DEVICE:-cpu} + # --- Cybersec scoring pipeline --- + TURNSTONE_CYBERSEC_MODEL: ${TURNSTONE_CYBERSEC_MODEL:-} + TURNSTONE_CYBERSEC_DEVICE: ${TURNSTONE_CYBERSEC_DEVICE:-cpu} + TURNSTONE_CYBERSEC_THRESHOLD: ${TURNSTONE_CYBERSEC_THRESHOLD:-0.60} + # --- Anomaly scoring pipeline --- + TURNSTONE_ANOMALY_MODEL: ${TURNSTONE_ANOMALY_MODEL:-} + TURNSTONE_ANOMALY_DEVICE: ${TURNSTONE_ANOMALY_DEVICE:-cpu} + TURNSTONE_ANOMALY_THRESHOLD: ${TURNSTONE_ANOMALY_THRESHOLD:-0.75} + TURNSTONE_ANOMALY_INTERVAL: ${TURNSTONE_ANOMALY_INTERVAL:-0} + # --- HuggingFace model cache --- + HF_HOME: /hf_cache + volumes: + - ./patterns:/app/patterns:ro + - ./data:/app/data # optional: persists SQLite files if DATABASE_URL unset + - ${HF_CACHE_PATH:-/Library/Assets/LLM}:/hf_cache:ro # shared model cache + +volumes: + turnstone_pgdata: + name: turnstone_pgdata diff --git a/docker-standalone.sh b/docker-standalone.sh index 7098fa8..3d77a9f 100755 --- a/docker-standalone.sh +++ b/docker-standalone.sh @@ -62,7 +62,10 @@ set -euo pipefail REPO_DIR="${HOME}/turnstone" DATA_DIR="${REPO_DIR}/data" PATTERNS_DIR="${REPO_DIR}/patterns" -HF_CACHE_DIR="${REPO_DIR}/hf-cache" # persists downloaded ML models across restarts +# HF_CACHE_DIR: override to a shared cache directory to avoid re-downloading models. +# Example (Heimdall, where byviz/bylastic_classification_logs is already cached): +# export HF_CACHE_DIR=/Library/Assets/LLM +HF_CACHE_DIR="${HF_CACHE_DIR:-${REPO_DIR}/hf-cache}" TZ="${TZ:-America/Los_Angeles}" @@ -83,11 +86,21 @@ TZ="${TZ:-America/Los_Angeles}" # bash ~/turnstone/docker-standalone.sh # +# ── Anomaly scoring pipeline (IDS / watchdog) ──────────────────────────────── +# Set TURNSTONE_ANOMALY_MODEL to enable automatic anomaly scoring after each +# glean run. The byviz classifier (already used by the diagnose pipeline) is +# a good default — it's cached alongside the other models. +# +# export TURNSTONE_ANOMALY_MODEL=byviz/bylastic_classification_logs +# export TURNSTONE_ANOMALY_THRESHOLD=0.80 # confidence floor (default 0.75) +# bash ~/turnstone/docker-standalone.sh +# + # ── Multi-agent diagnose pipeline ──────────────────────────────────────────── # Enable the 5-stage ML pipeline to get smarter diagnose results. # -# If your host has WireGuard to Heimdall's LAN (e.g. Huginn): -# export GPU_SERVER_URL=http://:7700 +# If your host has WireGuard to Heimdall's LAN: +# export GPU_SERVER_URL=http://:7700 # export TURNSTONE_MULTI_AGENT_DIAGNOSE=true # bash ~/turnstone/docker-standalone.sh # @@ -134,6 +147,13 @@ docker run -d \ -e TURNSTONE_EMBED_BACKEND="${TURNSTONE_EMBED_BACKEND:-sentence_transformers}" \ -e TURNSTONE_EMBED_MODEL="${TURNSTONE_EMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}" \ -e TURNSTONE_EMBED_DEVICE="${TURNSTONE_EMBED_DEVICE:-cpu}" \ + -e TURNSTONE_CYBERSEC_MODEL="${TURNSTONE_CYBERSEC_MODEL:-}" \ + -e TURNSTONE_CYBERSEC_DEVICE="${TURNSTONE_CYBERSEC_DEVICE:-cpu}" \ + -e TURNSTONE_CYBERSEC_THRESHOLD="${TURNSTONE_CYBERSEC_THRESHOLD:-0.60}" \ + -e TURNSTONE_ANOMALY_MODEL="${TURNSTONE_ANOMALY_MODEL:-}" \ + -e TURNSTONE_ANOMALY_DEVICE="${TURNSTONE_ANOMALY_DEVICE:-cpu}" \ + -e TURNSTONE_ANOMALY_THRESHOLD="${TURNSTONE_ANOMALY_THRESHOLD:-0.75}" \ + -e TURNSTONE_ANOMALY_INTERVAL="${TURNSTONE_ANOMALY_INTERVAL:-0}" \ localhost/turnstone:latest echo "" diff --git a/patterns/default.yaml b/patterns/default.yaml index 6142a7c..c5ea6a0 100644 --- a/patterns/default.yaml +++ b/patterns/default.yaml @@ -4,7 +4,7 @@ # # domain: groups patterns into service health domains for triage-level summaries. # Valid domains: service_health | networking | auth | storage | memory | -# kernel | power | web_proxy | media | gpu +# kernel | power | web_proxy | media | gpu | audio # # Patterns are applied in order; multiple can match a single entry. @@ -275,3 +275,41 @@ patterns: severity: ERROR domain: power description: Undervoltage event — instability risk, check PSU and cable connections + + # ── Audio / PipeWire / ALSA ────────────────────────────────────────────────── + + - name: pipewire_overflow + pattern: "(OVERFLOW channel|stream.*OVERFLOW|protocol.pulse.*OVERFLOW)" + severity: WARN + domain: audio + description: PipeWire-Pulse stream buffer overflow — client not draining audio fast enough; usually indicates a quantum/period-size mismatch or CPU scheduling issue + + - name: pipewire_underrun + pattern: "(pw\\.node.*underrun|spa\\.alsa.*underrun|alsa.*underrun|UNDERRUN)" + severity: WARN + domain: audio + description: PipeWire/ALSA buffer underrun (xrun) — audio thread missed its deadline; increase quantum or period-size for the affected device + + - name: alsa_xrun + pattern: "(ALSA.*[Xx][Rr][Uu][Nn]|alsa.*xrun|snd_pcm.*xrun|pcm.*underrun|pcm.*overrun)" + severity: WARN + domain: audio + description: ALSA xrun (hardware buffer overrun/underrun) — increase api.alsa.period-size via WirePlumber rule or raise clock.min-quantum + + - name: pipewire_quantum_mismatch + pattern: "(quantum.*mismatch|rate.*mismatch|sample.rate.*mismatch|resampl.*fail|can.*t adapt quantum)" + severity: WARN + domain: audio + description: PipeWire quantum or sample-rate mismatch between nodes — check for mixed 44100/48000 streams; may need per-device WirePlumber rules + + - name: pipewire_node_error + pattern: "(pw\\.node.*error|node.*ERROR|pipewire.*failed to set|spa\\.alsa.*error|alsa_sink.*error|alsa_source.*error)" + severity: ERROR + domain: audio + description: PipeWire node error — device may be unavailable or misconfigured + + - name: pipewire_jackdbus_missing + pattern: "(jackdbus.*reply|jackaudio.*service.*not.*provided|org\\.jackaudio\\.service)" + severity: INFO + domain: audio + description: PipeWire JACK D-Bus probe — JACK not running; benign on non-JACK systems, fires once per PipeWire restart diff --git a/patterns/sources-example.yaml b/patterns/sources-example.yaml new file mode 100644 index 0000000..804601b --- /dev/null +++ b/patterns/sources-example.yaml @@ -0,0 +1,50 @@ +# Turnstone log sources — example node (Docker/Podman, self-hosted media stack) +# +# Copy this file to your patterns directory and edit for your setup. +# Container paths: /opt and /var/log are bind-mounted read-only. +# journal-export.jsonl is written to /data/ by export_journal.sh (run via cron before glean). +# +# Add or remove sources freely. Missing paths are skipped with a warning. + +sources: + # ── System ──────────────────────────────────────────────────────────────── + # Requires: cron job to run export_journal.sh before each glean. + # Example cron (every 15 min — edit paths for your install): + # */15 * * * * /opt/turnstone/scripts/export_journal.sh \ + # /opt/turnstone-data/ + - id: system-journal + path: /data/journal-export.jsonl + + - id: dmesg + path: /data/dmesg-export.txt + + # ── Servarr stack ───────────────────────────────────────────────────────── + - id: sonarr + path: /opt/sonarr/config/logs/sonarr.0.txt + + - id: radarr + path: /opt/radarr/config/logs/radarr.0.txt + + - id: bazarr + path: /opt/bazarr/config/log/bazarr.log + + - id: prowlarr + path: /opt/prowlarr/config/logs/prowlarr.0.txt + + # ── Media server / tracking ──────────────────────────────────────────────── + - id: tautulli + path: /opt/tautulli/config/logs/plex_websocket.log + + # ── Download automation ──────────────────────────────────────────────────── + - id: autoscan + path: /opt/autoscan/config/autoscan.log + + # ── Web / proxy ──────────────────────────────────────────────────────────── + - id: organizr-nginx + path: /opt/organizr/log/nginx/error.log + + - id: organizr-app + path: /opt/organizr/www/organizr/server.log + + - id: nextcloud-nginx + path: /opt/nextcloud/config/log/nginx/error.log diff --git a/podman-standalone.sh b/podman-standalone.sh index 469c490..5bfc7f8 100755 --- a/podman-standalone.sh +++ b/podman-standalone.sh @@ -46,7 +46,7 @@ # ── Adding Caddy reverse proxy ──────────────────────────────────────────────── # Add to /etc/caddy/Caddyfile: # -# turnstone.example-node.tv { +# turnstone.your-domain.example { # import protected # reverse_proxy 10.0.0.10:8534 # import cloudflare @@ -59,11 +59,14 @@ # set -euo pipefail -REPO_DIR=/opt/turnstone -DATA_DIR=/opt/turnstone/data -PATTERNS_DIR=/opt/turnstone/patterns -HF_CACHE_DIR=/opt/turnstone/hf-cache # persists downloaded ML models across restarts -TZ=America/Los_Angeles +# Auto-detect repo from script location — works whether cloned to /opt/turnstone +# or to /Library/Development/CircuitForge/turnstone or any other path. +REPO_DIR="${TURNSTONE_REPO_DIR:-$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)}" +# Data and patterns live OUTSIDE the repo so they survive git pulls. +DATA_DIR="${TURNSTONE_DATA_DIR:-/opt/turnstone-data}" +PATTERNS_DIR="${TURNSTONE_PATTERNS_DIR:-${DATA_DIR}/patterns}" +HF_CACHE_DIR="${TURNSTONE_HF_CACHE:-${DATA_DIR}/hf-cache}" +TZ="${TZ:-America/Los_Angeles}" # ── Bundle push configuration ──────────────────────────────────────────────── # Set TURNSTONE_BUNDLE_ENDPOINT before running this script to enable the @@ -91,8 +94,7 @@ TZ=America/Los_Angeles # ML models are downloaded on first diagnose run and cached in HF_CACHE_DIR. # On a CPU-only host (no GPU) set TURNSTONE_EMBED_DEVICE=cpu (default). # -# For Contributor2's instance (example-node.tv) — no WireGuard to Heimdall LAN, -# use the public cf-orch endpoint instead: +# If your host has no WireGuard to Heimdall — use the public cf-orch endpoint: # export GPU_SERVER_URL=https://orch.circuitforge.tech # export TURNSTONE_MULTI_AGENT_DIAGNOSE=true # sudo bash /opt/turnstone/podman-standalone.sh @@ -114,13 +116,26 @@ TZ=America/Los_Angeles # Must be run as root (sudo bash podman-standalone.sh) — rootful Podman only. # +# Bootstrap data and patterns dirs if this is a first run +mkdir -p "${DATA_DIR}" "${PATTERNS_DIR}" "${HF_CACHE_DIR}" +# Copy default patterns if the dir is empty (first run only) +if [ -z "$(ls -A "${PATTERNS_DIR}")" ]; then + cp "${REPO_DIR}/patterns/default.yaml" "${PATTERNS_DIR}/" + # Copy host-specific sources if present, otherwise copy the generic template + HOST_SOURCES="${REPO_DIR}/patterns/sources-$(hostname).yaml" + if [ -f "${HOST_SOURCES}" ]; then + cp "${HOST_SOURCES}" "${PATTERNS_DIR}/sources.yaml" + echo "==> Installed host-specific sources: ${HOST_SOURCES}" + else + cp "${REPO_DIR}/patterns/sources.yaml" "${PATTERNS_DIR}/" + echo "==> Installed default sources.yaml — edit ${PATTERNS_DIR}/sources.yaml for this host" + fi +fi + # Build image from current source (bakes app/ code into the image) echo "Building Turnstone image..." podman build -t localhost/turnstone:latest "${REPO_DIR}" -# Create HF model cache dir if not present (persists across container rebuilds) -mkdir -p "${HF_CACHE_DIR}" - # Remove existing container if present (safe re-run) podman rm -f turnstone 2>/dev/null || true @@ -142,6 +157,9 @@ podman run -d \ -e TURNSTONE_MULTI_AGENT_DIAGNOSE="${TURNSTONE_MULTI_AGENT_DIAGNOSE:-false}" \ -e GPU_SERVER_URL="${GPU_SERVER_URL:-}" \ -e HF_HOME=/hf-cache \ + -e TURNSTONE_AUTO_INCIDENT="${TURNSTONE_AUTO_INCIDENT:-true}" \ + -e TURNSTONE_AUTO_INCIDENT_THRESHOLD="${TURNSTONE_AUTO_INCIDENT_THRESHOLD:-5}" \ + -e TURNSTONE_AUTO_INCIDENT_WINDOW="${TURNSTONE_AUTO_INCIDENT_WINDOW:-600}" \ -e TURNSTONE_CLASSIFIER_MODEL="${TURNSTONE_CLASSIFIER_MODEL:-byviz/bylastic_classification_logs}" \ -e TURNSTONE_EMBED_BACKEND="${TURNSTONE_EMBED_BACKEND:-sentence_transformers}" \ -e TURNSTONE_EMBED_MODEL="${TURNSTONE_EMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}" \ diff --git a/requirements.txt b/requirements.txt index f91b900..21b3c6c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,7 @@ fastapi>=0.110.0 uvicorn[standard]>=0.27.0 +# Postgres backend — optional; SQLite is used when DATABASE_URL is unset +psycopg[binary,pool]>=3.1.0 pydantic>=2.0.0 pyyaml>=6.0 aiofiles>=23.0.0 diff --git a/scripts/gen_corpus.py b/scripts/gen_corpus.py new file mode 100644 index 0000000..01b65f2 --- /dev/null +++ b/scripts/gen_corpus.py @@ -0,0 +1,383 @@ +"""Synthetic log corpus generator. + +Produces realistic-but-entirely-artificial log files for demos, load tests, +and parser regression suites — no production data required. + +Usage: + python scripts/gen_corpus.py --days 7 --out /tmp/demo-corpus/ + python scripts/gen_corpus.py --days 1 --out /tmp/test-run/ --seed 42 --error-rate 0.15 + python scripts/gen_corpus.py --help + +Output tree: + /journald/system.jsonl — systemd/kernel journald JSON + /docker/services.jsonl — containerised app stdout + /qbittorrent/qbt.log — hotio-format qBittorrent log + /ext_device/device.log — EXT_DEVICE device plaintext log +""" +from __future__ import annotations + +import argparse +import json +import random +import sys +from datetime import datetime, timedelta, timezone +from pathlib import Path +from typing import Callable + +# ── Severity distribution ────────────────────────────────────────────────────── + +_SYSLOG_PRIORITY = { + "CRITICAL": "2", + "ERROR": "3", + "WARN": "4", + "INFO": "6", + "DEBUG": "7", +} + +_SEVERITY_WEIGHTS = { + "INFO": 0.70, + "DEBUG": 0.10, + "WARN": 0.12, + "ERROR": 0.06, + "CRITICAL": 0.02, +} + + +def _pick_severity(rng: random.Random, error_rate: float) -> str: + """Return a severity string, boosting ERROR/CRITICAL by error_rate.""" + weights = dict(_SEVERITY_WEIGHTS) + boost = error_rate * 0.08 # distribute extra weight to error tiers + weights["ERROR"] += boost + weights["CRITICAL"] += boost / 2 + weights["INFO"] -= boost * 1.2 + weights["DEBUG"] -= boost * 0.3 + choices = list(weights.keys()) + probs = [max(0.0, weights[k]) for k in choices] + return rng.choices(choices, weights=probs, k=1)[0] + + +# ── Timestamp helpers ────────────────────────────────────────────────────────── + +def _ts_seq(start: datetime, end: datetime, rng: random.Random) -> list[datetime]: + """Return a sorted list of random timestamps between start and end.""" + total_seconds = (end - start).total_seconds() + # Roughly 1 event every ~4 seconds on average across all sources + count = int(total_seconds / 4) + offsets = sorted(rng.uniform(0, total_seconds) for _ in range(count)) + return [start + timedelta(seconds=o) for o in offsets] + + +def _micros(dt: datetime) -> str: + """Journald __REALTIME_TIMESTAMP: microseconds since epoch, as string.""" + return str(int(dt.timestamp() * 1_000_000)) + + +# ── Message libraries ────────────────────────────────────────────────────────── + +_JOURNALD_UNITS = [ + "sshd.service", "nginx.service", "docker.service", "systemd-resolved.service", + "cron.service", "systemd-journald.service", "NetworkManager.service", + "turnstone.service", "podman.service", "fail2ban.service", +] + +_JOURNALD_MESSAGES: dict[str, list[str]] = { + "INFO": [ + "Started {unit}.", + "Listening on {port}/tcp.", + "Reloaded configuration for {unit}.", + "New connection from {ip}:{port}", + "Session opened for user {user} by (uid=0)", + "Accepted publickey for {user} from {ip} port {port}", + "System time synchronized from NTP server {ip}", + "Unit {unit} entered active state.", + "Loaded kernel module {module}.", + "DNS query resolved: {host} -> {ip}", + ], + "DEBUG": [ + "Polling interval set to {n}ms", + "Cache hit for key '{key}'", + "Heartbeat OK from {host}", + "Timer {n} fired", + "Worker {n} idle", + ], + "WARN": [ + "High memory usage on {unit}: {pct}% used", + "Slow DNS response ({ms}ms) for {host}", + "Deprecated option '{key}' in config — will be removed in next release", + "Retrying connection to {host} (attempt {n}/5)", + "Journal size limit reached, rotating", + "Disk usage at {pct}% on /dev/sda1", + ], + "ERROR": [ + "Failed to start {unit}: exit code {n}", + "Connection refused to {host}:{port}", + "Segmentation fault in {unit} (core dumped)", + "Authentication failure for user {user} from {ip}", + "Timeout waiting for {unit} to become ready", + "Failed to bind {port}/tcp: address already in use", + ], + "CRITICAL": [ + "Kernel panic — not syncing: {msg}", + "Out of memory: killed process {n} ({unit})", + "Hardware error on /dev/sda1: I/O error", + "Disk quota exceeded on /home for user {user}", + "Critical service {unit} failed; system may be unstable", + ], +} + +_DOCKER_SERVICES = [ + "caddy", "postgres", "redis", "turnstone", "avocet", + "prometheus", "grafana", "loki", "minio", "vllm", +] + +_DOCKER_MESSAGES: dict[str, list[str]] = { + "INFO": [ + "level=info msg=\"Server listening on 0.0.0.0:{port}\"", + "level=info msg=\"Connected to database at {host}:5432\"", + 'level=info msg="GET /api/health 200 {ms}ms" user={user}', + 'level=info msg="POST /api/v1/jobs 201 {ms}ms"', + "INFO: Worker pool size: {n}", + "INFO: Cache warmed — {n} entries loaded", + "INFO: Startup complete in {ms}ms", + "INFO: Scheduled job '{key}' executed successfully", + ], + "DEBUG": [ + "DEBUG: SQL query took {ms}ms: SELECT * FROM {key}", + "DEBUG: Redis HIT for key {key}", + "level=debug msg=\"span {key} completed\" duration={ms}ms", + "DEBUG: Trace ID {key}: handler returned 200", + ], + "WARN": [ + "level=warn msg=\"Slow query ({ms}ms) on table {key}\"", + "WARN: Connection pool at {pct}% capacity", + "WARN: Rate limit approaching for client {ip}", + "WARN: Deprecated endpoint /v1/{key} called by {ip}", + "level=warn msg=\"GC pause {ms}ms — possible memory pressure\"", + ], + "ERROR": [ + "level=error msg=\"Unhandled exception in handler '{key}'\" err={msg}", + "ERROR: Database connection lost: {msg}", + "level=error msg=\"Failed to acquire lock on {key} after {ms}ms\"", + "ERROR: HTTP 500 POST /api/v1/{key}: internal server error", + "ERROR: Redis NOAUTH: authentication required", + ], + "CRITICAL": [ + "level=critical msg=\"Panic: nil pointer dereference in {key}\"", + "CRITICAL: Fatal: cannot open database: {msg}", + "CRITICAL: OOM killer invoked — process {n} terminated", + ], +} + +_QBT_MESSAGES: dict[str, list[str]] = { + "INFO": [ + "Successfully listening on IP: 0.0.0.0; port: {port}", + "Torrent '{key}' added to download queue", + "Download of '{key}' complete ({n} MB)", + "Seeding '{key}' at {n} KB/s", + "Tracker '{host}' working, {n} seeds", + "Peer {ip} connected to torrent '{key}'", + "Free disk space: {n} GB", + ], + "WARN": [ + "Tracker '{host}' is not working (retrying)", + "Slow download speed ({n} KB/s) for '{key}'", + "Too many open files — reducing connection limit", + "DHT bootstrap failed, retrying in {n}s", + ], + "CRITICAL": [ + "Not enough space on disk to download '{key}'", + "File I/O error for torrent '{key}': {msg}", + "Unable to bind listen port {port}", + ], +} + +_EXT_DEVICE_CODES: dict[str, list[str]] = { + "INFO": [ + "SYS-0100 Device boot complete, firmware v{n}.{n}.{n}", + "SYS-0101 Sensor array calibration OK", + "NET-0200 Link established on interface eth{n}", + "CFG-0300 Configuration loaded from flash", + "HW-0400 Fan speed nominal: {n} RPM", + ], + "WARN": [ + "NET-0210 Link quality degraded: RSSI -{n} dBm", + "HW-0410 Fan speed elevated: {n} RPM (threshold: {n} RPM)", + "CFG-0310 Unknown config key '{key}' ignored", + "SYS-0110 Watchdog near timeout — {n}ms remaining", + ], + "ERROR": [ + "ERR-1001 Sensor read failure on channel {n}: timeout", + "ERR-1002 I2C bus {n} NACK from address 0x{key}", + "ERR-2001 Network tx queue overflow — dropped {n} packets", + "ERR-3001 Flash write error at sector {n}", + ], + "CRITICAL": [ + "ERR-9001 Thermal runaway detected — initiating shutdown", + "ERR-9002 Supply voltage out of range: {n}mV", + "ERR-9003 Memory parity error at address 0x{key}", + ], +} + + +# ── Template substitution ────────────────────────────────────────────────────── + +_HOSTS = ["node1", "node2", "node3", "node4", "gateway", "remotehost"] +_USERS = ["alan", "root", "deployer", "backup", "nobody"] +_MODULES = ["btrfs", "xfs", "nf_conntrack", "ip6table_filter", "overlay"] + +def _fill(template: str, rng: random.Random) -> str: + """Replace {placeholder} tokens with plausible random values.""" + def _sub(m: re.Match) -> str: + import re + key = m.group(1) + if key == "ip": return f"10.{rng.randint(0,255)}.{rng.randint(0,255)}.{rng.randint(1,254)}" + if key == "port": return str(rng.randint(1024, 65535)) + if key == "n": return str(rng.randint(1, 9999)) + if key == "pct": return str(rng.randint(50, 99)) + if key == "ms": return str(rng.randint(1, 5000)) + if key == "unit": return rng.choice(_JOURNALD_UNITS) + if key == "user": return rng.choice(_USERS) + if key == "host": return rng.choice(_HOSTS) + if key == "module": return rng.choice(_MODULES) + if key == "msg": return rng.choice(["unexpected EOF", "connection reset", "no such file"]) + if key == "key": return rng.choice(["auth", "jobs", "cache", "index", "sessions", "queue"]) + return m.group(0) + import re + return re.sub(r"\{(\w+)\}", _sub, template) + + +def _pick_msg(library: dict[str, list[str]], severity: str, rng: random.Random) -> str: + candidates = library.get(severity) or library.get("INFO", ["log entry"]) + return _fill(rng.choice(candidates), rng) + + +# ── Per-format generators ────────────────────────────────────────────────────── + +def gen_journald(path: Path, start: datetime, end: datetime, rng: random.Random, error_rate: float) -> int: + """Emit journald JSON lines (-o json format).""" + lines = 0 + hostname = rng.choice(_HOSTS) + with path.open("w") as fh: + for dt in _ts_seq(start, end, rng): + severity = _pick_severity(rng, error_rate) + unit = rng.choice(_JOURNALD_UNITS) + msg = _pick_msg(_JOURNALD_MESSAGES, severity, rng) + entry = { + "__REALTIME_TIMESTAMP": _micros(dt), + "MESSAGE": msg, + "PRIORITY": _SYSLOG_PRIORITY.get(severity, "6"), + "_HOSTNAME": hostname, + "_SYSTEMD_UNIT": unit, + "SYSLOG_IDENTIFIER": unit.replace(".service", ""), + } + fh.write(json.dumps(entry) + "\n") + lines += 1 + return lines + + +def gen_docker(path: Path, start: datetime, end: datetime, rng: random.Random, error_rate: float) -> int: + """Emit Docker-format JSON lines (SOURCE + MESSAGE envelope).""" + lines = 0 + with path.open("w") as fh: + for dt in _ts_seq(start, end, rng): + severity = _pick_severity(rng, error_rate) + service = rng.choice(_DOCKER_SERVICES) + msg = _pick_msg(_DOCKER_MESSAGES, severity, rng) + entry = { + "SOURCE": f"docker:{service}", + "MESSAGE": msg, + } + fh.write(json.dumps(entry) + "\n") + lines += 1 + return lines + + +def gen_qbittorrent(path: Path, start: datetime, end: datetime, rng: random.Random, error_rate: float) -> int: + """Emit hotio-format qBittorrent plaintext log.""" + _CODE = {"INFO": "N", "WARN": "W", "CRITICAL": "C", "ERROR": "C", "DEBUG": "N"} + lines = 0 + with path.open("w") as fh: + for dt in _ts_seq(start, end, rng): + severity = _pick_severity(rng, error_rate) + msg = _pick_msg(_QBT_MESSAGES, severity, rng) + code = _CODE.get(severity, "N") + ts_str = dt.strftime("%Y-%m-%dT%H:%M:%S") + fh.write(f"({code}) {ts_str} - {msg}\n") + lines += 1 + return lines + + +def gen_ext_device(path: Path, start: datetime, end: datetime, rng: random.Random, error_rate: float) -> int: + """Emit EXT_DEVICE device plaintext log (ISO timestamp + level + ERR/SYS/NET code + message).""" + lines = 0 + with path.open("w") as fh: + for dt in _ts_seq(start, end, rng): + severity = _pick_severity(rng, error_rate) + msg = _pick_msg(_EXT_DEVICE_CODES, severity, rng) + ts_str = dt.strftime("%Y-%m-%dT%H:%M:%S") + fh.write(f"{ts_str} [{severity}] {msg}\n") + lines += 1 + return lines + + +# ── Orchestration ────────────────────────────────────────────────────────────── + +_GENERATORS: list[tuple[str, str, Callable]] = [ + ("journald", "system.jsonl", gen_journald), + ("docker", "services.jsonl", gen_docker), + ("qbittorrent", "qbt.log", gen_qbittorrent), + ("ext_device", "device.log", gen_ext_device), +] + + +def generate( + out: Path, + days: int, + seed: int | None, + error_rate: float, + reference_time: datetime | None = None, +) -> dict[str, int]: + rng = random.Random(seed) + end = reference_time or datetime.now(tz=timezone.utc) + start = end - timedelta(days=days) + + totals: dict[str, int] = {} + for subdir, filename, gen_fn in _GENERATORS: + dest = out / subdir / filename + dest.parent.mkdir(parents=True, exist_ok=True) + # Each source gets its own seeded sub-RNG so streams are independent + sub_rng = random.Random(rng.randint(0, 2**31)) + count = gen_fn(dest, start, end, sub_rng, error_rate) + totals[str(dest.relative_to(out))] = count + print(f" {dest.relative_to(out)}: {count:,} lines") + + return totals + + +# ── CLI ──────────────────────────────────────────────────────────────────────── + +def main(argv: list[str] | None = None) -> int: + parser = argparse.ArgumentParser( + description="Generate a synthetic Turnstone log corpus for demos and testing." + ) + parser.add_argument("--days", type=int, default=7, help="Days of history to generate (default: 7)") + parser.add_argument("--out", type=Path, required=True, help="Output directory") + parser.add_argument("--seed", type=int, default=None, help="RNG seed for reproducibility") + parser.add_argument("--error-rate", type=float, default=0.05, help="Error injection rate 0.0-1.0 (default: 0.05)") + args = parser.parse_args(argv) + + if not 0.0 <= args.error_rate <= 1.0: + print("ERROR: --error-rate must be between 0.0 and 1.0", file=sys.stderr) + return 1 + + args.out.mkdir(parents=True, exist_ok=True) + print(f"Generating {args.days}-day corpus → {args.out} (seed={args.seed}, error_rate={args.error_rate})") + + totals = generate(args.out, args.days, args.seed, args.error_rate) + total_lines = sum(totals.values()) + print(f"Done — {total_lines:,} total log lines across {len(totals)} files") + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/scripts/migrate_sqlite_to_postgres.py b/scripts/migrate_sqlite_to_postgres.py new file mode 100644 index 0000000..4402353 --- /dev/null +++ b/scripts/migrate_sqlite_to_postgres.py @@ -0,0 +1,204 @@ +#!/usr/bin/env python3 +"""One-shot migration: copy data from existing SQLite DBs into Postgres. + +Usage: + DATABASE_URL=postgresql://... python scripts/migrate_sqlite_to_postgres.py \ + --main-db data/turnstone.db \ + --context-db data/turnstone-context.db \ + --incidents-db data/turnstone-incidents.db \ + [--tenant-id heimdall] + +The script is idempotent: rows already present in Postgres (same id) are skipped. +It must be run ONCE per node after deploying the shared Postgres backend. + +Prerequisites: + pip install 'psycopg[binary,pool]' + Set DATABASE_URL to the target Postgres connection string. +""" +from __future__ import annotations + +import argparse +import os +import sqlite3 +import sys +from pathlib import Path + +# Allow running from the project root without installing the package +sys.path.insert(0, str(Path(__file__).parent.parent)) + + +def _pg_connect(): + import psycopg # type: ignore[import] + url = os.environ.get("DATABASE_URL") + if not url: + print("ERROR: DATABASE_URL not set", file=sys.stderr) + sys.exit(1) + return psycopg.connect(url, autocommit=False) + + +def _ensure_schema_pg() -> None: + from app.db.schema import ensure_schema, ensure_context_schema, ensure_incidents_schema + from pathlib import Path + ensure_schema(Path("/dev/null")) # db_path ignored for Postgres + ensure_context_schema(Path("/dev/null")) + ensure_incidents_schema(Path("/dev/null")) + print("Postgres schema verified") + + +def _migrate_table( + src_conn: sqlite3.Connection, + dst_conn, + table: str, + tenant_id: str, + columns: list[str], + conflict_cols: list[str], +) -> int: + """Copy rows from SQLite table to Postgres. Returns rows inserted.""" + # Check if source table exists + try: + rows = src_conn.execute(f"SELECT * FROM {table} LIMIT 0").fetchall() # noqa: S608 + except sqlite3.OperationalError: + print(f" {table}: not found in SQLite — skipping") + return 0 + + # Fetch all rows + src_conn.row_factory = sqlite3.Row + rows = src_conn.execute(f"SELECT * FROM {table}").fetchall() # noqa: S608 + if not rows: + print(f" {table}: empty — skipping") + return 0 + + # Build INSERT ... ON CONFLICT DO NOTHING + col_list = ", ".join(columns) + placeholders = ", ".join("%s" for _ in columns) + conflict = ", ".join(conflict_cols) + sql = ( + f"INSERT INTO {table} ({col_list}) VALUES ({placeholders}) " # noqa: S608 + f"ON CONFLICT ({conflict}) DO NOTHING" + ) + + inserted = 0 + with dst_conn.cursor() as cur: + for row in rows: + # Build values: inject tenant_id if not present in source row + vals = [] + for col in columns: + if col == "tenant_id": + try: + val = row["tenant_id"] or tenant_id + except (IndexError, KeyError): + val = tenant_id + else: + try: + vals.append(row[col]) + except (IndexError, KeyError): + vals.append(None) + continue + vals.append(val) + cur.execute(sql, vals) + inserted += cur.rowcount + + dst_conn.commit() + print(f" {table}: {inserted}/{len(rows)} rows inserted ({len(rows) - inserted} skipped)") + return inserted + + +def main() -> None: + parser = argparse.ArgumentParser(description="Migrate Turnstone SQLite → Postgres") + parser.add_argument("--main-db", default="data/turnstone.db") + parser.add_argument("--context-db", default="data/turnstone-context.db") + parser.add_argument("--incidents-db", default="data/turnstone-incidents.db") + parser.add_argument("--tenant-id", default=None, help="Override tenant ID (default: socket.gethostname())") + args = parser.parse_args() + + if args.tenant_id: + os.environ["TURNSTONE_TENANT_ID"] = args.tenant_id + + import socket + tenant_id = os.environ.get("TURNSTONE_TENANT_ID") or socket.gethostname() + print(f"Migrating as tenant_id={tenant_id!r}") + + # Ensure Postgres schema exists first + os.environ.setdefault("DATABASE_URL", "") # schema functions check this + _ensure_schema_pg() + + pg = _pg_connect() + total = 0 + + # ── Main DB ─────────────────────────────────────────────────────────────── + main_path = Path(args.main_db) + if main_path.exists(): + print(f"\nMigrating main DB: {main_path}") + src = sqlite3.connect(str(main_path)) + src.row_factory = sqlite3.Row + + total += _migrate_table(src, pg, "log_entries", tenant_id, + columns=["tenant_id", "id", "source_id", "sequence", "timestamp_raw", + "timestamp_iso", "ingest_time", "severity", "repeat_count", + "out_of_order", "matched_patterns", "text"], + conflict_cols=["tenant_id", "id"]) + + total += _migrate_table(src, pg, "glean_fingerprints", tenant_id, + columns=["tenant_id", "path", "mtime", "size", "gleaned_at"], + conflict_cols=["tenant_id", "path"]) + + total += _migrate_table(src, pg, "blocklist_candidates", tenant_id, + columns=["id", "tenant_id", "domain_or_ip", "source_device_ip", "source_device_name", + "first_seen", "last_seen", "hit_count", "status", "pushed_at", + "log_evidence", "matched_rule", "llm_score", "llm_reason"], + conflict_cols=["id"]) + src.close() + else: + print(f"Main DB not found at {main_path} — skipping") + + # ── Context DB ──────────────────────────────────────────────────────────── + ctx_path = Path(args.context_db) + if ctx_path.exists(): + print(f"\nMigrating context DB: {ctx_path}") + src = sqlite3.connect(str(ctx_path)) + + total += _migrate_table(src, pg, "context_facts", tenant_id, + columns=["id", "tenant_id", "category", "key", "value", "source", "created_at"], + conflict_cols=["id"]) + + total += _migrate_table(src, pg, "context_documents", tenant_id, + columns=["id", "tenant_id", "filename", "doc_type", "full_text", "file_size", "uploaded_at"], + conflict_cols=["id"]) + + total += _migrate_table(src, pg, "context_chunks", tenant_id, + columns=["id", "tenant_id", "document_id", "chunk_index", "text"], + conflict_cols=["id"]) + src.close() + else: + print(f"Context DB not found at {ctx_path} — skipping") + + # ── Incidents DB ────────────────────────────────────────────────────────── + inc_path = Path(args.incidents_db) + if inc_path.exists(): + print(f"\nMigrating incidents DB: {inc_path}") + src = sqlite3.connect(str(inc_path)) + + total += _migrate_table(src, pg, "incidents", tenant_id, + columns=["id", "tenant_id", "label", "issue_type", "started_at", "ended_at", + "notes", "created_at", "severity"], + conflict_cols=["id"]) + + total += _migrate_table(src, pg, "received_bundles", tenant_id, + columns=["id", "tenant_id", "source_host", "issue_type", "label", "severity", + "started_at", "bundled_at", "entry_count", "bundle_json"], + conflict_cols=["id"]) + + total += _migrate_table(src, pg, "sent_bundles", tenant_id, + columns=["id", "tenant_id", "incident_id", "exported_at", "sanitized", + "entry_count", "bundle_json"], + conflict_cols=["id"]) + src.close() + else: + print(f"Incidents DB not found at {inc_path} — skipping") + + pg.close() + print(f"\nDone. Total rows inserted: {total}") + + +if __name__ == "__main__": + main() diff --git a/scripts/update.sh b/scripts/update.sh index 50c7e7d..5db724c 100644 --- a/scripts/update.sh +++ b/scripts/update.sh @@ -6,8 +6,10 @@ # sudo bash /opt/turnstone/scripts/update.sh feat/live-watch # test a branch # # Local files preserved across updates: -# patterns/watch.yaml — site-specific watch source config -# data/ — database and live journal files (bind-mounted, untouched) +# patterns/watch.yaml — site-specific watch source config +# data/corpus_watermark.txt — corpus export watermark (last exported rowid) +# data/incident_watermark.txt — incident export watermark (last exported timestamp) +# data/ — database and live journal files (bind-mounted, untouched) set -euo pipefail @@ -21,7 +23,9 @@ echo "==> Turnstone update: branch=$BRANCH" # ── Preserve site-local config ──────────────────────────────────────────────── # watch.yaml is tracked in git as a template but overridden per-host. -# Back it up before the pull and restore it after. +# Corpus watermarks track the last exported entry/incident — must survive updates +# or the next export run will re-push everything from the beginning. +# Back them up before the pull and restore after. WATCH_YAML="$REPO_DIR/patterns/watch.yaml" WATCH_BACKUP="" if [ -f "$WATCH_YAML" ]; then @@ -29,6 +33,19 @@ if [ -f "$WATCH_YAML" ]; then cp "$WATCH_YAML" "$WATCH_BACKUP" fi +CORPUS_WM="$REPO_DIR/data/corpus_watermark.txt" +INCIDENT_WM="$REPO_DIR/data/incident_watermark.txt" +CORPUS_WM_BACKUP="" +INCIDENT_WM_BACKUP="" +if [ -f "$CORPUS_WM" ]; then + CORPUS_WM_BACKUP=$(mktemp /tmp/corpus-wm.XXXXXX) + cp "$CORPUS_WM" "$CORPUS_WM_BACKUP" +fi +if [ -f "$INCIDENT_WM" ]; then + INCIDENT_WM_BACKUP=$(mktemp /tmp/incident-wm.XXXXXX) + cp "$INCIDENT_WM" "$INCIDENT_WM_BACKUP" +fi + # ── Pull ────────────────────────────────────────────────────────────────────── git fetch --all --tags --quiet @@ -50,6 +67,16 @@ if [ -n "$WATCH_BACKUP" ]; then rm -f "$WATCH_BACKUP" echo "==> Restored patterns/watch.yaml" fi +if [ -n "$CORPUS_WM_BACKUP" ]; then + cp "$CORPUS_WM_BACKUP" "$CORPUS_WM" + rm -f "$CORPUS_WM_BACKUP" + echo "==> Restored data/corpus_watermark.txt" +fi +if [ -n "$INCIDENT_WM_BACKUP" ]; then + cp "$INCIDENT_WM_BACKUP" "$INCIDENT_WM" + rm -f "$INCIDENT_WM_BACKUP" + echo "==> Restored data/incident_watermark.txt" +fi # ── Build ───────────────────────────────────────────────────────────────────── echo "==> Building $IMAGE ..." diff --git a/tests/context/test_diagnose_context.py b/tests/context/test_diagnose_context.py index f34da5f..1a8a6e2 100644 --- a/tests/context/test_diagnose_context.py +++ b/tests/context/test_diagnose_context.py @@ -4,6 +4,7 @@ import sqlite3 from pathlib import Path from unittest.mock import patch import pytest +from app.db.schema import ensure_schema, ensure_context_schema from app.services.llm import summarize from app.services.search import SearchResult @@ -64,36 +65,14 @@ def test_summarize_without_context_block_unchanged(): @pytest.fixture def db_with_facts(tmp_path): db_path = tmp_path / "t.db" + ensure_schema(db_path) + ensure_context_schema(db_path) conn = sqlite3.connect(str(db_path)) - conn.executescript(""" - CREATE TABLE log_entries ( - id TEXT PRIMARY KEY, source_id TEXT NOT NULL, sequence INTEGER NOT NULL, - timestamp_raw TEXT, timestamp_iso TEXT, ingest_time TEXT NOT NULL, - severity TEXT, repeat_count INTEGER DEFAULT 1, out_of_order INTEGER DEFAULT 0, - matched_patterns TEXT DEFAULT '[]', text TEXT NOT NULL - ); - CREATE VIRTUAL TABLE IF NOT EXISTS log_fts USING fts5( - text, entry_id UNINDEXED, source_id UNINDEXED, sequence UNINDEXED, - severity UNINDEXED, timestamp_iso UNINDEXED, matched_patterns UNINDEXED, - repeat_count UNINDEXED, out_of_order UNINDEXED, tokenize='porter ascii' - ); - CREATE TABLE context_facts ( - id TEXT PRIMARY KEY, category TEXT NOT NULL, key TEXT NOT NULL, - value TEXT NOT NULL, source TEXT, created_at TEXT NOT NULL - ); - CREATE TABLE context_documents ( - id TEXT PRIMARY KEY, filename TEXT NOT NULL, doc_type TEXT NOT NULL, - full_text TEXT NOT NULL, file_size INTEGER, uploaded_at TEXT NOT NULL - ); - CREATE TABLE context_chunks ( - id TEXT PRIMARY KEY, document_id TEXT NOT NULL - REFERENCES context_documents(id) ON DELETE CASCADE, - chunk_index INTEGER NOT NULL, text TEXT NOT NULL, embedding BLOB - ); - INSERT INTO context_facts VALUES ( - 'f1','service','plex','port:32400','wizard','2026-05-13T00:00:00+00:00' - ); - """) + conn.execute( + "INSERT INTO context_facts(id, tenant_id, category, key, value, source, created_at) " + "VALUES (?,?,?,?,?,?,?)", + ("f1", "", "service", "plex", "port:32400", "wizard", "2026-05-13T00:00:00+00:00"), + ) conn.commit() conn.close() return db_path diff --git a/tests/context/test_doc_upload.py b/tests/context/test_doc_upload.py index 162f6f5..12e1fa0 100644 --- a/tests/context/test_doc_upload.py +++ b/tests/context/test_doc_upload.py @@ -1,8 +1,8 @@ """End-to-end upload pipeline: file bytes → DB rows.""" -import sqlite3 import pytest from pathlib import Path +from app.db.schema import ensure_context_schema from app.glean.doc_upload import glean_upload from app.context.store import list_facts, list_documents from app.context.chunker import UnsupportedDocType @@ -11,24 +11,7 @@ from app.context.chunker import UnsupportedDocType @pytest.fixture def db(tmp_path): db_path = tmp_path / "t.db" - conn = sqlite3.connect(str(db_path)) - conn.executescript(""" - CREATE TABLE context_facts ( - id TEXT PRIMARY KEY, category TEXT NOT NULL, key TEXT NOT NULL, - value TEXT NOT NULL, source TEXT, created_at TEXT NOT NULL - ); - CREATE TABLE context_documents ( - id TEXT PRIMARY KEY, filename TEXT NOT NULL, doc_type TEXT NOT NULL, - full_text TEXT NOT NULL, file_size INTEGER, uploaded_at TEXT NOT NULL - ); - CREATE TABLE context_chunks ( - id TEXT PRIMARY KEY, document_id TEXT NOT NULL - REFERENCES context_documents(id) ON DELETE CASCADE, - chunk_index INTEGER NOT NULL, text TEXT NOT NULL, embedding BLOB - ); - """) - conn.commit() - conn.close() + ensure_context_schema(db_path) return db_path diff --git a/tests/context/test_schema.py b/tests/context/test_schema.py index ea71812..4943b79 100644 --- a/tests/context/test_schema.py +++ b/tests/context/test_schema.py @@ -1,13 +1,13 @@ -"""Verify the three new context tables are created by ensure_schema.""" +"""Verify the three context tables are created by ensure_context_schema.""" import sqlite3 from pathlib import Path import pytest -from app.glean.pipeline import ensure_schema +from app.db.schema import ensure_context_schema def test_context_tables_created(tmp_path): db = tmp_path / "t.db" - ensure_schema(db) + ensure_context_schema(db) conn = sqlite3.connect(str(db)) tables = {r[0] for r in conn.execute( "SELECT name FROM sqlite_master WHERE type='table'" @@ -20,5 +20,5 @@ def test_context_tables_created(tmp_path): def test_context_schema_idempotent(tmp_path): db = tmp_path / "t.db" - ensure_schema(db) - ensure_schema(db) # second call must not raise + ensure_context_schema(db) + ensure_context_schema(db) # second call must not raise diff --git a/tests/context/test_store.py b/tests/context/test_store.py index 8c6edea..7197579 100644 --- a/tests/context/test_store.py +++ b/tests/context/test_store.py @@ -2,6 +2,7 @@ import sqlite3 import pytest from pathlib import Path +from app.db.schema import ensure_context_schema from app.context.store import ( add_fact, list_facts, delete_fact, add_document, list_documents, delete_document, @@ -12,24 +13,7 @@ from app.context.store import ( @pytest.fixture def db(tmp_path): db_path = tmp_path / "t.db" - conn = sqlite3.connect(str(db_path)) - conn.executescript(""" - CREATE TABLE context_facts ( - id TEXT PRIMARY KEY, category TEXT NOT NULL, key TEXT NOT NULL, - value TEXT NOT NULL, source TEXT, created_at TEXT NOT NULL - ); - CREATE TABLE context_documents ( - id TEXT PRIMARY KEY, filename TEXT NOT NULL, doc_type TEXT NOT NULL, - full_text TEXT NOT NULL, file_size INTEGER, uploaded_at TEXT NOT NULL - ); - CREATE TABLE context_chunks ( - id TEXT PRIMARY KEY, document_id TEXT NOT NULL - REFERENCES context_documents(id) ON DELETE CASCADE, - chunk_index INTEGER NOT NULL, text TEXT NOT NULL, embedding BLOB - ); - """) - conn.commit() - conn.close() + ensure_context_schema(db_path) return db_path diff --git a/tests/context/test_wizard.py b/tests/context/test_wizard.py index e10682e..8d76f81 100644 --- a/tests/context/test_wizard.py +++ b/tests/context/test_wizard.py @@ -2,21 +2,14 @@ import sqlite3 import pytest from pathlib import Path +from app.db.schema import ensure_context_schema from app.context.wizard import get_schema, advance_step, is_complete, apply_session, TOTAL_STEPS @pytest.fixture def db(tmp_path): db_path = tmp_path / "t.db" - conn = sqlite3.connect(str(db_path)) - conn.executescript(""" - CREATE TABLE context_facts ( - id TEXT PRIMARY KEY, category TEXT NOT NULL, key TEXT NOT NULL, - value TEXT NOT NULL, source TEXT, created_at TEXT NOT NULL - ); - """) - conn.commit() - conn.close() + ensure_context_schema(db_path) return db_path diff --git a/tests/test_anomaly.py b/tests/test_anomaly.py new file mode 100644 index 0000000..31bbe98 --- /dev/null +++ b/tests/test_anomaly.py @@ -0,0 +1,220 @@ +"""Tests for app/services/anomaly.py — anomaly scoring pipeline.""" +from __future__ import annotations + +import sqlite3 +import uuid +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +import app.services.anomaly as anomaly_mod +from app.db.schema import ensure_schema +from app.services.anomaly import ( + ScoringResult, + acknowledge_detection, + list_detections, + reset_pipeline, + score_unscored, +) + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _reset_pipeline(): + """Ensure the ML singleton is cleared between tests.""" + reset_pipeline() + yield + reset_pipeline() + + +@pytest.fixture +def db(tmp_path: Path) -> Path: + db_path = tmp_path / "t.db" + ensure_schema(db_path) + return db_path + + +def _insert_entry(db_path: Path, text: str, entry_id: str | None = None) -> str: + eid = entry_id or str(uuid.uuid4()) + conn = sqlite3.connect(str(db_path)) + conn.execute( + "INSERT INTO log_entries(id, tenant_id, source_id, sequence, ingest_time, text) " + "VALUES (?,?,?,?,?,?)", + (eid, "", "src", 1, "2026-01-01T00:00:00", text), + ) + conn.commit() + conn.close() + return eid + + +# --------------------------------------------------------------------------- +# score_unscored +# --------------------------------------------------------------------------- + + +def test_score_unscored_no_model_returns_skipped(db: Path): + result = score_unscored(db, model_id="") + assert result.skipped is True + assert result.scored == 0 + + +def test_score_unscored_scores_entries(db: Path, monkeypatch): + _insert_entry(db, "kernel panic — OOM killer invoked") + _insert_entry(db, "user login successful") + + mock_pipe = MagicMock(return_value=[ + {"label": "SYSTEM_FAILURE", "score": 0.92}, + {"label": "NORMAL", "score": 0.88}, + ]) + monkeypatch.setattr(anomaly_mod, "_pipeline", mock_pipe) + + result = score_unscored(db, model_id="fake-model", batch_size=10) + assert result.skipped is False + assert result.scored == 2 + + +def test_score_unscored_creates_detection_above_threshold(db: Path, monkeypatch): + _insert_entry(db, "segfault in service") + + mock_pipe = MagicMock(return_value=[ + {"label": "SYSTEM_FAILURE", "score": 0.95}, + ]) + monkeypatch.setattr(anomaly_mod, "_pipeline", mock_pipe) + + result = score_unscored(db, model_id="fake-model", threshold=0.80) + assert result.detections == 1 + + detections = list_detections(db) + assert len(detections) == 1 + assert detections[0]["anomaly_label"] == "SYSTEM_FAILURE" + assert detections[0]["anomaly_score"] == pytest.approx(0.95) + + +def test_score_unscored_no_detection_below_threshold(db: Path, monkeypatch): + _insert_entry(db, "warning: disk at 80%") + + mock_pipe = MagicMock(return_value=[ + {"label": "PERFORMANCE_ISSUE", "score": 0.60}, + ]) + monkeypatch.setattr(anomaly_mod, "_pipeline", mock_pipe) + + result = score_unscored(db, model_id="fake-model", threshold=0.80) + assert result.detections == 0 + assert result.scored == 1 + + +def test_score_unscored_normal_label_never_detection(db: Path, monkeypatch): + _insert_entry(db, "service started successfully") + + mock_pipe = MagicMock(return_value=[ + {"label": "NORMAL", "score": 0.99}, + ]) + monkeypatch.setattr(anomaly_mod, "_pipeline", mock_pipe) + + result = score_unscored(db, model_id="fake-model", threshold=0.50) + assert result.detections == 0 + + +def test_score_unscored_idempotent(db: Path, monkeypatch): + """Entries already scored are not re-scored on subsequent runs.""" + _insert_entry(db, "first entry") + + call_count = 0 + + def _side_effect(texts, **_kwargs): + nonlocal call_count + call_count += 1 + return [{"label": "NORMAL", "score": 0.90} for _ in texts] + + mock_pipe = MagicMock(side_effect=_side_effect) + monkeypatch.setattr(anomaly_mod, "_pipeline", mock_pipe) + + score_unscored(db, model_id="fake-model") + score_unscored(db, model_id="fake-model") + + assert call_count == 1 # second run finds no unscored rows + + +def test_score_unscored_pipeline_error_returns_error(db: Path, monkeypatch): + _insert_entry(db, "some log line") + + mock_pipe = MagicMock(side_effect=RuntimeError("CUDA OOM")) + monkeypatch.setattr(anomaly_mod, "_pipeline", mock_pipe) + + result = score_unscored(db, model_id="fake-model") + assert result.error is not None + assert "CUDA OOM" in result.error + + +# --------------------------------------------------------------------------- +# list_detections / acknowledge_detection +# --------------------------------------------------------------------------- + + +def test_list_detections_empty(db: Path): + assert list_detections(db) == [] + + +def test_list_detections_filters_unacked(db: Path, monkeypatch): + _insert_entry(db, "crash") + + mock_pipe = MagicMock(return_value=[{"label": "SYSTEM_FAILURE", "score": 0.91}]) + monkeypatch.setattr(anomaly_mod, "_pipeline", mock_pipe) + score_unscored(db, model_id="fake-model", threshold=0.80) + + all_dets = list_detections(db) + assert len(all_dets) == 1 + unacked = list_detections(db, unacked_only=True) + assert len(unacked) == 1 + + +def test_acknowledge_detection(db: Path, monkeypatch): + _insert_entry(db, "network anomaly") + + mock_pipe = MagicMock(return_value=[{"label": "NETWORK_ANOMALY", "score": 0.88}]) + monkeypatch.setattr(anomaly_mod, "_pipeline", mock_pipe) + score_unscored(db, model_id="fake-model", threshold=0.80) + + dets = list_detections(db) + assert len(dets) == 1 + det_id = dets[0]["id"] + + updated = acknowledge_detection(db, det_id, notes="benign test traffic") + assert updated is True + + unacked = list_detections(db, unacked_only=True) + assert len(unacked) == 0 + + all_dets = list_detections(db) + assert all_dets[0]["acknowledged"] == 1 + assert all_dets[0]["notes"] == "benign test traffic" + + +def test_acknowledge_detection_unknown_id(db: Path): + updated = acknowledge_detection(db, "nonexistent-id") + assert updated is False + + +def test_list_detections_label_filter(db: Path, monkeypatch): + _insert_entry(db, "OOM kill") + _insert_entry(db, "network timeout") + + mock_pipe = MagicMock(side_effect=[ + [{"label": "SYSTEM_FAILURE", "score": 0.93}], + [{"label": "NETWORK_ANOMALY", "score": 0.85}], + ]) + monkeypatch.setattr(anomaly_mod, "_pipeline", mock_pipe) + + score_unscored(db, model_id="fake-model", batch_size=1, threshold=0.80) + score_unscored(db, model_id="fake-model", batch_size=1, threshold=0.80) + + sys_dets = list_detections(db, label="SYSTEM_FAILURE") + assert all(d["anomaly_label"] == "SYSTEM_FAILURE" for d in sys_dets) + + net_dets = list_detections(db, label="NETWORK_ANOMALY") + assert all(d["anomaly_label"] == "NETWORK_ANOMALY" for d in net_dets) diff --git a/tests/test_cybersec.py b/tests/test_cybersec.py new file mode 100644 index 0000000..8f4f99a --- /dev/null +++ b/tests/test_cybersec.py @@ -0,0 +1,233 @@ +"""Tests for the cybersec zero-shot scoring pipeline.""" +from __future__ import annotations + +import sqlite3 +import tempfile +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + +from app.db.schema import ensure_schema +from app.services.cybersec import ( + CybersecResult, + CYBERSEC_LABELS, + _NORMAL_LABEL, + reset_pipeline, + score_security_entries, + list_cybersec_detections, +) +import app.services.cybersec as cybersec_mod + + +@pytest.fixture(autouse=True) +def _reset(tmp_path): + reset_pipeline() + yield + reset_pipeline() + + +@pytest.fixture +def db(tmp_path) -> Path: + path = tmp_path / "test.db" + ensure_schema(path) + return path + + +def _insert_entry(db: Path, entry_id: str, text: str, + anomaly_label: str | None = None, + matched_patterns: str = "[]") -> None: + with sqlite3.connect(db) as conn: + conn.execute( + """INSERT OR IGNORE INTO log_entries + (id, tenant_id, source_id, sequence, ingest_time, text, + anomaly_label, matched_patterns) + VALUES (?, '', 'test-src', 1, '2026-01-01T00:00:00Z', ?, ?, ?)""", + (entry_id, text, anomaly_label, matched_patterns), + ) + conn.commit() + + +# --------------------------------------------------------------------------- +# No model configured → skipped +# --------------------------------------------------------------------------- + +def test_no_model_returns_skipped(db): + result = score_security_entries(db, model_id="") + assert result.skipped is True + assert result.scored == 0 + + +# --------------------------------------------------------------------------- +# No eligible entries → skipped +# --------------------------------------------------------------------------- + +def test_no_eligible_entries_skipped(db): + _insert_entry(db, "e1", "Started nginx.service", anomaly_label=None, matched_patterns="[]") + mock_pipe = MagicMock(return_value=[{"labels": [_NORMAL_LABEL], "scores": [0.99]}]) + monkeypatch = pytest.MonkeyPatch() + monkeypatch.setattr(cybersec_mod, "_pipeline", mock_pipe) + result = score_security_entries(db, model_id="fake-model") + assert result.skipped is True + monkeypatch.undo() + + +# --------------------------------------------------------------------------- +# Security entry gets scored +# --------------------------------------------------------------------------- + +def test_security_entry_scored(db, monkeypatch): + _insert_entry(db, "e1", + "Failed password for root from 192.168.1.1 port 22 ssh2", + anomaly_label="SECURITY_ANOMALY") + + mock_pipe = MagicMock(return_value=[{ + "labels": ["authentication failure or brute force attack", _NORMAL_LABEL], + "scores": [0.85, 0.15], + }]) + monkeypatch.setattr(cybersec_mod, "_pipeline", mock_pipe) + + result = score_security_entries(db, model_id="fake-model", threshold=0.70) + assert result.scored == 1 + assert result.detections == 1 + assert result.error is None + + with sqlite3.connect(db) as conn: + conn.row_factory = sqlite3.Row + row = conn.execute("SELECT ml_score, ml_label, ml_scored_at FROM log_entries WHERE id='e1'").fetchone() + assert row["ml_score"] == pytest.approx(0.85) + assert row["ml_label"] == "authentication failure or brute force attack" + assert row["ml_scored_at"] is not None + + +# --------------------------------------------------------------------------- +# Detection created above threshold +# --------------------------------------------------------------------------- + +def test_detection_inserted_above_threshold(db, monkeypatch): + _insert_entry(db, "e1", "sudo: authentication failure", anomaly_label="ERROR") + + monkeypatch.setattr(cybersec_mod, "_pipeline", MagicMock(return_value=[{ + "labels": ["privilege escalation or unauthorized access", _NORMAL_LABEL], + "scores": [0.75, 0.25], + }])) + + score_security_entries(db, model_id="fake-model", threshold=0.60) + + with sqlite3.connect(db) as conn: + conn.row_factory = sqlite3.Row + dets = conn.execute("SELECT * FROM detections WHERE scorer='cybersec'").fetchall() + assert len(dets) == 1 + assert dets[0]["anomaly_label"] == "privilege escalation or unauthorized access" + assert dets[0]["severity"] == "CRITICAL" + + +# --------------------------------------------------------------------------- +# Normal label → no detection even above score threshold +# --------------------------------------------------------------------------- + +def test_normal_label_no_detection(db, monkeypatch): + _insert_entry(db, "e1", "Started nginx.service", anomaly_label="INFO", + matched_patterns='["service_start"]') + + monkeypatch.setattr(cybersec_mod, "_pipeline", MagicMock(return_value=[{ + "labels": [_NORMAL_LABEL, "network intrusion or port scan"], + "scores": [0.95, 0.05], + }])) + + result = score_security_entries(db, model_id="fake-model", threshold=0.60) + assert result.detections == 0 + + +# --------------------------------------------------------------------------- +# Below threshold → scored but no detection +# --------------------------------------------------------------------------- + +def test_below_threshold_no_detection(db, monkeypatch): + _insert_entry(db, "e1", "Some suspicious text", anomaly_label="WARN") + + monkeypatch.setattr(cybersec_mod, "_pipeline", MagicMock(return_value=[{ + "labels": ["network intrusion or port scan", _NORMAL_LABEL], + "scores": [0.45, 0.55], + }])) + + result = score_security_entries(db, model_id="fake-model", threshold=0.60) + assert result.scored == 1 + assert result.detections == 0 + + +# --------------------------------------------------------------------------- +# Pattern-matched entry (not anomaly-flagged) still gets scored +# --------------------------------------------------------------------------- + +def test_pattern_matched_entry_scored(db, monkeypatch): + _insert_entry(db, "e1", "SSH port forwarding conflict detected", + anomaly_label=None, + matched_patterns='["ssh_forward_conflict"]') + + monkeypatch.setattr(cybersec_mod, "_pipeline", MagicMock(return_value=[{ + "labels": ["network intrusion or port scan", _NORMAL_LABEL], + "scores": [0.70, 0.30], + }])) + + result = score_security_entries(db, model_id="fake-model", threshold=0.60) + assert result.scored == 1 + assert result.detections == 1 + + +# --------------------------------------------------------------------------- +# Idempotency — re-run finds nothing unscored +# --------------------------------------------------------------------------- + +def test_idempotent_rerun(db, monkeypatch): + _insert_entry(db, "e1", "Failed login", anomaly_label="ERROR") + + monkeypatch.setattr(cybersec_mod, "_pipeline", MagicMock(return_value=[{ + "labels": ["authentication failure or brute force attack"], + "scores": [0.80], + }])) + + score_security_entries(db, model_id="fake-model", threshold=0.60) + result2 = score_security_entries(db, model_id="fake-model", threshold=0.60) + assert result2.skipped is True + + +# --------------------------------------------------------------------------- +# list_cybersec_detections filters to scorer='cybersec' +# --------------------------------------------------------------------------- + +def test_list_cybersec_detections(db, monkeypatch): + _insert_entry(db, "e1", "Failed login", anomaly_label="ERROR") + + monkeypatch.setattr(cybersec_mod, "_pipeline", MagicMock(return_value=[{ + "labels": ["authentication failure or brute force attack"], + "scores": [0.90], + }])) + score_security_entries(db, model_id="fake-model", threshold=0.60) + + rows = list_cybersec_detections(db) + assert len(rows) == 1 + assert rows[0]["scorer"] == "cybersec" + + +# --------------------------------------------------------------------------- +# list_detections scorer filter (anomaly service) +# --------------------------------------------------------------------------- + +def test_list_detections_scorer_filter(db, monkeypatch): + from app.services.anomaly import list_detections + _insert_entry(db, "e1", "Failed login", anomaly_label="ERROR") + + monkeypatch.setattr(cybersec_mod, "_pipeline", MagicMock(return_value=[{ + "labels": ["authentication failure or brute force attack"], + "scores": [0.90], + }])) + score_security_entries(db, model_id="fake-model", threshold=0.60) + + all_dets = list_detections(db) + cybersec_dets = list_detections(db, scorer="cybersec") + anomaly_dets = list_detections(db, scorer="anomaly") + + assert len(cybersec_dets) == 1 + assert len(anomaly_dets) == 0 + assert len(all_dets) >= 1 diff --git a/tests/test_gen_corpus.py b/tests/test_gen_corpus.py new file mode 100644 index 0000000..59468f1 --- /dev/null +++ b/tests/test_gen_corpus.py @@ -0,0 +1,197 @@ +"""Tests for scripts/gen_corpus.py synthetic log generator.""" +from __future__ import annotations + +import json +import sys +from pathlib import Path + +import pytest + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from datetime import datetime, timezone + +from scripts.gen_corpus import generate, main + +# Fixed reference time keeps timestamps deterministic across test runs +_REF_TIME = datetime(2026, 6, 10, 12, 0, 0, tzinfo=timezone.utc) + + +# ── Helpers ──────────────────────────────────────────────────────────────────── + +def _run(tmp_path: Path, days: int = 1, seed: int = 42, error_rate: float = 0.05) -> dict[str, int]: + return generate(tmp_path, days=days, seed=seed, error_rate=error_rate, reference_time=_REF_TIME) + + +# ── Output structure ─────────────────────────────────────────────────────────── + +class TestOutputStructure: + def test_creates_all_four_files(self, tmp_path: Path) -> None: + _run(tmp_path) + assert (tmp_path / "journald" / "system.jsonl").exists() + assert (tmp_path / "docker" / "services.jsonl").exists() + assert (tmp_path / "qbittorrent" / "qbt.log").exists() + assert (tmp_path / "ext_device" / "device.log").exists() + + def test_returns_line_counts(self, tmp_path: Path) -> None: + totals = _run(tmp_path) + assert len(totals) == 4 + assert all(v > 0 for v in totals.values()) + + +# ── Reproducibility ──────────────────────────────────────────────────────────── + +class TestReproducibility: + def test_same_seed_same_output(self, tmp_path: Path) -> None: + out_a = tmp_path / "a" + out_b = tmp_path / "b" + _run(out_a, seed=99) + _run(out_b, seed=99) + for sub in ["journald/system.jsonl", "docker/services.jsonl"]: + assert (out_a / sub).read_text() == (out_b / sub).read_text() + + def test_different_seeds_differ(self, tmp_path: Path) -> None: + out_a = tmp_path / "a" + out_b = tmp_path / "b" + _run(out_a, seed=1) + _run(out_b, seed=2) + assert (out_a / "journald/system.jsonl").read_text() != (out_b / "journald/system.jsonl").read_text() + + +# ── Journald format ──────────────────────────────────────────────────────────── + +class TestJournaldFormat: + def test_valid_json_lines(self, tmp_path: Path) -> None: + _run(tmp_path) + lines = (tmp_path / "journald/system.jsonl").read_text().splitlines() + for line in lines[:100]: + obj = json.loads(line) + assert "__REALTIME_TIMESTAMP" in obj + assert "MESSAGE" in obj + assert "PRIORITY" in obj + + def test_timestamp_is_microseconds(self, tmp_path: Path) -> None: + _run(tmp_path) + lines = (tmp_path / "journald/system.jsonl").read_text().splitlines() + ts = int(json.loads(lines[0])["__REALTIME_TIMESTAMP"]) + # microseconds since epoch — should be > year 2020 + assert ts > 1_577_836_800_000_000 + + def test_parseable_by_journald_glean(self, tmp_path: Path) -> None: + from app.glean.journald import parse + _run(tmp_path) + with (tmp_path / "journald/system.jsonl").open() as fh: + entries = list(parse(fh, "test", [])) + assert len(entries) > 0 + severities = {e.severity for e in entries if e.severity} + assert severities <= {"INFO", "DEBUG", "WARN", "ERROR", "CRITICAL"} + + +# ── Docker format ────────────────────────────────────────────────────────────── + +class TestDockerFormat: + def test_valid_json_lines(self, tmp_path: Path) -> None: + _run(tmp_path) + lines = (tmp_path / "docker/services.jsonl").read_text().splitlines() + for line in lines[:100]: + obj = json.loads(line) + assert "SOURCE" in obj + assert "MESSAGE" in obj + + def test_parseable_by_docker_glean(self, tmp_path: Path) -> None: + from app.glean.docker_log import parse + _run(tmp_path) + with (tmp_path / "docker/services.jsonl").open() as fh: + entries = list(parse(fh, "test", [])) + assert len(entries) > 0 + # Severity should be detected in most entries (messages embed level= / LEVEL:) + detected = [e for e in entries if e.severity is not None] + assert len(detected) / len(entries) > 0.8 + + +# ── qBittorrent format ───────────────────────────────────────────────────────── + +class TestQbittorrentFormat: + def test_hotio_format_lines(self, tmp_path: Path) -> None: + _run(tmp_path) + lines = (tmp_path / "qbittorrent/qbt.log").read_text().splitlines() + import re + pattern = re.compile(r"^\([NIWC]\) \d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2} - .+$") + assert all(pattern.match(line) for line in lines[:50]) + + def test_parseable_by_qbt_glean(self, tmp_path: Path) -> None: + from app.glean.qbittorrent import parse + _run(tmp_path) + with (tmp_path / "qbittorrent/qbt.log").open() as fh: + entries = list(parse(fh, "test", [])) + assert len(entries) > 0 + severities = {e.severity for e in entries if e.severity} + assert severities <= {"INFO", "WARN", "CRITICAL"} + + +# ── EXT_DEVICE format ──────────────────────────────────────────────────────────────── + +class TestAvcxFormat: + def test_iso_timestamp_prefix(self, tmp_path: Path) -> None: + _run(tmp_path) + lines = (tmp_path / "ext_device/device.log").read_text().splitlines() + import re + pattern = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2} \[.+\] .+$") + assert all(pattern.match(line) for line in lines[:50]) + + def test_parseable_by_plaintext_glean(self, tmp_path: Path) -> None: + from app.glean.plaintext import parse + _run(tmp_path) + with (tmp_path / "ext_device/device.log").open() as fh: + entries = list(parse(fh, "test", [])) + assert len(entries) > 0 + # ISO timestamps should parse cleanly + timestamped = [e for e in entries if e.timestamp_iso] + assert len(timestamped) / len(entries) > 0.95 + + +# ── Error rate ───────────────────────────────────────────────────────────────── + +class TestErrorRate: + def test_high_error_rate_increases_errors(self, tmp_path: Path) -> None: + from app.glean.journald import parse + + low = tmp_path / "low" + high = tmp_path / "high" + _run(low, seed=7, error_rate=0.01) + _run(high, seed=7, error_rate=0.50) + + def error_ratio(path: Path) -> float: + with path.open() as fh: + entries = list(parse(fh, "test", [])) + errs = sum(1 for e in entries if e.severity in ("ERROR", "CRITICAL")) + return errs / len(entries) if entries else 0.0 + + assert error_ratio(high / "journald/system.jsonl") > error_ratio(low / "journald/system.jsonl") + + def test_invalid_error_rate_returns_nonzero(self, tmp_path: Path) -> None: + rc = main(["--days", "1", "--out", str(tmp_path), "--error-rate", "1.5"]) + assert rc != 0 + + +# ── CLI ──────────────────────────────────────────────────────────────────────── + +class TestCLI: + def test_acceptance_criteria(self, tmp_path: Path) -> None: + """Acceptance: --days 7 --out produces a gleanable corpus with varied severities.""" + from app.glean.journald import parse + + rc = main(["--days", "7", "--out", str(tmp_path)]) + assert rc == 0 + + with (tmp_path / "journald/system.jsonl").open() as fh: + entries = list(parse(fh, "test", [])) + + severities = {e.severity for e in entries if e.severity} + assert {"INFO", "WARN", "ERROR", "CRITICAL"}.issubset(severities) + assert len(entries) > 100_000 # 7 days of ~86k/day + + def test_missing_out_fails(self, tmp_path: Path, capsys: pytest.CaptureFixture) -> None: + with pytest.raises(SystemExit) as exc_info: + main(["--days", "1"]) + assert exc_info.value.code != 0 diff --git a/tests/test_glean_fingerprint.py b/tests/test_glean_fingerprint.py index 96aca23..827838b 100644 --- a/tests/test_glean_fingerprint.py +++ b/tests/test_glean_fingerprint.py @@ -51,12 +51,14 @@ class TestFingerprintHelpers: def test_fp_unchanged_returns_false_when_no_record(self, db_path: Path, log_file: Path) -> None: conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row mtime, size = _fingerprint(log_file) assert _fp_unchanged(conn, log_file, mtime, size) is False conn.close() def test_fp_unchanged_returns_true_after_save(self, db_path: Path, log_file: Path) -> None: conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row mtime, size = _fingerprint(log_file) _save_fingerprint(conn, log_file, mtime, size, now_iso()) conn.commit() @@ -65,6 +67,7 @@ class TestFingerprintHelpers: def test_fp_unchanged_returns_false_on_size_change(self, db_path: Path, log_file: Path) -> None: conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row mtime, size = _fingerprint(log_file) _save_fingerprint(conn, log_file, mtime, size, now_iso()) conn.commit() @@ -74,6 +77,7 @@ class TestFingerprintHelpers: def test_fp_unchanged_returns_false_on_mtime_change(self, db_path: Path, log_file: Path) -> None: conn = sqlite3.connect(str(db_path)) + conn.row_factory = sqlite3.Row mtime, size = _fingerprint(log_file) _save_fingerprint(conn, log_file, mtime, size, now_iso()) conn.commit() diff --git a/tests/test_glean_syslog.py b/tests/test_glean_syslog.py index cb3573d..b6115f1 100644 --- a/tests/test_glean_syslog.py +++ b/tests/test_glean_syslog.py @@ -4,24 +4,24 @@ from __future__ import annotations from app.glean.syslog import is_syslog, parse SYSLOG_SAMPLE = """\ -May 11 14:23:01 example-node sshd[1234]: Accepted publickey for x from 192.168.1.1 port 54321 ssh2 -May 11 14:23:05 example-node sshd[1234]: Failed password for invalid user admin from 10.0.0.99 port 22 ssh2 -May 11 14:23:10 example-node sudo[5678]: x : TTY=pts/0 ; PWD=/home/x ; USER=root ; COMMAND=/usr/bin/apt update -May 11 14:23:15 example-node kernel: [12345.678] usb 1-1: USB disconnect, device number 2 -May 1 04:00:00 example-node CRON[9999]: (root) CMD (/usr/local/sbin/backup.sh) -May 11 14:24:00 example-node systemd[1]: Started NetworkManager. +May 11 14:23:01 testhost sshd[1234]: Accepted publickey for x from 192.168.1.1 port 54321 ssh2 +May 11 14:23:05 testhost sshd[1234]: Failed password for invalid user admin from 10.0.0.99 port 22 ssh2 +May 11 14:23:10 testhost sudo[5678]: x : TTY=pts/0 ; PWD=/home/x ; USER=root ; COMMAND=/usr/bin/apt update +May 11 14:23:15 testhost kernel: [12345.678] usb 1-1: USB disconnect, device number 2 +May 1 04:00:00 testhost CRON[9999]: (root) CMD (/usr/local/sbin/backup.sh) +May 11 14:24:00 testhost systemd[1]: Started NetworkManager. """ class TestDetector: def test_detects_standard_line(self): - assert is_syslog("May 11 14:23:01 example-node sshd[1234]: message") + assert is_syslog("May 11 14:23:01 testhost sshd[1234]: message") def test_detects_no_pid(self): - assert is_syslog("May 11 14:23:01 example-node kernel: message") + assert is_syslog("May 11 14:23:01 testhost kernel: message") def test_detects_space_padded_day(self): - assert is_syslog("May 1 04:00:00 example-node CRON[9999]: message") + assert is_syslog("May 1 04:00:00 testhost CRON[9999]: message") def test_rejects_servarr(self): assert not is_syslog("2026-05-11 02:31:51.5|Info|ComponentName|Message") diff --git a/tests/test_hybrid_search.py b/tests/test_hybrid_search.py index 1e3101e..631c5fb 100644 --- a/tests/test_hybrid_search.py +++ b/tests/test_hybrid_search.py @@ -33,12 +33,11 @@ def db(tmp_path: Path) -> Path: ("database connection refused backend gone away", "ERROR"), ("mDNS avahi heartbeat ok", "INFO"), ]): - # Columns: id, source_id, sequence, timestamp_raw, timestamp_iso, - # ingest_time, severity, repeat_count, out_of_order, - # matched_patterns, text conn.execute( - "INSERT INTO log_entries VALUES (?,?,?,?,?,?,?,?,?,?,?)", - (str(uuid.uuid4()), "src", i, None, None, "2026-01-01T00:00:00", sev, 1, 0, "[]", text), + "INSERT INTO log_entries(id, tenant_id, source_id, sequence, timestamp_raw, " + "timestamp_iso, ingest_time, severity, repeat_count, out_of_order, " + "matched_patterns, text) VALUES (?,?,?,?,?,?,?,?,?,?,?,?)", + (str(uuid.uuid4()), "", "src", i, None, None, "2026-01-01T00:00:00", sev, 1, 0, "[]", text), ) conn.commit() conn.close() diff --git a/tests/test_incident_detector.py b/tests/test_incident_detector.py new file mode 100644 index 0000000..c3a5e32 --- /dev/null +++ b/tests/test_incident_detector.py @@ -0,0 +1,238 @@ +"""Tests for app/tasks/incident_detector.py auto-incident detection.""" +from __future__ import annotations + +import sqlite3 +import tempfile +from datetime import datetime, timedelta, timezone +from pathlib import Path + +import pytest + +import sys +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from app.db import ensure_schema, ensure_incidents_schema +from app.services.incidents import create_incident, list_incidents +from app.tasks.incident_detector import ( + _find_clusters, + _incident_exists_for_cluster, + _parse_ts, + detect_and_create, +) + + +# ── Helpers ──────────────────────────────────────────────────────────────────── + +def _make_db(path: Path) -> None: + ensure_schema(path) + + +def _make_incidents_db(path: Path) -> None: + ensure_incidents_schema(path) + + +def _iso(base: datetime, offset_s: float) -> str: + return (base + timedelta(seconds=offset_s)).isoformat() + + +def _insert_entry(db: Path, source_id: str, ts_iso: str, severity: str, ingest_time: str) -> None: + with sqlite3.connect(db) as conn: + conn.execute( + "INSERT INTO log_entries (id, source_id, sequence, timestamp_iso, ingest_time, " + "severity, text, repeat_count, out_of_order, matched_patterns, tenant_id) " + "VALUES (?,?,?,?,?,?,?,?,?,?,?)", + ( + f"{source_id}-{ts_iso}", source_id, 0, ts_iso, ingest_time, + severity, "error text", 0, 0, "[]", "", + ), + ) + + +# ── _parse_ts ────────────────────────────────────────────────────────────────── + +class TestParseTs: + def test_parses_utc_iso(self) -> None: + ts = _parse_ts("2026-06-11T12:00:00+00:00") + assert ts is not None + assert ts > 0 + + def test_parses_z_suffix(self) -> None: + ts = _parse_ts("2026-06-11T12:00:00Z") + assert ts is not None + + def test_none_input(self) -> None: + assert _parse_ts(None) is None + + def test_invalid_input(self) -> None: + assert _parse_ts("not-a-date") is None + + +# ── _find_clusters ───────────────────────────────────────────────────────────── + +class TestFindClusters: + BASE = datetime(2026, 6, 11, 12, 0, 0, tzinfo=timezone.utc) + + def _events(self, offsets: list[float], severity: str = "ERROR") -> list[dict]: + return [{"timestamp_iso": _iso(self.BASE, o), "severity": severity} for o in offsets] + + def test_dense_cluster_detected(self) -> None: + events = self._events([0, 60, 120, 180, 240]) # 5 errors in 4 min + clusters = _find_clusters(events, window_s=600, threshold=5) + assert len(clusters) == 1 + + def test_sparse_events_no_cluster(self) -> None: + events = self._events([0, 300, 600, 900, 1200]) # 5 errors, each 5 min apart + clusters = _find_clusters(events, window_s=60, threshold=5) + assert clusters == [] + + def test_threshold_not_met(self) -> None: + events = self._events([0, 10, 20, 30]) # only 4 events + clusters = _find_clusters(events, window_s=600, threshold=5) + assert clusters == [] + + def test_critical_wins_over_error(self) -> None: + events = self._events([0, 10, 20, 30, 40], "ERROR") + events[2]["severity"] = "CRITICAL" + clusters = _find_clusters(events, window_s=600, threshold=5) + assert clusters[0][2] == "CRITICAL" + + def test_two_non_overlapping_clusters(self) -> None: + # Dense cluster at 0-4 min, then another at 60-64 min + e1 = self._events([0, 60, 120, 180, 240]) + e2 = self._events([3600, 3660, 3720, 3780, 3840]) + clusters = _find_clusters(e1 + e2, window_s=600, threshold=5) + assert len(clusters) == 2 + + def test_no_timestamp_events_skipped(self) -> None: + events = [{"timestamp_iso": None, "severity": "ERROR"}] * 10 + clusters = _find_clusters(events, window_s=600, threshold=5) + assert clusters == [] + + +# ── _incident_exists_for_cluster ─────────────────────────────────────────────── + +class TestIncidentExists: + BASE = datetime(2026, 6, 11, 12, 0, 0, tzinfo=timezone.utc) + + def test_no_existing_incidents(self, tmp_path: Path) -> None: + db = tmp_path / "inc.db" + _make_incidents_db(db) + assert not _incident_exists_for_cluster( + db, "nginx", _iso(self.BASE, 0), _iso(self.BASE, 600) + ) + + def test_exact_overlap_detected(self, tmp_path: Path) -> None: + db = tmp_path / "inc.db" + _make_incidents_db(db) + create_incident( + db, label="Auto: nginx — 5 errors", + issue_type="auto:nginx", + started_at=_iso(self.BASE, 0), + ended_at=_iso(self.BASE, 600), + severity="high", + ) + assert _incident_exists_for_cluster( + db, "nginx", _iso(self.BASE, 100), _iso(self.BASE, 400) + ) + + def test_different_source_not_matched(self, tmp_path: Path) -> None: + db = tmp_path / "inc.db" + _make_incidents_db(db) + create_incident( + db, label="Auto: caddy — 5 errors", + issue_type="auto:caddy", + started_at=_iso(self.BASE, 0), + ended_at=_iso(self.BASE, 600), + severity="high", + ) + assert not _incident_exists_for_cluster( + db, "nginx", _iso(self.BASE, 0), _iso(self.BASE, 600) + ) + + def test_non_overlapping_not_matched(self, tmp_path: Path) -> None: + db = tmp_path / "inc.db" + _make_incidents_db(db) + create_incident( + db, label="Auto: nginx — 5 errors", + issue_type="auto:nginx", + started_at=_iso(self.BASE, 0), + ended_at=_iso(self.BASE, 300), + severity="high", + ) + # Cluster starts after existing incident ends + assert not _incident_exists_for_cluster( + db, "nginx", _iso(self.BASE, 900), _iso(self.BASE, 1200) + ) + + +# ── detect_and_create ────────────────────────────────────────────────────────── + +class TestDetectAndCreate: + BASE = datetime(2026, 6, 11, 12, 0, 0, tzinfo=timezone.utc) + + def _setup(self, tmp_path: Path) -> tuple[Path, Path]: + db = tmp_path / "ts.db" + idb = tmp_path / "incidents.db" + _make_db(db) + _make_incidents_db(idb) + return db, idb + + def test_creates_incident_on_cluster(self, tmp_path: Path) -> None: + db, idb = self._setup(tmp_path) + ingest = _iso(self.BASE, -60) + for i in range(6): + _insert_entry(db, "nginx", _iso(self.BASE, i * 30), "ERROR", ingest) + + result = detect_and_create(db, idb, since=_iso(self.BASE, -120)) + assert result["created"] == 1 + incidents = list_incidents(idb) + assert len(incidents) == 1 + assert "nginx" in incidents[0].label + assert incidents[0].issue_type == "auto:nginx" + + def test_no_incident_below_threshold(self, tmp_path: Path) -> None: + db, idb = self._setup(tmp_path) + ingest = _iso(self.BASE, -60) + for i in range(4): # only 4 errors — below default threshold of 5 + _insert_entry(db, "nginx", _iso(self.BASE, i * 30), "ERROR", ingest) + + result = detect_and_create(db, idb, since=_iso(self.BASE, -120), threshold=5) + assert result["created"] == 0 + + def test_no_duplicate_incidents(self, tmp_path: Path) -> None: + db, idb = self._setup(tmp_path) + ingest = _iso(self.BASE, -60) + for i in range(6): + _insert_entry(db, "nginx", _iso(self.BASE, i * 30), "ERROR", ingest) + + detect_and_create(db, idb, since=_iso(self.BASE, -120)) + detect_and_create(db, idb, since=_iso(self.BASE, -120)) # second run + + incidents = list_incidents(idb) + assert len(incidents) == 1 + + def test_critical_severity_mapped_to_critical_label(self, tmp_path: Path) -> None: + db, idb = self._setup(tmp_path) + ingest = _iso(self.BASE, -60) + for i in range(6): + sev = "CRITICAL" if i == 0 else "ERROR" + _insert_entry(db, "sshd", _iso(self.BASE, i * 30), sev, ingest) + + detect_and_create(db, idb, since=_iso(self.BASE, -120)) + incidents = list_incidents(idb) + assert incidents[0].severity == "critical" + + def test_empty_db_returns_zero(self, tmp_path: Path) -> None: + db, idb = self._setup(tmp_path) + result = detect_and_create(db, idb, since=None) + assert result["created"] == 0 + + def test_independent_sources_each_get_incident(self, tmp_path: Path) -> None: + db, idb = self._setup(tmp_path) + ingest = _iso(self.BASE, -60) + for src in ["caddy", "nginx"]: + for i in range(6): + _insert_entry(db, src, _iso(self.BASE, i * 30), "ERROR", ingest) + + result = detect_and_create(db, idb, since=_iso(self.BASE, -120)) + assert result["created"] == 2 diff --git a/web/src/App.vue b/web/src/App.vue index 914984d..f6a1b48 100644 --- a/web/src/App.vue +++ b/web/src/App.vue @@ -76,6 +76,7 @@ const navLinks = [ { to: '/search', label: 'Search' }, { to: '/diagnose', label: 'Diagnose' }, { to: '/incidents', label: 'Incidents' }, + { to: '/alerts', label: 'Alerts' }, { to: '/bundles', label: 'Bundles' }, { to: '/sources', label: 'Sources' }, { to: '/context', label: 'Context' }, diff --git a/web/src/components/ChatDiagnose.vue b/web/src/components/ChatDiagnose.vue new file mode 100644 index 0000000..eb87110 --- /dev/null +++ b/web/src/components/ChatDiagnose.vue @@ -0,0 +1,370 @@ + + + + +