turnstone/app/rest.py
pyr0ball aafb4e2cad fix: separate context KB into own SQLite file to eliminate write-lock contention
context_facts, context_documents, and context_chunks now live in
turnstone-context.db (sibling of turnstone.db).  The glean scheduler
held write locks on the main DB long enough to cause 5-second timeout
failures on context fact inserts; separate files have independent WAL
write locks so they never contend.

Changes:
- pipeline.py: extract _CONTEXT_SCHEMA + ensure_context_schema()
- rest.py: CONTEXT_DB_PATH (TURNSTONE_CONTEXT_DB env var, defaults to
  sibling file); init via ensure_context_schema(); all context routes
  pass CONTEXT_DB_PATH; diagnose_stream receives context_db_path kwarg
- diagnose/__init__.py: diagnose_stream() accepts context_db_path
  (falls back to db_path for backward compat); retrieve_context uses it
- store.py: sqlite3.connect() timeout=30.0 — Python driver retry loop
  is independent of PRAGMA busy_timeout; needed for any remaining
  contention during test or single-file deployments

Closes: #42
2026-05-25 21:19:32 -07:00

1120 lines
39 KiB
Python

"""Turnstone REST API — serves REST API and Vue SPA under the /turnstone prefix.
All routes (API + static files) are mounted at /turnstone so the app works
identically whether accessed directly (http://host:8534/turnstone/) or through
Caddy (menagerie.circuitforge.tech/turnstone) without prefix stripping.
"""
from __future__ import annotations
import asyncio
import dataclasses
import hmac
import json
import os
import sqlite3
import tempfile
import urllib.error
import urllib.request
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Annotated
import yaml
from fastapi import APIRouter, BackgroundTasks, FastAPI, HTTPException, Query, Request, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, RedirectResponse, StreamingResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from app.glean.pipeline import ensure_schema, ensure_context_schema, glean_file as _glean_file, glean_ssh_source as _glean_ssh_source
from app.glean.base import load_compiled_patterns, now_iso
from app.glean.tautulli import parse_webhook as _parse_tautulli
from app.glean.wazuh import is_wazuh_alert as _is_wazuh_alert, parse as _parse_wazuh
from app.services.blocklist import (
BlocklistCandidate,
get_candidate,
list_candidates,
load_telemetry_rules,
mark_pushed,
mark_unblocked,
run_scan,
update_candidate_status,
)
from app.services.pihole import PiholeClient
from app.services.incidents import (
build_bundle,
create_incident,
delete_incident,
get_bundle,
get_incident,
get_incident_entries,
list_bundles,
list_incidents,
store_bundle,
)
from app.services.search import (
search as _search,
list_sources as _list_sources,
recent_source_errors as _source_errors,
stats_summary as _stats,
format_results,
build_fts_index,
)
from app.services.diagnose import diagnose as _diagnose, diagnose_stream as _diagnose_stream
from app.watch.watcher import Watcher, load_watch_config
from app.context.store import (
add_fact as _add_fact,
list_facts as _list_facts,
delete_fact as _delete_fact,
list_documents as _list_documents,
delete_document as _delete_document,
)
from app.context.retriever import retrieve_context as _retrieve_context, format_context_block
from app.glean.doc_upload import glean_upload as _glean_upload
from app.context.wizard import get_schema as _wizard_schema, advance_step, is_complete, apply_session
from app.context.chunker import UnsupportedDocType, FileTooLarge
from app.tasks.glean_scheduler import get_state as _glean_state, run_once as _run_glean, scheduler_loop as _scheduler_loop, submit_matched as _submit_matched
from app.glean.mqtt_subscriber import run_mqtt_subscribers as _run_mqtt_subscribers
# Main log database (override with TURNSTONE_DB). Defaults to data/turnstone.db
# under the directory two levels above this file.
DB_PATH = Path(os.environ.get("TURNSTONE_DB", Path(__file__).parent.parent / "data" / "turnstone.db"))
# Context KB gets its own file so context fact writes never contend with the
# high-throughput glean scheduler. Defaults to a sibling file next to the main DB.
CONTEXT_DB_PATH = Path(
    os.environ.get("TURNSTONE_CONTEXT_DB", DB_PATH.parent / "turnstone-context.db")
)
# User preferences JSON (read by _load_prefs, written by _save_prefs). It can
# contain API keys and tokens, so treat the file as sensitive.
PREFS_PATH = DB_PATH.parent / "preferences.json"
# Built Vue SPA output, served by the /turnstone/assets mount and spa_fallback.
DIST_DIR = Path(__file__).parent.parent / "web" / "dist"
# Host label passed to incident bundles and to the glean scheduler / submission.
SOURCE_HOST = os.environ.get("TURNSTONE_SOURCE_HOST", "unknown")
# Receiver URL for POST /api/incidents/{id}/send; empty means sending returns 503.
BUNDLE_ENDPOINT = os.environ.get("TURNSTONE_BUNDLE_ENDPOINT", "")
# Directory holding default.yaml (patterns), watch.yaml, sources.yaml and telemetry.yaml.
PATTERN_DIR = Path(os.environ.get("TURNSTONE_PATTERNS", Path(__file__).parent.parent / "patterns"))
PATTERN_FILE = PATTERN_DIR / "default.yaml"
# Seconds between scheduled glean runs; a value <= 0 disables the scheduler.
# A non-integer value raises ValueError at import time.
GLEAN_INTERVAL = int(os.environ.get("TURNSTONE_GLEAN_INTERVAL", "900"))
# Optional upstream Turnstone instance to submit gleaned entries to (trailing
# "/" stripped). Callers pass it on as `SUBMIT_ENDPOINT or None`.
SUBMIT_ENDPOINT = os.environ.get("TURNSTONE_SUBMIT_ENDPOINT", "").rstrip("/")
# GPU inference server URL.
# Priority: GPU_SERVER_URL → CF_ORCH_URL (backward compat) → orch.circuitforge.tech (Paid+).
# The hosted default is only used when CF_LICENSE_KEY is set; otherwise None.
# Resolved value is written back to CF_ORCH_URL so cf-core callers see it automatically.
GPU_SERVER_URL: str | None = (
    os.environ.get("GPU_SERVER_URL")
    or os.environ.get("CF_ORCH_URL")
    or (
        "https://orch.circuitforge.tech"
        if os.environ.get("CF_LICENSE_KEY")
        else None
    )
)
if GPU_SERVER_URL:
    os.environ["CF_ORCH_URL"] = GPU_SERVER_URL
