Calls /admin/cloud/resolve after JWT validation to inject the user's current subscription tier (free/paid/premium/ultra) into session_state as cloud_tier. Cached 5 minutes via st.cache_data to avoid Heimdall spam on every Streamlit rerun. Degrades gracefully to free on timeout or missing token. New env vars: HEIMDALL_URL, HEIMDALL_ADMIN_TOKEN (added to .env.example and compose.cloud.yml). HEIMDALL_URL defaults to http://cf-license:8000 for internal Docker network access. New helper: get_cloud_tier() — returns tier string in cloud mode, "local" in local-first mode, so pages can distinguish self-hosted from cloud.
152 lines
5.4 KiB
Python
152 lines
5.4 KiB
Python
# peregrine/app/cloud_session.py
|
|
"""
|
|
Cloud session middleware for multi-tenant Peregrine deployment.
|
|
|
|
In local-first mode (CLOUD_MODE unset or false), all functions are no-ops.
|
|
In cloud mode (CLOUD_MODE=1/true/yes), reads the Directus session JWT from the
cf_session value carried in the X-CF-Session header, validates it, and injects
user_id, db_path, db_key (per-user SQLCipher key) and cloud_tier (resolved
from Heimdall) into st.session_state.
|
|
|
|
All Peregrine pages call get_db_path() instead of DEFAULT_DB directly to
|
|
transparently support both local and cloud deployments.
|
|
"""
|
|
import logging
|
|
import os
|
|
import re
|
|
import hmac
|
|
import hashlib
|
|
from pathlib import Path
|
|
|
|
import requests
|
|
import streamlit as st
|
|
|
|
from scripts.db import DEFAULT_DB
|
|
|
|
log = logging.getLogger(__name__)

# --- Deployment mode & storage ----------------------------------------------
# "1", "true" or "yes" (any case) switches on multi-tenant cloud mode.
CLOUD_MODE: bool = os.getenv("CLOUD_MODE", "").lower() in {"1", "true", "yes"}
# Root under which each user gets a <user_id>/<app>/ data directory.
CLOUD_DATA_ROOT: Path = Path(os.getenv("CLOUD_DATA_ROOT", "/devl/menagerie-data"))

# --- Secrets (empty string when unset) --------------------------------------
# HS256 key used to verify Directus session JWTs.
DIRECTUS_JWT_SECRET: str = os.getenv("DIRECTUS_JWT_SECRET", "")
# HMAC key used to derive per-user SQLCipher keys.
SERVER_SECRET: str = os.getenv("CF_SERVER_SECRET", "")

# --- Heimdall license server ------------------------------------------------
# Defaults to the public endpoint. When running on the same host / Docker
# network, set HEIMDALL_URL to the internal service instead
# (e.g. http://cf-license:8000).
# NOTE(review): deployment notes describe the internal URL as the default;
# confirm which default is intended here.
HEIMDALL_URL: str = os.getenv("HEIMDALL_URL", "https://license.circuitforge.tech")
# Bearer token for Heimdall's /admin endpoints; tier lookups fall back to
# "free" when it is empty.
HEIMDALL_ADMIN_TOKEN: str = os.getenv("HEIMDALL_ADMIN_TOKEN", "")
|
|
|
|
|
|
def _extract_session_token(cookie_header: str) -> str:
    """Return the cf_session value from a Cookie-style header, or "" if absent.

    The header is split on ";". The first segment that, after leading
    whitespace is removed, starts with "cf_session=" followed by at least one
    character wins. Its value is returned with surrounding whitespace stripped.
    """
    prefix = "cf_session="
    for segment in cookie_header.split(";"):
        segment = segment.lstrip()
        # An empty "cf_session=" entry is skipped so a later one can still match.
        if segment.startswith(prefix) and len(segment) > len(prefix):
            return segment[len(prefix):].strip()
    return ""
