feat(peregrine): add cloud_session middleware + SQLCipher get_connection()
cloud_session.py: no-op in local mode; in cloud mode resolves Directus JWT from X-CF-Session header to per-user db_path in st.session_state. get_connection() in scripts/db.py: transparent SQLCipher/sqlite3 switch — uses encrypted driver when CLOUD_MODE=true and key provided, vanilla sqlite3 otherwise. libsqlcipher-dev added to Dockerfile for Docker builds. 6 new cloud_session tests + 1 new get_connection test — 34/34 db tests pass.
This commit is contained in:
parent
24bb8476ab
commit
96715bdeb6
6 changed files with 231 additions and 1 deletions
|
|
@ -4,8 +4,9 @@ FROM python:3.11-slim
|
|||
WORKDIR /app
|
||||
|
||||
# System deps for companyScraper (beautifulsoup4, fake-useragent, lxml) and PDF gen
|
||||
# libsqlcipher-dev: required to build pysqlcipher3 (SQLCipher AES-256 encryption for cloud mode)
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
gcc libffi-dev curl \
|
||||
gcc libffi-dev curl libsqlcipher-dev \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
COPY requirements.txt .
|
||||
|
|
|
|||
94
app/cloud_session.py
Normal file
94
app/cloud_session.py
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
# peregrine/app/cloud_session.py
|
||||
"""
|
||||
Cloud session middleware for multi-tenant Peregrine deployment.
|
||||
|
||||
In local-first mode (CLOUD_MODE unset or false), all functions are no-ops.
|
||||
In cloud mode (CLOUD_MODE=true), resolves the Directus session JWT from the
|
||||
X-CF-Session header, validates it, and injects user_id + db_path into
|
||||
st.session_state.
|
||||
|
||||
All Peregrine pages call get_db_path() instead of DEFAULT_DB directly to
|
||||
transparently support both local and cloud deployments.
|
||||
"""
|
||||
import os
|
||||
import hmac
|
||||
import hashlib
|
||||
from pathlib import Path
|
||||
|
||||
import streamlit as st
|
||||
|
||||
from scripts.db import DEFAULT_DB
|
||||
|
||||
# Deployment settings, read once from the environment when this module is imported.

# Cloud mode is on only for the values "1", "true" or "yes" (case-insensitive).
CLOUD_MODE: bool = os.getenv("CLOUD_MODE", "").lower() in ("1", "true", "yes")
# Root directory under which the per-user data directories are created.
CLOUD_DATA_ROOT: Path = Path(os.getenv("CLOUD_DATA_ROOT", "/devl/menagerie-data"))
# Secret handed to jwt.decode() when verifying HS256 session tokens.
DIRECTUS_JWT_SECRET: str = os.getenv("DIRECTUS_JWT_SECRET", "")
# HMAC key used by derive_db_key(); taken from the CF_SERVER_SECRET variable.
SERVER_SECRET: str = os.getenv("CF_SERVER_SECRET", "")
|
||||
|
||||
|
||||
def validate_session_jwt(token: str) -> str:
    """Validate a Directus session JWT and return the user UUID.

    The token is verified as HS256 against DIRECTUS_JWT_SECRET. The user id is
    read from the ``id`` claim, falling back to ``sub``.

    Args:
        token: Encoded JWT taken from the session header.

    Returns:
        The user id claim.

    Raises:
        ValueError: If DIRECTUS_JWT_SECRET is empty, or the token carries no
            non-empty string user id claim.
        Exception: Whatever ``jwt.decode`` raises when verification fails.
    """
    # Fail closed: never attempt verification against an empty HMAC secret.
    if not DIRECTUS_JWT_SECRET:
        raise ValueError("JWT secret is not configured")

    import jwt  # PyJWT — lazy import so local mode never needs it

    payload = jwt.decode(token, DIRECTUS_JWT_SECRET, algorithms=["HS256"])
    user_id = payload.get("id") or payload.get("sub")
    if not user_id:
        raise ValueError("JWT missing user id claim")
    if not isinstance(user_id, str):
        # The id is later used as a path component and as HMAC input, both of
        # which require a string; reject other JSON types here.
        raise ValueError("JWT user id claim is not a string")
    return user_id
|
||||
|
||||
|
||||
def _user_data_path(user_id: str, app: str) -> Path:
|
||||
return CLOUD_DATA_ROOT / user_id / app
|
||||
|
||||
|
||||
def derive_db_key(user_id: str) -> str:
    """Return the per-user SQLCipher key for *user_id*.

    The key is the hex digest of an HMAC-SHA256 over the encoded user id,
    keyed with the encoded SERVER_SECRET, so a given user id and server
    secret always produce the same key.
    """
    secret_bytes = SERVER_SECRET.encode()
    message = user_id.encode()
    mac = hmac.new(secret_bytes, message, hashlib.sha256)
    return mac.hexdigest()