# Shared file watcher; configured/started in _lifespan and restarted by watch_reload.
_watcher = Watcher(DB_PATH, PATTERN_FILE)
# Compiled pattern cache: filled at startup, refreshed by watch_reload, and
# read by the Tautulli webhook handler.
_compiled_patterns: list = []
@asynccontextmanager
async def _lifespan(app: FastAPI):
    """FastAPI lifespan: set up storage and background work, then tear it down.

    Startup:
      * ensures the main log DB schema and the separate context-KB schema;
      * loads the compiled pattern set into ``_compiled_patterns``;
      * configures and starts the file watcher if watch.yaml yields configs;
      * if sources.yaml exists, starts the MQTT subscriber task and, when
        ``GLEAN_INTERVAL > 0``, the periodic glean scheduler task.

    Shutdown runs in ``finally`` blocks so it also happens when the server
    exits with an error. It stops the watcher, then cancels and awaits the
    background tasks, even if stopping the watcher raised.
    """
    global _compiled_patterns
    ensure_schema(DB_PATH)
    ensure_context_schema(CONTEXT_DB_PATH)
    _compiled_patterns = load_compiled_patterns(PATTERN_FILE)

    configs = load_watch_config(PATTERN_DIR / "watch.yaml")
    if configs:
        _watcher.configure(configs)
        _watcher.start()

    sources_file = PATTERN_DIR / "sources.yaml"
    background: list[asyncio.Task] = []
    if GLEAN_INTERVAL > 0 and sources_file.exists():
        background.append(
            asyncio.create_task(
                _scheduler_loop(
                    sources_file, DB_PATH, PATTERN_FILE, GLEAN_INTERVAL,
                    submit_endpoint=SUBMIT_ENDPOINT or None,
                    source_host=SOURCE_HOST,
                ),
                name="glean-scheduler",
            )
        )
    if sources_file.exists():
        background.append(
            asyncio.create_task(
                _run_mqtt_subscribers(sources_file, DB_PATH),
                name="mqtt-subscribers",
            )
        )

    try:
        yield
    finally:
        try:
            _watcher.stop()
        finally:
            # Cancel everything first so the tasks wind down together, then
            # await each one so cancellation has completed before we return.
            for task in background:
                task.cancel()
            for task in background:
                try:
                    await task
                except asyncio.CancelledError:
                    pass
# The application object. Interactive OpenAPI docs are served under the
# /turnstone prefix like everything else; ReDoc is disabled.
app = FastAPI(title="Turnstone API", version="0.5.0", docs_url="/turnstone/docs", redoc_url=None, lifespan=_lifespan)
# NOTE(review): allow_origins=["*"] lets any web page issue cross-origin requests
# to this API and read the responses. No route in this file requires
# authentication (apart from the optional Tautulli token). GET /api/settings
# returns stored secrets (llm_api_key, pihole_api_key, tautulli_token).
# Consider restricting origins if a browser that visits untrusted sites can
# reach this host.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET", "POST", "DELETE", "PATCH"],
    allow_headers=["*"],
)
# Baseline preferences. _load_prefs() layers the saved preferences.json over
# these. llm_url falls back to a local endpoint on port 11434 when no GPU server
# URL was resolved (presumably an Ollama instance; confirm).
_PREFS_DEFAULTS: dict = {
    "entry_point_style": "topbar",  # "topbar" or "fab" (validated in patch_settings)
    "llm_url": GPU_SERVER_URL or "http://localhost:11434",
    "llm_model": "llama3.1:8b",
    "llm_api_key": "",
    # Passed to the stats summary by get_stats. The shipped rule presumably
    # re-labels PAM authentication-failure noise as WARN; confirm against _stats.
    "severity_overrides": [
        {
            "name": "PAM auth noise",
            "pattern": r"pam_unix.*auth(?:entication)?\s+fail|auth could not identify",
            "override_severity": "WARN",
            "enabled": True,
        }
    ],
}
def _load_prefs() -> dict:
    """Return the effective preferences: saved values layered over the defaults.

    Best-effort: if preferences.json is missing, unreadable, not valid JSON, or
    does not contain a JSON object, the defaults are returned unchanged. The
    result is a deep copy, so callers may mutate it, including nested values
    such as the ``severity_overrides`` list, without corrupting
    ``_PREFS_DEFAULTS``.
    """
    import copy

    prefs = copy.deepcopy(_PREFS_DEFAULTS)
    if not PREFS_PATH.exists():
        return prefs
    try:
        saved = json.loads(PREFS_PATH.read_text(encoding="utf-8"))
    except (json.JSONDecodeError, UnicodeDecodeError, OSError):
        return prefs
    # A JSON array/string/number here would otherwise blow up every caller.
    if isinstance(saved, dict):
        prefs.update(saved)
    return prefs
def _save_prefs(data: dict) -> None:
    """Persist *data* to ``PREFS_PATH`` as JSON, atomically.

    Writing goes to a temporary file in the same directory, which is then
    ``os.replace``d over the target. A crash or full disk mid-write can
    therefore never leave a truncated file that ``_load_prefs`` would silently
    treat as "no saved prefs".

    The temporary file comes from ``tempfile.mkstemp``, so the result is
    readable only by the owning user. That suits a file holding API keys.
    """
    PREFS_PATH.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp_name = tempfile.mkstemp(prefix=".preferences-", suffix=".tmp", dir=PREFS_PATH.parent)
    try:
        with os.fdopen(fd, "w", encoding="utf-8") as fh:
            json.dump(data, fh)
        os.replace(tmp_name, PREFS_PATH)
    except BaseException:
        # Never leave stray temp files behind on failure.
        Path(tmp_name).unlink(missing_ok=True)
        raise
class DiagnoseRequest(BaseModel):
    """Body of POST /api/diagnose and POST /api/diagnose/stream."""

    query: str  # free-text problem description or service name
    since: str | None = None  # lower time bound (ISO timestamp, as on GET /api/diagnose)
    until: str | None = None  # upper time bound
    source: str | None = None  # source-id filter; handlers treat "" as None
class SeverityOverride(BaseModel):
    """A user-defined severity re-mapping rule stored in preferences.

    Passed to the stats summary by get_stats; _PREFS_DEFAULTS ships one rule.
    """

    name: str  # human-readable label
    pattern: str  # presumably a regular expression matched against entry text; confirm
    override_severity: str  # severity to apply when the rule matches (e.g. "WARN")
    enabled: bool = True
class SettingsBody(BaseModel):
    """Body of PATCH /api/settings. Every field is optional; None leaves the stored value unchanged."""

    entry_point_style: str | None = None  # must be "topbar" or "fab"
    llm_url: str | None = None
    llm_model: str | None = None
    llm_api_key: str | None = None
    tautulli_token: str | None = None  # when set, webhooks must send it in X-Tautulli-Token
    severity_overrides: list[SeverityOverride] | None = None  # replaces the whole list
    pihole_url: str | None = None
    pihole_version: str | None = None  # _make_pihole_client assumes "v6" when never set
    pihole_api_key: str | None = None
    router_source_ids: str | None = None  # comma-separated source IDs used by /api/blocklist/scan
    device_names: str | None = None  # JSON text, parsed by /api/blocklist/scan
class IncidentCreate(BaseModel):
    """Body of POST /api/incidents; fields are passed straight to create_incident."""

    label: str
    issue_type: str = ""
    started_at: str | None = None  # presumably ISO timestamps bounding the incident; confirm
    ended_at: str | None = None
    notes: str = ""
    severity: str = "medium"
class FactBody(BaseModel):
    """Body of POST /api/context/facts (stored in the context KB)."""

    category: str
    key: str
    value: str
    source: str | None = None  # optional provenance label
class WizardStepBody(BaseModel):
    """Body of POST /api/context/wizard/step."""

    session: dict  # client-held wizard state; advance_step returns the updated copy
    step_id: str
    answer: str | list[str] | None = None  # single answer, multiple answers, or none
class WizardApplyBody(BaseModel):
    """Body of POST /api/context/wizard/apply."""

    session: dict  # must be a complete wizard session, otherwise the route returns 400
