From 0311d72e53b3d19a4eb5be356c006131d215cd98 Mon Sep 17 00:00:00 2001
From: pyr0ball
Date: Mon, 8 Jun 2026 08:37:54 -0700
Subject: [PATCH 01/17] feat: dual-backend SQLite/Postgres + multi-tenant
source namespacing
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
- Add app/db/ abstraction layer: Backend enum, DbConn wrapper,
dialect helper (q() for ? vs %s paramstyle), get_conn(), tenant_id()
- Auto-detect backend from DATABASE_URL; SQLite remains default when
unset — no config change for local deployments
- Add tenant_id column to all three logical DBs (main, context, incidents);
idempotent ALTER TABLE migration runs before schema scripts on existing DBs
- All INSERTs inject tenant_id; SELECTs use (tenant_id = ? OR tenant_id = '')
for backward compat with pre-namespacing rows
- Add docker-compose.yml with named volume turnstone_pgdata (survives rebuilds)
and optional external Postgres support via DATABASE_URL override
- Add scripts/migrate_sqlite_to_postgres.py — one-shot idempotent migration
for existing SQLite data; ON CONFLICT DO NOTHING for safe re-runs
- Fix SSH glean path in pipeline.py to use ensure_schema + get_conn
(was still using raw sqlite3.connect + old _SCHEMA without tenant_id)
- Fix FTS5 JOIN ambiguity: qualify repeat_count as f.repeat_count in search
- Update all tests to use ensure_*_schema fixtures; add row_factory where needed
- 394/394 tests passing
Closes: https://git.opensourcesolarpunk.com/Circuit-Forge/turnstone/issues/42
Closes: https://git.opensourcesolarpunk.com/Circuit-Forge/turnstone/issues/50
---
app/context/store.py | 102 +++---
app/db/__init__.py | 36 ++
app/db/backend.py | 20 ++
app/db/conn.py | 136 ++++++++
app/db/dialect.py | 93 +++++
app/db/schema.py | 454 +++++++++++++++++++++++++
app/db/tenant.py | 12 +
app/glean/doc_upload.py | 19 +-
app/glean/pipeline.py | 270 ++++-----------
app/mcp_server.py | 10 +-
app/rest.py | 4 +-
app/services/blocklist.py | 116 +++----
app/services/incidents.py | 146 ++++----
app/services/search.py | 392 +++++++++++++--------
app/tasks/glean_scheduler.py | 19 +-
app/watch/watcher.py | 48 ++-
docker-compose.yml | 50 +++
requirements.txt | 2 +
scripts/migrate_sqlite_to_postgres.py | 204 +++++++++++
tests/context/test_diagnose_context.py | 37 +-
tests/context/test_doc_upload.py | 21 +-
tests/context/test_schema.py | 10 +-
tests/context/test_store.py | 20 +-
tests/context/test_wizard.py | 11 +-
tests/test_glean_fingerprint.py | 4 +
tests/test_hybrid_search.py | 9 +-
26 files changed, 1584 insertions(+), 661 deletions(-)
create mode 100644 app/db/__init__.py
create mode 100644 app/db/backend.py
create mode 100644 app/db/conn.py
create mode 100644 app/db/dialect.py
create mode 100644 app/db/schema.py
create mode 100644 app/db/tenant.py
create mode 100644 docker-compose.yml
create mode 100644 scripts/migrate_sqlite_to_postgres.py
diff --git a/app/context/store.py b/app/context/store.py
index 1ffa08a..a030570 100644
--- a/app/context/store.py
+++ b/app/context/store.py
@@ -1,12 +1,13 @@
"""Context fact and document CRUD — MIT licensed."""
from __future__ import annotations
-import sqlite3
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
+from app.db import get_conn, resolve_tenant_id
+
@dataclass(frozen=True)
class ContextFact:
@@ -28,19 +29,8 @@ class ContextDocument:
uploaded_at: str
-def _connect(db_path: Path) -> sqlite3.Connection:
- # timeout=30: retry for up to 30 s when another writer (e.g. the glean
- # collector) holds a WAL write lock. PRAGMA busy_timeout is a SQLite-level
- # hint that operates after the connection is open; the Python sqlite3 module's
- # own retry loop is controlled solely by this timeout= argument.
- conn = sqlite3.connect(str(db_path), timeout=30.0)
- conn.execute("PRAGMA journal_mode=WAL")
- conn.execute("PRAGMA foreign_keys=ON")
- conn.row_factory = sqlite3.Row
- return conn
-
-
def add_fact(db_path: Path, category: str, key: str, value: str, source: str | None = None) -> ContextFact:
+ tid = resolve_tenant_id()
fact = ContextFact(
id=str(uuid.uuid4()),
category=category,
@@ -49,27 +39,28 @@ def add_fact(db_path: Path, category: str, key: str, value: str, source: str | N
source=source,
created_at=datetime.now(timezone.utc).isoformat(),
)
- conn = _connect(db_path)
- conn.execute(
- "INSERT INTO context_facts(id, category, key, value, source, created_at) VALUES (?,?,?,?,?,?)",
- (fact.id, fact.category, fact.key, fact.value, fact.source, fact.created_at),
- )
- conn.commit()
- conn.close()
+ with get_conn(db_path) as conn:
+ conn.execute(
+ "INSERT INTO context_facts(id, tenant_id, category, key, value, source, created_at) VALUES (?,?,?,?,?,?,?)",
+ (fact.id, tid, fact.category, fact.key, fact.value, fact.source, fact.created_at),
+ )
+ conn.commit()
return fact
def list_facts(db_path: Path, category: str | None = None) -> list[ContextFact]:
- conn = _connect(db_path)
- if category:
- rows = conn.execute(
- "SELECT * FROM context_facts WHERE category=? ORDER BY created_at", (category,)
- ).fetchall()
- else:
- rows = conn.execute(
- "SELECT * FROM context_facts ORDER BY category, created_at"
- ).fetchall()
- conn.close()
+ tid = resolve_tenant_id()
+ with get_conn(db_path) as conn:
+ if category:
+ rows = conn.execute(
+ "SELECT * FROM context_facts WHERE category=? AND (tenant_id=? OR tenant_id='') ORDER BY created_at",
+ (category, tid),
+ ).fetchall()
+ else:
+ rows = conn.execute(
+ "SELECT * FROM context_facts WHERE (tenant_id=? OR tenant_id='') ORDER BY category, created_at",
+ (tid,),
+ ).fetchall()
return [
ContextFact(
id=r["id"], category=r["category"], key=r["key"],
@@ -80,10 +71,13 @@ def list_facts(db_path: Path, category: str | None = None) -> list[ContextFact]:
def delete_fact(db_path: Path, fact_id: str) -> bool:
- conn = _connect(db_path)
- cursor = conn.execute("DELETE FROM context_facts WHERE id=?", (fact_id,))
- conn.commit()
- conn.close()
+ tid = resolve_tenant_id()
+ with get_conn(db_path) as conn:
+ cursor = conn.execute(
+ "DELETE FROM context_facts WHERE id=? AND (tenant_id=? OR tenant_id='')",
+ (fact_id, tid),
+ )
+ conn.commit()
return cursor.rowcount > 0
@@ -94,6 +88,7 @@ def add_document(
full_text: str,
file_size: int | None = None,
) -> ContextDocument:
+ tid = resolve_tenant_id()
doc = ContextDocument(
id=str(uuid.uuid4()),
filename=filename,
@@ -102,24 +97,24 @@ def add_document(
file_size=file_size,
uploaded_at=datetime.now(timezone.utc).isoformat(),
)
- conn = _connect(db_path)
- conn.execute(
- "INSERT INTO context_documents(id, filename, doc_type, full_text, file_size, uploaded_at)"
- " VALUES (?,?,?,?,?,?)",
- (doc.id, doc.filename, doc.doc_type, doc.full_text, doc.file_size, doc.uploaded_at),
- )
- conn.commit()
- conn.close()
+ with get_conn(db_path) as conn:
+ conn.execute(
+ "INSERT INTO context_documents(id, tenant_id, filename, doc_type, full_text, file_size, uploaded_at)"
+ " VALUES (?,?,?,?,?,?,?)",
+ (doc.id, tid, doc.filename, doc.doc_type, doc.full_text, doc.file_size, doc.uploaded_at),
+ )
+ conn.commit()
return doc
def list_documents(db_path: Path) -> list[ContextDocument]:
- conn = _connect(db_path)
- rows = conn.execute(
- "SELECT id, filename, doc_type, full_text, file_size, uploaded_at"
- " FROM context_documents ORDER BY uploaded_at DESC"
- ).fetchall()
- conn.close()
+ tid = resolve_tenant_id()
+ with get_conn(db_path) as conn:
+ rows = conn.execute(
+ "SELECT id, filename, doc_type, full_text, file_size, uploaded_at"
+ " FROM context_documents WHERE (tenant_id=? OR tenant_id='') ORDER BY uploaded_at DESC",
+ (tid,),
+ ).fetchall()
return [
ContextDocument(
id=r["id"], filename=r["filename"], doc_type=r["doc_type"],
@@ -130,8 +125,11 @@ def list_documents(db_path: Path) -> list[ContextDocument]:
def delete_document(db_path: Path, doc_id: str) -> bool:
- conn = _connect(db_path)
- cursor = conn.execute("DELETE FROM context_documents WHERE id=?", (doc_id,))
- conn.commit()
- conn.close()
+ tid = resolve_tenant_id()
+ with get_conn(db_path) as conn:
+ cursor = conn.execute(
+ "DELETE FROM context_documents WHERE id=? AND (tenant_id=? OR tenant_id='')",
+ (doc_id, tid),
+ )
+ conn.commit()
return cursor.rowcount > 0
diff --git a/app/db/__init__.py b/app/db/__init__.py
new file mode 100644
index 0000000..5823b7b
--- /dev/null
+++ b/app/db/__init__.py
@@ -0,0 +1,36 @@
+"""Turnstone database abstraction — unified SQLite / Postgres interface.
+
+Public API:
+ BACKEND — Backend.SQLITE or Backend.POSTGRES
+ get_conn(path) — context manager yielding a DbConn
+ resolve_tenant_id() — this node's tenant ID (env or hostname)
+ q(sql) — rewrite ? placeholders to %s for Postgres
+ frag — SQL fragment helpers (insert_or_ignore, source_group_expr, …)
+ ensure_schema — idempotent schema init
+ close_pool — call during shutdown when using Postgres
+"""
+from app.db.backend import BACKEND, Backend
+from app.db.conn import DbConn, close_pool, get_conn
+from app.db.dialect import frag, q
+from app.db.schema import (
+ ensure_context_schema,
+ ensure_incidents_schema,
+ ensure_schema,
+ migrate_incidents_to_dedicated_db,
+)
+from app.db.tenant import resolve_tenant_id
+
+__all__ = [
+ "BACKEND",
+ "Backend",
+ "DbConn",
+ "close_pool",
+ "get_conn",
+ "frag",
+ "q",
+ "ensure_schema",
+ "ensure_context_schema",
+ "ensure_incidents_schema",
+ "migrate_incidents_to_dedicated_db",
+ "resolve_tenant_id",
+]
diff --git a/app/db/backend.py b/app/db/backend.py
new file mode 100644
index 0000000..2e86839
--- /dev/null
+++ b/app/db/backend.py
@@ -0,0 +1,20 @@
+"""Backend detection — SQLITE (default) or POSTGRES based on DATABASE_URL."""
+from __future__ import annotations
+
+import os
+from enum import Enum
+
+
+class Backend(Enum):
+ SQLITE = "sqlite"
+ POSTGRES = "postgres"
+
+
+def _detect() -> Backend:
+ url = os.environ.get("DATABASE_URL", "")
+ if url.startswith(("postgresql://", "postgres://", "postgresql+psycopg://")):
+ return Backend.POSTGRES
+ return Backend.SQLITE
+
+
+BACKEND: Backend = _detect()
diff --git a/app/db/conn.py b/app/db/conn.py
new file mode 100644
index 0000000..51f62ed
--- /dev/null
+++ b/app/db/conn.py
@@ -0,0 +1,136 @@
+"""Uniform connection wrapper over sqlite3 and psycopg3.
+
+Usage:
+ with get_conn(db_path) as conn:
+ conn.execute("SELECT ...", (param,))
+ conn.commit()
+
+For Postgres, db_path is ignored — all connections go through the shared pool.
+The pool is initialized lazily on first use from DATABASE_URL.
+"""
+from __future__ import annotations
+
+import logging
+import os
+import sqlite3
+from contextlib import contextmanager
+from pathlib import Path
+from typing import Any, Generator
+
+from app.db.backend import BACKEND, Backend
+
+logger = logging.getLogger(__name__)
+
+_pool: Any = None # psycopg_pool.ConnectionPool, typed as Any to avoid import-time errors
+
+
+class _NopCursor:
+ """Returned when a PRAGMA or other SQLite-only statement is skipped on Postgres."""
+ rowcount = 0
+
+ def fetchall(self) -> list:
+ return []
+
+ def fetchone(self) -> None:
+ return None
+
+ def __iter__(self):
+ return iter([])
+
+
+class DbConn:
+ """Wraps a raw sqlite3 or psycopg connection with a uniform execute API.
+
+ Row access is always dict-like:
+ - SQLite: conn.row_factory = sqlite3.Row (supports row["col"] and row[0])
+ - Postgres: row_factory = dict_row (returns plain dicts)
+ """
+
+ __slots__ = ("_c", "_backend")
+
+ def __init__(self, raw: Any, backend: Backend) -> None:
+ self._c = raw
+ self._backend = backend
+
+ def _prep(self, sql: str) -> str | None:
+ """Return None to skip (PRAGMA on Postgres), else return ready-to-execute SQL."""
+ stripped = sql.strip()
+ if self._backend == Backend.POSTGRES and stripped.lower().startswith("pragma"):
+ return None
+ if self._backend == Backend.POSTGRES:
+ return stripped.replace("?", "%s")
+ return stripped
+
+ def execute(self, sql: str, params: Any = ()) -> Any:
+ prepared = self._prep(sql)
+ if prepared is None:
+ return _NopCursor()
+ return self._c.execute(prepared, params)
+
+ def executemany(self, sql: str, params_seq: Any) -> Any:
+ prepared = self._prep(sql)
+ if prepared is None:
+ return _NopCursor()
+ return self._c.executemany(prepared, params_seq)
+
+ def commit(self) -> None:
+ self._c.commit()
+
+ def close(self) -> None:
+ self._c.close()
+
+ def __enter__(self) -> "DbConn":
+ return self
+
+ def __exit__(self, *_: Any) -> None:
+ self.close()
+
+
+def _get_pool() -> Any:
+ global _pool
+ if _pool is not None:
+ return _pool
+ try:
+ from psycopg_pool import ConnectionPool # type: ignore[import]
+ url = os.environ["DATABASE_URL"]
+ _pool = ConnectionPool(url, min_size=2, max_size=10, open=True)
+ logger.info("Postgres connection pool opened (DATABASE_URL set)")
+ return _pool
+ except ImportError as exc:
+ raise RuntimeError(
+ "psycopg[binary,pool] is required for Postgres backend. "
+ "Run: pip install 'psycopg[binary,pool]'"
+ ) from exc
+ except KeyError:
+ raise RuntimeError("DATABASE_URL must be set when using Postgres backend") from None
+
+
+@contextmanager
+def get_conn(db_path: Path | None = None) -> Generator[DbConn, None, None]:
+ """Yield a DbConn backed by sqlite3 (db_path required) or the Postgres pool."""
+ if BACKEND == Backend.POSTGRES:
+ pool = _get_pool()
+ from psycopg.rows import dict_row # type: ignore[import]
+ with pool.connection() as raw:
+ raw.row_factory = dict_row
+ yield DbConn(raw, BACKEND)
+ else:
+ if db_path is None:
+ raise ValueError("db_path is required for SQLite backend")
+ raw = sqlite3.connect(str(db_path), timeout=30.0)
+ raw.row_factory = sqlite3.Row
+ try:
+ raw.execute("PRAGMA journal_mode=WAL")
+ raw.execute("PRAGMA foreign_keys=ON")
+ yield DbConn(raw, BACKEND)
+ finally:
+ raw.close()
+
+
+def close_pool() -> None:
+ """Close the Postgres connection pool — call during application shutdown."""
+ global _pool
+ if _pool is not None:
+ _pool.close()
+ _pool = None
+ logger.info("Postgres connection pool closed")
diff --git a/app/db/dialect.py b/app/db/dialect.py
new file mode 100644
index 0000000..70f018a
--- /dev/null
+++ b/app/db/dialect.py
@@ -0,0 +1,93 @@
+"""Per-backend SQL fragments and placeholder rewriting.
+
+All production SQL should be written with SQLite-style `?` placeholders.
+Call q(sql) before passing to execute/executemany — it rewrites to %s for
+Postgres and leaves SQLite queries untouched.
+"""
+from __future__ import annotations
+
+from app.db.backend import BACKEND, Backend
+
+
+def q(sql: str) -> str:
+ """Rewrite ? placeholders to %s for Postgres; no-op for SQLite."""
+ if BACKEND == Backend.POSTGRES:
+ return sql.replace("?", "%s")
+ return sql
+
+
+class _Fragments:
+ """SQL fragments that differ between backends."""
+
+ @property
+ def insert_or_ignore(self) -> str:
+ return "INSERT" if BACKEND == Backend.POSTGRES else "INSERT OR IGNORE"
+
+ @property
+ def on_conflict_ignore(self) -> str:
+ # Caller must substitute the column name(s) at use time when using Postgres.
+ # For log_entries: ON CONFLICT (tenant_id, id) DO NOTHING
+ # For generic use this property is a no-op sentinel; prefer insert_ignore_into().
+ return ""
+
+ def insert_ignore_entries(self) -> str:
+ """Full INSERT ... ON CONFLICT clause for log_entries."""
+ if BACKEND == Backend.POSTGRES:
+ return "INSERT INTO log_entries"
+ return "INSERT OR IGNORE INTO log_entries"
+
+ def entries_conflict_clause(self) -> str:
+ if BACKEND == Backend.POSTGRES:
+ return "ON CONFLICT (tenant_id, id) DO NOTHING"
+ return ""
+
+ def fingerprint_upsert(self) -> str:
+ if BACKEND == Backend.POSTGRES:
+ return (
+ "INSERT INTO glean_fingerprints (tenant_id, path, mtime, size, gleaned_at)"
+ " VALUES (%s, %s, %s, %s, %s)"
+ " ON CONFLICT (tenant_id, path)"
+ " DO UPDATE SET mtime=EXCLUDED.mtime, size=EXCLUDED.size, gleaned_at=EXCLUDED.gleaned_at"
+ )
+ return (
+ "INSERT OR REPLACE INTO glean_fingerprints (tenant_id, path, mtime, size, gleaned_at)"
+ " VALUES (?,?,?,?,?)"
+ )
+
+ def source_group_expr(self, col: str = "source_id") -> str:
+ """SQL expression that collapses prefix:host:unit → prefix:host stem."""
+ if BACKEND == Backend.POSTGRES:
+ return f"""
+ CASE
+ WHEN array_length(string_to_array({col}, ':'), 1) >= 3
+ THEN split_part({col}, ':', 1) || ':' || split_part({col}, ':', 2)
+ ELSE {col}
+ END
+ """
+ return f"""
+ CASE
+ WHEN INSTR(SUBSTR({col}, INSTR({col}, ':')+1), ':') > 0
+ THEN SUBSTR({col}, 1,
+ INSTR({col}, ':')
+ + INSTR(SUBSTR({col}, INSTR({col}, ':')+1), ':')
+ - 1)
+ ELSE {col}
+ END
+ """
+
+ def fts_match_clause(self) -> str:
+ """WHERE clause fragment for FTS query. Caller supplies the query param."""
+ if BACKEND == Backend.POSTGRES:
+ return "text_tsv @@ websearch_to_tsquery('english', %s)"
+ return "log_fts MATCH ?"
+
+ def fts_rank_expr(self) -> str:
+ """ORDER BY expression for FTS rank (best match first). Postgres needs the query twice."""
+ if BACKEND == Backend.POSTGRES:
+ # ts_rank returns 0..1 where higher is better; pass the query again as param
+ return "ts_rank(text_tsv, websearch_to_tsquery('english', %s)) DESC"
+ # FTS5 rank is negative BM25; ASC = most-negative = best match
+ return "rank ASC"
+
+
+frag = _Fragments()
diff --git a/app/db/schema.py b/app/db/schema.py
new file mode 100644
index 0000000..7cc8d97
--- /dev/null
+++ b/app/db/schema.py
@@ -0,0 +1,454 @@
+"""Schema creation and idempotent migrations for all Turnstone databases.
+
+Three logical databases (main, context, incidents) map to:
+ - SQLite: three separate .db files (avoids write-lock contention)
+ - Postgres: three table-groups in one physical DB (row-level locking makes separation unnecessary)
+
+All ensure_* functions are idempotent: safe to call on every startup.
+"""
+from __future__ import annotations
+
+import logging
+import sqlite3
+from pathlib import Path
+
+from app.db.backend import BACKEND, Backend
+from app.db.conn import get_conn
+
+logger = logging.getLogger(__name__)
+
+# ---------------------------------------------------------------------------
+# SQLite DDL — kept as executescript strings (SQLite only)
+# ---------------------------------------------------------------------------
+
+_MAIN_SCHEMA_SQLITE = """
+CREATE TABLE IF NOT EXISTS log_entries (
+ id TEXT NOT NULL,
+ tenant_id TEXT NOT NULL DEFAULT '',
+ source_id TEXT NOT NULL,
+ sequence INTEGER NOT NULL,
+ timestamp_raw TEXT,
+ timestamp_iso TEXT,
+ ingest_time TEXT NOT NULL,
+ severity TEXT,
+ repeat_count INTEGER DEFAULT 1,
+ out_of_order INTEGER DEFAULT 0,
+ matched_patterns TEXT DEFAULT '[]',
+ text TEXT NOT NULL,
+ PRIMARY KEY (tenant_id, id)
+);
+CREATE INDEX IF NOT EXISTS idx_source ON log_entries(source_id);
+CREATE INDEX IF NOT EXISTS idx_tenant_src ON log_entries(tenant_id, source_id);
+CREATE INDEX IF NOT EXISTS idx_timestamp ON log_entries(timestamp_iso);
+CREATE INDEX IF NOT EXISTS idx_ts_repeat ON log_entries(timestamp_iso, repeat_count);
+CREATE INDEX IF NOT EXISTS idx_severity ON log_entries(tenant_id, severity);
+CREATE INDEX IF NOT EXISTS idx_patterns ON log_entries(matched_patterns);
+
+CREATE TABLE IF NOT EXISTS glean_fingerprints (
+ tenant_id TEXT NOT NULL DEFAULT '',
+ path TEXT NOT NULL,
+ mtime REAL NOT NULL,
+ size INTEGER NOT NULL,
+ gleaned_at TEXT NOT NULL,
+ PRIMARY KEY (tenant_id, path)
+);
+
+CREATE TABLE IF NOT EXISTS incidents (
+ id TEXT PRIMARY KEY,
+ tenant_id TEXT NOT NULL DEFAULT '',
+ label TEXT NOT NULL,
+ issue_type TEXT NOT NULL DEFAULT '',
+ started_at TEXT,
+ ended_at TEXT,
+ notes TEXT NOT NULL DEFAULT '',
+ created_at TEXT NOT NULL,
+ severity TEXT NOT NULL DEFAULT 'medium'
+);
+CREATE INDEX IF NOT EXISTS idx_incidents_time ON incidents(started_at, ended_at);
+CREATE INDEX IF NOT EXISTS idx_incidents_tenant ON incidents(tenant_id);
+
+CREATE TABLE IF NOT EXISTS received_bundles (
+ id TEXT PRIMARY KEY,
+ tenant_id TEXT NOT NULL DEFAULT '',
+ source_host TEXT NOT NULL,
+ issue_type TEXT NOT NULL DEFAULT '',
+ label TEXT NOT NULL,
+ severity TEXT NOT NULL DEFAULT 'medium',
+ started_at TEXT,
+ bundled_at TEXT NOT NULL,
+ entry_count INTEGER NOT NULL DEFAULT 0,
+ bundle_json TEXT NOT NULL
+);
+CREATE INDEX IF NOT EXISTS idx_bundles_bundled ON received_bundles(bundled_at);
+CREATE INDEX IF NOT EXISTS idx_bundles_type ON received_bundles(issue_type);
+
+CREATE TABLE IF NOT EXISTS sent_bundles (
+ id TEXT PRIMARY KEY,
+ tenant_id TEXT NOT NULL DEFAULT '',
+ incident_id TEXT NOT NULL,
+ exported_at TEXT NOT NULL,
+ sanitized INTEGER NOT NULL DEFAULT 0,
+ entry_count INTEGER NOT NULL DEFAULT 0,
+ bundle_json TEXT NOT NULL
+);
+CREATE INDEX IF NOT EXISTS idx_sent_bundles_incident ON sent_bundles(incident_id);
+CREATE INDEX IF NOT EXISTS idx_sent_bundles_time ON sent_bundles(exported_at);
+
+CREATE TABLE IF NOT EXISTS blocklist_candidates (
+ id TEXT PRIMARY KEY,
+ tenant_id TEXT NOT NULL DEFAULT '',
+ domain_or_ip TEXT NOT NULL,
+ source_device_ip TEXT,
+ source_device_name TEXT,
+ first_seen TEXT NOT NULL,
+ last_seen TEXT NOT NULL,
+ hit_count INTEGER DEFAULT 1,
+ status TEXT DEFAULT 'pending',
+ pushed_at TEXT,
+ log_evidence TEXT DEFAULT '[]',
+ matched_rule TEXT,
+ llm_score REAL,
+ llm_reason TEXT
+);
+CREATE INDEX IF NOT EXISTS idx_blocklist_device ON blocklist_candidates(source_device_ip);
+CREATE INDEX IF NOT EXISTS idx_blocklist_status ON blocklist_candidates(status);
+CREATE INDEX IF NOT EXISTS idx_blocklist_domain ON blocklist_candidates(domain_or_ip);
+CREATE INDEX IF NOT EXISTS idx_blocklist_tenant ON blocklist_candidates(tenant_id);
+"""
+
+_CONTEXT_SCHEMA_SQLITE = """
+CREATE TABLE IF NOT EXISTS context_facts (
+ id TEXT PRIMARY KEY,
+ tenant_id TEXT NOT NULL DEFAULT '',
+ category TEXT NOT NULL,
+ key TEXT NOT NULL,
+ value TEXT NOT NULL,
+ source TEXT,
+ created_at TEXT NOT NULL
+);
+CREATE INDEX IF NOT EXISTS idx_facts_category ON context_facts(category);
+CREATE INDEX IF NOT EXISTS idx_facts_key ON context_facts(key);
+CREATE INDEX IF NOT EXISTS idx_facts_tenant ON context_facts(tenant_id);
+
+CREATE TABLE IF NOT EXISTS context_documents (
+ id TEXT PRIMARY KEY,
+ tenant_id TEXT NOT NULL DEFAULT '',
+ filename TEXT NOT NULL,
+ doc_type TEXT NOT NULL,
+ full_text TEXT NOT NULL,
+ file_size INTEGER,
+ uploaded_at TEXT NOT NULL
+);
+CREATE INDEX IF NOT EXISTS idx_docs_tenant ON context_documents(tenant_id);
+
+CREATE TABLE IF NOT EXISTS context_chunks (
+ id TEXT PRIMARY KEY,
+ tenant_id TEXT NOT NULL DEFAULT '',
+ document_id TEXT NOT NULL REFERENCES context_documents(id) ON DELETE CASCADE,
+ chunk_index INTEGER NOT NULL,
+ text TEXT NOT NULL,
+ embedding BLOB
+);
+CREATE INDEX IF NOT EXISTS idx_chunks_doc ON context_chunks(document_id);
+CREATE INDEX IF NOT EXISTS idx_chunks_tenant ON context_chunks(tenant_id);
+"""
+
+
+# ---------------------------------------------------------------------------
+# Postgres DDL — executed statement-by-statement
+# ---------------------------------------------------------------------------
+
+_MAIN_SCHEMA_PG_STMTS = [
+ """
+ CREATE TABLE IF NOT EXISTS log_entries (
+ id TEXT NOT NULL,
+ tenant_id TEXT NOT NULL DEFAULT '',
+ source_id TEXT NOT NULL,
+ sequence INTEGER NOT NULL,
+ timestamp_raw TEXT,
+ timestamp_iso TEXT,
+ ingest_time TEXT NOT NULL,
+ severity TEXT,
+ repeat_count INTEGER DEFAULT 1,
+ out_of_order INTEGER DEFAULT 0,
+ matched_patterns TEXT DEFAULT '[]',
+ text TEXT NOT NULL,
+ text_tsv tsvector,
+ PRIMARY KEY (tenant_id, id)
+ )
+ """,
+ "CREATE INDEX IF NOT EXISTS idx_tenant_src ON log_entries(tenant_id, source_id)",
+ "CREATE INDEX IF NOT EXISTS idx_timestamp ON log_entries(timestamp_iso)",
+ "CREATE INDEX IF NOT EXISTS idx_severity ON log_entries(tenant_id, severity)",
+ "CREATE INDEX IF NOT EXISTS idx_patterns ON log_entries(matched_patterns)",
+ "CREATE INDEX IF NOT EXISTS idx_fts_gin ON log_entries USING GIN(text_tsv)",
+ """
+ CREATE OR REPLACE FUNCTION _ts_update_text_tsv() RETURNS trigger AS $$
+ BEGIN
+ NEW.text_tsv := to_tsvector('english', COALESCE(NEW.text, ''));
+ RETURN NEW;
+ END;
+ $$ LANGUAGE plpgsql
+ """,
+ """
+ DO $$ BEGIN
+ IF NOT EXISTS (
+ SELECT 1 FROM pg_trigger WHERE tgname = 'trig_log_entries_tsv'
+ ) THEN
+ CREATE TRIGGER trig_log_entries_tsv
+ BEFORE INSERT OR UPDATE OF text ON log_entries
+ FOR EACH ROW EXECUTE FUNCTION _ts_update_text_tsv();
+ END IF;
+ END $$
+ """,
+ """
+ CREATE TABLE IF NOT EXISTS glean_fingerprints (
+ tenant_id TEXT NOT NULL DEFAULT '',
+ path TEXT NOT NULL,
+ mtime DOUBLE PRECISION NOT NULL,
+ size BIGINT NOT NULL,
+ gleaned_at TEXT NOT NULL,
+ PRIMARY KEY (tenant_id, path)
+ )
+ """,
+ """
+ CREATE TABLE IF NOT EXISTS incidents (
+ id TEXT PRIMARY KEY,
+ tenant_id TEXT NOT NULL DEFAULT '',
+ label TEXT NOT NULL,
+ issue_type TEXT NOT NULL DEFAULT '',
+ started_at TEXT,
+ ended_at TEXT,
+ notes TEXT NOT NULL DEFAULT '',
+ created_at TEXT NOT NULL,
+ severity TEXT NOT NULL DEFAULT 'medium'
+ )
+ """,
+ "CREATE INDEX IF NOT EXISTS idx_incidents_time ON incidents(started_at, ended_at)",
+ "CREATE INDEX IF NOT EXISTS idx_incidents_tenant ON incidents(tenant_id)",
+ """
+ CREATE TABLE IF NOT EXISTS received_bundles (
+ id TEXT PRIMARY KEY,
+ tenant_id TEXT NOT NULL DEFAULT '',
+ source_host TEXT NOT NULL,
+ issue_type TEXT NOT NULL DEFAULT '',
+ label TEXT NOT NULL,
+ severity TEXT NOT NULL DEFAULT 'medium',
+ started_at TEXT,
+ bundled_at TEXT NOT NULL,
+ entry_count INTEGER NOT NULL DEFAULT 0,
+ bundle_json TEXT NOT NULL
+ )
+ """,
+ "CREATE INDEX IF NOT EXISTS idx_bundles_bundled ON received_bundles(bundled_at)",
+ "CREATE INDEX IF NOT EXISTS idx_bundles_type ON received_bundles(issue_type)",
+ """
+ CREATE TABLE IF NOT EXISTS sent_bundles (
+ id TEXT PRIMARY KEY,
+ tenant_id TEXT NOT NULL DEFAULT '',
+ incident_id TEXT NOT NULL,
+ exported_at TEXT NOT NULL,
+ sanitized INTEGER NOT NULL DEFAULT 0,
+ entry_count INTEGER NOT NULL DEFAULT 0,
+ bundle_json TEXT NOT NULL
+ )
+ """,
+ "CREATE INDEX IF NOT EXISTS idx_sent_bundles_incident ON sent_bundles(incident_id)",
+ "CREATE INDEX IF NOT EXISTS idx_sent_bundles_time ON sent_bundles(exported_at)",
+ """
+ CREATE TABLE IF NOT EXISTS blocklist_candidates (
+ id TEXT PRIMARY KEY,
+ tenant_id TEXT NOT NULL DEFAULT '',
+ domain_or_ip TEXT NOT NULL,
+ source_device_ip TEXT,
+ source_device_name TEXT,
+ first_seen TEXT NOT NULL,
+ last_seen TEXT NOT NULL,
+ hit_count INTEGER DEFAULT 1,
+ status TEXT DEFAULT 'pending',
+ pushed_at TEXT,
+ log_evidence TEXT DEFAULT '[]',
+ matched_rule TEXT,
+ llm_score DOUBLE PRECISION,
+ llm_reason TEXT
+ )
+ """,
+ "CREATE INDEX IF NOT EXISTS idx_blocklist_device ON blocklist_candidates(source_device_ip)",
+ "CREATE INDEX IF NOT EXISTS idx_blocklist_status ON blocklist_candidates(status)",
+ "CREATE INDEX IF NOT EXISTS idx_blocklist_domain ON blocklist_candidates(domain_or_ip)",
+ "CREATE INDEX IF NOT EXISTS idx_blocklist_tenant ON blocklist_candidates(tenant_id)",
+]
+
+_CONTEXT_SCHEMA_PG_STMTS = [
+ """
+ CREATE TABLE IF NOT EXISTS context_facts (
+ id TEXT PRIMARY KEY,
+ tenant_id TEXT NOT NULL DEFAULT '',
+ category TEXT NOT NULL,
+ key TEXT NOT NULL,
+ value TEXT NOT NULL,
+ source TEXT,
+ created_at TEXT NOT NULL
+ )
+ """,
+ "CREATE INDEX IF NOT EXISTS idx_facts_category ON context_facts(category)",
+ "CREATE INDEX IF NOT EXISTS idx_facts_key ON context_facts(key)",
+ "CREATE INDEX IF NOT EXISTS idx_facts_tenant ON context_facts(tenant_id)",
+ """
+ CREATE TABLE IF NOT EXISTS context_documents (
+ id TEXT PRIMARY KEY,
+ tenant_id TEXT NOT NULL DEFAULT '',
+ filename TEXT NOT NULL,
+ doc_type TEXT NOT NULL,
+ full_text TEXT NOT NULL,
+ file_size BIGINT,
+ uploaded_at TEXT NOT NULL
+ )
+ """,
+ "CREATE INDEX IF NOT EXISTS idx_docs_tenant ON context_documents(tenant_id)",
+ """
+ CREATE TABLE IF NOT EXISTS context_chunks (
+ id TEXT PRIMARY KEY,
+ tenant_id TEXT NOT NULL DEFAULT '',
+ document_id TEXT NOT NULL REFERENCES context_documents(id) ON DELETE CASCADE,
+ chunk_index INTEGER NOT NULL,
+ text TEXT NOT NULL,
+ embedding BYTEA
+ )
+ """,
+ "CREATE INDEX IF NOT EXISTS idx_chunks_doc ON context_chunks(document_id)",
+ "CREATE INDEX IF NOT EXISTS idx_chunks_tenant ON context_chunks(tenant_id)",
+]
+
+
+# ---------------------------------------------------------------------------
+# SQLite additive column migrations — applied after CREATE TABLE on every boot
+# ---------------------------------------------------------------------------
+
+_MAIN_MIGRATIONS_SQLITE = [
+ "ALTER TABLE log_entries ADD COLUMN tenant_id TEXT NOT NULL DEFAULT ''",
+ "ALTER TABLE incidents ADD COLUMN issue_type TEXT NOT NULL DEFAULT ''",
+ "ALTER TABLE incidents ADD COLUMN tenant_id TEXT NOT NULL DEFAULT ''",
+ "ALTER TABLE received_bundles ADD COLUMN tenant_id TEXT NOT NULL DEFAULT ''",
+ "ALTER TABLE sent_bundles ADD COLUMN tenant_id TEXT NOT NULL DEFAULT ''",
+ "ALTER TABLE blocklist_candidates ADD COLUMN tenant_id TEXT NOT NULL DEFAULT ''",
+ "ALTER TABLE glean_fingerprints ADD COLUMN tenant_id TEXT NOT NULL DEFAULT ''",
+ "ALTER TABLE glean_fingerprints ADD COLUMN mtime REAL",
+ "ALTER TABLE glean_fingerprints ADD COLUMN size INTEGER",
+ "ALTER TABLE glean_fingerprints ADD COLUMN gleaned_at TEXT",
+]
+
+_CONTEXT_MIGRATIONS_SQLITE = [
+ "ALTER TABLE context_facts ADD COLUMN tenant_id TEXT NOT NULL DEFAULT ''",
+ "ALTER TABLE context_documents ADD COLUMN tenant_id TEXT NOT NULL DEFAULT ''",
+ "ALTER TABLE context_chunks ADD COLUMN tenant_id TEXT NOT NULL DEFAULT ''",
+]
+
+
+def _run_sqlite_migrations(conn: sqlite3.Connection, stmts: list[str]) -> None:
+ for stmt in stmts:
+ try:
+ conn.execute(stmt)
+ except sqlite3.OperationalError:
+ pass # column already exists or table not present yet — both are fine
+
+
+def _run_pg_stmts(stmts: list[str]) -> None:
+ """Execute Postgres DDL statements — each in its own transaction for IF NOT EXISTS safety."""
+ from psycopg import connect as pg_connect # type: ignore[import]
+ import os
+ url = os.environ["DATABASE_URL"]
+ with pg_connect(url, autocommit=True) as conn:
+ for stmt in stmts:
+ stripped = stmt.strip()
+ if stripped:
+ conn.execute(stripped)
+
+
+# ---------------------------------------------------------------------------
+# Public API
+# ---------------------------------------------------------------------------
+
+def ensure_schema(db_path: Path) -> None:
+ """Ensure main log/incidents/blocklist tables exist. Idempotent."""
+ if BACKEND == Backend.POSTGRES:
+ _run_pg_stmts(_MAIN_SCHEMA_PG_STMTS)
+ logger.debug("Postgres main schema verified")
+ return
+
+ conn = sqlite3.connect(str(db_path), timeout=30.0)
+ conn.execute("PRAGMA journal_mode=WAL")
+ # Migrations first: add tenant_id to existing tables BEFORE index creation touches it
+ _run_sqlite_migrations(conn, _MAIN_MIGRATIONS_SQLITE)
+ conn.commit()
+ conn.executescript(_MAIN_SCHEMA_SQLITE)
+ conn.close()
+ logger.debug("SQLite main schema verified at %s", db_path)
+
+
+def ensure_context_schema(db_path: Path) -> None:
+ """Ensure context KB tables exist. Idempotent."""
+ if BACKEND == Backend.POSTGRES:
+ _run_pg_stmts(_CONTEXT_SCHEMA_PG_STMTS)
+ logger.debug("Postgres context schema verified")
+ return
+
+ conn = sqlite3.connect(str(db_path), timeout=30.0)
+ conn.execute("PRAGMA journal_mode=WAL")
+ conn.execute("PRAGMA foreign_keys=ON")
+ _run_sqlite_migrations(conn, _CONTEXT_MIGRATIONS_SQLITE)
+ conn.commit()
+ conn.executescript(_CONTEXT_SCHEMA_SQLITE)
+ conn.close()
+ logger.debug("SQLite context schema verified at %s", db_path)
+
+
+def migrate_incidents_to_dedicated_db(main_db: Path, incidents_db: Path) -> int:
+ """One-shot migration: copy incidents/bundles rows from main DB to incidents DB.
+
+ Safe to call on every startup — rows already in incidents_db are skipped.
+ No-op for Postgres (single DB, no migration needed).
+ """
+ if BACKEND == Backend.POSTGRES:
+ return 0
+
+ src = sqlite3.connect(str(main_db), timeout=30.0)
+ src.row_factory = sqlite3.Row
+ dst = sqlite3.connect(str(incidents_db), timeout=30.0)
+ migrated = 0
+ for table in ("incidents", "received_bundles", "sent_bundles"):
+ try:
+ rows = src.execute(f"SELECT * FROM {table}").fetchall() # noqa: S608
+ except sqlite3.OperationalError:
+ continue
+ if not rows:
+ continue
+ cols = ", ".join(rows[0].keys())
+ placeholders = ", ".join("?" * len(rows[0].keys()))
+ dst.executemany(
+ f"INSERT OR IGNORE INTO {table} ({cols}) VALUES ({placeholders})", # noqa: S608
+ [tuple(r) for r in rows],
+ )
+ migrated += len(rows)
+ dst.commit()
+ src.close()
+ dst.close()
+ return migrated
+
+
+def ensure_incidents_schema(db_path: Path) -> None:
+ """Ensure incidents/bundles tables exist. Idempotent.
+
+ For Postgres, incidents live in the same DB as log_entries (already created by
+ ensure_schema), so this is a no-op — the tables were created above.
+ """
+ if BACKEND == Backend.POSTGRES:
+ return
+
+ conn = sqlite3.connect(str(db_path), timeout=30.0)
+ conn.execute("PRAGMA journal_mode=WAL")
+ _run_sqlite_migrations(conn, _MAIN_MIGRATIONS_SQLITE)
+ conn.commit()
+ conn.executescript(_MAIN_SCHEMA_SQLITE)
+ conn.close()
+ logger.debug("SQLite incidents schema verified at %s", db_path)
diff --git a/app/db/tenant.py b/app/db/tenant.py
new file mode 100644
index 0000000..5d2542e
--- /dev/null
+++ b/app/db/tenant.py
@@ -0,0 +1,12 @@
+"""Tenant ID resolution — TURNSTONE_TENANT_ID env var, hostname fallback."""
+from __future__ import annotations
+
+import os
+import socket
+from functools import lru_cache
+
+
+@lru_cache(maxsize=1)
+def resolve_tenant_id() -> str:
+ """Return this node's tenant ID. Result is cached after first call."""
+ return os.environ.get("TURNSTONE_TENANT_ID") or socket.gethostname()
diff --git a/app/glean/doc_upload.py b/app/glean/doc_upload.py
index c2d4d9a..0cfd604 100644
--- a/app/glean/doc_upload.py
+++ b/app/glean/doc_upload.py
@@ -1,18 +1,19 @@
"""Upload adapter: processes file bytes and writes to context store — MIT licensed."""
from __future__ import annotations
-import sqlite3
import uuid
from pathlib import Path
from typing import Any
from app.context.chunker import process_upload
from app.context.store import add_document, add_fact
+from app.db import get_conn, resolve_tenant_id
def glean_upload(db_path: Path, filename: str, content: bytes) -> dict[str, Any]:
"""Process an uploaded file and write to context store. Returns result summary."""
doc_type, facts, chunks = process_upload(filename, content)
+ tid = resolve_tenant_id()
doc = add_document(
db_path,
@@ -25,15 +26,13 @@ def glean_upload(db_path: Path, filename: str, content: bytes) -> dict[str, Any]
for fact in facts:
add_fact(db_path, fact.category, fact.key, fact.value, source="upload")
- conn = sqlite3.connect(str(db_path), timeout=30.0)
- conn.execute("PRAGMA journal_mode=WAL")
- for i, chunk_text in enumerate(chunks):
- conn.execute(
- "INSERT INTO context_chunks(id, document_id, chunk_index, text) VALUES (?,?,?,?)",
- (str(uuid.uuid4()), doc.id, i, chunk_text),
- )
- conn.commit()
- conn.close()
+ with get_conn(db_path) as conn:
+ for i, chunk_text in enumerate(chunks):
+ conn.execute(
+ "INSERT INTO context_chunks(id, tenant_id, document_id, chunk_index, text) VALUES (?,?,?,?,?)",
+ (str(uuid.uuid4()), tid, doc.id, i, chunk_text),
+ )
+ conn.commit()
return {
"document_id": doc.id,
diff --git a/app/glean/pipeline.py b/app/glean/pipeline.py
index 38bd0f1..2cb3184 100644
--- a/app/glean/pipeline.py
+++ b/app/glean/pipeline.py
@@ -1,12 +1,24 @@
-"""Glean pipeline: auto-detect format, parse, write to SQLite."""
+"""Glean pipeline: auto-detect format, parse, write to SQLite or Postgres."""
from __future__ import annotations
import json
import logging
import re
-import sqlite3
+import sqlite3 # still used in migrate_incidents_to_dedicated_db (SQLite-only migration)
from pathlib import Path
-from typing import Iterator
+from typing import Any, Iterator
+
+from app.db import (
+ frag,
+ get_conn,
+ resolve_tenant_id,
+)
+from app.db.schema import (
+ ensure_context_schema,
+ ensure_incidents_schema,
+ ensure_schema,
+ migrate_incidents_to_dedicated_db,
+)
import yaml
@@ -169,127 +181,13 @@ CREATE INDEX IF NOT EXISTS idx_chunks_doc ON context_chunks(document_id);
"""
-def ensure_schema(db_path: Path) -> None:
- """Create all tables and apply additive migrations. Safe to call on every startup."""
- conn = sqlite3.connect(str(db_path), timeout=30.0)
- conn.execute("PRAGMA journal_mode=WAL")
- conn.executescript(_SCHEMA)
- # Additive column migrations — ALTER TABLE silently skips if column exists
- for stmt in [
- "ALTER TABLE incidents ADD COLUMN issue_type TEXT NOT NULL DEFAULT ''",
- ]:
- try:
- conn.execute(stmt)
- except sqlite3.OperationalError:
- pass
- conn.commit()
- conn.close()
+# ensure_schema / ensure_context_schema / ensure_incidents_schema / migrate_incidents_to_dedicated_db
+# are now implemented in app/db/schema.py and re-exported via app/db/__init__.py.
+# The imports at the top of this file bring them in; these names are kept as module-level
+# symbols so existing callers (rest.py, tests) still find them here without changes.
-def ensure_context_schema(db_path: Path) -> None:
- """Create context KB tables in a dedicated database file.
-
- Using a separate file from the main log DB means context fact writes never
- contend with the high-throughput glean scheduler, which can hold the main
- DB write lock for seconds at a time when flushing large journal batches.
- """
- conn = sqlite3.connect(str(db_path), timeout=30.0)
- conn.execute("PRAGMA journal_mode=WAL")
- conn.execute("PRAGMA foreign_keys=ON")
- conn.executescript(_CONTEXT_SCHEMA)
- conn.commit()
- conn.close()
-
-
-_INCIDENTS_SCHEMA = """
-CREATE TABLE IF NOT EXISTS incidents (
- id TEXT PRIMARY KEY,
- label TEXT NOT NULL,
- issue_type TEXT NOT NULL DEFAULT '',
- started_at TEXT,
- ended_at TEXT,
- notes TEXT NOT NULL DEFAULT '',
- created_at TEXT NOT NULL,
- severity TEXT NOT NULL DEFAULT 'medium'
-);
-CREATE INDEX IF NOT EXISTS idx_incidents_time ON incidents(started_at, ended_at);
-
-CREATE TABLE IF NOT EXISTS received_bundles (
- id TEXT PRIMARY KEY,
- source_host TEXT NOT NULL,
- issue_type TEXT NOT NULL DEFAULT '',
- label TEXT NOT NULL,
- severity TEXT NOT NULL DEFAULT 'medium',
- started_at TEXT,
- bundled_at TEXT NOT NULL,
- entry_count INTEGER NOT NULL DEFAULT 0,
- bundle_json TEXT NOT NULL
-);
-CREATE INDEX IF NOT EXISTS idx_bundles_bundled ON received_bundles(bundled_at);
-CREATE INDEX IF NOT EXISTS idx_bundles_type ON received_bundles(issue_type);
-
-CREATE TABLE IF NOT EXISTS sent_bundles (
- id TEXT PRIMARY KEY,
- incident_id TEXT NOT NULL,
- exported_at TEXT NOT NULL,
- sanitized INTEGER NOT NULL DEFAULT 0,
- entry_count INTEGER NOT NULL DEFAULT 0,
- bundle_json TEXT NOT NULL
-);
-CREATE INDEX IF NOT EXISTS idx_sent_bundles_incident ON sent_bundles(incident_id);
-CREATE INDEX IF NOT EXISTS idx_sent_bundles_time ON sent_bundles(exported_at);
-"""
-
-
-def ensure_incidents_schema(db_path: Path) -> None:
- """Create incidents tables in a dedicated database file.
-
- Using a separate file from the main log DB means incident writes never
- contend with the FTS5 bulk-insert write lock held by the glean scheduler.
- Mirrors the context_facts split (CONTEXT_DB_PATH / turnstone-context.db).
- """
- conn = sqlite3.connect(str(db_path), timeout=30.0)
- conn.execute("PRAGMA journal_mode=WAL")
- conn.executescript(_INCIDENTS_SCHEMA)
- for stmt in [
- "ALTER TABLE incidents ADD COLUMN issue_type TEXT NOT NULL DEFAULT ''",
- ]:
- try:
- conn.execute(stmt)
- except sqlite3.OperationalError:
- pass
- conn.commit()
- conn.close()
-
-
-def migrate_incidents_to_dedicated_db(main_db: Path, incidents_db: Path) -> int:
- """One-shot migration: copy incidents/bundles rows from main DB to incidents DB.
-
- Safe to call on every startup — rows already present in incidents_db are
- skipped via INSERT OR IGNORE. Returns the count of rows migrated.
- """
- src = sqlite3.connect(str(main_db), timeout=30.0)
- src.row_factory = sqlite3.Row
- dst = sqlite3.connect(str(incidents_db), timeout=30.0)
- migrated = 0
- for table in ("incidents", "received_bundles", "sent_bundles"):
- try:
- rows = src.execute(f"SELECT * FROM {table}").fetchall() # noqa: S608
- except sqlite3.OperationalError:
- continue
- if not rows:
- continue
- cols = ", ".join(rows[0].keys())
- placeholders = ", ".join("?" * len(rows[0].keys()))
- dst.executemany(
- f"INSERT OR IGNORE INTO {table} ({cols}) VALUES ({placeholders})", # noqa: S608
- [tuple(r) for r in rows],
- )
- migrated += len(rows)
- dst.commit()
- src.close()
- dst.close()
- return migrated
+# _INCIDENTS_SCHEMA and its ensure_/migrate_ functions moved to app/db/schema.py
def _fingerprint(path: Path) -> tuple[float, int]:
@@ -298,36 +196,28 @@ def _fingerprint(path: Path) -> tuple[float, int]:
return st.st_mtime, st.st_size
-def _fp_unchanged(conn: sqlite3.Connection, path: Path, mtime: float, size: int) -> bool:
- """Return True only when the stored fingerprint exactly matches (mtime, size).
-
- A smaller size (log rotation) or a larger size (new lines appended) both
- return False so the caller re-gleams the file.
- """
+def _fp_unchanged(conn: Any, path: Path, mtime: float, size: int) -> bool:
+ """Return True only when the stored fingerprint exactly matches (mtime, size)."""
+ tid = resolve_tenant_id()
row = conn.execute(
- "SELECT mtime, size FROM glean_fingerprints WHERE path = ?",
- (str(path),),
+ "SELECT mtime, size FROM glean_fingerprints WHERE path = ? AND (tenant_id = ? OR tenant_id = '')",
+ (str(path), tid),
).fetchone()
if row is None:
return False
- return row[0] == mtime and row[1] == size
+ return row["mtime"] == mtime and row["size"] == size
def _save_fingerprint(
- conn: sqlite3.Connection,
+ conn: Any,
path: Path,
mtime: float,
size: int,
gleaned_at: str,
) -> None:
"""Upsert the fingerprint for *path* after a successful glean."""
- conn.execute(
- """
- INSERT OR REPLACE INTO glean_fingerprints (path, mtime, size, gleaned_at)
- VALUES (?, ?, ?, ?)
- """,
- (str(path), mtime, size, gleaned_at),
- )
+ tid = resolve_tenant_id()
+ conn.execute(frag.fingerprint_upsert(), (tid, str(path), mtime, size, gleaned_at))
def _detect_format(first_line: str) -> str:
@@ -400,18 +290,22 @@ def _parse_file(
yield from plaintext.parse(all_lines(), source_id, compiled, ingest_time)
-def _write_batch(conn: sqlite3.Connection, batch: list[RetrievedEntry]) -> None:
- conn.executemany(
- """
- INSERT OR IGNORE INTO log_entries
- (id, source_id, sequence, timestamp_raw, timestamp_iso,
+def _write_batch(conn: Any, batch: list[RetrievedEntry]) -> None:
+ tid = resolve_tenant_id()
+ conflict = frag.entries_conflict_clause()
+ sql = f"""
+ {frag.insert_ignore_entries()}
+ (tenant_id, id, source_id, sequence, timestamp_raw, timestamp_iso,
ingest_time, severity, repeat_count, out_of_order,
matched_patterns, text)
- VALUES (?,?,?,?,?,?,?,?,?,?,?)
- """,
+ VALUES (?,?,?,?,?,?,?,?,?,?,?,?)
+ {conflict}
+ """
+ conn.executemany(
+ sql,
[
(
- e.entry_id, e.source_id, e.sequence,
+ tid, e.entry_id, e.source_id, e.sequence,
e.timestamp_raw, e.timestamp_iso, e.ingest_time,
e.severity, e.repeat_count, int(e.out_of_order),
json.dumps(list(e.matched_patterns)), e.text,
@@ -435,46 +329,41 @@ def _glean_files(
ingest_time = now_iso()
source_id_map = source_id_map or {}
- conn = sqlite3.connect(str(db_path), timeout=30.0)
- conn.execute("PRAGMA journal_mode=WAL")
- conn.executescript(_SCHEMA)
- conn.commit()
+ ensure_schema(db_path)
- stats: dict[str, int] = {}
- skipped: list[str] = []
+ with get_conn(db_path) as conn:
+ stats: dict[str, int] = {}
+ skipped: list[str] = []
- for log_file in files:
- source_id = source_id_map.get(log_file, log_file.stem)
+ for log_file in files:
+ source_id = source_id_map.get(log_file, log_file.stem)
- # Fingerprint check — skip files whose mtime+size haven't changed.
- mtime, size = _fingerprint(log_file)
- if not force and _fp_unchanged(conn, log_file, mtime, size):
- logger.debug("Skipping unchanged file: %s", log_file.name)
- skipped.append(log_file.name)
- stats[source_id] = stats.get(source_id, 0)
- continue
+ mtime, size = _fingerprint(log_file)
+ if not force and _fp_unchanged(conn, log_file, mtime, size):
+ logger.debug("Skipping unchanged file: %s", log_file.name)
+ skipped.append(log_file.name)
+ stats[source_id] = stats.get(source_id, 0)
+ continue
- count = 0
- batch: list[RetrievedEntry] = []
- for entry in _parse_file(log_file, compiled, ingest_time, source_id=source_id):
- batch.append(entry)
- if len(batch) >= batch_size:
+ count = 0
+ batch: list[RetrievedEntry] = []
+ for entry in _parse_file(log_file, compiled, ingest_time, source_id=source_id):
+ batch.append(entry)
+ if len(batch) >= batch_size:
+ _write_batch(conn, batch)
+ conn.commit()
+ count += len(batch)
+ batch.clear()
+ if batch:
_write_batch(conn, batch)
conn.commit()
count += len(batch)
- batch.clear()
- if batch:
- _write_batch(conn, batch)
+
+ _save_fingerprint(conn, log_file, mtime, size, ingest_time)
conn.commit()
- count += len(batch)
- _save_fingerprint(conn, log_file, mtime, size, ingest_time)
- conn.commit()
-
- stats[source_id] = stats.get(source_id, 0) + count
- logger.info("Gleaned %d entries from %s (source: %s)", count, log_file.name, source_id)
-
- conn.close()
+ stats[source_id] = stats.get(source_id, 0) + count
+ logger.info("Gleaned %d entries from %s (source: %s)", count, log_file.name, source_id)
if skipped:
logger.info("Skipped %d unchanged file(s): %s", len(skipped), ", ".join(skipped))
@@ -493,7 +382,7 @@ def _stream_and_write(
source_id: str,
compiled: list[tuple[LogPattern, object]],
ingest_time: str,
- conn: sqlite3.Connection,
+ conn: Any,
batch_size: int,
) -> int:
"""Stream *cmd* output through *parser* and write entries to *conn*.
@@ -525,7 +414,7 @@ def _glean_ssh_source(
src: dict, # type: ignore[type-arg]
compiled: list[tuple[LogPattern, object]],
ingest_time: str,
- conn: sqlite3.Connection,
+ conn: Any,
batch_size: int,
) -> dict[str, int]:
"""Open one SSHTransport connection for *src* and glean all its glean items.
@@ -618,15 +507,9 @@ def glean_ssh_source(
compiled = _compile(load_patterns(effective_pattern_file))
ingest_time = now_iso()
- conn = sqlite3.connect(str(db_path), timeout=30.0)
- conn.execute("PRAGMA journal_mode=WAL")
- conn.executescript(_SCHEMA)
- conn.commit()
-
- try:
+ ensure_schema(db_path)
+ with get_conn(db_path) as conn:
stats = _glean_ssh_source(src, compiled, ingest_time, conn, batch_size)
- finally:
- conn.close()
logger.info("Rebuilding FTS index after SSH source glean...")
build_fts_index(db_path)
@@ -740,18 +623,13 @@ def glean_sources(
compiled = _compile(load_patterns(effective_pattern_file))
ingest_time = now_iso()
- conn = sqlite3.connect(str(db_path), timeout=30.0)
- conn.execute("PRAGMA journal_mode=WAL")
- conn.executescript(_SCHEMA)
- conn.commit()
-
- try:
+ ensure_schema(db_path)
+ with get_conn(db_path) as conn:
for src in ssh_sources:
ssh_stats = _glean_ssh_source(src, compiled, ingest_time, conn, batch_size)
for k, v in ssh_stats.items():
stats[k] = stats.get(k, 0) + v
- finally:
- conn.close()
+ conn.commit()
# Rebuild FTS only when SSH sources added entries (_glean_files already
# rebuilds when local sources are present; safe to call again if both ran).
diff --git a/app/mcp_server.py b/app/mcp_server.py
index 607a3ca..5eec5fd 100644
--- a/app/mcp_server.py
+++ b/app/mcp_server.py
@@ -11,7 +11,7 @@ from __future__ import annotations
import logging
import os
-import sqlite3
+import sqlite3 # still used for the pre-index-check on SQLite backend
import sys
from pathlib import Path
@@ -53,15 +53,15 @@ _index_ready = False
def _ensure_index() -> None:
- """Build FTS index on first use; skip if already present."""
+ """Build FTS index on first use; skip if already present (SQLite only)."""
global _index_ready
if _index_ready:
return
try:
- conn = sqlite3.connect(str(DB_PATH), timeout=30.0)
- count = conn.execute("SELECT COUNT(*) FROM log_fts").fetchone()[0]
- conn.close()
+ raw = sqlite3.connect(str(DB_PATH), timeout=30.0)
+ count = raw.execute("SELECT COUNT(*) FROM log_fts").fetchone()[0]
+ raw.close()
if count > 0:
_index_ready = True
logger.info("FTS index present (%d entries)", count)
diff --git a/app/rest.py b/app/rest.py
index 9efe9df..cc87254 100644
--- a/app/rest.py
+++ b/app/rest.py
@@ -35,7 +35,8 @@ from fastapi.responses import FileResponse, RedirectResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
-from app.glean.pipeline import ensure_schema, ensure_context_schema, ensure_incidents_schema, migrate_incidents_to_dedicated_db, glean_file as _glean_file, glean_ssh_source as _glean_ssh_source
+from app.db import close_pool, ensure_schema, ensure_context_schema, ensure_incidents_schema, migrate_incidents_to_dedicated_db
+from app.glean.pipeline import glean_file as _glean_file, glean_ssh_source as _glean_ssh_source
from app.glean.base import load_compiled_patterns, now_iso
from app.glean.tautulli import parse_webhook as _parse_tautulli
from app.glean.wazuh import is_wazuh_alert as _is_wazuh_alert, parse as _parse_wazuh
@@ -185,6 +186,7 @@ async def _lifespan(app: FastAPI):
await task
except asyncio.CancelledError:
pass
+ close_pool() # no-op if SQLite backend
app = FastAPI(title="Turnstone API", version="0.6.2", docs_url="/turnstone/docs", redoc_url=None, lifespan=_lifespan)
diff --git a/app/services/blocklist.py b/app/services/blocklist.py
index 998014a..ea984a3 100644
--- a/app/services/blocklist.py
+++ b/app/services/blocklist.py
@@ -4,10 +4,12 @@ from __future__ import annotations
import dataclasses
import json
import re
-import sqlite3
import uuid
from datetime import datetime, timezone
from pathlib import Path
+from typing import Any
+
+from app.db import get_conn, resolve_tenant_id
import yaml
@@ -91,26 +93,26 @@ def _now_iso() -> str:
return datetime.now(timezone.utc).isoformat()
-def _row_to_candidate(row: tuple) -> BlocklistCandidate:
+def _row_to_candidate(row: Any) -> BlocklistCandidate:
return BlocklistCandidate(
- id=row[0],
- domain_or_ip=row[1],
- source_device_ip=row[2],
- source_device_name=row[3],
- first_seen=row[4],
- last_seen=row[5],
- hit_count=row[6],
- status=row[7],
- pushed_at=row[8],
- log_evidence=json.loads(row[9] or "[]"),
- matched_rule=row[10],
- llm_score=row[11],
- llm_reason=row[12],
+ id=row["id"],
+ domain_or_ip=row["domain_or_ip"],
+ source_device_ip=row["source_device_ip"],
+ source_device_name=row["source_device_name"],
+ first_seen=row["first_seen"],
+ last_seen=row["last_seen"],
+ hit_count=row["hit_count"],
+ status=row["status"],
+ pushed_at=row["pushed_at"],
+ log_evidence=json.loads(row["log_evidence"] or "[]"),
+ matched_rule=row["matched_rule"],
+ llm_score=row["llm_score"],
+ llm_reason=row["llm_reason"],
)
def _upsert_candidate(
- conn: sqlite3.Connection,
+ conn: Any,
domain_or_ip: str,
source_device_ip: str | None,
source_device_name: str | None,
@@ -119,26 +121,29 @@ def _upsert_candidate(
now: str,
) -> bool:
"""Insert or update a candidate. Returns True if a new row was created."""
+ tid = resolve_tenant_id()
row = conn.execute(
"SELECT id, hit_count, log_evidence FROM blocklist_candidates "
- "WHERE domain_or_ip = ? AND source_device_ip IS ?",
- (domain_or_ip, source_device_ip),
+ "WHERE domain_or_ip = ? AND source_device_ip IS ? AND (tenant_id = ? OR tenant_id = '')",
+ (domain_or_ip, source_device_ip, tid),
).fetchone()
if row is None:
conn.execute(
"""INSERT INTO blocklist_candidates
- (id, domain_or_ip, source_device_ip, source_device_name,
+ (id, tenant_id, domain_or_ip, source_device_ip, source_device_name,
first_seen, last_seen, hit_count, status, pushed_at, log_evidence, matched_rule)
- VALUES (?, ?, ?, ?, ?, ?, 1, 'pending', NULL, ?, ?)""",
+ VALUES (?, ?, ?, ?, ?, ?, ?, 1, 'pending', NULL, ?, ?)""",
(
- str(uuid.uuid4()), domain_or_ip, source_device_ip, source_device_name,
+ str(uuid.uuid4()), tid, domain_or_ip, source_device_ip, source_device_name,
now, now, json.dumps([entry_id]), matched_rule,
),
)
return True
- existing_id, hit_count, existing_evidence = row
+ existing_id = row["id"]
+ hit_count = row["hit_count"]
+ existing_evidence = row["log_evidence"]
evidence = json.loads(existing_evidence or "[]")
if entry_id not in evidence:
evidence.append(entry_id)
@@ -172,14 +177,16 @@ def run_scan(
now = _now_iso()
count = 0
- conn = sqlite3.connect(str(db_path), timeout=30.0)
- try:
+ tid = resolve_tenant_id()
+ with get_conn(db_path) as conn:
rows = conn.execute(
- f"SELECT id, text FROM log_entries WHERE source_id IN ({placeholders})",
- router_source_ids,
+ f"SELECT id, text FROM log_entries WHERE source_id IN ({placeholders}) AND (tenant_id = ? OR tenant_id = '')", # noqa: S608
+ (*router_source_ids, tid),
).fetchall()
- for entry_id, text in rows:
+ for row in rows:
+ entry_id, text = row["id"], row["text"]
+ # rest of loop body follows unchanged
src_ip: str | None = None
dst: str | None = None
@@ -204,8 +211,6 @@ def run_scan(
count += 1
conn.commit()
- finally:
- conn.close()
return count
@@ -226,26 +231,27 @@ def list_candidates(
status: str | None = None,
device_ip: str | None = None,
) -> list[BlocklistCandidate]:
- conn = sqlite3.connect(str(db_path), timeout=30.0)
- try:
- query = f"{_CANDIDATE_SELECT} WHERE 1=1"
- params: list = []
- if status and status != "all":
- query += " AND status = ?"
- params.append(status)
- if device_ip:
- query += " AND source_device_ip = ?"
- params.append(device_ip)
- query += " ORDER BY last_seen DESC"
- rows = conn.execute(query, params).fetchall()
- finally:
- conn.close()
+ tid = resolve_tenant_id()
+ conditions = ["(tenant_id = ? OR tenant_id = '')"]
+ params: list = [tid]
+ if status and status != "all":
+ conditions.append("status = ?")
+ params.append(status)
+ if device_ip:
+ conditions.append("source_device_ip = ?")
+ params.append(device_ip)
+ where = " AND ".join(conditions)
+ with get_conn(db_path) as conn:
+ rows = conn.execute(
+ f"{_CANDIDATE_SELECT} WHERE {where} ORDER BY last_seen DESC", # noqa: S608
+ params,
+ ).fetchall()
return [_row_to_candidate(r) for r in rows]
-def _get_candidate(conn: sqlite3.Connection, candidate_id: str) -> BlocklistCandidate:
+def _get_candidate(conn: Any, candidate_id: str) -> BlocklistCandidate:
row = conn.execute(
- f"{_CANDIDATE_SELECT} WHERE id=?",
+ f"{_CANDIDATE_SELECT} WHERE id=?", # noqa: S608
(candidate_id,),
).fetchone()
if row is None:
@@ -255,43 +261,31 @@ def _get_candidate(conn: sqlite3.Connection, candidate_id: str) -> BlocklistCand
def get_candidate(db_path: Path, candidate_id: str) -> BlocklistCandidate:
"""Fetch a single candidate by ID. Raises KeyError if not found."""
- conn = sqlite3.connect(str(db_path), timeout=30.0)
- try:
+ with get_conn(db_path) as conn:
return _get_candidate(conn, candidate_id)
- finally:
- conn.close()
def update_candidate_status(db_path: Path, candidate_id: str, new_status: str) -> BlocklistCandidate:
if new_status not in _VALID_STATUSES:
raise ValueError(f"Invalid status {new_status!r}. Must be one of {_VALID_STATUSES}")
- conn = sqlite3.connect(str(db_path), timeout=30.0)
- try:
+ with get_conn(db_path) as conn:
conn.execute("UPDATE blocklist_candidates SET status=? WHERE id=?", (new_status, candidate_id))
conn.commit()
return _get_candidate(conn, candidate_id)
- finally:
- conn.close()
def mark_pushed(db_path: Path, candidate_id: str) -> BlocklistCandidate:
- conn = sqlite3.connect(str(db_path), timeout=30.0)
- try:
+ with get_conn(db_path) as conn:
conn.execute(
"UPDATE blocklist_candidates SET status='pushed', pushed_at=? WHERE id=?",
(_now_iso(), candidate_id),
)
conn.commit()
return _get_candidate(conn, candidate_id)
- finally:
- conn.close()
def mark_unblocked(db_path: Path, candidate_id: str) -> BlocklistCandidate:
- conn = sqlite3.connect(str(db_path), timeout=30.0)
- try:
+ with get_conn(db_path) as conn:
conn.execute("UPDATE blocklist_candidates SET status='unblocked' WHERE id=?", (candidate_id,))
conn.commit()
return _get_candidate(conn, candidate_id)
- finally:
- conn.close()
diff --git a/app/services/incidents.py b/app/services/incidents.py
index 1d71422..9094de5 100644
--- a/app/services/incidents.py
+++ b/app/services/incidents.py
@@ -3,10 +3,10 @@ from __future__ import annotations
import json
import re
-import sqlite3
import uuid
from pathlib import Path
+from app.db import get_conn, resolve_tenant_id
from app.glean.base import now_iso
from app.services.models import Incident, ReceivedBundle, SentBundle
from app.services.search import SearchResult, entries_in_window, search
@@ -26,7 +26,7 @@ def _redact_text(text: str) -> str:
return text
-def _row_to_incident(row: sqlite3.Row) -> Incident:
+def _row_to_incident(row) -> Incident:
return Incident(
id=row["id"],
label=row["label"],
@@ -39,7 +39,7 @@ def _row_to_incident(row: sqlite3.Row) -> Incident:
)
-def _row_to_bundle(row: sqlite3.Row) -> ReceivedBundle:
+def _row_to_bundle(row) -> ReceivedBundle:
return ReceivedBundle(
id=row["id"],
source_host=row["source_host"],
@@ -62,6 +62,7 @@ def create_incident(
notes: str = "",
severity: str = "medium",
) -> Incident:
+ tid = resolve_tenant_id()
incident = Incident(
id=str(uuid.uuid4()),
label=label,
@@ -72,47 +73,45 @@ def create_incident(
created_at=now_iso(),
severity=severity,
)
- conn = sqlite3.connect(str(db_path), timeout=30.0)
- conn.execute("PRAGMA journal_mode=WAL")
- conn.execute(
- "INSERT INTO incidents (id, label, issue_type, started_at, ended_at, notes, created_at, severity) "
- "VALUES (?, ?, ?, ?, ?, ?, ?, ?)",
- (incident.id, incident.label, incident.issue_type, incident.started_at,
- incident.ended_at, incident.notes, incident.created_at, incident.severity),
- )
- conn.commit()
- conn.close()
+ with get_conn(db_path) as conn:
+ conn.execute(
+ "INSERT INTO incidents (id, tenant_id, label, issue_type, started_at, ended_at, notes, created_at, severity) "
+ "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
+ (incident.id, tid, incident.label, incident.issue_type, incident.started_at,
+ incident.ended_at, incident.notes, incident.created_at, incident.severity),
+ )
+ conn.commit()
return incident
def list_incidents(db_path: Path) -> list[Incident]:
- conn = sqlite3.connect(str(db_path), timeout=30.0)
- conn.execute("PRAGMA journal_mode=WAL")
- conn.row_factory = sqlite3.Row
- rows = conn.execute(
- "SELECT * FROM incidents ORDER BY created_at DESC"
- ).fetchall()
- conn.close()
+ tid = resolve_tenant_id()
+ with get_conn(db_path) as conn:
+ rows = conn.execute(
+ "SELECT * FROM incidents WHERE (tenant_id = ? OR tenant_id = '') ORDER BY created_at DESC",
+ (tid,),
+ ).fetchall()
return [_row_to_incident(r) for r in rows]
def get_incident(db_path: Path, incident_id: str) -> Incident | None:
- conn = sqlite3.connect(str(db_path), timeout=30.0)
- conn.execute("PRAGMA journal_mode=WAL")
- conn.row_factory = sqlite3.Row
- row = conn.execute(
- "SELECT * FROM incidents WHERE id = ?", (incident_id,)
- ).fetchone()
- conn.close()
+ tid = resolve_tenant_id()
+ with get_conn(db_path) as conn:
+ row = conn.execute(
+ "SELECT * FROM incidents WHERE id = ? AND (tenant_id = ? OR tenant_id = '')",
+ (incident_id, tid),
+ ).fetchone()
return _row_to_incident(row) if row else None
def delete_incident(db_path: Path, incident_id: str) -> bool:
- conn = sqlite3.connect(str(db_path), timeout=30.0)
- conn.execute("PRAGMA journal_mode=WAL")
- cur = conn.execute("DELETE FROM incidents WHERE id = ?", (incident_id,))
- conn.commit()
- conn.close()
+ tid = resolve_tenant_id()
+ with get_conn(db_path) as conn:
+ cur = conn.execute(
+ "DELETE FROM incidents WHERE id = ? AND (tenant_id = ? OR tenant_id = '')",
+ (incident_id, tid),
+ )
+ conn.commit()
return cur.rowcount > 0
@@ -191,6 +190,7 @@ def build_bundle(
def record_sent_bundle(db_path: Path, incident_id: str, bundle: dict, sanitized: bool) -> SentBundle:
"""Log an outgoing bundle export to the sent_bundles table."""
+ tid = resolve_tenant_id()
record = SentBundle(
id=str(uuid.uuid4()),
incident_id=incident_id,
@@ -199,28 +199,25 @@ def record_sent_bundle(db_path: Path, incident_id: str, bundle: dict, sanitized:
entry_count=len(bundle.get("log_entries", [])),
bundle_json=json.dumps(bundle),
)
- conn = sqlite3.connect(str(db_path), timeout=30.0)
- conn.execute("PRAGMA journal_mode=WAL")
- conn.execute(
- "INSERT INTO sent_bundles (id, incident_id, exported_at, sanitized, entry_count, bundle_json) "
- "VALUES (?, ?, ?, ?, ?, ?)",
- (record.id, record.incident_id, record.exported_at, int(record.sanitized),
- record.entry_count, record.bundle_json),
- )
- conn.commit()
- conn.close()
+ with get_conn(db_path) as conn:
+ conn.execute(
+ "INSERT INTO sent_bundles (id, tenant_id, incident_id, exported_at, sanitized, entry_count, bundle_json) "
+ "VALUES (?, ?, ?, ?, ?, ?, ?)",
+ (record.id, tid, record.incident_id, record.exported_at,
+ int(record.sanitized), record.entry_count, record.bundle_json),
+ )
+ conn.commit()
return record
def list_sent_bundles(db_path: Path) -> list[SentBundle]:
- conn = sqlite3.connect(str(db_path), timeout=30.0)
- conn.execute("PRAGMA journal_mode=WAL")
- conn.row_factory = sqlite3.Row
- rows = conn.execute(
- "SELECT id, incident_id, exported_at, sanitized, entry_count, bundle_json "
- "FROM sent_bundles ORDER BY exported_at DESC"
- ).fetchall()
- conn.close()
+ tid = resolve_tenant_id()
+ with get_conn(db_path) as conn:
+ rows = conn.execute(
+ "SELECT id, incident_id, exported_at, sanitized, entry_count, bundle_json "
+ "FROM sent_bundles WHERE (tenant_id = ? OR tenant_id = '') ORDER BY exported_at DESC",
+ (tid,),
+ ).fetchall()
return [
SentBundle(
id=r["id"],
@@ -236,6 +233,7 @@ def list_sent_bundles(db_path: Path) -> list[SentBundle]:
def store_bundle(db_path: Path, bundle: dict) -> ReceivedBundle:
"""Store an incoming bundle from a remote Turnstone instance."""
+ tid = resolve_tenant_id()
inc = bundle.get("incident", {})
record = ReceivedBundle(
id=str(uuid.uuid4()),
@@ -248,38 +246,34 @@ def store_bundle(db_path: Path, bundle: dict) -> ReceivedBundle:
entry_count=len(bundle.get("log_entries", [])),
bundle_json=json.dumps(bundle),
)
- conn = sqlite3.connect(str(db_path), timeout=30.0)
- conn.execute("PRAGMA journal_mode=WAL")
- conn.execute(
- "INSERT INTO received_bundles "
- "(id, source_host, issue_type, label, severity, started_at, bundled_at, entry_count, bundle_json) "
- "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)",
- (record.id, record.source_host, record.issue_type, record.label,
- record.severity, record.started_at, record.bundled_at, record.entry_count, record.bundle_json),
- )
- conn.commit()
- conn.close()
+ with get_conn(db_path) as conn:
+ conn.execute(
+ "INSERT INTO received_bundles "
+ "(id, tenant_id, source_host, issue_type, label, severity, started_at, bundled_at, entry_count, bundle_json) "
+ "VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
+ (record.id, tid, record.source_host, record.issue_type, record.label,
+ record.severity, record.started_at, record.bundled_at, record.entry_count, record.bundle_json),
+ )
+ conn.commit()
return record
def list_bundles(db_path: Path) -> list[ReceivedBundle]:
- conn = sqlite3.connect(str(db_path), timeout=30.0)
- conn.execute("PRAGMA journal_mode=WAL")
- conn.row_factory = sqlite3.Row
- rows = conn.execute(
- "SELECT id, source_host, issue_type, label, severity, started_at, bundled_at, entry_count, bundle_json "
- "FROM received_bundles ORDER BY bundled_at DESC"
- ).fetchall()
- conn.close()
+ tid = resolve_tenant_id()
+ with get_conn(db_path) as conn:
+ rows = conn.execute(
+ "SELECT id, source_host, issue_type, label, severity, started_at, bundled_at, entry_count, bundle_json "
+ "FROM received_bundles WHERE (tenant_id = ? OR tenant_id = '') ORDER BY bundled_at DESC",
+ (tid,),
+ ).fetchall()
return [_row_to_bundle(r) for r in rows]
def get_bundle(db_path: Path, bundle_id: str) -> ReceivedBundle | None:
- conn = sqlite3.connect(str(db_path), timeout=30.0)
- conn.execute("PRAGMA journal_mode=WAL")
- conn.row_factory = sqlite3.Row
- row = conn.execute(
- "SELECT * FROM received_bundles WHERE id = ?", (bundle_id,)
- ).fetchone()
- conn.close()
+ tid = resolve_tenant_id()
+ with get_conn(db_path) as conn:
+ row = conn.execute(
+ "SELECT * FROM received_bundles WHERE id = ? AND (tenant_id = ? OR tenant_id = '')",
+ (bundle_id, tid),
+ ).fetchone()
return _row_to_bundle(row) if row else None
diff --git a/app/services/search.py b/app/services/search.py
index 56b2c0a..47a74e9 100644
--- a/app/services/search.py
+++ b/app/services/search.py
@@ -1,4 +1,8 @@
-"""FTS5-based log search with optional hybrid BM25 + vector re-ranking."""
+"""FTS-based log search with optional hybrid BM25 + vector re-ranking.
+
+SQLite backend: FTS5 virtual table with Porter stemmer.
+Postgres backend: tsvector column with GIN index + websearch_to_tsquery.
+"""
from __future__ import annotations
import json
@@ -6,8 +10,11 @@ import logging
import re
import sqlite3
from dataclasses import dataclass
+from datetime import datetime, timedelta, timezone
from pathlib import Path
+from app.db import BACKEND, Backend, frag, get_conn, resolve_tenant_id
+
logger = logging.getLogger(__name__)
@@ -28,22 +35,24 @@ class SearchResult:
def build_fts_index(db_path: Path) -> None:
"""Build (or rebuild) the FTS5 index from log_entries. Safe to re-run.
- Drops and recreates the table if the schema is stale (missing sequence column).
+ For Postgres, the tsvector column is maintained by a trigger — this is a no-op.
"""
- conn = sqlite3.connect(str(db_path), timeout=30.0)
- conn.execute("PRAGMA journal_mode=WAL")
+ if BACKEND == Backend.POSTGRES:
+ return
+
+ raw = sqlite3.connect(str(db_path), timeout=30.0)
+ raw.execute("PRAGMA journal_mode=WAL")
- # Check whether existing table has the sequence column; rebuild if not.
needs_rebuild = False
try:
- conn.execute("SELECT sequence FROM log_fts LIMIT 0")
+ raw.execute("SELECT sequence FROM log_fts LIMIT 0")
except sqlite3.OperationalError:
needs_rebuild = True
if needs_rebuild:
- conn.execute("DROP TABLE IF EXISTS log_fts")
+ raw.execute("DROP TABLE IF EXISTS log_fts")
- conn.executescript("""
+ raw.executescript("""
CREATE VIRTUAL TABLE IF NOT EXISTS log_fts USING fts5(
text,
entry_id UNINDEXED,
@@ -57,8 +66,7 @@ def build_fts_index(db_path: Path) -> None:
tokenize = 'porter ascii'
);
""")
- # Only insert rows not already indexed
- conn.execute("""
+ raw.execute("""
INSERT INTO log_fts(text, entry_id, source_id, sequence, severity,
timestamp_iso, matched_patterns,
repeat_count, out_of_order)
@@ -68,8 +76,8 @@ def build_fts_index(db_path: Path) -> None:
FROM log_entries e
WHERE e.id NOT IN (SELECT entry_id FROM log_fts WHERE entry_id IS NOT NULL)
""")
- conn.commit()
- conn.close()
+ raw.commit()
+ raw.close()
def _sanitize_fts_query(raw: str, or_mode: bool = False) -> str:
@@ -198,14 +206,44 @@ def _bm25_search(
include_repeats: bool = False,
or_mode: bool = False,
) -> list[SearchResult]:
- """Pure BM25 FTS5 search — internal helper used by both search() and _hybrid_search()."""
- conn = sqlite3.connect(str(db_path), timeout=30.0)
- conn.execute("PRAGMA journal_mode=WAL")
- conn.row_factory = sqlite3.Row
+ """FTS search — BM25 via FTS5 (SQLite) or tsvector (Postgres)."""
+ tid = resolve_tenant_id()
+ if BACKEND == Backend.POSTGRES:
+ return _pg_fts_search(
+ db_path, query, tid,
+ severity=severity, source_filter=source_filter,
+ pattern_filter=pattern_filter, since=since, until=until,
+ limit=limit, include_repeats=include_repeats,
+ )
+
+ return _sqlite_fts_search(
+ db_path, query, tid,
+ severity=severity, source_filter=source_filter,
+ pattern_filter=pattern_filter, since=since, until=until,
+ limit=limit, include_repeats=include_repeats, or_mode=or_mode,
+ )
+
+
+def _sqlite_fts_search(
+ db_path: Path,
+ query: str,
+ tid: str,
+ severity: str | None,
+ source_filter: str | None,
+ pattern_filter: str | None,
+ since: str | None,
+ until: str | None,
+ limit: int,
+ include_repeats: bool,
+ or_mode: bool,
+) -> list[SearchResult]:
fts_query = _sanitize_fts_query(query, or_mode=or_mode)
- conditions = ["log_fts MATCH ?"]
- params: list = [fts_query]
+ conditions = [
+ "log_fts MATCH ?",
+ "(e.tenant_id = ? OR e.tenant_id = '')",
+ ]
+ params: list = [fts_query, tid]
if severity:
conditions.append("severity = ?")
@@ -223,29 +261,33 @@ def _bm25_search(
conditions.append("timestamp_iso <= ?")
params.append(until)
if not include_repeats:
- conditions.append("repeat_count = 1")
+ conditions.append("f.repeat_count = 1")
where = " AND ".join(conditions)
params.append(limit)
+ raw = sqlite3.connect(str(db_path), timeout=30.0)
+ raw.row_factory = sqlite3.Row
try:
- rows = conn.execute(
+ rows = raw.execute(
f"""
- SELECT entry_id, source_id, sequence, timestamp_iso, severity,
- repeat_count, out_of_order, matched_patterns, text, rank
- FROM log_fts
+ SELECT f.entry_id, f.source_id, f.sequence, f.timestamp_iso, f.severity,
+ f.repeat_count, f.out_of_order, f.matched_patterns, f.text, f.rank
+ FROM log_fts f
+ JOIN log_entries e ON e.id = f.entry_id
WHERE {where}
- ORDER BY rank
+ ORDER BY f.rank
LIMIT ?
""",
params,
).fetchall()
- except sqlite3.OperationalError as e:
- logger.warning("FTS query failed (%s) — index may not be built yet", e)
- conn.close()
+ except sqlite3.OperationalError as exc:
+ logger.warning("FTS query failed (%s) — index may not be built yet", exc)
return []
+ finally:
+ raw.close()
- results = [
+ return [
SearchResult(
entry_id=r["entry_id"],
source_id=r["source_id"],
@@ -256,12 +298,83 @@ def _bm25_search(
out_of_order=bool(r["out_of_order"]),
matched_patterns=json.loads(r["matched_patterns"] or "[]"),
text=r["text"],
- rank=r["rank"],
+ rank=float(r["rank"]),
+ )
+ for r in rows
+ ]
+
+
+def _pg_fts_search(
+ db_path: Path,
+ query: str,
+ tid: str,
+ severity: str | None,
+ source_filter: str | None,
+ pattern_filter: str | None,
+ since: str | None,
+ until: str | None,
+ limit: int,
+ include_repeats: bool,
+) -> list[SearchResult]:
+ """Postgres FTS via tsvector column and websearch_to_tsquery."""
+ tsq = "websearch_to_tsquery('english', %s)"
+ conditions = [
+ f"text_tsv @@ {tsq}",
+ "(tenant_id = %s OR tenant_id = '')",
+ ]
+ params: list = [query, tid]
+
+ if severity:
+ conditions.append("severity = %s")
+ params.append(severity.upper())
+ if source_filter:
+ conditions.append("source_id LIKE %s")
+ params.append(f"%{source_filter}%")
+ if pattern_filter:
+ conditions.append("matched_patterns LIKE %s")
+ params.append(f'%"{pattern_filter}"%')
+ if since:
+ conditions.append("timestamp_iso >= %s")
+ params.append(since)
+ if until:
+ conditions.append("timestamp_iso <= %s")
+ params.append(until)
+ if not include_repeats:
+ conditions.append("repeat_count = 1")
+
+ where = " AND ".join(conditions)
+ # ts_rank needs the tsquery again — append it then the limit
+ params.extend([query, limit])
+
+ with get_conn(db_path) as conn:
+ rows = conn.execute(
+ f"""
+ SELECT id AS entry_id, source_id, sequence, timestamp_iso, severity,
+ repeat_count, out_of_order, matched_patterns, text,
+ ts_rank(text_tsv, {tsq}) AS rank
+ FROM log_entries
+ WHERE {where}
+ ORDER BY rank DESC
+ LIMIT %s
+ """,
+ params,
+ ).fetchall()
+
+ return [
+ SearchResult(
+ entry_id=r["entry_id"],
+ source_id=r["source_id"],
+ sequence=r["sequence"],
+ timestamp_iso=r["timestamp_iso"],
+ severity=r["severity"],
+ repeat_count=r["repeat_count"],
+ out_of_order=bool(r["out_of_order"]),
+ matched_patterns=json.loads(r["matched_patterns"] or "[]"),
+ text=r["text"],
+ rank=float(r["rank"]),
)
for r in rows
]
- conn.close()
- return results
def entries_in_window(
@@ -282,12 +395,12 @@ def entries_in_window(
(e.g. network-syslog) don't crowd out lower-volume but more interesting ones.
Errors/warnings are ranked first within each source partition.
"""
- conn = sqlite3.connect(str(db_path), timeout=30.0)
- conn.execute("PRAGMA journal_mode=WAL")
- conn.row_factory = sqlite3.Row
-
- conditions: list[str] = ["repeat_count = 1"]
- params: list = []
+ tid = resolve_tenant_id()
+ conditions: list[str] = [
+ "repeat_count = 1",
+ "(tenant_id = ? OR tenant_id = '')",
+ ]
+ params: list = [tid]
if since:
conditions.append("timestamp_iso >= ?")
@@ -305,8 +418,7 @@ def entries_in_window(
where = " AND ".join(conditions)
if per_source_cap is not None:
- # Use a window function to cap rows per source, errors/warnings first.
- query = f"""
+ sql = f"""
WITH ranked AS (
SELECT id as entry_id, source_id, sequence, timestamp_iso, severity,
repeat_count, out_of_order, matched_patterns, text, 0.0 as rank,
@@ -333,7 +445,7 @@ def entries_in_window(
"""
params.extend([per_source_cap, limit])
else:
- query = f"""
+ sql = f"""
SELECT id as entry_id, source_id, sequence, timestamp_iso, severity,
repeat_count, out_of_order, matched_patterns, text, 0.0 as rank
FROM log_entries
@@ -343,8 +455,8 @@ def entries_in_window(
"""
params.append(limit)
- rows = conn.execute(query, params).fetchall()
- conn.close()
+ with get_conn(db_path) as conn:
+ rows = conn.execute(sql, params).fetchall()
return [
SearchResult(
@@ -357,7 +469,7 @@ def entries_in_window(
out_of_order=bool(r["out_of_order"]),
matched_patterns=json.loads(r["matched_patterns"] or "[]"),
text=r["text"],
- rank=r["rank"],
+ rank=float(r["rank"]),
)
for r in rows
]
@@ -376,16 +488,14 @@ def recent_source_errors(
Bypasses FTS ranking so text content doesn't affect which errors surface.
Used by diagnose when FTS keyword search returns nothing for a known source.
"""
- conn = sqlite3.connect(str(db_path), timeout=30.0)
- conn.execute("PRAGMA journal_mode=WAL")
- conn.row_factory = sqlite3.Row
-
+ tid = resolve_tenant_id()
conditions = [
"source_id LIKE ?",
"severity = ?",
"repeat_count = 1",
+ "(tenant_id = ? OR tenant_id = '')",
]
- params: list = [f"%{source_filter}%", severity.upper()]
+ params: list = [f"%{source_filter}%", severity.upper(), tid]
if since:
conditions.append("timestamp_iso >= ?")
@@ -397,18 +507,18 @@ def recent_source_errors(
params.append(limit)
where = " AND ".join(conditions)
- rows = conn.execute(
- f"""
- SELECT id as entry_id, source_id, sequence, timestamp_iso, severity,
- repeat_count, out_of_order, matched_patterns, text, 0.0 as rank
- FROM log_entries
- WHERE {where}
- ORDER BY timestamp_iso DESC
- LIMIT ?
- """,
- params,
- ).fetchall()
- conn.close()
+ with get_conn(db_path) as conn:
+ rows = conn.execute(
+ f"""
+ SELECT id as entry_id, source_id, sequence, timestamp_iso, severity,
+ repeat_count, out_of_order, matched_patterns, text, 0.0 as rank
+ FROM log_entries
+ WHERE {where}
+ ORDER BY timestamp_iso DESC
+ LIMIT ?
+ """,
+ params,
+ ).fetchall()
return [
SearchResult(
@@ -421,7 +531,7 @@ def recent_source_errors(
out_of_order=bool(r["out_of_order"]),
matched_patterns=json.loads(r["matched_patterns"] or "[]"),
text=r["text"],
- rank=r["rank"],
+ rank=float(r["rank"]),
)
for r in rows
]
@@ -436,37 +546,34 @@ def list_sources(db_path: Path) -> list[dict]:
returned as-is. ``unit_count`` reports how many distinct sub-units were
merged into each row.
"""
- conn = sqlite3.connect(str(db_path), timeout=30.0)
- conn.execute("PRAGMA journal_mode=WAL")
- rows = conn.execute("""
- SELECT
- CASE
- WHEN INSTR(SUBSTR(source_id, INSTR(source_id, ':')+1), ':') > 0
- THEN SUBSTR(source_id, 1,
- INSTR(source_id, ':')
- + INSTR(SUBSTR(source_id, INSTR(source_id, ':')+1), ':')
- - 1)
- ELSE source_id
- END AS group_id,
- COUNT(DISTINCT source_id) AS unit_count,
- COUNT(*) AS entry_count,
- MIN(timestamp_iso) AS earliest,
- MAX(timestamp_iso) AS latest,
- SUM(CASE WHEN severity IN ('ERROR','CRITICAL','EMERGENCY','ALERT')
- THEN 1 ELSE 0 END) AS error_count
- FROM log_entries
- GROUP BY group_id
- ORDER BY entry_count DESC
- """).fetchall()
- conn.close()
+ tid = resolve_tenant_id()
+ group_expr = frag.source_group_expr("source_id")
+ with get_conn(db_path) as conn:
+ rows = conn.execute(
+ f"""
+ SELECT
+ {group_expr} AS group_id,
+ COUNT(DISTINCT source_id) AS unit_count,
+ COUNT(*) AS entry_count,
+ MIN(timestamp_iso) AS earliest,
+ MAX(timestamp_iso) AS latest,
+ SUM(CASE WHEN severity IN ('ERROR','CRITICAL','EMERGENCY','ALERT')
+ THEN 1 ELSE 0 END) AS error_count
+ FROM log_entries
+ WHERE (tenant_id = ? OR tenant_id = '')
+ GROUP BY group_id
+ ORDER BY entry_count DESC
+ """,
+ (tid,),
+ ).fetchall()
return [
{
- "source_id": r[0],
- "unit_count": r[1],
- "entry_count": r[2],
- "earliest": r[3],
- "latest": r[4],
- "error_count": r[5],
+ "source_id": r["group_id"],
+ "unit_count": r["unit_count"],
+ "entry_count": r["entry_count"],
+ "earliest": r["earliest"],
+ "latest": r["latest"],
+ "error_count": r["error_count"],
}
for r in rows
]
@@ -498,47 +605,65 @@ def stats_summary(db_path: Path, window_hours: int = 24, severity_overrides: lis
Queries plain log_entries (not FTS) so it works even before the index is built.
"""
rules = _compile_overrides(severity_overrides or [])
+ tid = resolve_tenant_id()
+ group_expr = frag.source_group_expr("source_id")
+ since_iso = (
+ datetime.now(timezone.utc) - timedelta(hours=window_hours)
+ ).strftime("%Y-%m-%dT%H:%M:%S")
- conn = sqlite3.connect(str(db_path), timeout=30.0)
- conn.execute("PRAGMA journal_mode=WAL")
- conn.row_factory = sqlite3.Row
+ with get_conn(db_path) as conn:
+ row = conn.execute(
+ """
+ SELECT
+ COUNT(*) AS total,
+ SUM(CASE WHEN severity = 'CRITICAL' THEN 1 ELSE 0 END) AS criticals,
+ SUM(CASE WHEN severity IN ('ERROR','CRITICAL','EMERGENCY','ALERT') THEN 1 ELSE 0 END) AS errors
+ FROM log_entries
+ WHERE timestamp_iso >= ?
+ AND repeat_count = 1
+ AND (tenant_id = ? OR tenant_id = '')
+ """,
+ (since_iso, tid),
+ ).fetchone()
+ total_24h = int(row["total"] or 0)
+ criticals_24h = int(row["criticals"] or 0)
+ errors_24h = int(row["errors"] or 0)
- since_expr = f"strftime('%Y-%m-%dT%H:%M:%S', 'now', '-{window_hours} hours')"
+ source_rows = conn.execute(
+ f"""
+ SELECT
+ {group_expr} AS group_id,
+ COUNT(*) AS entry_count,
+ SUM(CASE WHEN severity IN ('ERROR','CRITICAL','EMERGENCY','ALERT') THEN 1 ELSE 0 END) AS error_count,
+ MAX(timestamp_iso) AS latest
+ FROM log_entries
+ WHERE timestamp_iso >= ?
+ AND repeat_count = 1
+ AND (tenant_id = ? OR tenant_id = '')
+ GROUP BY group_id
+ ORDER BY error_count DESC, entry_count DESC
+ """,
+ (since_iso, tid),
+ ).fetchall()
- # Overall counts in window
- row = conn.execute(f"""
- SELECT
- COUNT(*) AS total,
- SUM(CASE WHEN severity = 'CRITICAL' THEN 1 ELSE 0 END) AS criticals,
- SUM(CASE WHEN severity IN ('ERROR','CRITICAL','EMERGENCY','ALERT') THEN 1 ELSE 0 END) AS errors
- FROM log_entries
- WHERE timestamp_iso >= {since_expr}
- AND repeat_count = 1
- """).fetchone()
- total_24h = int(row["total"] or 0)
- criticals_24h = int(row["criticals"] or 0)
- errors_24h = int(row["errors"] or 0)
+ crit_rows = conn.execute(
+ """
+ SELECT id as entry_id, source_id, timestamp_iso, severity, text
+ FROM log_entries
+ WHERE severity = 'CRITICAL'
+ AND repeat_count = 1
+ AND (tenant_id = ? OR tenant_id = '')
+ ORDER BY timestamp_iso DESC
+ LIMIT 25
+ """,
+ (tid,),
+ ).fetchall()
+
+ last_row = conn.execute(
+ "SELECT MAX(ingest_time) AS t FROM log_entries WHERE (tenant_id = ? OR tenant_id = '')",
+ (tid,),
+ ).fetchone()
- # Per-source breakdown — grouped by prefix:host stem (same logic as list_sources).
- source_rows = conn.execute(f"""
- SELECT
- CASE
- WHEN INSTR(SUBSTR(source_id, INSTR(source_id, ':')+1), ':') > 0
- THEN SUBSTR(source_id, 1,
- INSTR(source_id, ':')
- + INSTR(SUBSTR(source_id, INSTR(source_id, ':')+1), ':')
- - 1)
- ELSE source_id
- END AS group_id,
- COUNT(*) AS entry_count,
- SUM(CASE WHEN severity IN ('ERROR','CRITICAL','EMERGENCY','ALERT') THEN 1 ELSE 0 END) AS error_count,
- MAX(timestamp_iso) AS latest
- FROM log_entries
- WHERE timestamp_iso >= {since_expr}
- AND repeat_count = 1
- GROUP BY group_id
- ORDER BY error_count DESC, entry_count DESC
- """).fetchall()
source_health = [
{
"source_id": r["group_id"],
@@ -549,16 +674,6 @@ def stats_summary(db_path: Path, window_hours: int = 24, severity_overrides: lis
for r in source_rows
]
- # Fetch candidate criticals (fetch more so filtering doesn't leave us with too few)
- crit_rows = conn.execute("""
- SELECT id as entry_id, source_id, timestamp_iso, severity, text
- FROM log_entries
- WHERE severity = 'CRITICAL' AND repeat_count = 1
- ORDER BY timestamp_iso DESC
- LIMIT 25
- """).fetchall()
-
- # Apply overrides: skip entries whose effective severity is no longer CRITICAL
suppressed = 0
recent_criticals = []
for r in crit_rows:
@@ -576,11 +691,8 @@ def stats_summary(db_path: Path, window_hours: int = 24, severity_overrides: lis
else:
suppressed += 1
- last_row = conn.execute("SELECT MAX(ingest_time) AS t FROM log_entries").fetchone()
last_gleaned: str | None = last_row["t"] if last_row else None
- conn.close()
-
return {
"window_hours": window_hours,
"total_24h": total_24h,
diff --git a/app/tasks/glean_scheduler.py b/app/tasks/glean_scheduler.py
index 02c6567..ba4e501 100644
--- a/app/tasks/glean_scheduler.py
+++ b/app/tasks/glean_scheduler.py
@@ -11,7 +11,7 @@ from __future__ import annotations
import asyncio
import json
import logging
-import sqlite3
+from app.db import get_conn, resolve_tenant_id
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from pathlib import Path
@@ -49,9 +49,8 @@ def get_state() -> IngestState:
def _query_matched_since(db_path: Path, since: str | None) -> list[dict]:
"""Return entries with non-empty matched_patterns, optionally filtered by ingest_time."""
- conn = sqlite3.connect(str(db_path), timeout=30.0)
- conn.row_factory = sqlite3.Row
- try:
+ tid = resolve_tenant_id()
+ with get_conn(db_path) as conn:
if since:
rows = conn.execute(
"""
@@ -59,11 +58,13 @@ def _query_matched_since(db_path: Path, since: str | None) -> list[dict]:
ingest_time, severity, repeat_count, out_of_order,
matched_patterns, text
FROM log_entries
- WHERE matched_patterns != '[]' AND ingest_time > ?
+ WHERE matched_patterns != '[]'
+ AND ingest_time > ?
+ AND (tenant_id = ? OR tenant_id = '')
ORDER BY ingest_time
LIMIT 5000
""",
- (since,),
+ (since, tid),
).fetchall()
else:
rows = conn.execute(
@@ -73,13 +74,13 @@ def _query_matched_since(db_path: Path, since: str | None) -> list[dict]:
matched_patterns, text
FROM log_entries
WHERE matched_patterns != '[]'
+ AND (tenant_id = ? OR tenant_id = '')
ORDER BY ingest_time DESC
LIMIT 5000
""",
+ (tid,),
).fetchall()
- return [dict(r) for r in rows]
- finally:
- conn.close()
+ return [dict(r) for r in rows]
async def submit_matched(
diff --git a/app/watch/watcher.py b/app/watch/watcher.py
index 1108087..dda8ad2 100644
--- a/app/watch/watcher.py
+++ b/app/watch/watcher.py
@@ -8,7 +8,6 @@ from __future__ import annotations
import json
import logging
-import sqlite3
import subprocess
import threading
from dataclasses import dataclass, field
@@ -21,9 +20,10 @@ import yaml
from app.glean import journald as journald_parser, syslog as syslog_parser
from app.glean import plaintext as plaintext_parser, servarr as servarr_parser, plex as plex_parser
from app.glean import qbittorrent as qbit_parser, caddy as caddy_parser
-from app.glean.pipeline import _detect_format
+from app.db import get_conn
+from app.db.schema import ensure_schema
+from app.glean.pipeline import _detect_format, _write_batch
from app.glean.base import _compile, load_patterns, now_iso
-from app.glean.pipeline import _write_batch, _SCHEMA
from app.services.search import build_fts_index
from app.services.models import RetrievedEntry
@@ -111,28 +111,24 @@ class WatchSource:
patterns = load_patterns(self.pattern_file)
compiled = _compile(patterns)
- conn = sqlite3.connect(str(self.db_path), timeout=30.0)
- conn.execute("PRAGMA journal_mode=WAL")
- conn.executescript(_SCHEMA)
- conn.commit()
+ ensure_schema(self.db_path)
- try:
- cmd = self._build_command()
- if not cmd:
- return
- self._proc = subprocess.Popen(
- cmd,
- stdout=subprocess.PIPE,
- stderr=subprocess.PIPE,
- text=True,
- bufsize=1,
- )
- self._drain(conn, compiled)
- except Exception as exc:
- self._error = str(exc)
- logger.error("Watch source %r crashed: %s", self.config.source_id, exc)
- finally:
- conn.close()
+ with get_conn(self.db_path) as conn:
+ try:
+ cmd = self._build_command()
+ if not cmd:
+ return
+ self._proc = subprocess.Popen(
+ cmd,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ text=True,
+ bufsize=1,
+ )
+ self._drain(conn, compiled)
+ except Exception as exc:
+ self._error = str(exc)
+ logger.error("Watch source %r crashed: %s", self.config.source_id, exc)
def _build_command(self) -> list[str] | None:
t = self.config.source_type
@@ -193,7 +189,7 @@ class WatchSource:
return []
- def _drain(self, conn: sqlite3.Connection, compiled) -> None:
+ def _drain(self, conn, compiled) -> None:
"""Read lines from the subprocess and flush to DB periodically."""
assert self._proc is not None
buffer: list[str] = []
@@ -229,7 +225,7 @@ class WatchSource:
if buffer:
self._flush(conn, buffer, compiled, flush_count)
- def _flush(self, conn: sqlite3.Connection, lines: list[str], compiled, flush_count: int) -> int:
+ def _flush(self, conn, lines: list[str], compiled, flush_count: int) -> int:
ingest_time = now_iso()
try:
entries = self._parse_lines(lines, ingest_time, compiled)
diff --git a/docker-compose.yml b/docker-compose.yml
new file mode 100644
index 0000000..8c9bf29
--- /dev/null
+++ b/docker-compose.yml
@@ -0,0 +1,50 @@
+version: "3.9"
+
+# Turnstone with external Postgres DB.
+# Data lives in the named volume `turnstone_pgdata` — survives image rebuilds.
+# To adopt an EXISTING Postgres install, set DATABASE_URL to point at it and
+# remove the `db` service and `depends_on` blocks.
+#
+# Quick start:
+# docker compose up -d
+# # Then open http://localhost:8520
+
+services:
+ db:
+ image: postgres:16-alpine
+ restart: unless-stopped
+ environment:
+ POSTGRES_DB: turnstone
+ POSTGRES_USER: turnstone
+ POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-turnstone_dev}
+ volumes:
+ - turnstone_pgdata:/var/lib/postgresql/data
+ healthcheck:
+ test: ["CMD-SHELL", "pg_isready -U turnstone -d turnstone"]
+ interval: 5s
+ timeout: 5s
+ retries: 5
+
+ turnstone:
+ build: .
+ restart: unless-stopped
+ ports:
+ - "${TURNSTONE_PORT:-8520}:8520"
+ depends_on:
+ db:
+ condition: service_healthy
+ environment:
+ # Backend selection — comment out DATABASE_URL to fall back to SQLite
+ DATABASE_URL: postgresql://turnstone:${POSTGRES_PASSWORD:-turnstone_dev}@db:5432/turnstone
+ TURNSTONE_TENANT_ID: ${TURNSTONE_TENANT_ID:-}
+ TURNSTONE_API_KEY: ${TURNSTONE_API_KEY:-}
+ TURNSTONE_GLEAN_INTERVAL: ${TURNSTONE_GLEAN_INTERVAL:-900}
+ TURNSTONE_SOURCE_HOST: ${TURNSTONE_SOURCE_HOST:-}
+ TURNSTONE_SUBMIT_ENDPOINT: ${TURNSTONE_SUBMIT_ENDPOINT:-}
+ volumes:
+ - ./patterns:/app/patterns:ro
+ - ./data:/app/data # optional: persists SQLite files if DATABASE_URL unset
+
+volumes:
+ turnstone_pgdata:
+ name: turnstone_pgdata
diff --git a/requirements.txt b/requirements.txt
index f91b900..21b3c6c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,7 @@
fastapi>=0.110.0
uvicorn[standard]>=0.27.0
+# Postgres backend — optional; SQLite is used when DATABASE_URL is unset
+psycopg[binary,pool]>=3.1.0
pydantic>=2.0.0
pyyaml>=6.0
aiofiles>=23.0.0
diff --git a/scripts/migrate_sqlite_to_postgres.py b/scripts/migrate_sqlite_to_postgres.py
new file mode 100644
index 0000000..4402353
--- /dev/null
+++ b/scripts/migrate_sqlite_to_postgres.py
@@ -0,0 +1,204 @@
+#!/usr/bin/env python3
+"""One-shot migration: copy data from existing SQLite DBs into Postgres.
+
+Usage:
+ DATABASE_URL=postgresql://... python scripts/migrate_sqlite_to_postgres.py \
+ --main-db data/turnstone.db \
+ --context-db data/turnstone-context.db \
+ --incidents-db data/turnstone-incidents.db \
+ [--tenant-id heimdall]
+
+The script is idempotent: rows already present in Postgres (same id) are skipped.
+It must be run ONCE per node after deploying the shared Postgres backend.
+
+Prerequisites:
+ pip install 'psycopg[binary,pool]'
+ Set DATABASE_URL to the target Postgres connection string.
+"""
+from __future__ import annotations
+
+import argparse
+import os
+import sqlite3
+import sys
+from pathlib import Path
+
+# Allow running from the project root without installing the package
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+
+def _pg_connect():
+ import psycopg # type: ignore[import]
+ url = os.environ.get("DATABASE_URL")
+ if not url:
+ print("ERROR: DATABASE_URL not set", file=sys.stderr)
+ sys.exit(1)
+ return psycopg.connect(url, autocommit=False)
+
+
+def _ensure_schema_pg() -> None:
+ from app.db.schema import ensure_schema, ensure_context_schema, ensure_incidents_schema
+ from pathlib import Path
+ ensure_schema(Path("/dev/null")) # db_path ignored for Postgres
+ ensure_context_schema(Path("/dev/null"))
+ ensure_incidents_schema(Path("/dev/null"))
+ print("Postgres schema verified")
+
+
+def _migrate_table(
+ src_conn: sqlite3.Connection,
+ dst_conn,
+ table: str,
+ tenant_id: str,
+ columns: list[str],
+ conflict_cols: list[str],
+) -> int:
+ """Copy rows from SQLite table to Postgres. Returns rows inserted."""
+ # Check if source table exists
+ try:
+ rows = src_conn.execute(f"SELECT * FROM {table} LIMIT 0").fetchall() # noqa: S608
+ except sqlite3.OperationalError:
+ print(f" {table}: not found in SQLite — skipping")
+ return 0
+
+ # Fetch all rows
+ src_conn.row_factory = sqlite3.Row
+ rows = src_conn.execute(f"SELECT * FROM {table}").fetchall() # noqa: S608
+ if not rows:
+ print(f" {table}: empty — skipping")
+ return 0
+
+ # Build INSERT ... ON CONFLICT DO NOTHING
+ col_list = ", ".join(columns)
+ placeholders = ", ".join("%s" for _ in columns)
+ conflict = ", ".join(conflict_cols)
+ sql = (
+ f"INSERT INTO {table} ({col_list}) VALUES ({placeholders}) " # noqa: S608
+ f"ON CONFLICT ({conflict}) DO NOTHING"
+ )
+
+ inserted = 0
+ with dst_conn.cursor() as cur:
+ for row in rows:
+ # Build values: inject tenant_id if not present in source row
+ vals = []
+ for col in columns:
+ if col == "tenant_id":
+ try:
+ val = row["tenant_id"] or tenant_id
+ except (IndexError, KeyError):
+ val = tenant_id
+ else:
+ try:
+ vals.append(row[col])
+ except (IndexError, KeyError):
+ vals.append(None)
+ continue
+ vals.append(val)
+ cur.execute(sql, vals)
+ inserted += cur.rowcount
+
+ dst_conn.commit()
+ print(f" {table}: {inserted}/{len(rows)} rows inserted ({len(rows) - inserted} skipped)")
+ return inserted
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="Migrate Turnstone SQLite → Postgres")
+ parser.add_argument("--main-db", default="data/turnstone.db")
+ parser.add_argument("--context-db", default="data/turnstone-context.db")
+ parser.add_argument("--incidents-db", default="data/turnstone-incidents.db")
+ parser.add_argument("--tenant-id", default=None, help="Override tenant ID (default: socket.gethostname())")
+ args = parser.parse_args()
+
+ if args.tenant_id:
+ os.environ["TURNSTONE_TENANT_ID"] = args.tenant_id
+
+ import socket
+ tenant_id = os.environ.get("TURNSTONE_TENANT_ID") or socket.gethostname()
+ print(f"Migrating as tenant_id={tenant_id!r}")
+
+ # Ensure Postgres schema exists first
+ os.environ.setdefault("DATABASE_URL", "") # schema functions check this
+ _ensure_schema_pg()
+
+ pg = _pg_connect()
+ total = 0
+
+ # ── Main DB ───────────────────────────────────────────────────────────────
+ main_path = Path(args.main_db)
+ if main_path.exists():
+ print(f"\nMigrating main DB: {main_path}")
+ src = sqlite3.connect(str(main_path))
+ src.row_factory = sqlite3.Row
+
+ total += _migrate_table(src, pg, "log_entries", tenant_id,
+ columns=["tenant_id", "id", "source_id", "sequence", "timestamp_raw",
+ "timestamp_iso", "ingest_time", "severity", "repeat_count",
+ "out_of_order", "matched_patterns", "text"],
+ conflict_cols=["tenant_id", "id"])
+
+ total += _migrate_table(src, pg, "glean_fingerprints", tenant_id,
+ columns=["tenant_id", "path", "mtime", "size", "gleaned_at"],
+ conflict_cols=["tenant_id", "path"])
+
+ total += _migrate_table(src, pg, "blocklist_candidates", tenant_id,
+ columns=["id", "tenant_id", "domain_or_ip", "source_device_ip", "source_device_name",
+ "first_seen", "last_seen", "hit_count", "status", "pushed_at",
+ "log_evidence", "matched_rule", "llm_score", "llm_reason"],
+ conflict_cols=["id"])
+ src.close()
+ else:
+ print(f"Main DB not found at {main_path} — skipping")
+
+ # ── Context DB ────────────────────────────────────────────────────────────
+ ctx_path = Path(args.context_db)
+ if ctx_path.exists():
+ print(f"\nMigrating context DB: {ctx_path}")
+ src = sqlite3.connect(str(ctx_path))
+
+ total += _migrate_table(src, pg, "context_facts", tenant_id,
+ columns=["id", "tenant_id", "category", "key", "value", "source", "created_at"],
+ conflict_cols=["id"])
+
+ total += _migrate_table(src, pg, "context_documents", tenant_id,
+ columns=["id", "tenant_id", "filename", "doc_type", "full_text", "file_size", "uploaded_at"],
+ conflict_cols=["id"])
+
+ total += _migrate_table(src, pg, "context_chunks", tenant_id,
+ columns=["id", "tenant_id", "document_id", "chunk_index", "text"],
+ conflict_cols=["id"])
+ src.close()
+ else:
+ print(f"Context DB not found at {ctx_path} — skipping")
+
+ # ── Incidents DB ──────────────────────────────────────────────────────────
+ inc_path = Path(args.incidents_db)
+ if inc_path.exists():
+ print(f"\nMigrating incidents DB: {inc_path}")
+ src = sqlite3.connect(str(inc_path))
+
+ total += _migrate_table(src, pg, "incidents", tenant_id,
+ columns=["id", "tenant_id", "label", "issue_type", "started_at", "ended_at",
+ "notes", "created_at", "severity"],
+ conflict_cols=["id"])
+
+ total += _migrate_table(src, pg, "received_bundles", tenant_id,
+ columns=["id", "tenant_id", "source_host", "issue_type", "label", "severity",
+ "started_at", "bundled_at", "entry_count", "bundle_json"],
+ conflict_cols=["id"])
+
+ total += _migrate_table(src, pg, "sent_bundles", tenant_id,
+ columns=["id", "tenant_id", "incident_id", "exported_at", "sanitized",
+ "entry_count", "bundle_json"],
+ conflict_cols=["id"])
+ src.close()
+ else:
+ print(f"Incidents DB not found at {inc_path} — skipping")
+
+ pg.close()
+ print(f"\nDone. Total rows inserted: {total}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/context/test_diagnose_context.py b/tests/context/test_diagnose_context.py
index f34da5f..1a8a6e2 100644
--- a/tests/context/test_diagnose_context.py
+++ b/tests/context/test_diagnose_context.py
@@ -4,6 +4,7 @@ import sqlite3
from pathlib import Path
from unittest.mock import patch
import pytest
+from app.db.schema import ensure_schema, ensure_context_schema
from app.services.llm import summarize
from app.services.search import SearchResult
@@ -64,36 +65,14 @@ def test_summarize_without_context_block_unchanged():
@pytest.fixture
def db_with_facts(tmp_path):
db_path = tmp_path / "t.db"
+ ensure_schema(db_path)
+ ensure_context_schema(db_path)
conn = sqlite3.connect(str(db_path))
- conn.executescript("""
- CREATE TABLE log_entries (
- id TEXT PRIMARY KEY, source_id TEXT NOT NULL, sequence INTEGER NOT NULL,
- timestamp_raw TEXT, timestamp_iso TEXT, ingest_time TEXT NOT NULL,
- severity TEXT, repeat_count INTEGER DEFAULT 1, out_of_order INTEGER DEFAULT 0,
- matched_patterns TEXT DEFAULT '[]', text TEXT NOT NULL
- );
- CREATE VIRTUAL TABLE IF NOT EXISTS log_fts USING fts5(
- text, entry_id UNINDEXED, source_id UNINDEXED, sequence UNINDEXED,
- severity UNINDEXED, timestamp_iso UNINDEXED, matched_patterns UNINDEXED,
- repeat_count UNINDEXED, out_of_order UNINDEXED, tokenize='porter ascii'
- );
- CREATE TABLE context_facts (
- id TEXT PRIMARY KEY, category TEXT NOT NULL, key TEXT NOT NULL,
- value TEXT NOT NULL, source TEXT, created_at TEXT NOT NULL
- );
- CREATE TABLE context_documents (
- id TEXT PRIMARY KEY, filename TEXT NOT NULL, doc_type TEXT NOT NULL,
- full_text TEXT NOT NULL, file_size INTEGER, uploaded_at TEXT NOT NULL
- );
- CREATE TABLE context_chunks (
- id TEXT PRIMARY KEY, document_id TEXT NOT NULL
- REFERENCES context_documents(id) ON DELETE CASCADE,
- chunk_index INTEGER NOT NULL, text TEXT NOT NULL, embedding BLOB
- );
- INSERT INTO context_facts VALUES (
- 'f1','service','plex','port:32400','wizard','2026-05-13T00:00:00+00:00'
- );
- """)
+ conn.execute(
+ "INSERT INTO context_facts(id, tenant_id, category, key, value, source, created_at) "
+ "VALUES (?,?,?,?,?,?,?)",
+ ("f1", "", "service", "plex", "port:32400", "wizard", "2026-05-13T00:00:00+00:00"),
+ )
conn.commit()
conn.close()
return db_path
diff --git a/tests/context/test_doc_upload.py b/tests/context/test_doc_upload.py
index 162f6f5..12e1fa0 100644
--- a/tests/context/test_doc_upload.py
+++ b/tests/context/test_doc_upload.py
@@ -1,8 +1,8 @@
"""End-to-end upload pipeline: file bytes → DB rows."""
-import sqlite3
import pytest
from pathlib import Path
+from app.db.schema import ensure_context_schema
from app.glean.doc_upload import glean_upload
from app.context.store import list_facts, list_documents
from app.context.chunker import UnsupportedDocType
@@ -11,24 +11,7 @@ from app.context.chunker import UnsupportedDocType
@pytest.fixture
def db(tmp_path):
db_path = tmp_path / "t.db"
- conn = sqlite3.connect(str(db_path))
- conn.executescript("""
- CREATE TABLE context_facts (
- id TEXT PRIMARY KEY, category TEXT NOT NULL, key TEXT NOT NULL,
- value TEXT NOT NULL, source TEXT, created_at TEXT NOT NULL
- );
- CREATE TABLE context_documents (
- id TEXT PRIMARY KEY, filename TEXT NOT NULL, doc_type TEXT NOT NULL,
- full_text TEXT NOT NULL, file_size INTEGER, uploaded_at TEXT NOT NULL
- );
- CREATE TABLE context_chunks (
- id TEXT PRIMARY KEY, document_id TEXT NOT NULL
- REFERENCES context_documents(id) ON DELETE CASCADE,
- chunk_index INTEGER NOT NULL, text TEXT NOT NULL, embedding BLOB
- );
- """)
- conn.commit()
- conn.close()
+ ensure_context_schema(db_path)
return db_path
diff --git a/tests/context/test_schema.py b/tests/context/test_schema.py
index ea71812..4943b79 100644
--- a/tests/context/test_schema.py
+++ b/tests/context/test_schema.py
@@ -1,13 +1,13 @@
-"""Verify the three new context tables are created by ensure_schema."""
+"""Verify the three context tables are created by ensure_context_schema."""
import sqlite3
from pathlib import Path
import pytest
-from app.glean.pipeline import ensure_schema
+from app.db.schema import ensure_context_schema
def test_context_tables_created(tmp_path):
db = tmp_path / "t.db"
- ensure_schema(db)
+ ensure_context_schema(db)
conn = sqlite3.connect(str(db))
tables = {r[0] for r in conn.execute(
"SELECT name FROM sqlite_master WHERE type='table'"
@@ -20,5 +20,5 @@ def test_context_tables_created(tmp_path):
def test_context_schema_idempotent(tmp_path):
db = tmp_path / "t.db"
- ensure_schema(db)
- ensure_schema(db) # second call must not raise
+ ensure_context_schema(db)
+ ensure_context_schema(db) # second call must not raise
diff --git a/tests/context/test_store.py b/tests/context/test_store.py
index 8c6edea..7197579 100644
--- a/tests/context/test_store.py
+++ b/tests/context/test_store.py
@@ -2,6 +2,7 @@
import sqlite3
import pytest
from pathlib import Path
+from app.db.schema import ensure_context_schema
from app.context.store import (
add_fact, list_facts, delete_fact,
add_document, list_documents, delete_document,
@@ -12,24 +13,7 @@ from app.context.store import (
@pytest.fixture
def db(tmp_path):
db_path = tmp_path / "t.db"
- conn = sqlite3.connect(str(db_path))
- conn.executescript("""
- CREATE TABLE context_facts (
- id TEXT PRIMARY KEY, category TEXT NOT NULL, key TEXT NOT NULL,
- value TEXT NOT NULL, source TEXT, created_at TEXT NOT NULL
- );
- CREATE TABLE context_documents (
- id TEXT PRIMARY KEY, filename TEXT NOT NULL, doc_type TEXT NOT NULL,
- full_text TEXT NOT NULL, file_size INTEGER, uploaded_at TEXT NOT NULL
- );
- CREATE TABLE context_chunks (
- id TEXT PRIMARY KEY, document_id TEXT NOT NULL
- REFERENCES context_documents(id) ON DELETE CASCADE,
- chunk_index INTEGER NOT NULL, text TEXT NOT NULL, embedding BLOB
- );
- """)
- conn.commit()
- conn.close()
+ ensure_context_schema(db_path)
return db_path
diff --git a/tests/context/test_wizard.py b/tests/context/test_wizard.py
index e10682e..8d76f81 100644
--- a/tests/context/test_wizard.py
+++ b/tests/context/test_wizard.py
@@ -2,21 +2,14 @@
import sqlite3
import pytest
from pathlib import Path
+from app.db.schema import ensure_context_schema
from app.context.wizard import get_schema, advance_step, is_complete, apply_session, TOTAL_STEPS
@pytest.fixture
def db(tmp_path):
db_path = tmp_path / "t.db"
- conn = sqlite3.connect(str(db_path))
- conn.executescript("""
- CREATE TABLE context_facts (
- id TEXT PRIMARY KEY, category TEXT NOT NULL, key TEXT NOT NULL,
- value TEXT NOT NULL, source TEXT, created_at TEXT NOT NULL
- );
- """)
- conn.commit()
- conn.close()
+ ensure_context_schema(db_path)
return db_path
diff --git a/tests/test_glean_fingerprint.py b/tests/test_glean_fingerprint.py
index 96aca23..827838b 100644
--- a/tests/test_glean_fingerprint.py
+++ b/tests/test_glean_fingerprint.py
@@ -51,12 +51,14 @@ class TestFingerprintHelpers:
def test_fp_unchanged_returns_false_when_no_record(self, db_path: Path, log_file: Path) -> None:
conn = sqlite3.connect(str(db_path))
+ conn.row_factory = sqlite3.Row
mtime, size = _fingerprint(log_file)
assert _fp_unchanged(conn, log_file, mtime, size) is False
conn.close()
def test_fp_unchanged_returns_true_after_save(self, db_path: Path, log_file: Path) -> None:
conn = sqlite3.connect(str(db_path))
+ conn.row_factory = sqlite3.Row
mtime, size = _fingerprint(log_file)
_save_fingerprint(conn, log_file, mtime, size, now_iso())
conn.commit()
@@ -65,6 +67,7 @@ class TestFingerprintHelpers:
def test_fp_unchanged_returns_false_on_size_change(self, db_path: Path, log_file: Path) -> None:
conn = sqlite3.connect(str(db_path))
+ conn.row_factory = sqlite3.Row
mtime, size = _fingerprint(log_file)
_save_fingerprint(conn, log_file, mtime, size, now_iso())
conn.commit()
@@ -74,6 +77,7 @@ class TestFingerprintHelpers:
def test_fp_unchanged_returns_false_on_mtime_change(self, db_path: Path, log_file: Path) -> None:
conn = sqlite3.connect(str(db_path))
+ conn.row_factory = sqlite3.Row
mtime, size = _fingerprint(log_file)
_save_fingerprint(conn, log_file, mtime, size, now_iso())
conn.commit()
diff --git a/tests/test_hybrid_search.py b/tests/test_hybrid_search.py
index 1e3101e..631c5fb 100644
--- a/tests/test_hybrid_search.py
+++ b/tests/test_hybrid_search.py
@@ -33,12 +33,11 @@ def db(tmp_path: Path) -> Path:
("database connection refused backend gone away", "ERROR"),
("mDNS avahi heartbeat ok", "INFO"),
]):
- # Columns: id, source_id, sequence, timestamp_raw, timestamp_iso,
- # ingest_time, severity, repeat_count, out_of_order,
- # matched_patterns, text
conn.execute(
- "INSERT INTO log_entries VALUES (?,?,?,?,?,?,?,?,?,?,?)",
- (str(uuid.uuid4()), "src", i, None, None, "2026-01-01T00:00:00", sev, 1, 0, "[]", text),
+ "INSERT INTO log_entries(id, tenant_id, source_id, sequence, timestamp_raw, "
+ "timestamp_iso, ingest_time, severity, repeat_count, out_of_order, "
+ "matched_patterns, text) VALUES (?,?,?,?,?,?,?,?,?,?,?,?)",
+ (str(uuid.uuid4()), "", "src", i, None, None, "2026-01-01T00:00:00", sev, 1, 0, "[]", text),
)
conn.commit()
conn.close()
From 0693e1fd5459de5f0f7d169e6ac5b7f04e42e927 Mon Sep 17 00:00:00 2001
From: pyr0ball
Date: Tue, 9 Jun 2026 11:15:13 -0700
Subject: [PATCH 02/17] feat: anomaly scoring pipeline (#10)
- Add app/services/anomaly.py: batch scorer using HF text-classification
pipeline; rewrites anomaly_score/anomaly_label/anomaly_scored_at on
log_entries; inserts high-confidence hits into detections table
- Add app/tasks/anomaly_scorer.py: background task (same shape as
glean_scheduler); triggered after each glean cycle when
TURNSTONE_ANOMALY_MODEL is set
- DB schema: add anomaly_score/anomaly_label/anomaly_scored_at columns to
log_entries (idempotent ALTER TABLE migration); add detections table
- Wire scorer into scheduler_loop and glean_scheduler.run_once; no-op when
model env var is empty (safe to leave unconfigured)
- REST endpoints: GET/POST /api/anomaly/status, /api/anomaly/run,
GET /api/anomaly/detections, POST /api/anomaly/detections/{id}/acknowledge
- Reuses Hybrid-BERT label map from diagnose/classifier.py; works with any
HF text-classification model
- 12 new tests; 406/406 passing
Closes: https://git.opensourcesolarpunk.com/Circuit-Forge/turnstone/issues/10
---
app/db/schema.py | 74 +++++++--
app/rest.py | 68 ++++++++
app/services/anomaly.py | 291 +++++++++++++++++++++++++++++++++++
app/tasks/anomaly_scorer.py | 114 ++++++++++++++
app/tasks/glean_scheduler.py | 21 ++-
tests/test_anomaly.py | 220 ++++++++++++++++++++++++++
6 files changed, 775 insertions(+), 13 deletions(-)
create mode 100644 app/services/anomaly.py
create mode 100644 app/tasks/anomaly_scorer.py
create mode 100644 tests/test_anomaly.py
diff --git a/app/db/schema.py b/app/db/schema.py
index 7cc8d97..0e9ad2f 100644
--- a/app/db/schema.py
+++ b/app/db/schema.py
@@ -23,18 +23,21 @@ logger = logging.getLogger(__name__)
_MAIN_SCHEMA_SQLITE = """
CREATE TABLE IF NOT EXISTS log_entries (
- id TEXT NOT NULL,
- tenant_id TEXT NOT NULL DEFAULT '',
- source_id TEXT NOT NULL,
- sequence INTEGER NOT NULL,
- timestamp_raw TEXT,
- timestamp_iso TEXT,
- ingest_time TEXT NOT NULL,
- severity TEXT,
- repeat_count INTEGER DEFAULT 1,
- out_of_order INTEGER DEFAULT 0,
+ id TEXT NOT NULL,
+ tenant_id TEXT NOT NULL DEFAULT '',
+ source_id TEXT NOT NULL,
+ sequence INTEGER NOT NULL,
+ timestamp_raw TEXT,
+ timestamp_iso TEXT,
+ ingest_time TEXT NOT NULL,
+ severity TEXT,
+ repeat_count INTEGER DEFAULT 1,
+ out_of_order INTEGER DEFAULT 0,
matched_patterns TEXT DEFAULT '[]',
- text TEXT NOT NULL,
+ text TEXT NOT NULL,
+ anomaly_score REAL,
+ anomaly_label TEXT,
+ anomaly_scored_at TEXT,
PRIMARY KEY (tenant_id, id)
);
CREATE INDEX IF NOT EXISTS idx_source ON log_entries(source_id);
@@ -43,6 +46,27 @@ CREATE INDEX IF NOT EXISTS idx_timestamp ON log_entries(timestamp_iso);
CREATE INDEX IF NOT EXISTS idx_ts_repeat ON log_entries(timestamp_iso, repeat_count);
CREATE INDEX IF NOT EXISTS idx_severity ON log_entries(tenant_id, severity);
CREATE INDEX IF NOT EXISTS idx_patterns ON log_entries(matched_patterns);
+CREATE INDEX IF NOT EXISTS idx_anomaly ON log_entries(tenant_id, anomaly_score);
+
+CREATE TABLE IF NOT EXISTS detections (
+ id TEXT PRIMARY KEY,
+ tenant_id TEXT NOT NULL DEFAULT '',
+ entry_id TEXT NOT NULL,
+ source_id TEXT NOT NULL,
+ anomaly_label TEXT NOT NULL,
+ anomaly_score REAL NOT NULL,
+ severity TEXT NOT NULL,
+ text TEXT NOT NULL,
+ timestamp_iso TEXT,
+ detected_at TEXT NOT NULL,
+ acknowledged INTEGER NOT NULL DEFAULT 0,
+ acknowledged_at TEXT,
+ notes TEXT NOT NULL DEFAULT ''
+);
+CREATE INDEX IF NOT EXISTS idx_detections_tenant ON detections(tenant_id, detected_at);
+CREATE INDEX IF NOT EXISTS idx_detections_ack ON detections(acknowledged);
+CREATE INDEX IF NOT EXISTS idx_detections_label ON detections(anomaly_label);
+CREATE INDEX IF NOT EXISTS idx_detections_entry ON detections(entry_id);
CREATE TABLE IF NOT EXISTS glean_fingerprints (
tenant_id TEXT NOT NULL DEFAULT '',
@@ -174,6 +198,9 @@ _MAIN_SCHEMA_PG_STMTS = [
matched_patterns TEXT DEFAULT '[]',
text TEXT NOT NULL,
text_tsv tsvector,
+ anomaly_score DOUBLE PRECISION,
+ anomaly_label TEXT,
+ anomaly_scored_at TEXT,
PRIMARY KEY (tenant_id, id)
)
""",
@@ -182,6 +209,28 @@ _MAIN_SCHEMA_PG_STMTS = [
"CREATE INDEX IF NOT EXISTS idx_severity ON log_entries(tenant_id, severity)",
"CREATE INDEX IF NOT EXISTS idx_patterns ON log_entries(matched_patterns)",
"CREATE INDEX IF NOT EXISTS idx_fts_gin ON log_entries USING GIN(text_tsv)",
+ "CREATE INDEX IF NOT EXISTS idx_anomaly ON log_entries(tenant_id, anomaly_score)",
+ """
+ CREATE TABLE IF NOT EXISTS detections (
+ id TEXT PRIMARY KEY,
+ tenant_id TEXT NOT NULL DEFAULT '',
+ entry_id TEXT NOT NULL,
+ source_id TEXT NOT NULL,
+ anomaly_label TEXT NOT NULL,
+ anomaly_score DOUBLE PRECISION NOT NULL,
+ severity TEXT NOT NULL,
+ text TEXT NOT NULL,
+ timestamp_iso TEXT,
+ detected_at TEXT NOT NULL,
+ acknowledged INTEGER NOT NULL DEFAULT 0,
+ acknowledged_at TEXT,
+ notes TEXT NOT NULL DEFAULT ''
+ )
+ """,
+ "CREATE INDEX IF NOT EXISTS idx_detections_tenant ON detections(tenant_id, detected_at)",
+ "CREATE INDEX IF NOT EXISTS idx_detections_ack ON detections(acknowledged)",
+ "CREATE INDEX IF NOT EXISTS idx_detections_label ON detections(anomaly_label)",
+ "CREATE INDEX IF NOT EXISTS idx_detections_entry ON detections(entry_id)",
"""
CREATE OR REPLACE FUNCTION _ts_update_text_tsv() RETURNS trigger AS $$
BEGIN
@@ -336,6 +385,9 @@ _MAIN_MIGRATIONS_SQLITE = [
"ALTER TABLE glean_fingerprints ADD COLUMN mtime REAL",
"ALTER TABLE glean_fingerprints ADD COLUMN size INTEGER",
"ALTER TABLE glean_fingerprints ADD COLUMN gleaned_at TEXT",
+ "ALTER TABLE log_entries ADD COLUMN anomaly_score REAL",
+ "ALTER TABLE log_entries ADD COLUMN anomaly_label TEXT",
+ "ALTER TABLE log_entries ADD COLUMN anomaly_scored_at TEXT",
]
_CONTEXT_MIGRATIONS_SQLITE = [
diff --git a/app/rest.py b/app/rest.py
index cc87254..d187979 100644
--- a/app/rest.py
+++ b/app/rest.py
@@ -88,6 +88,8 @@ from app.glean.doc_upload import glean_upload as _glean_upload
from app.context.wizard import get_schema as _wizard_schema, advance_step, is_complete, apply_session
from app.context.chunker import UnsupportedDocType, FileTooLarge
from app.tasks.glean_scheduler import get_state as _glean_state, run_once as _run_glean, scheduler_loop as _scheduler_loop, submit_matched as _submit_matched
+from app.tasks.anomaly_scorer import get_state as _scorer_state, run_once as _run_scorer
+from app.services.anomaly import list_detections as _list_detections, acknowledge_detection as _ack_detection
from app.glean.mqtt_subscriber import run_mqtt_subscribers as _run_mqtt_subscribers
DB_PATH = Path(os.environ.get("TURNSTONE_DB", Path(__file__).parent.parent / "data" / "turnstone.db"))
@@ -109,6 +111,9 @@ PATTERN_DIR = Path(os.environ.get("TURNSTONE_PATTERNS", Path(__file__).parent.pa
PATTERN_FILE = PATTERN_DIR / "default.yaml"
GLEAN_INTERVAL = int(os.environ.get("TURNSTONE_GLEAN_INTERVAL", "900"))
SUBMIT_ENDPOINT = os.environ.get("TURNSTONE_SUBMIT_ENDPOINT", "").rstrip("/")
+ANOMALY_MODEL = os.environ.get("TURNSTONE_ANOMALY_MODEL", "")
+ANOMALY_DEVICE = os.environ.get("TURNSTONE_ANOMALY_DEVICE", "cpu")
+ANOMALY_THRESHOLD = float(os.environ.get("TURNSTONE_ANOMALY_THRESHOLD", "0.75"))
# When set, all /api/ routes require Authorization: Bearer .
# Unset (default) means no authentication — suitable for local-only deployments.
_API_KEY: str | None = os.environ.get("TURNSTONE_API_KEY") or None
@@ -165,6 +170,9 @@ async def _lifespan(app: FastAPI):
sources_file, DB_PATH, PATTERN_FILE, GLEAN_INTERVAL,
submit_endpoint=SUBMIT_ENDPOINT or None,
source_host=SOURCE_HOST,
+ anomaly_model=ANOMALY_MODEL,
+ anomaly_device=ANOMALY_DEVICE,
+ anomaly_threshold=ANOMALY_THRESHOLD,
),
name="glean-scheduler",
)
@@ -1318,6 +1326,66 @@ async def debug_search(q: str):
app.include_router(_ctx)
+# ---------------------------------------------------------------------------
+# Anomaly scoring endpoints
+# ---------------------------------------------------------------------------
+
+_anomaly = APIRouter(prefix="/turnstone/api/anomaly", dependencies=[Depends(_check_api_key)])
+
+
+@_anomaly.get("/status")
+async def anomaly_status():
+ """Return scorer state and configuration."""
+ state = _scorer_state()
+ return {
+ "model": ANOMALY_MODEL or None,
+ "threshold": ANOMALY_THRESHOLD,
+ "device": ANOMALY_DEVICE,
+ "enabled": bool(ANOMALY_MODEL),
+ **vars(state),
+ }
+
+
+@_anomaly.post("/run")
+async def anomaly_run(background_tasks: BackgroundTasks):
+ """Trigger a manual anomaly scoring pass (runs in background)."""
+ if not ANOMALY_MODEL:
+ raise HTTPException(status_code=400, detail="TURNSTONE_ANOMALY_MODEL not configured")
+ background_tasks.add_task(
+ _run_scorer, DB_PATH, ANOMALY_MODEL, ANOMALY_DEVICE, 256, ANOMALY_THRESHOLD
+ )
+ return {"ok": True, "message": "scorer triggered"}
+
+
+@_anomaly.get("/detections")
+async def anomaly_detections(
+ limit: int = Query(100, ge=1, le=1000),
+ unacked_only: bool = Query(False),
+ label: str | None = Query(None),
+):
+ """List anomaly detections ordered by detected_at DESC."""
+ loop = asyncio.get_running_loop()
+ rows = await loop.run_in_executor(
+ None, lambda: _list_detections(DB_PATH, limit=limit, unacked_only=unacked_only, label=label)
+ )
+ return {"detections": rows, "total": len(rows)}
+
+
+@_anomaly.post("/detections/{detection_id}/acknowledge")
+async def acknowledge_detection(detection_id: str, notes: str = ""):
+ """Acknowledge a detection (mark as reviewed)."""
+ loop = asyncio.get_running_loop()
+ updated = await loop.run_in_executor(
+ None, lambda: _ack_detection(DB_PATH, detection_id, notes)
+ )
+ if not updated:
+ raise HTTPException(status_code=404, detail="Detection not found")
+ return {"ok": True}
+
+
+app.include_router(_anomaly)
+
+
# Root redirect → /turnstone/
@app.get("/")
def root_redirect() -> RedirectResponse:
diff --git a/app/services/anomaly.py b/app/services/anomaly.py
new file mode 100644
index 0000000..85e7317
--- /dev/null
+++ b/app/services/anomaly.py
@@ -0,0 +1,291 @@
+"""Anomaly scoring pipeline — batch-score log_entries with a HF classifier.
+
+Designed to run after each glean cycle (or standalone). When no model is
+configured the scorer is a no-op and returns immediately, so it is always
+safe to wire into the glean pipeline.
+
+Model: any HuggingFace text-classification model. The existing Hybrid-BERT
+label map (from diagnose/classifier.py) is reused when the model produces
+NORMAL/SECURITY_ANOMALY/… outputs; other models get a generic severity map.
+
+Scoring strategy
+----------------
+- Query unscored rows in batches (WHERE anomaly_scored_at IS NULL)
+- Run each entry text through the HF pipeline
+- Write anomaly_score + anomaly_label + anomaly_scored_at back
+- INSERT high-confidence hits (score >= threshold) into detections table,
+ skipping duplicates so the scorer is safe to re-run
+"""
+from __future__ import annotations
+
+import logging
+import os
+import uuid
+from dataclasses import dataclass
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Any
+
+from app.db import get_conn, resolve_tenant_id
+from app.db.dialect import q
+
+logger = logging.getLogger(__name__)
+
+# ---------------------------------------------------------------------------
+# Label maps — reuse Hybrid-BERT vocabulary from diagnose/classifier.py
+# ---------------------------------------------------------------------------
+
+_HYBRID_BERT_SEVERITY: dict[str, str] = {
+ "NORMAL": "INFO",
+ "SECURITY_ANOMALY": "ERROR",
+ "SYSTEM_FAILURE": "CRITICAL",
+ "PERFORMANCE_ISSUE": "WARN",
+ "NETWORK_ANOMALY": "WARN",
+ "CONFIG_ERROR": "ERROR",
+ "HARDWARE_ISSUE": "CRITICAL",
+}
+
+_GENERIC_SEVERITY: dict[str, str] = {
+ "CRITICAL": "CRITICAL",
+ "ERROR": "ERROR",
+ "WARNING": "WARN",
+ "WARN": "WARN",
+ "INFO": "INFO",
+ "DEBUG": "DEBUG",
+}
+
+_ANOMALOUS_LABELS: frozenset[str] = frozenset(
+ {
+ "SECURITY_ANOMALY",
+ "SYSTEM_FAILURE",
+ "PERFORMANCE_ISSUE",
+ "NETWORK_ANOMALY",
+ "CONFIG_ERROR",
+ "HARDWARE_ISSUE",
+ "CRITICAL",
+ "ERROR",
+ }
+)
+
+_DEFAULT_THRESHOLD = float(os.environ.get("TURNSTONE_ANOMALY_THRESHOLD", "0.75"))
+_DEFAULT_MODEL = os.environ.get("TURNSTONE_ANOMALY_MODEL", "")
+_DEFAULT_DEVICE = os.environ.get("TURNSTONE_ANOMALY_DEVICE", "cpu")
+_DEFAULT_BATCH = int(os.environ.get("TURNSTONE_ANOMALY_BATCH", "256"))
+
+# ---------------------------------------------------------------------------
+# ML singleton
+# ---------------------------------------------------------------------------
+
+_pipeline: Any | None = None
+
+
+def _get_pipeline(model_id: str, device: str) -> Any:
+ global _pipeline # noqa: PLW0603
+ if _pipeline is None:
+ from transformers import pipeline as hf_pipeline # type: ignore[import-untyped]
+ _pipeline = hf_pipeline("text-classification", model=model_id, device=device)
+ return _pipeline
+
+
+def reset_pipeline() -> None:
+ """Reset the cached pipeline singleton (test helper)."""
+ global _pipeline # noqa: PLW0603
+ _pipeline = None
+
+
+# ---------------------------------------------------------------------------
+# Result types
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class ScoringResult:
+ scored: int = 0
+ detections: int = 0
+ skipped: bool = False
+ error: str | None = None
+
+
+# ---------------------------------------------------------------------------
+# Internal helpers
+# ---------------------------------------------------------------------------
+
+
+def _map_label(raw_label: str, score: float) -> tuple[str, str]:
+ """Return (normalised_label, severity) for a raw model output label."""
+ upper = raw_label.upper()
+ if upper in _HYBRID_BERT_SEVERITY:
+ return upper, _HYBRID_BERT_SEVERITY[upper]
+ sev = _GENERIC_SEVERITY.get(upper, "WARN")
+ return upper, sev
+
+
+def _fetch_unscored(conn: Any, tenant_id: str, limit: int) -> list[dict]:
+ rows = conn.execute(
+ q("""
+ SELECT id, source_id, text, timestamp_iso, severity
+ FROM log_entries
+ WHERE anomaly_scored_at IS NULL
+ AND (tenant_id = ? OR tenant_id = '')
+ ORDER BY ingest_time DESC
+ LIMIT ?
+ """),
+ (tenant_id, limit),
+ ).fetchall()
+ return [dict(r) for r in rows]
+
+
+def _write_scores(
+ conn: Any,
+ rows: list[dict],
+ scored_at: str,
+) -> None:
+ conn.executemany(
+ q("UPDATE log_entries SET anomaly_score = ?, anomaly_label = ?, anomaly_scored_at = ? WHERE id = ?"),
+ [(r["anomaly_score"], r["anomaly_label"], scored_at, r["id"]) for r in rows],
+ )
+
+
+def _insert_detections(conn: Any, rows: list[dict], tenant_id: str, detected_at: str) -> int:
+ inserted = 0
+ for r in rows:
+ try:
+ conn.execute(
+ q("""
+ INSERT INTO detections
+ (id, tenant_id, entry_id, source_id, anomaly_label, anomaly_score,
+ severity, text, timestamp_iso, detected_at)
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
+ """),
+ (
+ str(uuid.uuid4()),
+ tenant_id,
+ r["id"],
+ r["source_id"],
+ r["anomaly_label"],
+ r["anomaly_score"],
+ r["severity"],
+ r["text"][:2000],
+ r.get("timestamp_iso"),
+ detected_at,
+ ),
+ )
+ inserted += 1
+ except Exception: # noqa: BLE001
+ pass # duplicate entry_id or constraint violation — skip
+ return inserted
+
+
+# ---------------------------------------------------------------------------
+# Public API
+# ---------------------------------------------------------------------------
+
+
+def score_unscored(
+ db_path: Path,
+ model_id: str = _DEFAULT_MODEL,
+ device: str = _DEFAULT_DEVICE,
+ batch_size: int = _DEFAULT_BATCH,
+ threshold: float = _DEFAULT_THRESHOLD,
+) -> ScoringResult:
+ """Score all unscored log_entries in batches.
+
+ Returns immediately (skipped=True) when model_id is empty — allows
+ unconditional wiring without requiring the model to be configured.
+ """
+ if not model_id:
+ return ScoringResult(skipped=True)
+
+ try:
+ pipe = _get_pipeline(model_id, device)
+ except Exception as exc:
+ logger.error("Failed to load anomaly model %r: %s", model_id, exc)
+ return ScoringResult(error=str(exc))
+
+ tenant_id = resolve_tenant_id()
+ total_scored = 0
+ total_detections = 0
+
+ while True:
+ with get_conn(db_path) as conn:
+ batch = _fetch_unscored(conn, tenant_id, batch_size)
+ if not batch:
+ break
+
+ texts = [r["text"][:512] for r in batch]
+ try:
+ predictions = pipe(texts, truncation=True, max_length=512)
+ except Exception as exc:
+ logger.error("Inference error on batch of %d: %s", len(batch), exc)
+ return ScoringResult(scored=total_scored, detections=total_detections, error=str(exc))
+
+ scored_at = datetime.now(tz=timezone.utc).isoformat()
+ scored_rows: list[dict] = []
+ detection_rows: list[dict] = []
+
+ for row, pred in zip(batch, predictions):
+ label, severity = _map_label(pred["label"], pred["score"])
+ enriched = {**row, "anomaly_score": pred["score"], "anomaly_label": label, "severity": severity}
+ scored_rows.append(enriched)
+ if label in _ANOMALOUS_LABELS and pred["score"] >= threshold:
+ detection_rows.append(enriched)
+
+ with get_conn(db_path) as conn:
+ _write_scores(conn, scored_rows, scored_at)
+ det_count = _insert_detections(conn, detection_rows, tenant_id, scored_at)
+ conn.commit()
+
+ total_scored += len(scored_rows)
+ total_detections += det_count
+ logger.info(
+ "Scored %d entries, %d detections (threshold=%.2f)",
+ len(scored_rows), det_count, threshold,
+ )
+
+ if len(batch) < batch_size:
+ break
+
+ return ScoringResult(scored=total_scored, detections=total_detections)
+
+
+def list_detections(
+ db_path: Path,
+ limit: int = 100,
+ unacked_only: bool = False,
+ label: str | None = None,
+) -> list[dict]:
+ """Return detections ordered by detected_at DESC."""
+ tenant_id = resolve_tenant_id()
+ conditions = ["(tenant_id = ? OR tenant_id = '')"]
+ params: list[Any] = [tenant_id]
+
+ if unacked_only:
+ conditions.append("acknowledged = 0")
+ if label:
+ conditions.append(q("anomaly_label = ?"))
+ params.append(label.upper())
+
+ where = " AND ".join(conditions)
+ with get_conn(db_path) as conn:
+ rows = conn.execute(
+ q(f"SELECT * FROM detections WHERE {where} ORDER BY detected_at DESC LIMIT ?"), # noqa: S608
+ (*params, limit),
+ ).fetchall()
+ return [dict(r) for r in rows]
+
+
+def acknowledge_detection(db_path: Path, detection_id: str, notes: str = "") -> bool:
+ """Mark a detection as acknowledged. Returns True if a row was updated."""
+ tenant_id = resolve_tenant_id()
+ acked_at = datetime.now(tz=timezone.utc).isoformat()
+ with get_conn(db_path) as conn:
+ cur = conn.execute(
+ q("""
+ UPDATE detections
+ SET acknowledged = 1, acknowledged_at = ?, notes = ?
+ WHERE id = ? AND (tenant_id = ? OR tenant_id = '')
+ """),
+ (acked_at, notes, detection_id, tenant_id),
+ )
+ conn.commit()
+ return cur.rowcount > 0
diff --git a/app/tasks/anomaly_scorer.py b/app/tasks/anomaly_scorer.py
new file mode 100644
index 0000000..e952b62
--- /dev/null
+++ b/app/tasks/anomaly_scorer.py
@@ -0,0 +1,114 @@
+"""Background anomaly scoring task.
+
+Runs score_unscored() after each glean cycle (triggered by glean_scheduler)
+or on its own interval when TURNSTONE_ANOMALY_INTERVAL is set.
+
+Set TURNSTONE_ANOMALY_MODEL to a HuggingFace model ID to activate.
+When the env var is empty (default) the scorer is a no-op.
+"""
+from __future__ import annotations
+
+import asyncio
+import logging
+import os
+from dataclasses import dataclass, field
+from datetime import datetime, timedelta, timezone
+from pathlib import Path
+
+from app.services.anomaly import ScoringResult, score_unscored
+
+logger = logging.getLogger(__name__)
+
+_DEFAULT_INTERVAL = int(os.environ.get("TURNSTONE_ANOMALY_INTERVAL", "0"))
+
+_lock = asyncio.Lock()
+
+
+@dataclass
+class ScorerState:
+ last_run_at: str | None = None
+ last_duration_s: float | None = None
+ last_scored: int = 0
+ last_detections: int = 0
+ last_error: str | None = None
+ run_count: int = 0
+ next_run_at: str | None = None
+ running: bool = False
+ total_scored: int = 0
+ total_detections: int = 0
+
+
+_state = ScorerState()
+
+
+def get_state() -> ScorerState:
+ return _state
+
+
+async def run_once(
+ db_path: Path,
+ model_id: str = "",
+ device: str = "cpu",
+ batch_size: int = 256,
+ threshold: float = 0.75,
+) -> ScoringResult:
+ """Score unscored entries once. Skips if already running or model not configured."""
+ if _lock.locked():
+ return ScoringResult(skipped=True, error="scorer already running")
+
+ async with _lock:
+ _state.running = True
+ started = datetime.now(tz=timezone.utc)
+ try:
+ loop = asyncio.get_running_loop()
+ result: ScoringResult = await loop.run_in_executor(
+ None,
+ lambda: score_unscored(db_path, model_id, device, batch_size, threshold),
+ )
+ duration = (datetime.now(tz=timezone.utc) - started).total_seconds()
+ _state.last_run_at = started.isoformat()
+ _state.last_duration_s = round(duration, 2)
+ _state.last_scored = result.scored
+ _state.last_detections = result.detections
+ _state.last_error = result.error
+ _state.run_count += 1
+ _state.total_scored += result.scored
+ _state.total_detections += result.detections
+ if not result.skipped:
+ logger.info(
+ "Anomaly scorer: %d scored, %d detections in %.1fs",
+ result.scored, result.detections, duration,
+ )
+ return result
+ except Exception as exc:
+ duration = (datetime.now(tz=timezone.utc) - started).total_seconds()
+ _state.last_run_at = started.isoformat()
+ _state.last_duration_s = round(duration, 2)
+ _state.last_error = str(exc)
+ _state.run_count += 1
+ logger.error("Anomaly scorer failed: %s", exc)
+ return ScoringResult(error=str(exc))
+ finally:
+ _state.running = False
+
+
+async def scorer_loop(
+ db_path: Path,
+ model_id: str,
+ device: str,
+ interval_s: int,
+ batch_size: int = 256,
+ threshold: float = 0.75,
+) -> None:
+ """Score unscored entries every interval_s seconds until cancelled."""
+ logger.info("Anomaly scorer loop started — interval %ds, model: %s", interval_s, model_id)
+ while True:
+ await run_once(db_path, model_id, device, batch_size, threshold)
+ next_run = datetime.now(tz=timezone.utc) + timedelta(seconds=interval_s)
+ _state.next_run_at = next_run.isoformat()
+ try:
+ await asyncio.sleep(interval_s)
+ except asyncio.CancelledError:
+ logger.info("Anomaly scorer loop cancelled")
+ _state.next_run_at = None
+ raise
diff --git a/app/tasks/glean_scheduler.py b/app/tasks/glean_scheduler.py
index ba4e501..7322158 100644
--- a/app/tasks/glean_scheduler.py
+++ b/app/tasks/glean_scheduler.py
@@ -20,6 +20,7 @@ from typing import Any
import httpx
from app.glean.pipeline import glean_sources
+from app.tasks.anomaly_scorer import run_once as _run_scorer
logger = logging.getLogger(__name__)
@@ -123,6 +124,9 @@ async def run_once(
submit_endpoint: str | None = None,
source_host: str = "unknown",
force: bool = False,
+ anomaly_model: str = "",
+ anomaly_device: str = "cpu",
+ anomaly_threshold: float = 0.75,
) -> dict[str, Any]:
"""Ingest all sources once, then submit matched entries if configured.
@@ -163,6 +167,9 @@ async def run_once(
if submit_endpoint:
await submit_matched(db_path, submit_endpoint, source_host, since=_state.last_submitted_at)
+ if anomaly_model:
+ await _run_scorer(db_path, anomaly_model, anomaly_device, threshold=anomaly_threshold)
+
return {"ok": True, "stats": _state.last_stats, "duration_s": _state.last_duration_s}
@@ -173,13 +180,23 @@ async def scheduler_loop(
interval_s: int,
submit_endpoint: str | None = None,
source_host: str = "unknown",
+ anomaly_model: str = "",
+ anomaly_device: str = "cpu",
+ anomaly_threshold: float = 0.75,
) -> None:
- """Run glean + optional submission every interval_s seconds until cancelled."""
+ """Run glean + optional submission + optional anomaly scoring every interval_s seconds."""
logger.info("Ingest scheduler started — interval %ds, sources: %s", interval_s, sources_file)
if submit_endpoint:
logger.info("Submission enabled — endpoint: %s", submit_endpoint)
+ if anomaly_model:
+ logger.info("Anomaly scoring enabled — model: %s", anomaly_model)
while True:
- await run_once(sources_file, db_path, pattern_file, submit_endpoint, source_host)
+ await run_once(
+ sources_file, db_path, pattern_file, submit_endpoint, source_host,
+ anomaly_model=anomaly_model,
+ anomaly_device=anomaly_device,
+ anomaly_threshold=anomaly_threshold,
+ )
next_run = datetime.now(tz=timezone.utc) + timedelta(seconds=interval_s)
_state.next_run_at = next_run.isoformat()
try:
diff --git a/tests/test_anomaly.py b/tests/test_anomaly.py
new file mode 100644
index 0000000..31bbe98
--- /dev/null
+++ b/tests/test_anomaly.py
@@ -0,0 +1,220 @@
+"""Tests for app/services/anomaly.py — anomaly scoring pipeline."""
+from __future__ import annotations
+
+import sqlite3
+import uuid
+from pathlib import Path
+from unittest.mock import MagicMock
+
+import pytest
+
+import app.services.anomaly as anomaly_mod
+from app.db.schema import ensure_schema
+from app.services.anomaly import (
+ ScoringResult,
+ acknowledge_detection,
+ list_detections,
+ reset_pipeline,
+ score_unscored,
+)
+
+
+# ---------------------------------------------------------------------------
+# Fixtures
+# ---------------------------------------------------------------------------
+
+
+@pytest.fixture(autouse=True)
+def _reset_pipeline():
+ """Ensure the ML singleton is cleared between tests."""
+ reset_pipeline()
+ yield
+ reset_pipeline()
+
+
+@pytest.fixture
+def db(tmp_path: Path) -> Path:
+ db_path = tmp_path / "t.db"
+ ensure_schema(db_path)
+ return db_path
+
+
+def _insert_entry(db_path: Path, text: str, entry_id: str | None = None) -> str:
+ eid = entry_id or str(uuid.uuid4())
+ conn = sqlite3.connect(str(db_path))
+ conn.execute(
+ "INSERT INTO log_entries(id, tenant_id, source_id, sequence, ingest_time, text) "
+ "VALUES (?,?,?,?,?,?)",
+ (eid, "", "src", 1, "2026-01-01T00:00:00", text),
+ )
+ conn.commit()
+ conn.close()
+ return eid
+
+
+# ---------------------------------------------------------------------------
+# score_unscored
+# ---------------------------------------------------------------------------
+
+
+def test_score_unscored_no_model_returns_skipped(db: Path):
+ result = score_unscored(db, model_id="")
+ assert result.skipped is True
+ assert result.scored == 0
+
+
+def test_score_unscored_scores_entries(db: Path, monkeypatch):
+ _insert_entry(db, "kernel panic — OOM killer invoked")
+ _insert_entry(db, "user login successful")
+
+ mock_pipe = MagicMock(return_value=[
+ {"label": "SYSTEM_FAILURE", "score": 0.92},
+ {"label": "NORMAL", "score": 0.88},
+ ])
+ monkeypatch.setattr(anomaly_mod, "_pipeline", mock_pipe)
+
+ result = score_unscored(db, model_id="fake-model", batch_size=10)
+ assert result.skipped is False
+ assert result.scored == 2
+
+
+def test_score_unscored_creates_detection_above_threshold(db: Path, monkeypatch):
+ _insert_entry(db, "segfault in service")
+
+ mock_pipe = MagicMock(return_value=[
+ {"label": "SYSTEM_FAILURE", "score": 0.95},
+ ])
+ monkeypatch.setattr(anomaly_mod, "_pipeline", mock_pipe)
+
+ result = score_unscored(db, model_id="fake-model", threshold=0.80)
+ assert result.detections == 1
+
+ detections = list_detections(db)
+ assert len(detections) == 1
+ assert detections[0]["anomaly_label"] == "SYSTEM_FAILURE"
+ assert detections[0]["anomaly_score"] == pytest.approx(0.95)
+
+
+def test_score_unscored_no_detection_below_threshold(db: Path, monkeypatch):
+ _insert_entry(db, "warning: disk at 80%")
+
+ mock_pipe = MagicMock(return_value=[
+ {"label": "PERFORMANCE_ISSUE", "score": 0.60},
+ ])
+ monkeypatch.setattr(anomaly_mod, "_pipeline", mock_pipe)
+
+ result = score_unscored(db, model_id="fake-model", threshold=0.80)
+ assert result.detections == 0
+ assert result.scored == 1
+
+
+def test_score_unscored_normal_label_never_detection(db: Path, monkeypatch):
+ _insert_entry(db, "service started successfully")
+
+ mock_pipe = MagicMock(return_value=[
+ {"label": "NORMAL", "score": 0.99},
+ ])
+ monkeypatch.setattr(anomaly_mod, "_pipeline", mock_pipe)
+
+ result = score_unscored(db, model_id="fake-model", threshold=0.50)
+ assert result.detections == 0
+
+
+def test_score_unscored_idempotent(db: Path, monkeypatch):
+ """Entries already scored are not re-scored on subsequent runs."""
+ _insert_entry(db, "first entry")
+
+ call_count = 0
+
+ def _side_effect(texts, **_kwargs):
+ nonlocal call_count
+ call_count += 1
+ return [{"label": "NORMAL", "score": 0.90} for _ in texts]
+
+ mock_pipe = MagicMock(side_effect=_side_effect)
+ monkeypatch.setattr(anomaly_mod, "_pipeline", mock_pipe)
+
+ score_unscored(db, model_id="fake-model")
+ score_unscored(db, model_id="fake-model")
+
+ assert call_count == 1 # second run finds no unscored rows
+
+
+def test_score_unscored_pipeline_error_returns_error(db: Path, monkeypatch):
+ _insert_entry(db, "some log line")
+
+ mock_pipe = MagicMock(side_effect=RuntimeError("CUDA OOM"))
+ monkeypatch.setattr(anomaly_mod, "_pipeline", mock_pipe)
+
+ result = score_unscored(db, model_id="fake-model")
+ assert result.error is not None
+ assert "CUDA OOM" in result.error
+
+
+# ---------------------------------------------------------------------------
+# list_detections / acknowledge_detection
+# ---------------------------------------------------------------------------
+
+
+def test_list_detections_empty(db: Path):
+ assert list_detections(db) == []
+
+
+def test_list_detections_filters_unacked(db: Path, monkeypatch):
+ _insert_entry(db, "crash")
+
+ mock_pipe = MagicMock(return_value=[{"label": "SYSTEM_FAILURE", "score": 0.91}])
+ monkeypatch.setattr(anomaly_mod, "_pipeline", mock_pipe)
+ score_unscored(db, model_id="fake-model", threshold=0.80)
+
+ all_dets = list_detections(db)
+ assert len(all_dets) == 1
+ unacked = list_detections(db, unacked_only=True)
+ assert len(unacked) == 1
+
+
+def test_acknowledge_detection(db: Path, monkeypatch):
+ _insert_entry(db, "network anomaly")
+
+ mock_pipe = MagicMock(return_value=[{"label": "NETWORK_ANOMALY", "score": 0.88}])
+ monkeypatch.setattr(anomaly_mod, "_pipeline", mock_pipe)
+ score_unscored(db, model_id="fake-model", threshold=0.80)
+
+ dets = list_detections(db)
+ assert len(dets) == 1
+ det_id = dets[0]["id"]
+
+ updated = acknowledge_detection(db, det_id, notes="benign test traffic")
+ assert updated is True
+
+ unacked = list_detections(db, unacked_only=True)
+ assert len(unacked) == 0
+
+ all_dets = list_detections(db)
+ assert all_dets[0]["acknowledged"] == 1
+ assert all_dets[0]["notes"] == "benign test traffic"
+
+
+def test_acknowledge_detection_unknown_id(db: Path):
+ updated = acknowledge_detection(db, "nonexistent-id")
+ assert updated is False
+
+
+def test_list_detections_label_filter(db: Path, monkeypatch):
+ _insert_entry(db, "OOM kill")
+ _insert_entry(db, "network timeout")
+
+ mock_pipe = MagicMock(side_effect=[
+ [{"label": "SYSTEM_FAILURE", "score": 0.93}],
+ [{"label": "NETWORK_ANOMALY", "score": 0.85}],
+ ])
+ monkeypatch.setattr(anomaly_mod, "_pipeline", mock_pipe)
+
+ score_unscored(db, model_id="fake-model", batch_size=1, threshold=0.80)
+ score_unscored(db, model_id="fake-model", batch_size=1, threshold=0.80)
+
+ sys_dets = list_detections(db, label="SYSTEM_FAILURE")
+ assert all(d["anomaly_label"] == "SYSTEM_FAILURE" for d in sys_dets)
+
+ net_dets = list_detections(db, label="NETWORK_ANOMALY")
+ assert all(d["anomaly_label"] == "NETWORK_ANOMALY" for d in net_dets)
From 40694a30e5928df8e0bf99f0b10f832ff4d710ae Mon Sep 17 00:00:00 2001
From: pyr0ball
Date: Tue, 9 Jun 2026 23:01:48 -0700
Subject: [PATCH 03/17] chore: wire anomaly scoring pipeline into deployment
config
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Add TURNSTONE_ANOMALY_* env vars to docker-compose.yml, docker-standalone.sh,
and .env.example. Mount shared HF model cache (/Library/Assets/LLM on Heimdall)
as read-only bind in both compose and standalone — avoids re-downloading models
that are already cached by the diagnose pipeline.
Heimdall: byviz/bylastic_classification_logs already cached, threshold 0.80,
glean-triggered only (TURNSTONE_ANOMALY_INTERVAL=0).
---
.env.example | 17 +++++++++++++++++
docker-compose.yml | 14 ++++++++++++++
docker-standalone.sh | 19 ++++++++++++++++++-
3 files changed, 49 insertions(+), 1 deletion(-)
diff --git a/.env.example b/.env.example
index fff4a27..97b21e6 100644
--- a/.env.example
+++ b/.env.example
@@ -42,6 +42,23 @@
# TURNSTONE_EMBED_MODEL=BAAI/bge-small-en-v1.5
# TURNSTONE_EMBED_DEVICE=cpu
+# --- Anomaly scoring pipeline (IDS / watchdog) ---
+# Batch-scores every ingested log entry after each glean cycle.
+# Any HuggingFace text-classification model works; the byviz classifier (already
+# required by the diagnose pipeline) is the recommended starting point.
+# Detections above the threshold are inserted into the detections table and
+# surfaced in the Security Alerts tab.
+#
+# Set TURNSTONE_ANOMALY_MODEL to enable; leave unset to disable (safe default).
+# TURNSTONE_ANOMALY_MODEL=byviz/bylastic_classification_logs
+# TURNSTONE_ANOMALY_DEVICE=cpu # or "cuda" / "mps" for GPU inference
+# TURNSTONE_ANOMALY_THRESHOLD=0.80 # confidence floor for detection insertion
+# TURNSTONE_ANOMALY_INTERVAL=0 # standalone loop (0 = glean-triggered only)
+#
+# HuggingFace model cache — share with the host to avoid re-downloading models.
+# HF_HOME=/hf_cache # inside container (set in docker-compose)
+# HF_CACHE_PATH=/Library/Assets/LLM # host bind-mount source (docker-compose only)
+
# --- Air-gapped / offline deployment ---
# Set to 1 to block all HuggingFace hub network access at runtime.
# Pre-download models to ~/.cache/huggingface/ before deploying — see docs/air-gapped-deployment.md.
diff --git a/docker-compose.yml b/docker-compose.yml
index 8c9bf29..d197bc1 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -41,9 +41,23 @@ services:
TURNSTONE_GLEAN_INTERVAL: ${TURNSTONE_GLEAN_INTERVAL:-900}
TURNSTONE_SOURCE_HOST: ${TURNSTONE_SOURCE_HOST:-}
TURNSTONE_SUBMIT_ENDPOINT: ${TURNSTONE_SUBMIT_ENDPOINT:-}
+ # --- Multi-agent diagnose pipeline ---
+ TURNSTONE_MULTI_AGENT_DIAGNOSE: ${TURNSTONE_MULTI_AGENT_DIAGNOSE:-false}
+ TURNSTONE_CLASSIFIER_MODEL: ${TURNSTONE_CLASSIFIER_MODEL:-}
+ TURNSTONE_EMBED_BACKEND: ${TURNSTONE_EMBED_BACKEND:-}
+ TURNSTONE_EMBED_MODEL: ${TURNSTONE_EMBED_MODEL:-}
+ TURNSTONE_EMBED_DEVICE: ${TURNSTONE_EMBED_DEVICE:-cpu}
+ # --- Anomaly scoring pipeline ---
+ TURNSTONE_ANOMALY_MODEL: ${TURNSTONE_ANOMALY_MODEL:-}
+ TURNSTONE_ANOMALY_DEVICE: ${TURNSTONE_ANOMALY_DEVICE:-cpu}
+ TURNSTONE_ANOMALY_THRESHOLD: ${TURNSTONE_ANOMALY_THRESHOLD:-0.75}
+ TURNSTONE_ANOMALY_INTERVAL: ${TURNSTONE_ANOMALY_INTERVAL:-0}
+ # --- HuggingFace model cache ---
+ HF_HOME: /hf_cache
volumes:
- ./patterns:/app/patterns:ro
- ./data:/app/data # optional: persists SQLite files if DATABASE_URL unset
+ - ${HF_CACHE_PATH:-/Library/Assets/LLM}:/hf_cache:ro # shared model cache
volumes:
turnstone_pgdata:
diff --git a/docker-standalone.sh b/docker-standalone.sh
index 7098fa8..8d45406 100755
--- a/docker-standalone.sh
+++ b/docker-standalone.sh
@@ -62,7 +62,10 @@ set -euo pipefail
REPO_DIR="${HOME}/turnstone"
DATA_DIR="${REPO_DIR}/data"
PATTERNS_DIR="${REPO_DIR}/patterns"
-HF_CACHE_DIR="${REPO_DIR}/hf-cache" # persists downloaded ML models across restarts
+# HF_CACHE_DIR: override to a shared cache directory to avoid re-downloading models.
+# Example (Heimdall, where byviz/bylastic_classification_logs is already cached):
+# export HF_CACHE_DIR=/Library/Assets/LLM
+HF_CACHE_DIR="${HF_CACHE_DIR:-${REPO_DIR}/hf-cache}"
TZ="${TZ:-America/Los_Angeles}"
@@ -83,6 +86,16 @@ TZ="${TZ:-America/Los_Angeles}"
# bash ~/turnstone/docker-standalone.sh
#
+# ── Anomaly scoring pipeline (IDS / watchdog) ────────────────────────────────
+# Set TURNSTONE_ANOMALY_MODEL to enable automatic anomaly scoring after each
+# glean run. The byviz classifier (already used by the diagnose pipeline) is
+# a good default — it's cached alongside the other models.
+#
+# export TURNSTONE_ANOMALY_MODEL=byviz/bylastic_classification_logs
+# export TURNSTONE_ANOMALY_THRESHOLD=0.80 # confidence floor (default 0.75)
+# bash ~/turnstone/docker-standalone.sh
+#
+
# ── Multi-agent diagnose pipeline ────────────────────────────────────────────
# Enable the 5-stage ML pipeline to get smarter diagnose results.
#
@@ -134,6 +147,10 @@ docker run -d \
-e TURNSTONE_EMBED_BACKEND="${TURNSTONE_EMBED_BACKEND:-sentence_transformers}" \
-e TURNSTONE_EMBED_MODEL="${TURNSTONE_EMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}" \
-e TURNSTONE_EMBED_DEVICE="${TURNSTONE_EMBED_DEVICE:-cpu}" \
+ -e TURNSTONE_ANOMALY_MODEL="${TURNSTONE_ANOMALY_MODEL:-}" \
+ -e TURNSTONE_ANOMALY_DEVICE="${TURNSTONE_ANOMALY_DEVICE:-cpu}" \
+ -e TURNSTONE_ANOMALY_THRESHOLD="${TURNSTONE_ANOMALY_THRESHOLD:-0.75}" \
+ -e TURNSTONE_ANOMALY_INTERVAL="${TURNSTONE_ANOMALY_INTERVAL:-0}" \
localhost/turnstone:latest
echo ""
From 6e228fe0bfae8eef688f72994bed41a1b6c5f9f3 Mon Sep 17 00:00:00 2001
From: pyr0ball
Date: Wed, 10 Jun 2026 00:28:15 -0700
Subject: [PATCH 04/17] =?UTF-8?q?feat:=20security=20alerts=20tab=20?=
=?UTF-8?q?=E2=80=94=20UI=20view=20for=20anomaly=20detections=20(#11)?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
New SecurityAlertsView (/alerts route) surfaces the detections table built
in #10. Features:
- All / Unacknowledged tab filter with live counts
- Label dropdown (SECURITY_ANOMALY, SYSTEM_FAILURE, NETWORK_ANOMALY, etc.)
- Score confidence bar per detection (colour-coded by threshold)
- Acknowledge drawer: full log text, optional notes, in-place row dim on save
- Scorer status badge + manual "Run scorer" button
- Config warning when TURNSTONE_ANOMALY_MODEL is unset
Dashboard: new "Unreviewed Alerts" stat card (red border when > 0) links
to /alerts so alerts surface on the landing page without navigating away.
Closes: https://git.opensourcesolarpunk.com/Circuit-Forge/turnstone/issues/11
---
web/src/App.vue | 1 +
web/src/router/index.ts | 2 +
web/src/views/DashboardView.vue | 22 +-
web/src/views/SecurityAlertsView.vue | 458 +++++++++++++++++++++++++++
4 files changed, 482 insertions(+), 1 deletion(-)
create mode 100644 web/src/views/SecurityAlertsView.vue
diff --git a/web/src/App.vue b/web/src/App.vue
index 914984d..f6a1b48 100644
--- a/web/src/App.vue
+++ b/web/src/App.vue
@@ -76,6 +76,7 @@ const navLinks = [
{ to: '/search', label: 'Search' },
{ to: '/diagnose', label: 'Diagnose' },
{ to: '/incidents', label: 'Incidents' },
+ { to: '/alerts', label: 'Alerts' },
{ to: '/bundles', label: 'Bundles' },
{ to: '/sources', label: 'Sources' },
{ to: '/context', label: 'Context' },
diff --git a/web/src/router/index.ts b/web/src/router/index.ts
index b2c7f97..e5bba57 100644
--- a/web/src/router/index.ts
+++ b/web/src/router/index.ts
@@ -8,6 +8,7 @@ import BundlesView from '@/views/BundlesView.vue'
import SettingsView from '@/views/SettingsView.vue'
import ContextView from '@/views/ContextView.vue'
import BlocklistView from '@/views/BlocklistView.vue'
+import SecurityAlertsView from '@/views/SecurityAlertsView.vue'
export default createRouter({
history: createWebHistory(import.meta.env.BASE_URL),
@@ -17,6 +18,7 @@ export default createRouter({
{ path: '/search', component: LogSearchView },
{ path: '/diagnose', component: DiagnoseView },
{ path: '/incidents', component: IncidentsView },
+ { path: '/alerts', component: SecurityAlertsView },
{ path: '/bundles', component: BundlesView },
{ path: '/sources', component: SourcesView },
{ path: '/context', component: ContextView },
diff --git a/web/src/views/DashboardView.vue b/web/src/views/DashboardView.vue
index 98a9c4f..3d6a73a 100644
--- a/web/src/views/DashboardView.vue
+++ b/web/src/views/DashboardView.vue
@@ -52,6 +52,16 @@
{{ incidentsLoading ? '…' : activeIncidents }}
+
+ Unreviewed Alerts
+
+ {{ alertsLoading ? '…' : unackedAlerts }}
+
+
@@ -201,6 +211,8 @@ const loading = ref(true)
const incidents = ref([])
const incidentsLoading = ref(true)
const watchSources = ref([])
+const unackedAlerts = ref(0)
+const alertsLoading = ref(true)
const activeIncidents = computed(() =>
incidents.value.filter(i => !i.ended_at).length
@@ -217,7 +229,7 @@ const isStale = computed(() => {
})
onMounted(async () => {
- await Promise.all([loadStats(), loadIncidents(), loadWatchStatus()])
+ await Promise.all([loadStats(), loadIncidents(), loadWatchStatus(), loadAlertCount()])
})
async function loadStats() {
@@ -245,6 +257,14 @@ async function loadWatchStatus() {
} catch { /* non-critical */ }
}
+async function loadAlertCount() {
+ try {
+ const res = await fetch(`${BASE}/turnstone/api/anomaly/detections?unacked_only=true&limit=1000`)
+ if (res.ok) unackedAlerts.value = (await res.json()).total ?? 0
+ } catch { /* non-critical — scorer may be disabled */ }
+ finally { alertsLoading.value = false }
+}
+
function healthDot(errors: number, total: number): string {
if (errors === 0) return 'bg-green-500'
const ratio = errors / Math.max(total, 1)
diff --git a/web/src/views/SecurityAlertsView.vue b/web/src/views/SecurityAlertsView.vue
new file mode 100644
index 0000000..7ac5361
--- /dev/null
+++ b/web/src/views/SecurityAlertsView.vue
@@ -0,0 +1,458 @@
+
+
+
+
+
+
+
Security Alerts
+
+ Anomaly detections from the scoring pipeline.
+ Acknowledge entries after review to track your triage state.
+
+
+
+
+
+
+
+ {{ scorerStatus.running ? 'scoring…' : scorerStatus.enabled ? 'scorer ready' : 'scorer off' }}
+
+
+
+
+
+
+
+
+ Anomaly scoring is disabled — set TURNSTONE_ANOMALY_MODEL
+ in your .env and restart Turnstone.
+
+
+
+
+ Total scored: {{ scorerStatus.total_scored ?? '—' }}
+ Total detections: {{ scorerStatus.total_detections ?? '—' }}
+
+ Last run: {{ formatTs(scorerStatus.last_run_at) }}
+
+
+ Last error: {{ scorerStatus.last_error }}
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
Loading…
+
+
+
+
No unacknowledged detections — all clear.
+
Enable anomaly scoring to start detecting.
+
No detections yet. Run the scorer after gleaning to populate this list.
+
+
+
+
+
+
+
+
+ | Sev |
+ Label |
+ Score |
+ Source |
+ Log entry |
+ Detected |
+ |
+
+
+
+
+ |
+
+ {{ det.severity }}
+
+ |
+
+
+ {{ det.anomaly_label }}
+
+ |
+
+
+
+ {{ Math.round(det.anomaly_score * 100) }}%
+
+ |
+ {{ det.source_id }} |
+ {{ det.text }} |
+ {{ formatTs(det.detected_at) }} |
+
+ reviewed
+
+ |
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ drawer.severity }}
+
+ {{ drawer.anomaly_label }}
+
+ {{ Math.round(drawer.anomaly_score * 100) }}% confidence
+
+
+ source: {{ drawer.source_id }}
+ · {{ formatTs(drawer.timestamp_iso) }}
+
+
+
+
+
+
+
+ {{ drawer.text }}
+
+
+
+
+
Acknowledged {{ formatTs(drawer.acknowledged_at) }}
+
{{ drawer.notes }}
+
+
+
+
+
+
+
+
+
+ {{ ackError }}
+
+
+
+
+
+
+
+
+
+
+
From cffe6bcd3104093c034bfbbbd50ba097ebbd3ea2 Mon Sep 17 00:00:00 2001
From: pyr0ball
Date: Wed, 10 Jun 2026 01:03:25 -0700
Subject: [PATCH 05/17] feat: cybersec zero-shot scoring pipeline (#9)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
Second-pass cybersec classifier using DeBERTa-v3-base-mnli (already
cached — no download required). Runs after each anomaly scoring pass on
entries flagged by the anomaly scorer or with pattern matches.
Architecture:
- app/services/cybersec.py: zero-shot-classification pipeline with 5
cybersec candidate labels (auth failure, privilege escalation, network
intrusion, malware, data exfiltration). Writes ml_score/ml_label/
ml_scored_at to log_entries; inserts high-confidence hits into
detections with scorer='cybersec'.
- app/tasks/cybersec_scorer.py: async background task (same shape as
anomaly_scorer.py).
- REST: GET/POST /turnstone/api/cybersec/status|run|detections.
GET /turnstone/api/anomaly/detections now accepts scorer= filter.
Schema: ml_score, ml_label, ml_scored_at added to log_entries; scorer
column added to detections (idempotent migrations + DDL for both SQLite
and Postgres).
UI: Security Alerts view gains Source dropdown (All / Anomaly / Cybersec)
and cybersec scorer status badge. Label dropdown split into optgroups.
Deployment: TURNSTONE_CYBERSEC_MODEL/DEVICE/THRESHOLD vars added to
.env.example, docker-compose.yml, docker-standalone.sh.
Tests: 10 new tests — no model, no eligible entries, scoring, detection
creation, normal label suppression, threshold filtering, pattern-tag
filtering, idempotency, list filtering, scorer column filter.
416/416 passing.
Closes: https://git.opensourcesolarpunk.com/Circuit-Forge/turnstone/issues/9
---
.env.example | 9 +
app/db/schema.py | 20 ++-
app/rest.py | 61 ++++++-
app/services/anomaly.py | 4 +
app/services/cybersec.py | 241 +++++++++++++++++++++++++++
app/tasks/cybersec_scorer.py | 84 ++++++++++
app/tasks/glean_scheduler.py | 17 +-
docker-compose.yml | 4 +
docker-standalone.sh | 3 +
tests/test_cybersec.py | 233 ++++++++++++++++++++++++++
web/src/views/SecurityAlertsView.vue | 65 +++++++-
11 files changed, 730 insertions(+), 11 deletions(-)
create mode 100644 app/services/cybersec.py
create mode 100644 app/tasks/cybersec_scorer.py
create mode 100644 tests/test_cybersec.py
diff --git a/.env.example b/.env.example
index 97b21e6..2c1da08 100644
--- a/.env.example
+++ b/.env.example
@@ -42,6 +42,15 @@
# TURNSTONE_EMBED_MODEL=BAAI/bge-small-en-v1.5
# TURNSTONE_EMBED_DEVICE=cpu
+# --- Cybersec scoring pipeline (zero-shot, second-pass on flagged entries) ---
+# Runs a zero-shot classifier on entries already flagged by the anomaly scorer
+# or that have pattern matches — a focused second opinion using cybersec vocabulary.
+# The DeBERTa-v3-base-mnli model (required by the diagnose pipeline) is the recommended
+# zero-shot classifier — it produces human-readable cybersec labels with no fine-tuning.
+# TURNSTONE_CYBERSEC_MODEL=MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli
+# TURNSTONE_CYBERSEC_DEVICE=cpu
+# TURNSTONE_CYBERSEC_THRESHOLD=0.60 # lower than anomaly threshold (zero-shot is calibrated differently)
+
# --- Anomaly scoring pipeline (IDS / watchdog) ---
# Batch-scores every ingested log entry after each glean cycle.
# Any HuggingFace text-classification model works; the byviz classifier (already
diff --git a/app/db/schema.py b/app/db/schema.py
index 0e9ad2f..311a321 100644
--- a/app/db/schema.py
+++ b/app/db/schema.py
@@ -38,6 +38,9 @@ CREATE TABLE IF NOT EXISTS log_entries (
anomaly_score REAL,
anomaly_label TEXT,
anomaly_scored_at TEXT,
+ ml_score REAL,
+ ml_label TEXT,
+ ml_scored_at TEXT,
PRIMARY KEY (tenant_id, id)
);
CREATE INDEX IF NOT EXISTS idx_source ON log_entries(source_id);
@@ -47,6 +50,7 @@ CREATE INDEX IF NOT EXISTS idx_ts_repeat ON log_entries(timestamp_iso, repeat_
CREATE INDEX IF NOT EXISTS idx_severity ON log_entries(tenant_id, severity);
CREATE INDEX IF NOT EXISTS idx_patterns ON log_entries(matched_patterns);
CREATE INDEX IF NOT EXISTS idx_anomaly ON log_entries(tenant_id, anomaly_score);
+CREATE INDEX IF NOT EXISTS idx_ml_scored ON log_entries(tenant_id, ml_scored_at);
CREATE TABLE IF NOT EXISTS detections (
id TEXT PRIMARY KEY,
@@ -61,12 +65,14 @@ CREATE TABLE IF NOT EXISTS detections (
detected_at TEXT NOT NULL,
acknowledged INTEGER NOT NULL DEFAULT 0,
acknowledged_at TEXT,
- notes TEXT NOT NULL DEFAULT ''
+ notes TEXT NOT NULL DEFAULT '',
+ scorer TEXT NOT NULL DEFAULT 'anomaly'
);
CREATE INDEX IF NOT EXISTS idx_detections_tenant ON detections(tenant_id, detected_at);
CREATE INDEX IF NOT EXISTS idx_detections_ack ON detections(acknowledged);
CREATE INDEX IF NOT EXISTS idx_detections_label ON detections(anomaly_label);
CREATE INDEX IF NOT EXISTS idx_detections_entry ON detections(entry_id);
+CREATE INDEX IF NOT EXISTS idx_detections_scorer ON detections(scorer);
CREATE TABLE IF NOT EXISTS glean_fingerprints (
tenant_id TEXT NOT NULL DEFAULT '',
@@ -201,6 +207,9 @@ _MAIN_SCHEMA_PG_STMTS = [
anomaly_score DOUBLE PRECISION,
anomaly_label TEXT,
anomaly_scored_at TEXT,
+ ml_score DOUBLE PRECISION,
+ ml_label TEXT,
+ ml_scored_at TEXT,
PRIMARY KEY (tenant_id, id)
)
""",
@@ -210,6 +219,7 @@ _MAIN_SCHEMA_PG_STMTS = [
"CREATE INDEX IF NOT EXISTS idx_patterns ON log_entries(matched_patterns)",
"CREATE INDEX IF NOT EXISTS idx_fts_gin ON log_entries USING GIN(text_tsv)",
"CREATE INDEX IF NOT EXISTS idx_anomaly ON log_entries(tenant_id, anomaly_score)",
+ "CREATE INDEX IF NOT EXISTS idx_ml_scored ON log_entries(tenant_id, ml_scored_at)",
"""
CREATE TABLE IF NOT EXISTS detections (
id TEXT PRIMARY KEY,
@@ -224,13 +234,15 @@ _MAIN_SCHEMA_PG_STMTS = [
detected_at TEXT NOT NULL,
acknowledged INTEGER NOT NULL DEFAULT 0,
acknowledged_at TEXT,
- notes TEXT NOT NULL DEFAULT ''
+ notes TEXT NOT NULL DEFAULT '',
+ scorer TEXT NOT NULL DEFAULT 'anomaly'
)
""",
"CREATE INDEX IF NOT EXISTS idx_detections_tenant ON detections(tenant_id, detected_at)",
"CREATE INDEX IF NOT EXISTS idx_detections_ack ON detections(acknowledged)",
"CREATE INDEX IF NOT EXISTS idx_detections_label ON detections(anomaly_label)",
"CREATE INDEX IF NOT EXISTS idx_detections_entry ON detections(entry_id)",
+ "CREATE INDEX IF NOT EXISTS idx_detections_scorer ON detections(scorer)",
"""
CREATE OR REPLACE FUNCTION _ts_update_text_tsv() RETURNS trigger AS $$
BEGIN
@@ -388,6 +400,10 @@ _MAIN_MIGRATIONS_SQLITE = [
"ALTER TABLE log_entries ADD COLUMN anomaly_score REAL",
"ALTER TABLE log_entries ADD COLUMN anomaly_label TEXT",
"ALTER TABLE log_entries ADD COLUMN anomaly_scored_at TEXT",
+ "ALTER TABLE log_entries ADD COLUMN ml_score REAL",
+ "ALTER TABLE log_entries ADD COLUMN ml_label TEXT",
+ "ALTER TABLE log_entries ADD COLUMN ml_scored_at TEXT",
+ "ALTER TABLE detections ADD COLUMN scorer TEXT NOT NULL DEFAULT 'anomaly'",
]
_CONTEXT_MIGRATIONS_SQLITE = [
diff --git a/app/rest.py b/app/rest.py
index d187979..a59ede9 100644
--- a/app/rest.py
+++ b/app/rest.py
@@ -89,7 +89,9 @@ from app.context.wizard import get_schema as _wizard_schema, advance_step, is_co
from app.context.chunker import UnsupportedDocType, FileTooLarge
from app.tasks.glean_scheduler import get_state as _glean_state, run_once as _run_glean, scheduler_loop as _scheduler_loop, submit_matched as _submit_matched
from app.tasks.anomaly_scorer import get_state as _scorer_state, run_once as _run_scorer
+from app.tasks.cybersec_scorer import get_state as _cybersec_state, run_once as _run_cybersec
from app.services.anomaly import list_detections as _list_detections, acknowledge_detection as _ack_detection
+from app.services.cybersec import list_cybersec_detections as _list_cybersec, CYBERSEC_LABELS
from app.glean.mqtt_subscriber import run_mqtt_subscribers as _run_mqtt_subscribers
DB_PATH = Path(os.environ.get("TURNSTONE_DB", Path(__file__).parent.parent / "data" / "turnstone.db"))
@@ -114,6 +116,9 @@ SUBMIT_ENDPOINT = os.environ.get("TURNSTONE_SUBMIT_ENDPOINT", "").rstrip("/")
ANOMALY_MODEL = os.environ.get("TURNSTONE_ANOMALY_MODEL", "")
ANOMALY_DEVICE = os.environ.get("TURNSTONE_ANOMALY_DEVICE", "cpu")
ANOMALY_THRESHOLD = float(os.environ.get("TURNSTONE_ANOMALY_THRESHOLD", "0.75"))
+CYBERSEC_MODEL = os.environ.get("TURNSTONE_CYBERSEC_MODEL", "")
+CYBERSEC_DEVICE = os.environ.get("TURNSTONE_CYBERSEC_DEVICE", "cpu")
+CYBERSEC_THRESHOLD = float(os.environ.get("TURNSTONE_CYBERSEC_THRESHOLD", "0.60"))
# When set, all /api/ routes require Authorization: Bearer .
# Unset (default) means no authentication — suitable for local-only deployments.
_API_KEY: str | None = os.environ.get("TURNSTONE_API_KEY") or None
@@ -173,6 +178,9 @@ async def _lifespan(app: FastAPI):
anomaly_model=ANOMALY_MODEL,
anomaly_device=ANOMALY_DEVICE,
anomaly_threshold=ANOMALY_THRESHOLD,
+ cybersec_model=CYBERSEC_MODEL,
+ cybersec_device=CYBERSEC_DEVICE,
+ cybersec_threshold=CYBERSEC_THRESHOLD,
),
name="glean-scheduler",
)
@@ -1362,11 +1370,12 @@ async def anomaly_detections(
limit: int = Query(100, ge=1, le=1000),
unacked_only: bool = Query(False),
label: str | None = Query(None),
+ scorer: str | None = Query(None),
):
- """List anomaly detections ordered by detected_at DESC."""
+ """List detections ordered by detected_at DESC. Optionally filter by scorer ('anomaly'|'cybersec')."""
loop = asyncio.get_running_loop()
rows = await loop.run_in_executor(
- None, lambda: _list_detections(DB_PATH, limit=limit, unacked_only=unacked_only, label=label)
+ None, lambda: _list_detections(DB_PATH, limit=limit, unacked_only=unacked_only, label=label, scorer=scorer)
)
return {"detections": rows, "total": len(rows)}
@@ -1386,6 +1395,54 @@ async def acknowledge_detection(detection_id: str, notes: str = ""):
app.include_router(_anomaly)
+# ---------------------------------------------------------------------------
+# Cybersec scoring endpoints
+# ---------------------------------------------------------------------------
+
+_cybersec_router = APIRouter(prefix="/turnstone/api/cybersec", dependencies=[Depends(_check_api_key)])
+
+
+@_cybersec_router.get("/status")
+async def cybersec_status():
+ """Return cybersec scorer state and configuration."""
+ return {
+ "model": CYBERSEC_MODEL or None,
+ "threshold": CYBERSEC_THRESHOLD,
+ "device": CYBERSEC_DEVICE,
+ "enabled": bool(CYBERSEC_MODEL),
+ "candidate_labels": CYBERSEC_LABELS,
+ **_cybersec_state(),
+ }
+
+
+@_cybersec_router.post("/run")
+async def cybersec_run(background_tasks: BackgroundTasks):
+ """Trigger a manual cybersec scoring pass (runs in background)."""
+ if not CYBERSEC_MODEL:
+ raise HTTPException(status_code=400, detail="TURNSTONE_CYBERSEC_MODEL not configured")
+ background_tasks.add_task(
+ _run_cybersec, DB_PATH, CYBERSEC_MODEL, CYBERSEC_DEVICE, 32, CYBERSEC_THRESHOLD
+ )
+ return {"ok": True, "message": "cybersec scorer triggered"}
+
+
+@_cybersec_router.get("/detections")
+async def cybersec_detections(
+ limit: int = Query(100, ge=1, le=1000),
+ unacked_only: bool = Query(False),
+ label: str | None = Query(None),
+):
+ """List cybersec detections ordered by detected_at DESC."""
+ loop = asyncio.get_running_loop()
+ rows = await loop.run_in_executor(
+ None, lambda: _list_cybersec(DB_PATH, limit=limit, unacked_only=unacked_only, label=label)
+ )
+ return {"detections": rows, "total": len(rows)}
+
+
+app.include_router(_cybersec_router)
+
+
# Root redirect → /turnstone/
@app.get("/")
def root_redirect() -> RedirectResponse:
diff --git a/app/services/anomaly.py b/app/services/anomaly.py
index 85e7317..4e525fe 100644
--- a/app/services/anomaly.py
+++ b/app/services/anomaly.py
@@ -253,6 +253,7 @@ def list_detections(
limit: int = 100,
unacked_only: bool = False,
label: str | None = None,
+ scorer: str | None = None,
) -> list[dict]:
"""Return detections ordered by detected_at DESC."""
tenant_id = resolve_tenant_id()
@@ -264,6 +265,9 @@ def list_detections(
if label:
conditions.append(q("anomaly_label = ?"))
params.append(label.upper())
+ if scorer:
+ conditions.append(q("scorer = ?"))
+ params.append(scorer.lower())
where = " AND ".join(conditions)
with get_conn(db_path) as conn:
diff --git a/app/services/cybersec.py b/app/services/cybersec.py
new file mode 100644
index 0000000..66fd893
--- /dev/null
+++ b/app/services/cybersec.py
@@ -0,0 +1,241 @@
+"""Cybersecurity-focused scoring pipeline using zero-shot classification.
+
+Runs a second-pass analysis on entries that were already flagged by the
+anomaly scorer or that have pattern matches. Uses a zero-shot classification
+model (DeBERTa-v3-base-mnli is cached locally) so no fine-tuning is needed.
+
+The scorer writes ml_score / ml_label / ml_scored_at to log_entries and
+inserts high-confidence non-normal hits into the detections table tagged
+with scorer='cybersec'.
+
+Env vars
+--------
+TURNSTONE_CYBERSEC_MODEL — HF model id for zero-shot classification.
+ Recommended: MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli
+ (already cached from the diagnose pipeline).
+ Set to empty string to disable (safe default).
+TURNSTONE_CYBERSEC_DEVICE — 'cpu' (default) or 'cuda'
+TURNSTONE_CYBERSEC_THRESHOLD — float confidence floor for detection insertion (default 0.60)
+"""
+from __future__ import annotations
+
+import logging
+import uuid
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from pathlib import Path
+from typing import Any
+
+from app.db import get_conn, resolve_tenant_id
+from app.db.dialect import q
+
+logger = logging.getLogger(__name__)
+
+# ---------------------------------------------------------------------------
+# Candidate labels — cybersec vocabulary for zero-shot inference
+# ---------------------------------------------------------------------------
+
+CYBERSEC_LABELS: list[str] = [
+ "authentication failure or brute force attack",
+ "privilege escalation or unauthorized access",
+ "network intrusion or port scan",
+ "malware or suspicious process activity",
+ "data exfiltration or unusual outbound traffic",
+ "normal system operation",
+]
+
+_NORMAL_LABEL = "normal system operation"
+
+_LABEL_SEVERITY: dict[str, str] = {
+ "authentication failure or brute force attack": "ERROR",
+ "privilege escalation or unauthorized access": "CRITICAL",
+ "network intrusion or port scan": "ERROR",
+ "malware or suspicious process activity": "CRITICAL",
+ "data exfiltration or unusual outbound traffic":"CRITICAL",
+ "normal system operation": "INFO",
+}
+
+# ---------------------------------------------------------------------------
+# Pipeline singleton
+# ---------------------------------------------------------------------------
+
+_pipeline: Any = None
+
+
+def _get_pipeline(model_id: str, device: str) -> Any:
+ global _pipeline # noqa: PLW0603
+ if _pipeline is None:
+ from transformers import pipeline # type: ignore[import-untyped]
+ logger.info("loading cybersec zero-shot pipeline: %s on %s", model_id, device)
+ _pipeline = pipeline(
+ "zero-shot-classification",
+ model=model_id,
+ device=0 if device == "cuda" else -1,
+ )
+ logger.info("cybersec pipeline ready")
+ return _pipeline
+
+
+def reset_pipeline() -> None:
+ """Clear the cached pipeline — for testing only."""
+ global _pipeline # noqa: PLW0603
+ _pipeline = None
+
+
+# ---------------------------------------------------------------------------
+# Result type
+# ---------------------------------------------------------------------------
+
+@dataclass
+class CybersecResult:
+ scored: int = 0
+ detections: int = 0
+ skipped: bool = False
+ error: str | None = None
+
+
+# ---------------------------------------------------------------------------
+# Core scoring function
+# ---------------------------------------------------------------------------
+
+def score_security_entries(
+ db_path: Path,
+ model_id: str,
+ device: str = "cpu",
+ batch_size: int = 32,
+ threshold: float = 0.60,
+) -> CybersecResult:
+ """Score entries that were anomaly-flagged or pattern-matched.
+
+ Only entries with ml_scored_at IS NULL are processed (idempotent).
+ Writes ml_score / ml_label / ml_scored_at and inserts high-confidence
+ hits into detections with scorer='cybersec'.
+ """
+ if not model_id:
+ return CybersecResult(skipped=True)
+
+ tenant_id = resolve_tenant_id()
+ try:
+ pipe = _get_pipeline(model_id, device)
+ except Exception as exc:
+ logger.error("failed to load cybersec pipeline: %s", exc)
+ return CybersecResult(error=str(exc))
+
+ total_scored = 0
+ total_detections = 0
+
+ try:
+ with get_conn(db_path) as conn:
+ # Only score entries that are worth a second look:
+ # anomaly-flagged (non-normal) OR have at least one pattern match.
+ rows = conn.execute(
+ q("""
+ SELECT id, source_id, text, timestamp_iso
+ FROM log_entries
+ WHERE (tenant_id = ? OR tenant_id = '')
+ AND ml_scored_at IS NULL
+ AND (
+ (anomaly_label IS NOT NULL AND anomaly_label != 'NORMAL')
+ OR (matched_patterns IS NOT NULL AND matched_patterns != '[]' AND matched_patterns != '')
+ )
+ LIMIT ?
+ """),
+ (tenant_id, batch_size * 10),
+ ).fetchall()
+
+ if not rows:
+ return CybersecResult(skipped=True)
+
+ # Process in chunks to avoid OOM on large backlogs
+ for i in range(0, len(rows), batch_size):
+ chunk = rows[i : i + batch_size]
+ texts = [r["text"] for r in chunk]
+
+ try:
+ results = pipe(texts, candidate_labels=CYBERSEC_LABELS, multi_label=False)
+ except Exception as exc:
+ logger.warning("zero-shot inference error on chunk %d: %s", i, exc)
+ continue
+
+ now = datetime.now(tz=timezone.utc).isoformat()
+
+ with get_conn(db_path) as conn:
+ for row, result in zip(chunk, results):
+ top_label: str = result["labels"][0]
+ top_score: float = result["scores"][0]
+
+ conn.execute(
+ q("""
+ UPDATE log_entries
+ SET ml_score = ?, ml_label = ?, ml_scored_at = ?
+ WHERE id = ? AND (tenant_id = ? OR tenant_id = '')
+ """),
+ (top_score, top_label, now, row["id"], tenant_id),
+ )
+ total_scored += 1
+
+ if top_score >= threshold and top_label != _NORMAL_LABEL:
+ severity = _LABEL_SEVERITY.get(top_label, "WARN")
+ try:
+ conn.execute(
+ q("""
+ INSERT INTO detections
+ (id, tenant_id, entry_id, source_id, anomaly_label,
+ anomaly_score, severity, text, timestamp_iso,
+ detected_at, scorer)
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, 'cybersec')
+ """),
+ (
+ str(uuid.uuid4()),
+ tenant_id,
+ row["id"],
+ row["source_id"],
+ top_label,
+ top_score,
+ severity,
+ row["text"],
+ row["timestamp_iso"],
+ now,
+ ),
+ )
+ total_detections += 1
+ except Exception:
+ pass # entry may already have a detection — skip
+
+ conn.commit()
+
+ except Exception as exc:
+ logger.error("cybersec scoring failed: %s", exc)
+ return CybersecResult(scored=total_scored, detections=total_detections, error=str(exc))
+
+ return CybersecResult(scored=total_scored, detections=total_detections)
+
+
+# ---------------------------------------------------------------------------
+# Query helpers (used by REST layer)
+# ---------------------------------------------------------------------------
+
+def list_cybersec_detections(
+ db_path: Path,
+ limit: int = 100,
+ unacked_only: bool = False,
+ label: str | None = None,
+) -> list[dict]:
+ """Return cybersec detections ordered by detected_at DESC."""
+ tenant_id = resolve_tenant_id()
+ conditions = ["(tenant_id = ? OR tenant_id = '')", "scorer = 'cybersec'"]
+ params: list[Any] = [tenant_id]
+
+ if unacked_only:
+ conditions.append("acknowledged = 0")
+ if label:
+ conditions.append(q("anomaly_label = ?"))
+ params.append(label)
+
+ where = " AND ".join(conditions)
+ with get_conn(db_path) as conn:
+ rows = conn.execute(
+ q(f"SELECT * FROM detections WHERE {where} ORDER BY detected_at DESC LIMIT ?"), # noqa: S608
+ (*params, limit),
+ ).fetchall()
+ return [dict(r) for r in rows]
diff --git a/app/tasks/cybersec_scorer.py b/app/tasks/cybersec_scorer.py
new file mode 100644
index 0000000..6b3ca4c
--- /dev/null
+++ b/app/tasks/cybersec_scorer.py
@@ -0,0 +1,84 @@
+"""Background task wrapper for the cybersec zero-shot scoring pipeline."""
+from __future__ import annotations
+
+import asyncio
+import logging
+from dataclasses import dataclass, field
+from datetime import datetime, timezone
+from pathlib import Path
+
+from app.services.cybersec import score_security_entries
+
+logger = logging.getLogger(__name__)
+
+_lock = asyncio.Lock()
+
+
+@dataclass
+class CybersecState:
+ last_run_at: str | None = None
+ last_duration_s: float | None = None
+ last_scored: int = 0
+ last_detections: int = 0
+ last_error: str | None = None
+ run_count: int = 0
+ running: bool = False
+ total_scored: int = 0
+ total_detections: int = 0
+
+
+_state = CybersecState()
+
+
+def get_state() -> dict:
+ return {
+ "last_run_at": _state.last_run_at,
+ "last_duration_s":_state.last_duration_s,
+ "last_scored": _state.last_scored,
+ "last_detections":_state.last_detections,
+ "last_error": _state.last_error,
+ "run_count": _state.run_count,
+ "running": _state.running,
+ "total_scored": _state.total_scored,
+ "total_detections": _state.total_detections,
+ }
+
+
+async def run_once(
+ db_path: Path,
+ model_id: str,
+ device: str = "cpu",
+ batch_size: int = 32,
+ threshold: float = 0.60,
+) -> None:
+ """Single cybersec scoring pass — no-op if already running or no model set."""
+ if not model_id or _lock.locked():
+ return
+
+ async with _lock:
+ _state.running = True
+ started = datetime.now(tz=timezone.utc)
+ try:
+ loop = asyncio.get_running_loop()
+ result = await loop.run_in_executor(
+ None,
+ lambda: score_security_entries(db_path, model_id, device, batch_size, threshold),
+ )
+ elapsed = (datetime.now(tz=timezone.utc) - started).total_seconds()
+ _state.last_run_at = started.isoformat()
+ _state.last_duration_s = elapsed
+ _state.last_scored = result.scored
+ _state.last_detections = result.detections
+ _state.last_error = result.error
+ _state.run_count += 1
+ _state.total_scored += result.scored
+ _state.total_detections += result.detections
+ if result.error:
+ logger.error("cybersec scorer error: %s", result.error)
+ elif not result.skipped:
+ logger.info(
+ "cybersec scorer: scored=%d detections=%d in %.1fs",
+ result.scored, result.detections, elapsed,
+ )
+ finally:
+ _state.running = False
diff --git a/app/tasks/glean_scheduler.py b/app/tasks/glean_scheduler.py
index 7322158..fa05040 100644
--- a/app/tasks/glean_scheduler.py
+++ b/app/tasks/glean_scheduler.py
@@ -21,6 +21,7 @@ import httpx
from app.glean.pipeline import glean_sources
from app.tasks.anomaly_scorer import run_once as _run_scorer
+from app.tasks.cybersec_scorer import run_once as _run_cybersec
logger = logging.getLogger(__name__)
@@ -127,6 +128,9 @@ async def run_once(
anomaly_model: str = "",
anomaly_device: str = "cpu",
anomaly_threshold: float = 0.75,
+ cybersec_model: str = "",
+ cybersec_device: str = "cpu",
+ cybersec_threshold: float = 0.60,
) -> dict[str, Any]:
"""Ingest all sources once, then submit matched entries if configured.
@@ -170,6 +174,9 @@ async def run_once(
if anomaly_model:
await _run_scorer(db_path, anomaly_model, anomaly_device, threshold=anomaly_threshold)
+ if cybersec_model:
+ await _run_cybersec(db_path, cybersec_model, cybersec_device, threshold=cybersec_threshold)
+
return {"ok": True, "stats": _state.last_stats, "duration_s": _state.last_duration_s}
@@ -183,19 +190,27 @@ async def scheduler_loop(
anomaly_model: str = "",
anomaly_device: str = "cpu",
anomaly_threshold: float = 0.75,
+ cybersec_model: str = "",
+ cybersec_device: str = "cpu",
+ cybersec_threshold: float = 0.60,
) -> None:
- """Run glean + optional submission + optional anomaly scoring every interval_s seconds."""
+ """Run glean + optional submission + optional anomaly/cybersec scoring every interval_s seconds."""
logger.info("Ingest scheduler started — interval %ds, sources: %s", interval_s, sources_file)
if submit_endpoint:
logger.info("Submission enabled — endpoint: %s", submit_endpoint)
if anomaly_model:
logger.info("Anomaly scoring enabled — model: %s", anomaly_model)
+ if cybersec_model:
+ logger.info("Cybersec scoring enabled — model: %s", cybersec_model)
while True:
await run_once(
sources_file, db_path, pattern_file, submit_endpoint, source_host,
anomaly_model=anomaly_model,
anomaly_device=anomaly_device,
anomaly_threshold=anomaly_threshold,
+ cybersec_model=cybersec_model,
+ cybersec_device=cybersec_device,
+ cybersec_threshold=cybersec_threshold,
)
next_run = datetime.now(tz=timezone.utc) + timedelta(seconds=interval_s)
_state.next_run_at = next_run.isoformat()
diff --git a/docker-compose.yml b/docker-compose.yml
index d197bc1..2e064a4 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -47,6 +47,10 @@ services:
TURNSTONE_EMBED_BACKEND: ${TURNSTONE_EMBED_BACKEND:-}
TURNSTONE_EMBED_MODEL: ${TURNSTONE_EMBED_MODEL:-}
TURNSTONE_EMBED_DEVICE: ${TURNSTONE_EMBED_DEVICE:-cpu}
+ # --- Cybersec scoring pipeline ---
+ TURNSTONE_CYBERSEC_MODEL: ${TURNSTONE_CYBERSEC_MODEL:-}
+ TURNSTONE_CYBERSEC_DEVICE: ${TURNSTONE_CYBERSEC_DEVICE:-cpu}
+ TURNSTONE_CYBERSEC_THRESHOLD: ${TURNSTONE_CYBERSEC_THRESHOLD:-0.60}
# --- Anomaly scoring pipeline ---
TURNSTONE_ANOMALY_MODEL: ${TURNSTONE_ANOMALY_MODEL:-}
TURNSTONE_ANOMALY_DEVICE: ${TURNSTONE_ANOMALY_DEVICE:-cpu}
diff --git a/docker-standalone.sh b/docker-standalone.sh
index 8d45406..fece648 100755
--- a/docker-standalone.sh
+++ b/docker-standalone.sh
@@ -147,6 +147,9 @@ docker run -d \
-e TURNSTONE_EMBED_BACKEND="${TURNSTONE_EMBED_BACKEND:-sentence_transformers}" \
-e TURNSTONE_EMBED_MODEL="${TURNSTONE_EMBED_MODEL:-sentence-transformers/all-MiniLM-L6-v2}" \
-e TURNSTONE_EMBED_DEVICE="${TURNSTONE_EMBED_DEVICE:-cpu}" \
+ -e TURNSTONE_CYBERSEC_MODEL="${TURNSTONE_CYBERSEC_MODEL:-}" \
+ -e TURNSTONE_CYBERSEC_DEVICE="${TURNSTONE_CYBERSEC_DEVICE:-cpu}" \
+ -e TURNSTONE_CYBERSEC_THRESHOLD="${TURNSTONE_CYBERSEC_THRESHOLD:-0.60}" \
-e TURNSTONE_ANOMALY_MODEL="${TURNSTONE_ANOMALY_MODEL:-}" \
-e TURNSTONE_ANOMALY_DEVICE="${TURNSTONE_ANOMALY_DEVICE:-cpu}" \
-e TURNSTONE_ANOMALY_THRESHOLD="${TURNSTONE_ANOMALY_THRESHOLD:-0.75}" \
diff --git a/tests/test_cybersec.py b/tests/test_cybersec.py
new file mode 100644
index 0000000..8f4f99a
--- /dev/null
+++ b/tests/test_cybersec.py
@@ -0,0 +1,233 @@
+"""Tests for the cybersec zero-shot scoring pipeline."""
+from __future__ import annotations
+
+import sqlite3
+import tempfile
+from pathlib import Path
+from unittest.mock import MagicMock
+
+import pytest
+
+from app.db.schema import ensure_schema
+from app.services.cybersec import (
+ CybersecResult,
+ CYBERSEC_LABELS,
+ _NORMAL_LABEL,
+ reset_pipeline,
+ score_security_entries,
+ list_cybersec_detections,
+)
+import app.services.cybersec as cybersec_mod
+
+
+@pytest.fixture(autouse=True)
+def _reset(tmp_path):
+ reset_pipeline()
+ yield
+ reset_pipeline()
+
+
+@pytest.fixture
+def db(tmp_path) -> Path:
+ path = tmp_path / "test.db"
+ ensure_schema(path)
+ return path
+
+
+def _insert_entry(db: Path, entry_id: str, text: str,
+ anomaly_label: str | None = None,
+ matched_patterns: str = "[]") -> None:
+ with sqlite3.connect(db) as conn:
+ conn.execute(
+ """INSERT OR IGNORE INTO log_entries
+ (id, tenant_id, source_id, sequence, ingest_time, text,
+ anomaly_label, matched_patterns)
+ VALUES (?, '', 'test-src', 1, '2026-01-01T00:00:00Z', ?, ?, ?)""",
+ (entry_id, text, anomaly_label, matched_patterns),
+ )
+ conn.commit()
+
+
+# ---------------------------------------------------------------------------
+# No model configured → skipped
+# ---------------------------------------------------------------------------
+
+def test_no_model_returns_skipped(db):
+ result = score_security_entries(db, model_id="")
+ assert result.skipped is True
+ assert result.scored == 0
+
+
+# ---------------------------------------------------------------------------
+# No eligible entries → skipped
+# ---------------------------------------------------------------------------
+
+def test_no_eligible_entries_skipped(db):
+ _insert_entry(db, "e1", "Started nginx.service", anomaly_label=None, matched_patterns="[]")
+ mock_pipe = MagicMock(return_value=[{"labels": [_NORMAL_LABEL], "scores": [0.99]}])
+ monkeypatch = pytest.MonkeyPatch()
+ monkeypatch.setattr(cybersec_mod, "_pipeline", mock_pipe)
+ result = score_security_entries(db, model_id="fake-model")
+ assert result.skipped is True
+ monkeypatch.undo()
+
+
+# ---------------------------------------------------------------------------
+# Security entry gets scored
+# ---------------------------------------------------------------------------
+
+def test_security_entry_scored(db, monkeypatch):
+ _insert_entry(db, "e1",
+ "Failed password for root from 192.168.1.1 port 22 ssh2",
+ anomaly_label="SECURITY_ANOMALY")
+
+ mock_pipe = MagicMock(return_value=[{
+ "labels": ["authentication failure or brute force attack", _NORMAL_LABEL],
+ "scores": [0.85, 0.15],
+ }])
+ monkeypatch.setattr(cybersec_mod, "_pipeline", mock_pipe)
+
+ result = score_security_entries(db, model_id="fake-model", threshold=0.70)
+ assert result.scored == 1
+ assert result.detections == 1
+ assert result.error is None
+
+ with sqlite3.connect(db) as conn:
+ conn.row_factory = sqlite3.Row
+ row = conn.execute("SELECT ml_score, ml_label, ml_scored_at FROM log_entries WHERE id='e1'").fetchone()
+ assert row["ml_score"] == pytest.approx(0.85)
+ assert row["ml_label"] == "authentication failure or brute force attack"
+ assert row["ml_scored_at"] is not None
+
+
+# ---------------------------------------------------------------------------
+# Detection created above threshold
+# ---------------------------------------------------------------------------
+
+def test_detection_inserted_above_threshold(db, monkeypatch):
+ _insert_entry(db, "e1", "sudo: authentication failure", anomaly_label="ERROR")
+
+ monkeypatch.setattr(cybersec_mod, "_pipeline", MagicMock(return_value=[{
+ "labels": ["privilege escalation or unauthorized access", _NORMAL_LABEL],
+ "scores": [0.75, 0.25],
+ }]))
+
+ score_security_entries(db, model_id="fake-model", threshold=0.60)
+
+ with sqlite3.connect(db) as conn:
+ conn.row_factory = sqlite3.Row
+ dets = conn.execute("SELECT * FROM detections WHERE scorer='cybersec'").fetchall()
+ assert len(dets) == 1
+ assert dets[0]["anomaly_label"] == "privilege escalation or unauthorized access"
+ assert dets[0]["severity"] == "CRITICAL"
+
+
+# ---------------------------------------------------------------------------
+# Normal label → no detection even above score threshold
+# ---------------------------------------------------------------------------
+
+def test_normal_label_no_detection(db, monkeypatch):
+ _insert_entry(db, "e1", "Started nginx.service", anomaly_label="INFO",
+ matched_patterns='["service_start"]')
+
+ monkeypatch.setattr(cybersec_mod, "_pipeline", MagicMock(return_value=[{
+ "labels": [_NORMAL_LABEL, "network intrusion or port scan"],
+ "scores": [0.95, 0.05],
+ }]))
+
+ result = score_security_entries(db, model_id="fake-model", threshold=0.60)
+ assert result.detections == 0
+
+
+# ---------------------------------------------------------------------------
+# Below threshold → scored but no detection
+# ---------------------------------------------------------------------------
+
+def test_below_threshold_no_detection(db, monkeypatch):
+ _insert_entry(db, "e1", "Some suspicious text", anomaly_label="WARN")
+
+ monkeypatch.setattr(cybersec_mod, "_pipeline", MagicMock(return_value=[{
+ "labels": ["network intrusion or port scan", _NORMAL_LABEL],
+ "scores": [0.45, 0.55],
+ }]))
+
+ result = score_security_entries(db, model_id="fake-model", threshold=0.60)
+ assert result.scored == 1
+ assert result.detections == 0
+
+
+# ---------------------------------------------------------------------------
+# Pattern-matched entry (not anomaly-flagged) still gets scored
+# ---------------------------------------------------------------------------
+
+def test_pattern_matched_entry_scored(db, monkeypatch):
+ _insert_entry(db, "e1", "SSH port forwarding conflict detected",
+ anomaly_label=None,
+ matched_patterns='["ssh_forward_conflict"]')
+
+ monkeypatch.setattr(cybersec_mod, "_pipeline", MagicMock(return_value=[{
+ "labels": ["network intrusion or port scan", _NORMAL_LABEL],
+ "scores": [0.70, 0.30],
+ }]))
+
+ result = score_security_entries(db, model_id="fake-model", threshold=0.60)
+ assert result.scored == 1
+ assert result.detections == 1
+
+
+# ---------------------------------------------------------------------------
+# Idempotency — re-run finds nothing unscored
+# ---------------------------------------------------------------------------
+
+def test_idempotent_rerun(db, monkeypatch):
+ _insert_entry(db, "e1", "Failed login", anomaly_label="ERROR")
+
+ monkeypatch.setattr(cybersec_mod, "_pipeline", MagicMock(return_value=[{
+ "labels": ["authentication failure or brute force attack"],
+ "scores": [0.80],
+ }]))
+
+ score_security_entries(db, model_id="fake-model", threshold=0.60)
+ result2 = score_security_entries(db, model_id="fake-model", threshold=0.60)
+ assert result2.skipped is True
+
+
+# ---------------------------------------------------------------------------
+# list_cybersec_detections filters to scorer='cybersec'
+# ---------------------------------------------------------------------------
+
+def test_list_cybersec_detections(db, monkeypatch):
+ _insert_entry(db, "e1", "Failed login", anomaly_label="ERROR")
+
+ monkeypatch.setattr(cybersec_mod, "_pipeline", MagicMock(return_value=[{
+ "labels": ["authentication failure or brute force attack"],
+ "scores": [0.90],
+ }]))
+ score_security_entries(db, model_id="fake-model", threshold=0.60)
+
+ rows = list_cybersec_detections(db)
+ assert len(rows) == 1
+ assert rows[0]["scorer"] == "cybersec"
+
+
+# ---------------------------------------------------------------------------
+# list_detections scorer filter (anomaly service)
+# ---------------------------------------------------------------------------
+
+def test_list_detections_scorer_filter(db, monkeypatch):
+ from app.services.anomaly import list_detections
+ _insert_entry(db, "e1", "Failed login", anomaly_label="ERROR")
+
+ monkeypatch.setattr(cybersec_mod, "_pipeline", MagicMock(return_value=[{
+ "labels": ["authentication failure or brute force attack"],
+ "scores": [0.90],
+ }]))
+ score_security_entries(db, model_id="fake-model", threshold=0.60)
+
+ all_dets = list_detections(db)
+ cybersec_dets = list_detections(db, scorer="cybersec")
+ anomaly_dets = list_detections(db, scorer="anomaly")
+
+ assert len(cybersec_dets) == 1
+ assert len(anomaly_dets) == 0
+ assert len(all_dets) >= 1
diff --git a/web/src/views/SecurityAlertsView.vue b/web/src/views/SecurityAlertsView.vue
index 7ac5361..5f71189 100644
--- a/web/src/views/SecurityAlertsView.vue
+++ b/web/src/views/SecurityAlertsView.vue
@@ -29,6 +29,20 @@
{{ scorerStatus.running ? 'scoring…' : scorerStatus.enabled ? 'scorer ready' : 'scorer off' }}
+
+
+ {{ cybersecStatus.enabled ? 'cybersec on' : 'cybersec off' }}
+
+