|
||||
|
||||
|
||||
def resolve_session(app: str = "peregrine") -> None:
    """
    Call at the top of each Streamlit page.

    In local mode: no-op.
    In cloud mode: reads the X-CF-Session header, validates the JWT, creates
    the user data directory (with ``config`` and ``data`` subdirectories) on
    first visit, and sets st.session_state keys:
    - user_id: str
    - db_path: Path
    - db_key: str (SQLCipher key for this user)

    If the header is missing or the token is invalid, an error is shown and
    the page is stopped; nothing is written to session_state.

    Idempotent — skips if user_id already in session_state.

    Args:
        app: Application name used as the per-user subdirectory.
    """
    if not CLOUD_MODE:
        return
    if st.session_state.get("user_id"):
        return

    token = st.context.headers.get("x-cf-session", "")
    if not token:
        st.error("Session token missing. Please log in at circuitforge.tech.")
        st.stop()
        # st.stop() is expected to halt the script run; the explicit return
        # guarantees we never continue without a token if it does return
        # (for example when st is replaced by a test double).
        return

    try:
        user_id = validate_session_jwt(token)
    except Exception as exc:
        st.error(f"Invalid session — please log in again. ({exc})")
        st.stop()
        return  # user_id is unbound here; never fall through

    user_path = _user_data_path(user_id, app)
    for directory in (user_path, user_path / "config", user_path / "data"):
        directory.mkdir(parents=True, exist_ok=True)

    # Compute everything that can fail before touching session_state, and
    # write user_id last: the idempotence check above keys on user_id, so it
    # must only be present once db_path and db_key are in place.
    db_key = derive_db_key(user_id)
    st.session_state["db_path"] = user_path / "staging.db"
    st.session_state["db_key"] = db_key
    st.session_state["user_id"] = user_id
|
||||
|
||||
|
||||
def get_db_path() -> Path:
    """
    Resolve which database file the current session should use.

    Cloud: the user-scoped path stored under "db_path" in st.session_state.
    Local: DEFAULT_DB, returned whenever no "db_path" entry is present.
    """
    fallback = DEFAULT_DB
    active_path = st.session_state.get("db_path", fallback)
    return active_path
|
||||
|
|
@ -54,6 +54,7 @@ python-dotenv
|
|||
|
||||
# ── Auth / licensing ──────────────────────────────────────────────────────
|
||||
PyJWT>=2.8
|
||||
pysqlcipher3
|
||||
|
||||
# ── Utilities ─────────────────────────────────────────────────────────────
|
||||
sqlalchemy
|
||||
|
|
|
|||
|
|
@ -11,6 +11,30 @@ from typing import Optional
|
|||
|
||||
# Default database location: the STAGING_DB environment variable when set,
# otherwise "staging.db" in the parent of this file's directory.
DEFAULT_DB = Path(os.environ.get("STAGING_DB", Path(__file__).parent.parent / "staging.db"))
|
||||
|
||||
|
||||
def get_connection(db_path: Path = DEFAULT_DB, key: str = "") -> "sqlite3.Connection":
    """
    Open a database connection.

    In cloud mode with a key: uses SQLCipher (AES-256 encrypted, API-identical to sqlite3).
    Otherwise: vanilla sqlite3.

    Cloud mode is read from the CLOUD_MODE environment variable on every call
    ("1", "true" or "yes", case-insensitive).

    Args:
        db_path: Path to the SQLite/SQLCipher database file.
        key: SQLCipher encryption key (hex string). Empty = unencrypted.

    Returns:
        An open connection; the caller is responsible for closing it.
    """
    cloud_mode = os.environ.get("CLOUD_MODE", "").lower() in ("1", "true", "yes")
    if not (cloud_mode and key):
        import sqlite3 as _sqlite3

        return _sqlite3.connect(str(db_path))

    from pysqlcipher3 import dbapi2 as _sqlcipher

    conn = _sqlcipher.connect(str(db_path))
    try:
        # The key is embedded in a quoted SQL literal, so double any single
        # quote to keep an arbitrary key from terminating the literal.
        escaped_key = key.replace("'", "''")
        conn.execute(f"PRAGMA key='{escaped_key}'")
    except Exception:
        # Do not leak the handle if keying the database fails.
        conn.close()
        raise
    return conn