# Serve built Vue assets at the path Vite embeds in index.html.
# Only mounted when the build output exists (StaticFiles would otherwise
# reject the missing directory).
if (DIST_DIR / "assets").exists():
    app.mount("/turnstone/assets", StaticFiles(directory=str(DIST_DIR / "assets")), name="assets")
# API router — all routes accessible at /turnstone/api/* and /turnstone/health.
router = APIRouter(prefix="/turnstone")
@router.get("/health")
def health() -> dict:
    """Liveness probe; also reports which main database file is in use."""
    payload = {"status": "ok"}
    payload["db"] = str(DB_PATH)
    return payload
@router.get("/api/search")
def search_logs(
    q: Annotated[str, Query(description="Search query")] = "",
    source: Annotated[str | None, Query(description="Filter by log source ID (partial match)")] = None,
    severity: Annotated[str | None, Query(description="Filter by severity (DEBUG/INFO/WARN/ERROR/CRITICAL)")] = None,
    since: Annotated[str | None, Query(description="ISO timestamp lower bound")] = None,
    until: Annotated[str | None, Query(description="ISO timestamp upper bound")] = None,
    limit: Annotated[int, Query(ge=1, le=500)] = 50,
) -> dict:
    """Full-text search over the main log DB.

    An empty query returns an empty result set without touching the DB.
    Each hit is serialised with ``dataclasses.asdict``.
    """
    if q:
        hits = _search(
            DB_PATH,
            query=q,
            source_filter=source,
            severity=severity,
            since=since,
            until=until,
            limit=limit,
        )
        return {"count": len(hits), "results": list(map(dataclasses.asdict, hits))}
    return {"count": 0, "results": []}
def _detect_source_hint(query: str, source_ids: list[str]) -> str | None:
    """Return the first source-ID token that appears in *query*, or None.

    Each source ID is split on ``:``, ``/``, ``-`` and ``_``. Tokens longer than
    three characters are matched case-insensitively as substrings of the query.
    A match is returned in its original case so it stays a substring of the
    source ID it came from, since it is used as a partial-match source filter.
    """
    q_lower = query.lower()
    for src in source_ids:
        normalized = src
        for sep in (":", "/", "-", "_"):
            normalized = normalized.replace(sep, " ")
        for token in normalized.split():
            if len(token) > 3 and token.lower() in q_lower:
                return token
    return None


@router.get("/api/diagnose")
def diagnose(
    q: Annotated[str, Query(description="Service name or problem description")] = "",
    source: Annotated[str | None, Query(description="Limit to a specific source ID (partial match)")] = None,
    since: Annotated[str | None, Query(description="ISO timestamp lower bound")] = None,
    until: Annotated[str | None, Query(description="ISO timestamp upper bound")] = None,
) -> dict:
    """Gather log evidence for a service name or problem description (no LLM).

    Source hint: when *source* is not given, a source hint is auto-detected from
    the query (see ``_detect_source_hint``). The matched token is used rather
    than the full source ID, so every matching source is searched (e.g. all
    rotated plex logs).

    Evidence: results are merged from an OR-mode FTS pass plus CRITICAL- and
    ERROR-filtered passes. If a hint was auto-detected and the ERROR pass found
    nothing, recent ERROR entries for that source are pulled via plain SQL,
    falling back to CRITICAL.

    Output: results are de-duplicated by entry_id and sorted by
    (timestamp_iso, sequence), with untimestamped entries last. The first 20
    are returned.

    NOTE(review): slicing after an ascending sort keeps the 20 *earliest*
    entries; confirm that is intended rather than the most recent.
    """
    if not q:
        return {"count": 0, "results": [], "formatted": ""}
    detected_source = source
    if not detected_source:
        hint = _detect_source_hint(q, [s["source_id"] for s in _list_sources(DB_PATH)])
        if hint:
            detected_source = hint
    common: dict = dict(source_filter=detected_source, since=since, until=until, include_repeats=False)
    # Broad pass uses OR so any symptom keyword surfaces evidence
    broad = _search(DB_PATH, query=q, limit=15, or_mode=True, **common)
    critical = _search(DB_PATH, query=q, severity="CRITICAL", limit=5, **common)
    errors = _search(DB_PATH, query=q, severity="ERROR", limit=10, **common)
    # When a source was auto-detected, also pull its most recent errors via plain SQL —
    # FTS ranking can bury real errors from the named service if their text doesn't
    # match the symptom keywords. Plain-SQL scan returns actual recent errors regardless.
    source_errors: list = []
    if detected_source and not source and not errors:
        source_errors = _source_errors(
            DB_PATH, source_filter=detected_source, severity="ERROR",
            limit=10, since=since, until=until,
        )
        if not source_errors:
            source_errors = _source_errors(
                DB_PATH, source_filter=detected_source, severity="CRITICAL",
                limit=5, since=since, until=until,
            )
    seen: set[str] = set()
    combined = []
    for r in broad + critical + errors + source_errors:
        if r.entry_id not in seen:
            seen.add(r.entry_id)
            combined.append(r)
    # "\xff" sorts after any ISO timestamp, pushing untimestamped entries last.
    combined.sort(key=lambda r: (r.timestamp_iso or "\xff", r.sequence))
    combined = combined[:20]
    return {
        "count": len(combined),
        "results": [dataclasses.asdict(r) for r in combined],
        "formatted": format_results(combined),
    }
@router.post("/api/diagnose")
def diagnose_post(body: DiagnoseRequest) -> dict:
    """Run the diagnose pipeline: evidence search plus optional LLM reasoning.

    LLM settings come from the saved preferences, with empty strings passed on
    as None. A blank query short-circuits to an empty summary. Both branches
    return the same top-level keys (``summary``, ``reasoning``, ``entries``),
    so clients can rely on one response shape.

    NOTE(review): unlike the streaming variant, this does not pass a
    context_db_path. Confirm whether ``diagnose`` consults the context KB.
    """
    if not body.query.strip():
        return {
            "summary": {
                "total": 0, "window_start": None, "window_end": None,
                "time_detected": False, "by_severity": {}, "by_source": {},
            },
            "reasoning": None,
            "entries": [],
        }
    prefs = _load_prefs()
    result = _diagnose(
        DB_PATH,
        query=body.query,
        since=body.since,
        until=body.until,
        source_filter=body.source or None,
        llm_url=prefs.get("llm_url") or None,
        llm_model=prefs.get("llm_model") or None,
        llm_api_key=prefs.get("llm_api_key") or None,
    )
    return {
        "summary": result["summary"],
        "reasoning": result.get("reasoning"),
        "entries": [dataclasses.asdict(r) for r in result["entries"]],
    }
