refactor: replace hand-rolled JWT+Heimdall with cf-core CloudSessionFactory
Delegates JWT validation, Heimdall provision/tier-resolve, bypass-IP handling, and guest session management to circuitforge_core. Snipe keeps its own CloudUser (shared_db + user_db), SessionFeatures, compute_features, and DB path helpers. Removes ~158 lines of duplicated auth code. Note: get_session() now takes (Request, Response) — FastAPI auto-injects both, no call-site changes needed.
This commit is contained in:
parent
ec0af07905
commit
0354234f86
1 changed files with 21 additions and 157 deletions
|
|
@ -1,11 +1,9 @@
|
|||
"""Cloud session resolution for Snipe FastAPI.
|
||||
|
||||
In local mode (CLOUD_MODE unset/false): all functions return a local CloudUser
|
||||
with no auth checks, full tier access, and both DB paths pointing to SNIPE_DB.
|
||||
|
||||
In cloud mode (CLOUD_MODE=true): validates the cf_session JWT injected by Caddy
|
||||
as X-CF-Session, resolves user_id, auto-provisions a free Heimdall license on
|
||||
first visit, fetches the tier, and returns per-user DB paths.
|
||||
Delegates JWT validation, Heimdall provisioning, tier resolution, and guest
|
||||
session management to circuitforge_core.CloudSessionFactory. Snipe-specific
|
||||
CloudUser (shared_db + user_db paths), SessionFeatures, and DB helpers are
|
||||
kept here.
|
||||
|
||||
FastAPI usage:
|
||||
@app.get("/api/search")
|
||||
|
|
@ -18,15 +16,12 @@ from __future__ import annotations
|
|||
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import jwt as pyjwt
|
||||
import requests
|
||||
from fastapi import Depends, HTTPException, Request
|
||||
from circuitforge_core.cloud_session import CloudSessionFactory as _CoreFactory
|
||||
from fastapi import Depends, HTTPException, Request, Response
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -34,20 +29,13 @@ log = logging.getLogger(__name__)
|
|||
|
||||
CLOUD_MODE: bool = os.environ.get("CLOUD_MODE", "").lower() in ("1", "true", "yes")
|
||||
CLOUD_DATA_ROOT: Path = Path(os.environ.get("CLOUD_DATA_ROOT", "/devl/snipe-cloud-data"))
|
||||
DIRECTUS_JWT_SECRET: str = os.environ.get("DIRECTUS_JWT_SECRET", "")
|
||||
CF_SERVER_SECRET: str = os.environ.get("CF_SERVER_SECRET", "")
|
||||
HEIMDALL_URL: str = os.environ.get("HEIMDALL_URL", "https://license.circuitforge.tech")
|
||||
HEIMDALL_ADMIN_TOKEN: str = os.environ.get("HEIMDALL_ADMIN_TOKEN", "")
|
||||
|
||||
# Local-mode DB paths (ignored in cloud mode)
|
||||
_LOCAL_SNIPE_DB: Path = Path(os.environ.get("SNIPE_DB", "data/snipe.db"))
|
||||
|
||||
# Tier cache: user_id → (tier, fetched_at_epoch)
|
||||
_TIER_CACHE: dict[str, tuple[str, float]] = {}
|
||||
_TIER_CACHE_TTL = 300 # 5 minutes
|
||||
|
||||
TIERS = ["free", "paid", "premium", "ultra"]
|
||||
|
||||
_core = _CoreFactory(product="snipe")
|
||||
|
||||
|
||||
# ── Domain ────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
|
@ -90,97 +78,6 @@ def compute_features(tier: str) -> SessionFeatures:
|
|||
)
|
||||
|
||||
|
||||
# ── JWT validation ────────────────────────────────────────────────────────────
|
||||
|
||||
def _extract_session_token(header_value: str) -> str:
|
||||
"""Extract cf_session value from a Cookie or X-CF-Session header string.
|
||||
|
||||
Returns the JWT token string, or "" if no valid session token is found.
|
||||
Cookie strings like "snipe_guest=abc123" (no cf_session key) return ""
|
||||
so the caller falls through to the guest/anonymous path rather than
|
||||
passing a non-JWT string to validate_session_jwt().
|
||||
"""
|
||||
m = re.search(r'(?:^|;)\s*cf_session=([^;]+)', header_value)
|
||||
if m:
|
||||
return m.group(1).strip()
|
||||
# Only treat as a raw JWT if it has exactly three base64url segments (header.payload.sig).
|
||||
# Cookie strings like "snipe_guest=abc123" must NOT be forwarded to JWT validation.
|
||||
stripped = header_value.strip()
|
||||
if re.match(r'^[A-Za-z0-9\-_]+\.[A-Za-z0-9\-_]+\.[A-Za-z0-9\-_=]+$', stripped):
|
||||
return stripped # bare JWT forwarded directly by Caddy
|
||||
return "" # not a JWT and no cf_session cookie — treat as unauthenticated
|
||||
|
||||
|
||||
def _extract_guest_token(cookie_header: str) -> str | None:
|
||||
"""Extract snipe_guest UUID from the Cookie header, if present."""
|
||||
m = re.search(r'(?:^|;)\s*snipe_guest=([^;]+)', cookie_header)
|
||||
return m.group(1).strip() if m else None
|
||||
|
||||
|
||||
def validate_session_jwt(token: str) -> str:
|
||||
"""Validate a cf_session JWT and return the Directus user_id.
|
||||
|
||||
Uses HMAC-SHA256 verification against DIRECTUS_JWT_SECRET (same secret
|
||||
cf-directus uses to sign session tokens). Returns user_id on success,
|
||||
raises HTTPException(401) on failure.
|
||||
|
||||
Directus 11+ uses 'id' (not 'sub') for the user UUID in its JWT payload.
|
||||
"""
|
||||
try:
|
||||
payload = pyjwt.decode(
|
||||
token,
|
||||
DIRECTUS_JWT_SECRET,
|
||||
algorithms=["HS256"],
|
||||
options={"require": ["id", "exp"]},
|
||||
)
|
||||
return payload["id"]
|
||||
except Exception as exc:
|
||||
log.debug("JWT validation failed: %s", exc)
|
||||
raise HTTPException(status_code=401, detail="Session invalid or expired")
|
||||
|
||||
|
||||
# ── Heimdall integration ──────────────────────────────────────────────────────
|
||||
|
||||
def _ensure_provisioned(user_id: str) -> None:
|
||||
"""Idempotent: create a free Heimdall license for this user if none exists."""
|
||||
if not HEIMDALL_ADMIN_TOKEN:
|
||||
return
|
||||
try:
|
||||
requests.post(
|
||||
f"{HEIMDALL_URL}/admin/provision",
|
||||
json={"directus_user_id": user_id, "product": "snipe", "tier": "free"},
|
||||
headers={"Authorization": f"Bearer {HEIMDALL_ADMIN_TOKEN}"},
|
||||
timeout=5,
|
||||
)
|
||||
except Exception as exc:
|
||||
log.warning("Heimdall provision failed for user %s: %s", user_id, exc)
|
||||
|
||||
|
||||
def _fetch_cloud_tier(user_id: str) -> str:
|
||||
"""Resolve tier from Heimdall with a 5-minute in-process cache."""
|
||||
now = time.monotonic()
|
||||
cached = _TIER_CACHE.get(user_id)
|
||||
if cached and (now - cached[1]) < _TIER_CACHE_TTL:
|
||||
return cached[0]
|
||||
|
||||
if not HEIMDALL_ADMIN_TOKEN:
|
||||
return "free"
|
||||
try:
|
||||
resp = requests.post(
|
||||
f"{HEIMDALL_URL}/admin/cloud/resolve",
|
||||
json={"directus_user_id": user_id, "product": "snipe"},
|
||||
headers={"Authorization": f"Bearer {HEIMDALL_ADMIN_TOKEN}"},
|
||||
timeout=5,
|
||||
)
|
||||
tier = resp.json().get("tier", "free") if resp.ok else "free"
|
||||
except Exception as exc:
|
||||
log.warning("Heimdall tier resolve failed for user %s: %s", user_id, exc)
|
||||
tier = "free"
|
||||
|
||||
_TIER_CACHE[user_id] = (tier, now)
|
||||
return tier
|
||||
|
||||
|
||||
# ── DB path helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
def _shared_db_path() -> Path:
|
||||
|
|
@ -209,58 +106,25 @@ def _anon_db_path() -> Path:
|
|||
|
||||
# ── FastAPI dependency ────────────────────────────────────────────────────────
|
||||
|
||||
def get_session(request: Request) -> CloudUser:
|
||||
def get_session(request: Request, response: Response) -> CloudUser:
|
||||
"""FastAPI dependency — resolves the current user from the request.
|
||||
|
||||
Local mode: returns a fully-privileged "local" user pointing at SNIPE_DB.
|
||||
Delegates auth/tier resolution to cf-core CloudSessionFactory, then maps
|
||||
the result to Snipe's CloudUser with shared_db + user_db paths.
|
||||
|
||||
Local mode: fully-privileged "local" user pointing at SNIPE_DB.
|
||||
Cloud mode: validates X-CF-Session JWT, provisions Heimdall license,
|
||||
resolves tier, returns per-user DB paths.
|
||||
Unauthenticated cloud visitors: returns a free-tier anonymous user so
|
||||
search and scoring work without an account.
|
||||
Anonymous: guest session with free-tier access to shared scammer corpus.
|
||||
"""
|
||||
if not CLOUD_MODE:
|
||||
return CloudUser(
|
||||
user_id="local",
|
||||
tier="local",
|
||||
shared_db=_LOCAL_SNIPE_DB,
|
||||
user_db=_LOCAL_SNIPE_DB,
|
||||
)
|
||||
core_user = _core.resolve(request, response)
|
||||
uid, tier = core_user.user_id, core_user.tier
|
||||
|
||||
cookie_header = request.headers.get("cookie", "")
|
||||
raw_header = request.headers.get("x-cf-session", "") or cookie_header
|
||||
|
||||
if not raw_header:
|
||||
# No session at all — check for a guest UUID cookie set by /api/session
|
||||
guest_uuid = _extract_guest_token(cookie_header)
|
||||
user_id = f"guest:{guest_uuid}" if guest_uuid else "anonymous"
|
||||
return CloudUser(
|
||||
user_id=user_id,
|
||||
tier="free",
|
||||
shared_db=_shared_db_path(),
|
||||
user_db=_anon_db_path(),
|
||||
)
|
||||
|
||||
token = _extract_session_token(raw_header)
|
||||
if not token:
|
||||
guest_uuid = _extract_guest_token(cookie_header)
|
||||
user_id = f"guest:{guest_uuid}" if guest_uuid else "anonymous"
|
||||
return CloudUser(
|
||||
user_id=user_id,
|
||||
tier="free",
|
||||
shared_db=_shared_db_path(),
|
||||
user_db=_anon_db_path(),
|
||||
)
|
||||
|
||||
user_id = validate_session_jwt(token)
|
||||
_ensure_provisioned(user_id)
|
||||
tier = _fetch_cloud_tier(user_id)
|
||||
|
||||
return CloudUser(
|
||||
user_id=user_id,
|
||||
tier=tier,
|
||||
shared_db=_shared_db_path(),
|
||||
user_db=_user_db_path(user_id),
|
||||
)
|
||||
if not CLOUD_MODE or uid in ("local", "local-dev"):
|
||||
return CloudUser(user_id=uid, tier=tier, shared_db=_LOCAL_SNIPE_DB, user_db=_LOCAL_SNIPE_DB)
|
||||
if uid.startswith("anon-"):
|
||||
return CloudUser(user_id=uid, tier=tier, shared_db=_shared_db_path(), user_db=_anon_db_path())
|
||||
return CloudUser(user_id=uid, tier=tier, shared_db=_shared_db_path(), user_db=_user_db_path(uid))
|
||||
|
||||
|
||||
def require_tier(min_tier: str):
|
||||
|
|
|
|||
Loading…
Reference in a new issue