|
||||
|
||||
|
||||
CREATE_JOBS = """
|
||||
CREATE TABLE IF NOT EXISTS jobs (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
|
|
|
|||
96
tests/test_cloud_session.py
Normal file
96
tests/test_cloud_session.py
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
import pytest
|
||||
import os
|
||||
from unittest.mock import patch, MagicMock
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
def test_resolve_session_is_noop_in_local_mode(monkeypatch):
    """resolve_session() leaves Streamlit untouched when CLOUD_MODE is not set."""
    monkeypatch.delenv("CLOUD_MODE", raising=False)
    # CLOUD_MODE is evaluated at import time, so reload after the env change.
    import importlib
    import app.cloud_session as cs
    importlib.reload(cs)

    mock_state = {}
    with patch.object(cs, "st") as mock_st:
        mock_st.session_state = mock_state
        cs.resolve_session("peregrine")

    # Local mode must neither write session state nor read headers, and must
    # not surface an error or stop the page.
    assert mock_state == {}
    mock_st.context.headers.get.assert_not_called()
    mock_st.error.assert_not_called()
    mock_st.stop.assert_not_called()
|
||||
|
||||
|
||||
def test_resolve_session_sets_db_path(tmp_path, monkeypatch):
    """resolve_session() fills user_id, db_path and db_key from a valid JWT."""
    monkeypatch.setenv("CLOUD_MODE", "true")
    # CLOUD_MODE is evaluated at import time, so reload after the env change.
    import importlib
    import app.cloud_session as cs
    importlib.reload(cs)

    mock_state = {}
    with patch.object(cs, "validate_session_jwt", return_value="user-uuid-123"), \
         patch.object(cs, "st") as mock_st, \
         patch.object(cs, "CLOUD_DATA_ROOT", tmp_path):
        mock_st.session_state = mock_state
        mock_st.context.headers = {"x-cf-session": "valid.jwt.token"}
        cs.resolve_session("peregrine")

    assert mock_state["user_id"] == "user-uuid-123"
    assert mock_state["db_path"] == tmp_path / "user-uuid-123" / "peregrine" / "staging.db"
    # The SQLCipher key must be the one derived for this same user.
    assert mock_state["db_key"] == cs.derive_db_key("user-uuid-123")
|
||||
|
||||
|
||||
def test_resolve_session_creates_user_dir(tmp_path, monkeypatch):
    """First login creates the per-user directory tree under CLOUD_DATA_ROOT."""
    monkeypatch.setenv("CLOUD_MODE", "true")
    import importlib
    import app.cloud_session as cs
    importlib.reload(cs)

    session = {}
    jwt_patch = patch.object(cs, "validate_session_jwt", return_value="new-user")
    root_patch = patch.object(cs, "CLOUD_DATA_ROOT", tmp_path)
    with jwt_patch, patch.object(cs, "st") as fake_st, root_patch:
        fake_st.session_state = session
        fake_st.context.headers = {"x-cf-session": "valid.jwt.token"}
        cs.resolve_session("peregrine")

    # The app directory plus its two fixed subdirectories must all exist.
    app_dir = tmp_path / "new-user" / "peregrine"
    assert app_dir.is_dir()
    assert (app_dir / "config").is_dir()
    assert (app_dir / "data").is_dir()
|
||||
|
||||
|
||||
def test_resolve_session_idempotent(monkeypatch):
    """resolve_session() skips all work if user_id is already in session state."""
    monkeypatch.setenv("CLOUD_MODE", "true")
    # CLOUD_MODE is evaluated at import time, so reload after the env change.
    import importlib
    import app.cloud_session as cs
    importlib.reload(cs)

    mock_state = {"user_id": "existing-user"}
    with patch.object(cs, "st") as mock_st, \
         patch.object(cs, "validate_session_jwt") as mock_validate:
        mock_st.session_state = mock_state
        cs.resolve_session("peregrine")

    # resolve_session reads the header via headers.get(), so that is the call
    # that must not happen; the JWT must not be validated again either.
    mock_st.context.headers.get.assert_not_called()
    mock_validate.assert_not_called()
    assert mock_state == {"user_id": "existing-user"}
|
||||
|
||||
|
||||
def test_get_db_path_returns_session_path(tmp_path, monkeypatch):
    """A db_path stored in session state is what get_db_path() hands back."""
    import importlib
    import app.cloud_session as cs
    importlib.reload(cs)

    expected = tmp_path / "staging.db"
    with patch.object(cs, "st") as fake_st:
        fake_st.session_state = {"db_path": expected}
        actual = cs.get_db_path()

    assert actual == expected
|
||||
|
||||
|
||||
def test_get_db_path_falls_back_to_default(monkeypatch):
    """With no db_path in session state, get_db_path() yields DEFAULT_DB."""
    monkeypatch.delenv("CLOUD_MODE", raising=False)
    import importlib
    import app.cloud_session as cs
    importlib.reload(cs)
    from scripts.db import DEFAULT_DB

    empty_session = {}
    with patch.object(cs, "st") as fake_st:
        fake_st.session_state = empty_session
        actual = cs.get_db_path()

    assert actual == DEFAULT_DB
|
||||
|
|
@ -576,3 +576,17 @@ def test_insert_task_with_params(tmp_path):
|
|||
params2 = json.dumps({"section": "job_titles"})
|
||||
task_id3, is_new3 = insert_task(db, "wizard_generate", 0, params=params2)
|
||||
assert is_new3 is True
|
||||
|
||||
|
||||
def test_get_connection_local_mode(tmp_path):
    """With no key, get_connection() yields a connection that can write and read back a row."""
    from scripts.db import get_connection

    db_file = tmp_path / "test_conn.db"
    connection = get_connection(db_file)

    # Round-trip a single value through a fresh table.
    for statement in ("CREATE TABLE t (x INTEGER)", "INSERT INTO t VALUES (42)"):
        connection.execute(statement)
    connection.commit()
    row = connection.execute("SELECT x FROM t").fetchone()
    connection.close()

    stored_value = row[0]
    assert stored_value == 42
    assert db_file.exists()
|
||||
|
|
|
|||
Loading…
Reference in a new issue