@router.post("/api/diagnose/stream")
async def diagnose_post_stream(body: DiagnoseRequest) -> StreamingResponse:
    """Stream diagnose progress as Server-Sent Events.

    Every event produced by ``diagnose_stream`` is JSON-encoded into one
    ``data: ...`` frame. LLM settings come from the saved preferences (empty
    strings become None), and context-KB retrieval uses CONTEXT_DB_PATH.
    ``X-Accel-Buffering: no`` asks an nginx-style proxy not to buffer the stream.
    """
    prefs = _load_prefs()
    llm_settings = {
        "llm_url": prefs.get("llm_url") or None,
        "llm_model": prefs.get("llm_model") or None,
        "llm_api_key": prefs.get("llm_api_key") or None,
    }

    async def sse_gen():
        events = _diagnose_stream(
            DB_PATH,
            query=body.query,
            since=body.since,
            until=body.until,
            source_filter=body.source or None,
            context_db_path=CONTEXT_DB_PATH,
            **llm_settings,
        )
        async for event in events:
            yield "data: " + json.dumps(event) + "\n\n"

    no_buffering = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}
    return StreamingResponse(sse_gen(), media_type="text/event-stream", headers=no_buffering)
@router.get("/api/settings")
def get_settings() -> dict:
    """Return the effective preferences (saved values layered over defaults).

    NOTE(review): the response includes stored secrets verbatim (llm_api_key,
    and pihole_api_key / tautulli_token when set). Consider redacting them.
    """
    current = _load_prefs()
    return current
@router.patch("/api/settings")
def patch_settings(body: SettingsBody) -> dict:
    """Merge the non-None fields of *body* into the saved preferences.

    ``entry_point_style`` is validated first; an invalid value raises 422
    before anything is written. ``severity_overrides`` replaces the stored list
    wholesale. Returns the full, updated preference dict.
    """
    prefs = _load_prefs()

    style = body.entry_point_style
    if style is not None:
        if style not in ("topbar", "fab"):
            raise HTTPException(status_code=422, detail="entry_point_style must be 'topbar' or 'fab'")
        prefs["entry_point_style"] = style

    # Plain string fields are copied through as-is. The two loops keep the
    # same key-insertion order as the model, with severity_overrides between them.
    for field_name in ("llm_url", "llm_model", "llm_api_key", "tautulli_token"):
        value = getattr(body, field_name)
        if value is not None:
            prefs[field_name] = value

    overrides = body.severity_overrides
    if overrides is not None:
        prefs["severity_overrides"] = [rule.model_dump() for rule in overrides]

    for field_name in ("pihole_url", "pihole_version", "pihole_api_key", "router_source_ids", "device_names"):
        value = getattr(body, field_name)
        if value is not None:
            prefs[field_name] = value

    _save_prefs(prefs)
    return prefs
@router.get("/api/sources")
def list_sources() -> dict:
    """Return per-source stats as recorded in the main DB.

    This view is DB-only; /api/sources/configured shows the sources.yaml view.
    """
    rows = _list_sources(DB_PATH)
    return dict(sources=rows)
@router.get("/api/sources/configured")
def list_configured_sources() -> dict:
    """Return every source in sources.yaml, enriched with DB stats.

    Unlike ``/api/sources`` (which is DB-only), this endpoint reads sources.yaml
    so SSH sources appear even before their first successful glean. DB entry
    counts, error counts, and timestamps are aggregated and merged in.
    For SSH sources, sub-source IDs (e.g. ``rack01/journald``) are summed to
    produce a single aggregate stat row for the top-level host entry.

    Empty YAML keys (``sources:`` or ``glean:`` with no value) load as None
    and are treated as empty lists instead of raising.
    """
    sources_file = PATTERN_DIR / "sources.yaml"
    if not sources_file.exists():
        return {"sources": []}
    with open(sources_file, encoding="utf-8") as f:
        config = yaml.safe_load(f) or {}

    # Fetch all DB source stats once; key by source_id for O(1) lookup.
    db_stats: dict[str, dict]
    try:
        db_stats = {row["source_id"]: row for row in _list_sources(DB_PATH)}
    except Exception:
        db_stats = {}  # best-effort: the DB may not exist on first run

    result = []
    for src in config.get("sources") or []:
        transport = src.get("transport", "local")
        src_id = src.get("id", "")
        entry: dict = {"id": src_id, "transport": transport}
        if transport != "ssh":
            entry["path"] = src.get("path", "")
            db = db_stats.get(src_id, {})
            entry["entry_count"] = db.get("entry_count", 0)
            entry["error_count"] = db.get("error_count", 0)
            entry["earliest"] = db.get("earliest")
            entry["latest"] = db.get("latest")
        else:
            entry["host"] = src.get("host", "")
            entry["user"] = src.get("user", "")
            glean_items: list[dict] = src.get("glean") or []
            entry["glean_types"] = sorted({item.get("type", "plaintext") for item in glean_items})
            entry["glean_items"] = glean_items
            # Aggregate sub-source DB rows that belong to this SSH host.
            # Sub-sources use IDs like "{host_id}/{type}" or "{host_id}/{type}/{container}".
            prefix = src_id + "/"
            matching_rows = [v for k, v in db_stats.items() if k == src_id or k.startswith(prefix)]
            entry["entry_count"] = sum(r.get("entry_count", 0) for r in matching_rows)
            entry["error_count"] = sum(r.get("error_count", 0) for r in matching_rows)
            earliests = [r["earliest"] for r in matching_rows if r.get("earliest")]
            latests = [r["latest"] for r in matching_rows if r.get("latest")]
            # ISO timestamps compare correctly as strings.
            entry["earliest"] = min(earliests) if earliests else None
            entry["latest"] = max(latests) if latests else None
        result.append(entry)
    return {"sources": result}
@router.delete("/api/sources/{source_id}")
def delete_source(source_id: str) -> dict:
    """Delete all log entries (and FTS index rows) for a given source.

    Both DELETEs run in one implicit sqlite3 transaction that is committed
    together. Returns the number of ``log_entries`` rows removed.

    ``timeout=30.0`` matches the store layer: the glean scheduler can hold the
    main DB's write lock longer than the driver's default 5 s wait.
    """
    conn = sqlite3.connect(str(DB_PATH), timeout=30.0)
    try:
        conn.execute("PRAGMA journal_mode=WAL")
        conn.execute("DELETE FROM log_fts WHERE source_id = ?", (source_id,))
        deleted = conn.execute("DELETE FROM log_entries WHERE source_id = ?", (source_id,)).rowcount
        conn.commit()
    finally:
        conn.close()
    return {"deleted": deleted, "source_id": source_id}
@router.post("/api/sources/{source_id}/glean")
def reglean_source(
    source_id: str,
    background_tasks: BackgroundTasks,
    force: Annotated[bool, Query(description="Bypass fingerprint check and re-glean even if file is unchanged")] = False,
) -> dict:
    """Trigger a re-glean for a configured source from sources.yaml.

    Handles both local file sources and SSH remote sources. In both cases the
    glean itself finishes before the response is sent, so the returned count
    is final when the response arrives.

    The FTS index is handled differently per transport:
      * SSH sources: the rebuild happens inside ``glean_ssh_source`` (per the
        inline note below).
      * Local sources: ``build_fts_index`` is scheduled as a background task
        that runs after the response.

    Use ``?force=true`` to bypass the fingerprint cache and re-glean the file
    even if mtime and size appear unchanged since the last run. This applies
    to local sources only; SSH sources ignore it.

    Raises 404 if sources.yaml or the source entry is missing. Raises 422 if a
    local source has no ``path`` or the path does not exist.
    """
    sources_file = PATTERN_DIR / "sources.yaml"
    if not sources_file.exists():
        raise HTTPException(status_code=404, detail="sources.yaml not found")
    with open(sources_file, encoding="utf-8") as f:
        config = yaml.safe_load(f) or {}
    # `sources:` with no value loads as None; treat it as an empty list.
    src = next((s for s in config.get("sources") or [] if s.get("id") == source_id), None)
    if src is None:
        raise HTTPException(status_code=404, detail=f"Source {source_id!r} not in sources.yaml")
    if src.get("transport") == "ssh":
        # SSH sources: open connection, glean all items, rebuild FTS inline.
        # Fingerprint skipping applies only to local file sources.
        stats = _glean_ssh_source(src, DB_PATH, PATTERN_FILE)
        return {"source_id": source_id, "gleaned": sum(stats.values())}
    # Local file source.
    raw_path = src.get("path")
    if not raw_path:
        raise HTTPException(status_code=422, detail=f"Source {source_id!r} has no path configured")
    src_path = Path(raw_path)
    if not src_path.exists():
        raise HTTPException(status_code=422, detail=f"Path does not exist: {src_path}")
    stats = _glean_file(src_path, DB_PATH, PATTERN_FILE, force=force)
    background_tasks.add_task(build_fts_index, DB_PATH)
    return {"source_id": source_id, "gleaned": stats.get(source_id, sum(stats.values()))}