|
|
|
|
|
|
@st.cache_data(ttl=300, show_spinner=False)
def _fetch_cloud_tier_cached(user_id: str, product: str) -> str:
    """Ask Heimdall for the user's tier; raise on anything but a definitive answer.

    st.cache_data stores only returned values, never exceptions. Transient
    failures (timeouts, unexpected status codes, malformed bodies) therefore
    raise here and are retried on the next call, instead of pinning the user
    to "free" for the whole 5-minute TTL.
    """
    resp = requests.post(
        # rstrip tolerates HEIMDALL_URL being configured with a trailing slash.
        f"{HEIMDALL_URL.rstrip('/')}/admin/cloud/resolve",
        json={"user_id": user_id, "product": product},
        headers={"Authorization": f"Bearer {HEIMDALL_ADMIN_TOKEN}"},
        timeout=5,
    )
    if resp.status_code == 404:
        # No cloud key yet (user signed up before provisioning ran). This is a
        # definitive answer, so caching "free" is correct.
        return "free"
    if resp.status_code != 200:
        raise RuntimeError(f"Heimdall resolve returned {resp.status_code}")
    payload = resp.json()  # raises ValueError on a non-JSON body
    tier = payload.get("tier") if isinstance(payload, dict) else None
    if not isinstance(tier, str) or not tier:
        raise ValueError("Heimdall resolve response has no string 'tier'")
    return tier


def _fetch_cloud_tier(user_id: str, product: str) -> str:
    """Call Heimdall to resolve the current cloud tier for this user.

    Successful lookups are cached per (user_id, product) for 5 minutes to
    avoid hammering Heimdall on every Streamlit rerun. Returns "free" on any
    error so the app degrades gracefully rather than blocking the user. Error
    fallbacks are deliberately not cached, so the next call retries.
    """
    if not HEIMDALL_ADMIN_TOKEN:
        log.warning("HEIMDALL_ADMIN_TOKEN not set — defaulting tier to free")
        return "free"
    try:
        return _fetch_cloud_tier_cached(user_id, product)
    except Exception as exc:  # best-effort boundary: never block the user on Heimdall
        log.warning("Heimdall tier resolve failed: %s — defaulting to free", exc)
        return "free"
|
|
|
|
|
|
def validate_session_jwt(token: str) -> str:
    """Validate a Directus session JWT and return the user id it carries.

    The id is read from the "id" claim, falling back to "sub".

    Raises:
        RuntimeError: DIRECTUS_JWT_SECRET is empty. Without this guard, HS256
            signatures would be checked against an empty HMAC key, and anyone
            could forge a token that is accepted here.
        jwt.InvalidTokenError: PyJWT rejected the token (bad signature,
            expired, malformed, ...).
        ValueError: the payload has no usable string user id claim.
    """
    if not DIRECTUS_JWT_SECRET:
        raise RuntimeError("JWT secret is not configured; refusing to validate session tokens")

    import jwt  # PyJWT — lazy import so local mode never needs it

    payload = jwt.decode(token, DIRECTUS_JWT_SECRET, algorithms=["HS256"])
    user_id = payload.get("id") or payload.get("sub")
    if not user_id:
        raise ValueError("JWT missing user id claim")
    if not isinstance(user_id, str):
        # The id is later used as a path component and HMAC message; require a str.
        raise ValueError("JWT user id claim is not a string")
    return user_id
|
|
|
|
|
|
def _user_data_path(user_id: str, app: str) -> Path:
    """Return the per-user data directory CLOUD_DATA_ROOT / user_id / app.

    user_id comes from a JWT claim, so both parts must be a single plain path
    component. An absolute value would make pathlib discard CLOUD_DATA_ROOT
    entirely, and ".." or separators would escape it.

    Raises:
        ValueError: if either part is not a str, is "", "." or "..", or
            contains "/", "\\" or a NUL byte.
    """
    for label, part in (("user_id", user_id), ("app", app)):
        if (
            not isinstance(part, str)
            or part in ("", ".", "..")
            or any(ch in part for ch in ("/", "\\", "\x00"))
        ):
            raise ValueError(f"unsafe {label} for data path: {part!r}")
    return CLOUD_DATA_ROOT / user_id / app
