diff --git a/Dockerfile b/Dockerfile
index adc363b..f8cac14 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -4,8 +4,9 @@ FROM python:3.11-slim
 WORKDIR /app
 
 # System deps for companyScraper (beautifulsoup4, fake-useragent, lxml) and PDF gen
+# libsqlcipher-dev: required to build pysqlcipher3 (SQLCipher AES-256 encryption for cloud mode)
 RUN apt-get update && apt-get install -y --no-install-recommends \
-    gcc libffi-dev curl \
+    gcc libffi-dev curl libsqlcipher-dev \
     && rm -rf /var/lib/apt/lists/*
 
 COPY requirements.txt .
diff --git a/app/cloud_session.py b/app/cloud_session.py
new file mode 100644
index 0000000..14a8b85
--- /dev/null
+++ b/app/cloud_session.py
@@ -0,0 +1,94 @@
+# peregrine/app/cloud_session.py
+"""
+Cloud session middleware for multi-tenant Peregrine deployment.
+
+In local-first mode (CLOUD_MODE unset or false), all functions are no-ops.
+In cloud mode (CLOUD_MODE=true), resolves the Directus session JWT from the
+X-CF-Session header, validates it, and injects user_id + db_path into
+st.session_state.
+
+All Peregrine pages call get_db_path() instead of DEFAULT_DB directly to
+transparently support both local and cloud deployments.
+"""
+import os
+import hmac
+import hashlib
+from pathlib import Path
+
+import streamlit as st
+
+from scripts.db import DEFAULT_DB
+
+CLOUD_MODE: bool = os.environ.get("CLOUD_MODE", "").lower() in ("1", "true", "yes")
+CLOUD_DATA_ROOT: Path = Path(os.environ.get("CLOUD_DATA_ROOT", "/devl/menagerie-data"))
+DIRECTUS_JWT_SECRET: str = os.environ.get("DIRECTUS_JWT_SECRET", "")
+SERVER_SECRET: str = os.environ.get("CF_SERVER_SECRET", "")
+
+
+def validate_session_jwt(token: str) -> str:
+    """Validate a Directus session JWT and return the user UUID. Raises on failure."""
+    import jwt  # PyJWT — lazy import so local mode never needs it
+    payload = jwt.decode(token, DIRECTUS_JWT_SECRET, algorithms=["HS256"])
+    user_id = payload.get("id") or payload.get("sub")
+    if not user_id:
+        raise ValueError("JWT missing user id claim")
+    return user_id
+
+
+def _user_data_path(user_id: str, app: str) -> Path:
+    return CLOUD_DATA_ROOT / user_id / app
+
+
+def derive_db_key(user_id: str) -> str:
+    """Derive a per-user SQLCipher encryption key from the server secret."""
+    return hmac.new(
+        SERVER_SECRET.encode(),
+        user_id.encode(),
+        hashlib.sha256,
+    ).hexdigest()
+
+
+def resolve_session(app: str = "peregrine") -> None:
+    """
+    Call at the top of each Streamlit page.
+    In local mode: no-op.
+    In cloud mode: reads X-CF-Session header, validates JWT, creates user
+    data directory on first visit, and sets st.session_state keys:
+      - user_id: str
+      - db_path: Path
+      - db_key: str (SQLCipher key for this user)
+    Idempotent — skips if user_id already in session_state.
+    """
+    if not CLOUD_MODE:
+        return
+    if st.session_state.get("user_id"):
+        return
+
+    token = st.context.headers.get("x-cf-session", "")
+    if not token:
+        st.error("Session token missing. Please log in at circuitforge.tech.")
+        st.stop()
+
+    try:
+        user_id = validate_session_jwt(token)
+    except Exception as exc:
+        st.error(f"Invalid session — please log in again. ({exc})")
+        st.stop()
+
+    user_path = _user_data_path(user_id, app)
+    user_path.mkdir(parents=True, exist_ok=True)
+    (user_path / "config").mkdir(exist_ok=True)
+    (user_path / "data").mkdir(exist_ok=True)
+
+    st.session_state["user_id"] = user_id
+    st.session_state["db_path"] = user_path / "staging.db"
+    st.session_state["db_key"] = derive_db_key(user_id)
+
+
+def get_db_path() -> Path:
+    """
+    Return the active db_path for this session.
+    Cloud: user-scoped path from session_state.
+    Local: DEFAULT_DB (from STAGING_DB env var or repo default).
+    """
+    return st.session_state.get("db_path", DEFAULT_DB)
diff --git a/requirements.txt b/requirements.txt
index 81e8237..b48998c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -54,6 +54,7 @@ python-dotenv
 
 # ── Auth / licensing ──────────────────────────────────────────────────────
 PyJWT>=2.8
+pysqlcipher3
 
 # ── Utilities ─────────────────────────────────────────────────────────────
 sqlalchemy
diff --git a/scripts/db.py b/scripts/db.py
index a091a87..0bc5515 100644
--- a/scripts/db.py
+++ b/scripts/db.py
@@ -11,6 +11,30 @@ from typing import Optional
 
 DEFAULT_DB = Path(os.environ.get("STAGING_DB", Path(__file__).parent.parent / "staging.db"))
 
+
+def get_connection(db_path: Path = DEFAULT_DB, key: str = "") -> "sqlite3.Connection":
+    """
+    Open a database connection.
+
+    In cloud mode with a key: uses SQLCipher (AES-256 encrypted, API-identical to sqlite3).
+    Otherwise: vanilla sqlite3.
+
+    Args:
+        db_path: Path to the SQLite/SQLCipher database file.
+        key: SQLCipher encryption key (hex string). Empty = unencrypted.
+    """
+    import os as _os
+    cloud_mode = _os.environ.get("CLOUD_MODE", "").lower() in ("1", "true", "yes")
+    if cloud_mode and key:
+        from pysqlcipher3 import dbapi2 as _sqlcipher
+        conn = _sqlcipher.connect(str(db_path))
+        conn.execute(f"PRAGMA key='{key}'")
+        return conn
+    else:
+        import sqlite3 as _sqlite3
+        return _sqlite3.connect(str(db_path))
+
+
 CREATE_JOBS = """
 CREATE TABLE IF NOT EXISTS jobs (
     id INTEGER PRIMARY KEY AUTOINCREMENT,
diff --git a/tests/test_cloud_session.py b/tests/test_cloud_session.py
new file mode 100644
index 0000000..8d637a4
--- /dev/null
+++ b/tests/test_cloud_session.py
@@ -0,0 +1,96 @@
+import pytest
+import os
+from unittest.mock import patch, MagicMock
+from pathlib import Path
+
+
+def test_resolve_session_is_noop_in_local_mode(monkeypatch):
+    """resolve_session() does nothing when CLOUD_MODE is not set."""
+    monkeypatch.delenv("CLOUD_MODE", raising=False)
+    # Must reimport after env change
+    import importlib
+    import app.cloud_session as cs
+    importlib.reload(cs)
+    # Should return without touching st
+    cs.resolve_session("peregrine")  # no error = pass
+
+
+def test_resolve_session_sets_db_path(tmp_path, monkeypatch):
+    """resolve_session() sets st.session_state.db_path from a valid JWT."""
+    monkeypatch.setenv("CLOUD_MODE", "true")
+    import importlib
+    import app.cloud_session as cs
+    importlib.reload(cs)
+
+    mock_state = {}
+    with patch.object(cs, "validate_session_jwt", return_value="user-uuid-123"), \
+         patch.object(cs, "st") as mock_st, \
+         patch.object(cs, "CLOUD_DATA_ROOT", tmp_path):
+        mock_st.session_state = mock_state
+        mock_st.context.headers = {"x-cf-session": "valid.jwt.token"}
+        cs.resolve_session("peregrine")
+
+    assert mock_state["user_id"] == "user-uuid-123"
+    assert mock_state["db_path"] == tmp_path / "user-uuid-123" / "peregrine" / "staging.db"
+
+
+def test_resolve_session_creates_user_dir(tmp_path, monkeypatch):
+    """resolve_session() creates the user data directory on first login."""
+    monkeypatch.setenv("CLOUD_MODE", "true")
+    import importlib
+    import app.cloud_session as cs
+    importlib.reload(cs)
+
+    mock_state = {}
+    with patch.object(cs, "validate_session_jwt", return_value="new-user"), \
+         patch.object(cs, "st") as mock_st, \
+         patch.object(cs, "CLOUD_DATA_ROOT", tmp_path):
+        mock_st.session_state = mock_state
+        mock_st.context.headers = {"x-cf-session": "valid.jwt.token"}
+        cs.resolve_session("peregrine")
+
+    assert (tmp_path / "new-user" / "peregrine").is_dir()
+    assert (tmp_path / "new-user" / "peregrine" / "config").is_dir()
+    assert (tmp_path / "new-user" / "peregrine" / "data").is_dir()
+
+
+def test_resolve_session_idempotent(monkeypatch):
+    """resolve_session() skips if user_id already in session state."""
+    monkeypatch.setenv("CLOUD_MODE", "true")
+    import importlib
+    import app.cloud_session as cs
+    importlib.reload(cs)
+
+    with patch.object(cs, "st") as mock_st:
+        mock_st.session_state = {"user_id": "existing-user"}
+        # Should not try to read headers or validate JWT
+        cs.resolve_session("peregrine")
+        # context.headers should never be accessed
+        mock_st.context.headers.__getitem__.assert_not_called() if hasattr(mock_st.context, 'headers') else None
+
+
+def test_get_db_path_returns_session_path(tmp_path, monkeypatch):
+    """get_db_path() returns session-scoped path when set."""
+    import importlib
+    import app.cloud_session as cs
+    importlib.reload(cs)
+
+    session_db = tmp_path / "staging.db"
+    with patch.object(cs, "st") as mock_st:
+        mock_st.session_state = {"db_path": session_db}
+        result = cs.get_db_path()
+    assert result == session_db
+
+
+def test_get_db_path_falls_back_to_default(monkeypatch):
+    """get_db_path() returns DEFAULT_DB when no session path set."""
+    monkeypatch.delenv("CLOUD_MODE", raising=False)
+    import importlib
+    import app.cloud_session as cs
+    importlib.reload(cs)
+    from scripts.db import DEFAULT_DB
+
+    with patch.object(cs, "st") as mock_st:
+        mock_st.session_state = {}
+        result = cs.get_db_path()
+    assert result == DEFAULT_DB
diff --git a/tests/test_db.py b/tests/test_db.py
index 9b0148c..b8b1331 100644
--- a/tests/test_db.py
+++ b/tests/test_db.py
@@ -576,3 +576,17 @@ def test_insert_task_with_params(tmp_path):
     params2 = json.dumps({"section": "job_titles"})
     task_id3, is_new3 = insert_task(db, "wizard_generate", 0, params=params2)
     assert is_new3 is True
+
+
+def test_get_connection_local_mode(tmp_path):
+    """get_connection() returns a working sqlite3 connection in local mode (no key)."""
+    from scripts.db import get_connection
+    db = tmp_path / "test_conn.db"
+    conn = get_connection(db)
+    conn.execute("CREATE TABLE t (x INTEGER)")
+    conn.execute("INSERT INTO t VALUES (42)")
+    conn.commit()
+    result = conn.execute("SELECT x FROM t").fetchone()
+    conn.close()
+    assert result[0] == 42
+    assert db.exists()
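
Review note (not part of the patch): `derive_db_key()` in `app/cloud_session.py` computes an HMAC-SHA256 over the user id, keyed by `CF_SERVER_SECRET`, and returns the hex digest. The standalone sketch below reproduces that derivation with the standard library only, to illustrate that the key is deterministic per user and differs between users. The `server_secret` parameter, the sample secret and user ids, and the empty-secret check are assumptions added for illustration; the patch reads the secret from the module-level `SERVER_SECRET` and does not reject an empty value.

```python
import hashlib
import hmac


def derive_db_key(user_id: str, server_secret: str) -> str:
    """Return a 64-character hex key: HMAC-SHA256(server_secret, user_id).

    Hypothetical variant of the patch's derive_db_key(): the secret is passed
    in explicitly, and an empty secret is rejected (assumption, not in the patch).
    """
    if not server_secret:
        raise ValueError("server secret must not be empty")
    return hmac.new(
        server_secret.encode(),
        user_id.encode(),
        hashlib.sha256,
    ).hexdigest()


# Sample values are placeholders, not real secrets or user ids.
key_a = derive_db_key("user-uuid-123", "example-secret")
key_b = derive_db_key("new-user", "example-secret")
```

Because the key depends only on the server secret and the user id, rotating `CF_SERVER_SECRET` would change every derived key; whether existing databases need re-keying in that case is outside what this diff shows.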