@router.post("/api/glean/upload")
async def glean_upload(
    file: UploadFile,
    source_id: Annotated[str | None, Query(description="Override source ID (defaults to filename)")] = None,
    background_tasks: BackgroundTasks = None,
) -> dict:
    """Accept a multipart log file, auto-detect format, glean into DB.

    The upload is written to a temporary file that keeps the original suffix
    (``.log`` if there is none). That file is passed to ``glean_file`` and
    always removed afterwards, even if the write fails. Writing and gleaning
    are blocking, so they run in a worker thread instead of stalling the event
    loop. An FTS rebuild is scheduled as a background task.

    NOTE(review): ``sid`` is only echoed in the response; it is never passed to
    ``glean_file``, which receives a randomly named temp path. If glean_file
    derives source IDs from the path, entries are stored under the temp name
    rather than ``sid``. Confirm against app.glean.pipeline.
    """
    sid = source_id or Path(file.filename or "upload").stem
    suffix = Path(file.filename or "log.txt").suffix or ".log"
    content = await file.read()

    def _glean_via_tempfile() -> dict:
        # Runs in a worker thread: temp-file write, glean, and cleanup.
        fd, tmp_name = tempfile.mkstemp(suffix=suffix)
        tmp_path = Path(tmp_name)
        try:
            with os.fdopen(fd, "wb") as fh:
                fh.write(content)
            return _glean_file(tmp_path, DB_PATH, PATTERN_FILE)
        finally:
            tmp_path.unlink(missing_ok=True)

    stats = await asyncio.to_thread(_glean_via_tempfile)
    if background_tasks is not None:
        background_tasks.add_task(build_fts_index, DB_PATH)
    return {"source_id": sid, "gleaned": sum(stats.values()), "stats": stats}
class BatchEntry(BaseModel):
    """One pre-parsed log entry in a POST /api/glean/batch submission."""

    id: str  # stored as log_entries.id; rows that conflict are skipped (INSERT OR IGNORE)
    source_id: str  # stored prefixed as "{source_host}/{source_id}"
    sequence: int
    timestamp_raw: str | None = None
    timestamp_iso: str | None = None
    ingest_time: str
    severity: str | None = None
    repeat_count: int = 1
    out_of_order: int = 0  # presumably a 0/1 flag; confirm
    matched_patterns: list[str] = []  # JSON-encoded into log_entries.matched_patterns
    text: str
class BatchGleanRequest(BaseModel):
    """Body of POST /api/glean/batch: entries submitted by one remote host."""

    source_host: str = "unknown"  # prefixed onto each entry's source_id on insert
    entries: list[BatchEntry]
@router.post("/api/glean/batch")
def glean_batch(payload: BatchGleanRequest, background_tasks: BackgroundTasks) -> dict:
    """Accept pre-parsed log entries from a remote Turnstone instance (submission protocol).

    Used by nodes with TURNSTONE_SUBMIT_ENDPOINT configured to push their
    pattern-matched entries to a central receiving instance.

    Each entry's source_id is stored as ``"{source_host}/{source_id}"``. Rows
    are inserted with ``INSERT OR IGNORE``, so re-submitted entries that
    conflict with existing rows are skipped. Note that ``gleaned`` reports the
    number of entries received, not the number actually inserted. An FTS
    rebuild is scheduled as a background task.

    NOTE(review): this endpoint is unauthenticated. Any client that can reach
    the API can inject log entries.
    """
    if not payload.entries:
        return {"gleaned": 0}
    rows = [
        (
            e.id,
            f"{payload.source_host}/{e.source_id}",
            e.sequence,
            e.timestamp_raw,
            e.timestamp_iso,
            e.ingest_time,
            e.severity,
            e.repeat_count,
            e.out_of_order,
            json.dumps(e.matched_patterns),
            e.text,
        )
        for e in payload.entries
    ]
    # timeout=30.0: the glean scheduler can hold the main DB's write lock
    # longer than sqlite3's default 5 s busy wait.
    conn = sqlite3.connect(str(DB_PATH), timeout=30.0)
    try:
        conn.execute("PRAGMA journal_mode=WAL")
        conn.executemany(
            """
            INSERT OR IGNORE INTO log_entries
            (id, source_id, sequence, timestamp_raw, timestamp_iso,
            ingest_time, severity, repeat_count, out_of_order,
            matched_patterns, text)
            VALUES (?,?,?,?,?,?,?,?,?,?,?)
            """,
            rows,
        )
        conn.commit()
    finally:
        # Release the connection even if the insert fails (e.g. a lock timeout).
        conn.close()
    background_tasks.add_task(build_fts_index, DB_PATH)
    return {"gleaned": len(payload.entries), "source_host": payload.source_host}
@router.get("/api/tasks/glean/status")
def glean_task_status() -> dict:
    """Return the current state of the periodic glean scheduler.

    Combines the scheduler's run/submit bookkeeping with this process's
    configuration (interval, whether the scheduler is active, submit endpoint).
    """
    state = _glean_state()
    run_fields = ("running", "run_count", "last_run_at", "last_duration_s", "last_stats", "last_error", "next_run_at")
    report = {name: getattr(state, name) for name in run_fields}
    report["interval_s"] = GLEAN_INTERVAL
    report["scheduler_active"] = GLEAN_INTERVAL > 0 and (PATTERN_DIR / "sources.yaml").exists()
    report["submit_endpoint"] = SUBMIT_ENDPOINT or None
    for name in ("last_submitted_at", "last_submit_count", "last_submit_error"):
        report[name] = getattr(state, name)
    return report
@router.post("/api/tasks/glean")
async def trigger_glean(
    force: Annotated[bool, Query(description="Bypass fingerprint check and re-glean all sources")] = False,
) -> dict:
    """Manually trigger a glean of all configured sources.

    ``run_once`` is expected to no-op if a run is already in progress.
    Use ``?force=true`` to bypass the fingerprint cache and re-glean every local
    file source even when mtime and size are unchanged since the last run.
    Raises 404 when sources.yaml does not exist.
    """
    sources_file = PATTERN_DIR / "sources.yaml"
    if sources_file.exists():
        return await _run_glean(
            sources_file,
            DB_PATH,
            PATTERN_FILE,
            submit_endpoint=SUBMIT_ENDPOINT or None,
            source_host=SOURCE_HOST,
            force=force,
        )
    raise HTTPException(status_code=404, detail="sources.yaml not found — configure log sources first")