|
|
|
|
|
|
def derive_db_key(user_id: str, secret: "str | None" = None) -> str:
    """Derive a per-user SQLCipher encryption key from the server secret.

    The key is the hex digest of HMAC-SHA256(secret, user_id), a 64-character
    string that is stable for a given (secret, user_id) pair.

    Args:
        user_id: The user's id; used as the HMAC message.
        secret: HMAC key. Defaults to SERVER_SECRET (CF_SERVER_SECRET). Pass it
            explicitly to derive with a different key.

    Raises:
        RuntimeError: If the effective secret is empty. With an empty HMAC key,
            every derived key is computable from the user id alone.
    """
    key = SERVER_SECRET if secret is None else secret
    if not key:
        raise RuntimeError("CF_SERVER_SECRET is not set; refusing to derive a database key")
    return hmac.new(key.encode(), user_id.encode(), hashlib.sha256).hexdigest()
|
|
|
|
|
|
def resolve_session(app: str = "peregrine") -> None:
    """
    Resolve the cloud user for this Streamlit session. Call at the top of each page.

    In local mode: no-op.

    In cloud mode: reads the X-CF-Session header, validates the session JWT,
    creates the user's data directory on first visit, and sets these
    st.session_state keys:
      - user_id: str
      - db_path: Path
      - db_key: str (SQLCipher key for this user)
      - cloud_tier: str (free | paid | premium | ultra — resolved from Heimdall)

    Idempotent — skips if user_id is already in session_state. user_id is
    written last, so a failure part-way through leaves the session unresolved
    and the next rerun retries from scratch.

    A missing or invalid token shows an error and halts the script run via
    st.stop().
    """
    if not CLOUD_MODE:
        return
    if st.session_state.get("user_id"):
        return

    # The header value is parsed as a Cookie-style string (cf_session=...).
    cookie_header = st.context.headers.get("x-cf-session", "")
    session_jwt = _extract_session_token(cookie_header)
    if not session_jwt:
        st.error("Session token missing. Please log in at circuitforge.tech.")
        st.stop()

    try:
        user_id = validate_session_jwt(session_jwt)
    except Exception as exc:
        log.warning("Session JWT rejected: %s", exc)
        st.error(f"Invalid session — please log in again. ({exc})")
        st.stop()

    user_path = _user_data_path(user_id, app)
    user_path.mkdir(parents=True, exist_ok=True)
    (user_path / "config").mkdir(exist_ok=True)
    (user_path / "data").mkdir(exist_ok=True)

    # Compute everything before touching session_state. "user_id" is the
    # idempotency marker checked above, so if anything below raised after it
    # was set, later reruns would skip resolution with db_key/cloud_tier missing.
    db_key = derive_db_key(user_id)
    cloud_tier = _fetch_cloud_tier(user_id, app)

    st.session_state["db_path"] = user_path / "staging.db"
    st.session_state["db_key"] = db_key
    st.session_state["cloud_tier"] = cloud_tier
    st.session_state["user_id"] = user_id  # written last: marks the session resolved
|
|
|
|
|
|
def get_db_path() -> Path:
    """
    Resolve which database file the current page should open.

    Cloud sessions carry a user-scoped path that resolve_session() stored in
    st.session_state["db_path"]. When that key is absent (local-first mode,
    or before the session has been resolved), fall back to DEFAULT_DB.
    """
    state = st.session_state
    if "db_path" in state:
        return state["db_path"]
    return DEFAULT_DB
|
|
|
|
|
|
def get_cloud_tier() -> str:
    """
    Report the subscription tier the current session runs under.

    With CLOUD_MODE off this is always "local", so pages can tell self-hosted
    installs apart from the cloud service. In cloud mode it is the value
    resolve_session() stored under "cloud_tier", or "free" if nothing has
    been stored yet.
    """
    return st.session_state.get("cloud_tier", "free") if CLOUD_MODE else "local"
|