@router.post("/api/glean/wazuh/alert")
async def glean_wazuh_alert(
    alert: dict,
    source_id: Annotated[str | None, Query(description="Source label (defaults to 'wazuh')")] = None,
    background_tasks: BackgroundTasks = None,
) -> dict:
    """Accept a single Wazuh alert JSON object pushed by a Wazuh custom integration.

    Configure in Wazuh: ossec.conf → <integration><name>custom-turnstone</name>
    pointing to a script that POSTs the alert JSON to this endpoint.

    Returns 422 when the body is not recognised as a Wazuh alert. Patterns are
    reloaded from PATTERN_FILE on every call rather than taken from the startup
    cache. The blocking work (pattern load, parse, SQLite insert) runs in a
    worker thread so it does not stall the event loop. An FTS rebuild is
    scheduled only when something was ingested.
    """
    if not _is_wazuh_alert(alert):
        raise HTTPException(status_code=422, detail="Not a valid Wazuh alert object")
    sid = source_id or "wazuh"
    ingest_time = now_iso()

    def _ingest() -> int:
        compiled = load_compiled_patterns(PATTERN_FILE)
        entries = list(_parse_wazuh(iter([json.dumps(alert)]), sid, compiled, ingest_time))
        if not entries:
            return 0
        rows = [
            (
                e.entry_id, e.source_id, e.sequence,
                e.timestamp_raw, e.timestamp_iso, e.ingest_time,
                e.severity, e.repeat_count, int(e.out_of_order),
                json.dumps(list(e.matched_patterns)), e.text,
            )
            for e in entries
        ]
        # timeout=30.0: tolerate the glean scheduler holding the write lock.
        conn = sqlite3.connect(str(DB_PATH), timeout=30.0)
        try:
            conn.execute("PRAGMA journal_mode=WAL")
            conn.executemany(
                """
                INSERT OR IGNORE INTO log_entries
                (id, source_id, sequence, timestamp_raw, timestamp_iso,
                ingest_time, severity, repeat_count, out_of_order,
                matched_patterns, text)
                VALUES (?,?,?,?,?,?,?,?,?,?,?)
                """,
                rows,
            )
            conn.commit()
        finally:
            conn.close()
        return len(entries)

    ingested = await asyncio.to_thread(_ingest)
    if ingested and background_tasks is not None:
        background_tasks.add_task(build_fts_index, DB_PATH)
    return {"ingested": ingested, "source_id": sid}
@router.get("/api/watch/status")
def watch_status() -> dict:
    """Report whether the file watcher is active, plus its per-source status."""
    active = _watcher.is_active()
    return {"active": active, "sources": _watcher.status}
@router.post("/api/watch/reload")
def watch_reload() -> dict:
    """Stop all watch sources and restart with current watch.yaml.

    Also refreshes the module-level compiled-pattern cache from PATTERN_FILE.
    The watcher is only restarted when watch.yaml yields configs.
    """
    global _compiled_patterns
    _watcher.stop()
    _compiled_patterns = load_compiled_patterns(PATTERN_FILE)
    configs = load_watch_config(PATTERN_DIR / "watch.yaml")
    if configs:
        _watcher.configure(configs)
        _watcher.start()
    return dict(reloaded=True, source_count=len(configs))
@router.get("/api/stats")
def get_stats(
    window: Annotated[int, Query(ge=1, le=168, description="Hours to look back")] = 24,
) -> dict:
    """Summary statistics over the last *window* hours, applying saved severity overrides."""
    overrides = _load_prefs().get("severity_overrides", [])
    return _stats(DB_PATH, window_hours=window, severity_overrides=overrides)
@router.post("/api/incidents")
def create_incident_endpoint(body: IncidentCreate) -> dict:
    """Create an incident and return it as a dict."""
    fields = dict(
        label=body.label,
        issue_type=body.issue_type,
        started_at=body.started_at,
        ended_at=body.ended_at,
        notes=body.notes,
        severity=body.severity,
    )
    return dataclasses.asdict(create_incident(DB_PATH, **fields))
@router.get("/api/incidents")
def list_incidents_endpoint() -> dict:
    """List all incidents."""
    incidents = list_incidents(DB_PATH)
    return {"incidents": list(map(dataclasses.asdict, incidents))}
@router.get("/api/incidents/{incident_id}")
def get_incident_endpoint(incident_id: str) -> dict:
    """Return one incident together with its log entries; 404 if unknown."""
    incident = get_incident(DB_PATH, incident_id)
    if not incident:
        raise HTTPException(status_code=404, detail="Incident not found")
    related = get_incident_entries(DB_PATH, incident)
    detail = dataclasses.asdict(incident)
    detail["entries"] = [dataclasses.asdict(item) for item in related]
    return detail
@router.delete("/api/incidents/{incident_id}")
def delete_incident_endpoint(incident_id: str) -> dict:
    """Delete an incident; 404 if it does not exist."""
    if delete_incident(DB_PATH, incident_id):
        return {"deleted": incident_id}
    raise HTTPException(status_code=404, detail="Incident not found")
@router.get("/api/incidents/{incident_id}/bundle")
def get_incident_bundle(incident_id: str) -> dict:
    """Build (without sending) the shareable bundle for an incident; 404 if unknown."""
    incident = get_incident(DB_PATH, incident_id)
    if incident:
        return build_bundle(DB_PATH, incident, source_host=SOURCE_HOST)
    raise HTTPException(status_code=404, detail="Incident not found")
@router.post("/api/incidents/{incident_id}/send")
def send_incident_bundle(incident_id: str) -> dict:
    """POST an incident's bundle as JSON to TURNSTONE_BUNDLE_ENDPOINT.

    Returns 503 if no endpoint is configured and 404 for an unknown incident.
    Returns 502 if the receiver answers with an HTTP error or the request
    fails; the timeout is 15 s.
    """
    if not BUNDLE_ENDPOINT:
        raise HTTPException(status_code=503, detail="TURNSTONE_BUNDLE_ENDPOINT not configured")
    incident = get_incident(DB_PATH, incident_id)
    if not incident:
        raise HTTPException(status_code=404, detail="Incident not found")
    bundle = build_bundle(DB_PATH, incident, source_host=SOURCE_HOST)
    outgoing = urllib.request.Request(
        BUNDLE_ENDPOINT,
        data=json.dumps(bundle).encode(),
        headers={"Content-Type": "application/json"},
        method="POST",
    )
    try:
        with urllib.request.urlopen(outgoing, timeout=15) as resp:
            status = resp.status
    except urllib.error.HTTPError as exc:
        # HTTPError subclasses OSError, so it must be handled first.
        raise HTTPException(status_code=502, detail=f"Receiver returned {exc.code}") from exc
    except OSError as exc:
        raise HTTPException(status_code=502, detail=f"Send failed: {exc}") from exc
    return {"sent": True, "status": status, "entry_count": len(bundle["log_entries"])}
@router.post("/api/bundles")
def receive_bundle(bundle: dict) -> dict:
    """Store an incoming incident bundle (presumably the receiving side of /send).

    NOTE(review): unauthenticated. Any client that can reach this API can store bundles.
    """
    stored = store_bundle(DB_PATH, bundle)
    return dict(id=stored.id, entry_count=stored.entry_count)
@router.get("/api/bundles")
def list_bundles_endpoint() -> dict:
    """List stored bundles."""
    return {"bundles": list(map(dataclasses.asdict, list_bundles(DB_PATH)))}
@router.get("/api/bundles/{bundle_id}")
def get_bundle_endpoint(bundle_id: str) -> dict:
    """Return one stored bundle; 404 if unknown."""
    record = get_bundle(DB_PATH, bundle_id)
    if record:
        return dataclasses.asdict(record)
    raise HTTPException(status_code=404, detail="Bundle not found")
def _tautulli_write_entry(conn: sqlite3.Connection, entry) -> None:
conn.execute(
"""
INSERT OR IGNORE INTO log_entries
(id, source_id, sequence, timestamp_raw, timestamp_iso,
ingest_time, severity, repeat_count, out_of_order,
matched_patterns, text)
VALUES (?,?,?,?,?,?,?,?,?,?,?)
""",
(
entry.entry_id, entry.source_id, entry.sequence,
entry.timestamp_raw, entry.timestamp_iso, entry.ingest_time,
entry.severity, entry.repeat_count, int(entry.out_of_order),
json.dumps(list(entry.matched_patterns)), entry.text,
),
)
@router.post("/api/glean/tautulli")
def glean_tautulli(
    payload: dict,
    request: Request,
    background_tasks: BackgroundTasks,
) -> dict:
    """Accept a Tautulli webhook POST and store the event as a log entry.

    If a ``tautulli_token`` preference is set, the request must carry the same
    value in ``X-Tautulli-Token``, otherwise 403. With no token configured the
    webhook is open. The payload must contain ``action``, otherwise 400.
    Parsing uses the compiled-pattern cache, and an FTS rebuild is scheduled
    after the insert.
    """
    prefs = _load_prefs()
    token = prefs.get("tautulli_token", "")
    if token:
        header_token = request.headers.get("X-Tautulli-Token", "")
        # Compare as bytes: hmac.compare_digest raises TypeError for str
        # arguments containing non-ASCII characters, which would turn a bad
        # header into a 500 instead of a 403.
        if not hmac.compare_digest(header_token.encode("utf-8"), token.encode("utf-8")):
            raise HTTPException(status_code=403, detail="Invalid Tautulli token")
    if "action" not in payload:
        raise HTTPException(status_code=400, detail="Missing required field: action")
    entry = _parse_tautulli(payload, _compiled_patterns)
    # timeout=30.0: tolerate the glean scheduler holding the main DB write lock.
    conn = sqlite3.connect(str(DB_PATH), timeout=30.0)
    try:
        conn.execute("PRAGMA journal_mode=WAL")
        _tautulli_write_entry(conn, entry)
        conn.commit()
    finally:
        conn.close()
    background_tasks.add_task(build_fts_index, DB_PATH)
    return {"stored": 1, "entry_id": entry.entry_id, "action": payload.get("action")}
class BlocklistStatusBody(BaseModel):
    """Body of PATCH /api/blocklist/candidates/{candidate_id}."""

    status: str  # checked by update_candidate_status; a ValueError there becomes 400
def _make_pihole_client() -> PiholeClient:
    """Build a PiholeClient from saved preferences, or raise HTTP 503.

    The 503 comes from translating the ValueError that PiholeClient.__post_init__
    raises when url or api_key is empty. When PiholeClient is mocked in tests,
    __post_init__ never runs, so no 503 is raised. The version defaults to "v6"
    when never configured.
    """
    prefs = _load_prefs()
    client_kwargs = {
        "url": prefs.get("pihole_url", ""),
        "api_key": prefs.get("pihole_api_key", ""),
        "version": prefs.get("pihole_version", "v6"),
    }
    try:
        return PiholeClient(**client_kwargs)
    except ValueError as exc:
        raise HTTPException(
            status_code=503,
            detail="Pi-hole not configured — set pihole_url and pihole_api_key in Settings",
        ) from exc
@router.get("/api/blocklist/candidates")
def list_blocklist_candidates(
    status: Annotated[str | None, Query()] = None,
    device_ip: Annotated[str | None, Query()] = None,
) -> dict:
    """List blocklist candidates, optionally filtered by status and/or device IP."""
    found = list_candidates(DB_PATH, status=status, device_ip=device_ip)
    serialized = [dataclasses.asdict(item) for item in found]
    return {"candidates": serialized, "total": len(serialized)}
@router.post("/api/blocklist/scan")
def scan_blocklist(background_tasks: BackgroundTasks) -> dict:
    """Start a background blocklist scan over the configured router sources.

    Inputs come from preferences:
      * ``router_source_ids``: comma-separated list; blanks are dropped.
      * ``device_names``: must be a JSON object, validated here so a bad value
        is reported as 400 now instead of failing silently inside the
        background task.

    Telemetry rules are read from telemetry.yaml when it exists. Returns
    immediately; the scan itself runs as a background task.
    """
    prefs = _load_prefs()
    source_ids = [s.strip() for s in prefs.get("router_source_ids", "").split(",") if s.strip()]
    device_map: dict[str, str] = {}
    raw_devices = prefs.get("device_names", "")
    if raw_devices:
        try:
            parsed = json.loads(raw_devices)
        except (ValueError, TypeError) as exc:
            raise HTTPException(
                status_code=400, detail="device_names is not valid JSON — update it in Settings"
            ) from exc
        if not isinstance(parsed, dict):
            raise HTTPException(
                status_code=400, detail="device_names must be a JSON object — update it in Settings"
            )
        device_map = parsed
    telemetry_path = PATTERN_DIR / "telemetry.yaml"
    telemetry_rules = load_telemetry_rules(telemetry_path) if telemetry_path.exists() else []
    background_tasks.add_task(run_scan, DB_PATH, source_ids, device_map, telemetry_rules)
    return {"started": True}
@router.patch("/api/blocklist/candidates/{candidate_id}")
def update_blocklist_status(candidate_id: str, body: BlocklistStatusBody) -> dict:
    """Set a blocklist candidate's review status and return the updated candidate.

    A ValueError from the service (presumably an invalid status) maps to 400.
    A KeyError (no such candidate) maps to 404. The try covers only the service
    call, and the original exception is chained for debugging.
    """
    try:
        candidate = update_candidate_status(DB_PATH, candidate_id, body.status)
    except ValueError as exc:
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    except KeyError as exc:
        raise HTTPException(status_code=404, detail="Candidate not found") from exc
    return dataclasses.asdict(candidate)
@router.post("/api/blocklist/push/{candidate_id}")
def push_to_pihole(candidate_id: str) -> dict:
    """Block an approved candidate's domain/IP on Pi-hole and mark it pushed.

    Error responses:
      * 404 for an unknown candidate;
      * 400 unless the status is "approved";
      * 503 if Pi-hole is not configured.

    NOTE(review): if ``block`` succeeds but ``mark_pushed`` fails, Pi-hole and
    the DB disagree.
    """
    try:
        candidate = get_candidate(DB_PATH, candidate_id)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail="Candidate not found") from exc
    if candidate.status != "approved":
        raise HTTPException(
            status_code=400,
            detail=f"Candidate must be approved before pushing (current status: {candidate.status!r})",
        )
    pihole = _make_pihole_client()
    pihole.block(candidate.domain_or_ip)
    mark_pushed(DB_PATH, candidate_id)
    return {"pushed": True, "domain": candidate.domain_or_ip}
@router.delete("/api/blocklist/push/{candidate_id}")
def unblock_from_pihole(candidate_id: str) -> dict:
    """Remove a pushed candidate's domain/IP from Pi-hole and mark it unblocked.

    Error responses:
      * 404 for an unknown candidate;
      * 400 unless the status is "pushed";
      * 503 if Pi-hole is not configured.
    """
    try:
        candidate = get_candidate(DB_PATH, candidate_id)
    except KeyError as exc:
        raise HTTPException(status_code=404, detail="Candidate not found") from exc
    if candidate.status != "pushed":
        raise HTTPException(
            status_code=400,
            detail=f"Candidate is not currently pushed (status: {candidate.status!r})",
        )
    pihole = _make_pihole_client()
    pihole.unblock(candidate.domain_or_ip)
    mark_unblocked(DB_PATH, candidate_id)
    return {"unblocked": True, "domain": candidate.domain_or_ip}
@router.post("/api/blocklist/test")
def test_pihole_connection() -> dict:
    """Check connectivity to the configured Pi-hole (503 if it is not configured)."""
    return _make_pihole_client().test_connection()
# Register the main API router. This must come before the /turnstone/{path:path}
# catch-all is declared, since routes match in registration order.
app.include_router(router)
# Context knowledge-base routes. The DB-touching handlers here use
# CONTEXT_DB_PATH (the separate SQLite file), not DB_PATH.
_ctx = APIRouter(prefix="/turnstone/api/context")
@_ctx.post("/docs")
async def upload_doc(file: UploadFile):
    """Upload a document into the context knowledge base (CONTEXT_DB_PATH).

    Ingestion runs in a worker thread. Chunker errors map to HTTP status codes:
    UnsupportedDocType → 415, FileTooLarge → 413.
    """
    content = await file.read()
    filename = file.filename or "upload"
    try:
        # Ingest via app.glean.doc_upload.glean_upload (imported as _glean_upload).
        # NOTE(review): assumes the signature glean_upload(db_path, filename, content);
        # confirm against app/glean/doc_upload.py.
        result = await asyncio.to_thread(_glean_upload, CONTEXT_DB_PATH, filename, content)
    except UnsupportedDocType as e:
        raise HTTPException(status_code=415, detail=str(e)) from e
    except FileTooLarge as e:
        raise HTTPException(status_code=413, detail=str(e)) from e
    return result
@_ctx.get("/docs")
async def list_docs():
    """List context-KB documents (metadata only: id, name, type, size, upload time)."""
    docs = await asyncio.to_thread(_list_documents, CONTEXT_DB_PATH)
    keys = ("id", "filename", "doc_type", "file_size", "uploaded_at")
    return [{key: getattr(doc, key) for key in keys} for doc in docs]
@_ctx.delete("/docs/{doc_id}")
async def delete_doc(doc_id: str):
    """Delete one context-KB document; 404 if it does not exist."""
    if await asyncio.to_thread(_delete_document, CONTEXT_DB_PATH, doc_id):
        return {"deleted": doc_id}
    raise HTTPException(status_code=404, detail="Document not found")
@_ctx.post("/facts")
async def create_fact(body: FactBody):
    """Add a fact to the context KB and return the stored record."""
    fact = await asyncio.to_thread(
        _add_fact, CONTEXT_DB_PATH, body.category, body.key, body.value, body.source
    )
    return {
        "id": fact.id,
        "category": fact.category,
        "key": fact.key,
        "value": fact.value,
        "source": fact.source,
        "created_at": fact.created_at,
    }
@_ctx.get("/facts")
async def list_facts_endpoint(category: str | None = None):
    """List context-KB facts; *category* is passed through to list_facts as a filter."""
    facts = await asyncio.to_thread(_list_facts, CONTEXT_DB_PATH, category)
    columns = ("id", "category", "key", "value", "source", "created_at")
    return [{col: getattr(fact, col) for col in columns} for fact in facts]
@_ctx.delete("/facts/{fact_id}")
async def delete_fact_endpoint(fact_id: str):
    """Delete one context-KB fact; 404 if it does not exist."""
    removed = await asyncio.to_thread(_delete_fact, CONTEXT_DB_PATH, fact_id)
    if removed:
        return {"deleted": fact_id}
    raise HTTPException(status_code=404, detail="Fact not found")
@_ctx.get("/wizard/schema")
async def wizard_schema():
    """Return the context wizard's schema (from app.context.wizard.get_schema)."""
    schema = _wizard_schema()
    return schema
@_ctx.post("/wizard/step")
async def wizard_step(body: WizardStepBody):
    """Apply one answer to the client-held session and report whether it is now complete."""
    session = advance_step(body.session, body.step_id, body.answer)
    done = is_complete(session)
    return {"session": session, "complete": done}
@_ctx.post("/wizard/apply")
async def wizard_apply(body: WizardApplyBody):
    """Write a completed wizard session into the context KB; 400 if it is incomplete."""
    if is_complete(body.session):
        return await asyncio.to_thread(apply_session, CONTEXT_DB_PATH, body.session)
    raise HTTPException(status_code=400, detail="Wizard session is not complete")
@_ctx.get("/debug/search")
async def debug_search(q: str):
    """Debug view of context retrieval for *q*: raw facts, chunks, and the formatted block."""
    ctx = await asyncio.to_thread(_retrieve_context, CONTEXT_DB_PATH, q)
    return {
        "facts": ctx.facts,
        "chunks": ctx.chunks,
        "block": format_context_block(ctx),
    }
# Register the context KB routes, also ahead of the SPA catch-all below.
app.include_router(_ctx)
# Bare root → SPA entry point under /turnstone/.
@app.get("/")
def root_redirect() -> RedirectResponse:
    """Redirect ``/`` to ``/turnstone/``."""
    target = "/turnstone/"
    return RedirectResponse(url=target)
# SPA catch-all — serves index.html for any /turnstone/* path that isn't a
# static asset or API route. Must be registered after include_router.
@app.get("/turnstone/{path:path}")
def spa_fallback(path: str) -> FileResponse:
    """Serve a file from the built SPA, falling back to index.html.

    Only regular files that resolve *inside* DIST_DIR are served. ``path`` is
    attacker-controlled: a leading "/" (``/turnstone//etc/passwd``) makes
    ``DIST_DIR / path`` absolute, and ``..`` segments (possibly
    percent-encoded) climb out of the build directory. Both are treated like
    any unknown client-side route and get index.html. Returns 503 with an
    empty body when no frontend build exists.
    """
    if not DIST_DIR.exists():
        return FileResponse("/dev/null", status_code=503)
    dist_root = DIST_DIR.resolve()
    candidate = (dist_root / path).resolve()
    if candidate.is_relative_to(dist_root) and candidate.is_file():
        return FileResponse(str(candidate))
    return FileResponse(str(DIST_DIR / "index.html"))