refactor: replace hand-rolled JWT+Heimdall with cf-core CloudSessionFactory
Delegates JWT validation, Heimdall provision/tier-resolve, bypass-IP handling, and guest session management to circuitforge_core. Snipe keeps its own CloudUser (shared_db + user_db), SessionFeatures, compute_features, and DB path helpers. Removes ~158 lines of duplicated auth code. Note: get_session() now takes (Request, Response) — FastAPI auto-injects both, no call-site changes needed.
This commit is contained in:
parent
ec0af07905
commit
0354234f86
1 changed files with 21 additions and 157 deletions
|
|
@ -1,11 +1,9 @@
|
||||||
"""Cloud session resolution for Snipe FastAPI.
|
"""Cloud session resolution for Snipe FastAPI.
|
||||||
|
|
||||||
In local mode (CLOUD_MODE unset/false): all functions return a local CloudUser
|
Delegates JWT validation, Heimdall provisioning, tier resolution, and guest
|
||||||
with no auth checks, full tier access, and both DB paths pointing to SNIPE_DB.
|
session management to circuitforge_core.CloudSessionFactory. Snipe-specific
|
||||||
|
CloudUser (shared_db + user_db paths), SessionFeatures, and DB helpers are
|
||||||
In cloud mode (CLOUD_MODE=true): validates the cf_session JWT injected by Caddy
|
kept here.
|
||||||
as X-CF-Session, resolves user_id, auto-provisions a free Heimdall license on
|
|
||||||
first visit, fetches the tier, and returns per-user DB paths.
|
|
||||||
|
|
||||||
FastAPI usage:
|
FastAPI usage:
|
||||||
@app.get("/api/search")
|
@app.get("/api/search")
|
||||||
|
|
@ -18,15 +16,12 @@ from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import re
|
|
||||||
import time
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import jwt as pyjwt
|
from circuitforge_core.cloud_session import CloudSessionFactory as _CoreFactory
|
||||||
import requests
|
from fastapi import Depends, HTTPException, Request, Response
|
||||||
from fastapi import Depends, HTTPException, Request
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -34,20 +29,13 @@ log = logging.getLogger(__name__)
|
||||||
|
|
||||||
CLOUD_MODE: bool = os.environ.get("CLOUD_MODE", "").lower() in ("1", "true", "yes")
|
CLOUD_MODE: bool = os.environ.get("CLOUD_MODE", "").lower() in ("1", "true", "yes")
|
||||||
CLOUD_DATA_ROOT: Path = Path(os.environ.get("CLOUD_DATA_ROOT", "/devl/snipe-cloud-data"))
|
CLOUD_DATA_ROOT: Path = Path(os.environ.get("CLOUD_DATA_ROOT", "/devl/snipe-cloud-data"))
|
||||||
DIRECTUS_JWT_SECRET: str = os.environ.get("DIRECTUS_JWT_SECRET", "")
|
|
||||||
CF_SERVER_SECRET: str = os.environ.get("CF_SERVER_SECRET", "")
|
|
||||||
HEIMDALL_URL: str = os.environ.get("HEIMDALL_URL", "https://license.circuitforge.tech")
|
|
||||||
HEIMDALL_ADMIN_TOKEN: str = os.environ.get("HEIMDALL_ADMIN_TOKEN", "")
|
|
||||||
|
|
||||||
# Local-mode DB paths (ignored in cloud mode)
|
|
||||||
_LOCAL_SNIPE_DB: Path = Path(os.environ.get("SNIPE_DB", "data/snipe.db"))
|
_LOCAL_SNIPE_DB: Path = Path(os.environ.get("SNIPE_DB", "data/snipe.db"))
|
||||||
|
|
||||||
# Tier cache: user_id → (tier, fetched_at_epoch)
|
|
||||||
_TIER_CACHE: dict[str, tuple[str, float]] = {}
|
|
||||||
_TIER_CACHE_TTL = 300 # 5 minutes
|
|
||||||
|
|
||||||
TIERS = ["free", "paid", "premium", "ultra"]
|
TIERS = ["free", "paid", "premium", "ultra"]
|
||||||
|
|
||||||
|
_core = _CoreFactory(product="snipe")
|
||||||
|
|
||||||
|
|
||||||
# ── Domain ────────────────────────────────────────────────────────────────────
|
# ── Domain ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
@ -90,97 +78,6 @@ def compute_features(tier: str) -> SessionFeatures:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── JWT validation ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _extract_session_token(header_value: str) -> str:
|
|
||||||
"""Extract cf_session value from a Cookie or X-CF-Session header string.
|
|
||||||
|
|
||||||
Returns the JWT token string, or "" if no valid session token is found.
|
|
||||||
Cookie strings like "snipe_guest=abc123" (no cf_session key) return ""
|
|
||||||
so the caller falls through to the guest/anonymous path rather than
|
|
||||||
passing a non-JWT string to validate_session_jwt().
|
|
||||||
"""
|
|
||||||
m = re.search(r'(?:^|;)\s*cf_session=([^;]+)', header_value)
|
|
||||||
if m:
|
|
||||||
return m.group(1).strip()
|
|
||||||
# Only treat as a raw JWT if it has exactly three base64url segments (header.payload.sig).
|
|
||||||
# Cookie strings like "snipe_guest=abc123" must NOT be forwarded to JWT validation.
|
|
||||||
stripped = header_value.strip()
|
|
||||||
if re.match(r'^[A-Za-z0-9\-_]+\.[A-Za-z0-9\-_]+\.[A-Za-z0-9\-_=]+$', stripped):
|
|
||||||
return stripped # bare JWT forwarded directly by Caddy
|
|
||||||
return "" # not a JWT and no cf_session cookie — treat as unauthenticated
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_guest_token(cookie_header: str) -> str | None:
|
|
||||||
"""Extract snipe_guest UUID from the Cookie header, if present."""
|
|
||||||
m = re.search(r'(?:^|;)\s*snipe_guest=([^;]+)', cookie_header)
|
|
||||||
return m.group(1).strip() if m else None
|
|
||||||
|
|
||||||
|
|
||||||
def validate_session_jwt(token: str) -> str:
|
|
||||||
"""Validate a cf_session JWT and return the Directus user_id.
|
|
||||||
|
|
||||||
Uses HMAC-SHA256 verification against DIRECTUS_JWT_SECRET (same secret
|
|
||||||
cf-directus uses to sign session tokens). Returns user_id on success,
|
|
||||||
raises HTTPException(401) on failure.
|
|
||||||
|
|
||||||
Directus 11+ uses 'id' (not 'sub') for the user UUID in its JWT payload.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
payload = pyjwt.decode(
|
|
||||||
token,
|
|
||||||
DIRECTUS_JWT_SECRET,
|
|
||||||
algorithms=["HS256"],
|
|
||||||
options={"require": ["id", "exp"]},
|
|
||||||
)
|
|
||||||
return payload["id"]
|
|
||||||
except Exception as exc:
|
|
||||||
log.debug("JWT validation failed: %s", exc)
|
|
||||||
raise HTTPException(status_code=401, detail="Session invalid or expired")
|
|
||||||
|
|
||||||
|
|
||||||
# ── Heimdall integration ──────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _ensure_provisioned(user_id: str) -> None:
|
|
||||||
"""Idempotent: create a free Heimdall license for this user if none exists."""
|
|
||||||
if not HEIMDALL_ADMIN_TOKEN:
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
requests.post(
|
|
||||||
f"{HEIMDALL_URL}/admin/provision",
|
|
||||||
json={"directus_user_id": user_id, "product": "snipe", "tier": "free"},
|
|
||||||
headers={"Authorization": f"Bearer {HEIMDALL_ADMIN_TOKEN}"},
|
|
||||||
timeout=5,
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
log.warning("Heimdall provision failed for user %s: %s", user_id, exc)
|
|
||||||
|
|
||||||
|
|
||||||
def _fetch_cloud_tier(user_id: str) -> str:
|
|
||||||
"""Resolve tier from Heimdall with a 5-minute in-process cache."""
|
|
||||||
now = time.monotonic()
|
|
||||||
cached = _TIER_CACHE.get(user_id)
|
|
||||||
if cached and (now - cached[1]) < _TIER_CACHE_TTL:
|
|
||||||
return cached[0]
|
|
||||||
|
|
||||||
if not HEIMDALL_ADMIN_TOKEN:
|
|
||||||
return "free"
|
|
||||||
try:
|
|
||||||
resp = requests.post(
|
|
||||||
f"{HEIMDALL_URL}/admin/cloud/resolve",
|
|
||||||
json={"directus_user_id": user_id, "product": "snipe"},
|
|
||||||
headers={"Authorization": f"Bearer {HEIMDALL_ADMIN_TOKEN}"},
|
|
||||||
timeout=5,
|
|
||||||
)
|
|
||||||
tier = resp.json().get("tier", "free") if resp.ok else "free"
|
|
||||||
except Exception as exc:
|
|
||||||
log.warning("Heimdall tier resolve failed for user %s: %s", user_id, exc)
|
|
||||||
tier = "free"
|
|
||||||
|
|
||||||
_TIER_CACHE[user_id] = (tier, now)
|
|
||||||
return tier
|
|
||||||
|
|
||||||
|
|
||||||
# ── DB path helpers ───────────────────────────────────────────────────────────
|
# ── DB path helpers ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
def _shared_db_path() -> Path:
|
def _shared_db_path() -> Path:
|
||||||
|
|
@ -209,58 +106,25 @@ def _anon_db_path() -> Path:
|
||||||
|
|
||||||
# ── FastAPI dependency ────────────────────────────────────────────────────────
|
# ── FastAPI dependency ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
def get_session(request: Request) -> CloudUser:
|
def get_session(request: Request, response: Response) -> CloudUser:
|
||||||
"""FastAPI dependency — resolves the current user from the request.
|
"""FastAPI dependency — resolves the current user from the request.
|
||||||
|
|
||||||
Local mode: returns a fully-privileged "local" user pointing at SNIPE_DB.
|
Delegates auth/tier resolution to cf-core CloudSessionFactory, then maps
|
||||||
|
the result to Snipe's CloudUser with shared_db + user_db paths.
|
||||||
|
|
||||||
|
Local mode: fully-privileged "local" user pointing at SNIPE_DB.
|
||||||
Cloud mode: validates X-CF-Session JWT, provisions Heimdall license,
|
Cloud mode: validates X-CF-Session JWT, provisions Heimdall license,
|
||||||
resolves tier, returns per-user DB paths.
|
resolves tier, returns per-user DB paths.
|
||||||
Unauthenticated cloud visitors: returns a free-tier anonymous user so
|
Anonymous: guest session with free-tier access to shared scammer corpus.
|
||||||
search and scoring work without an account.
|
|
||||||
"""
|
"""
|
||||||
if not CLOUD_MODE:
|
core_user = _core.resolve(request, response)
|
||||||
return CloudUser(
|
uid, tier = core_user.user_id, core_user.tier
|
||||||
user_id="local",
|
|
||||||
tier="local",
|
|
||||||
shared_db=_LOCAL_SNIPE_DB,
|
|
||||||
user_db=_LOCAL_SNIPE_DB,
|
|
||||||
)
|
|
||||||
|
|
||||||
cookie_header = request.headers.get("cookie", "")
|
if not CLOUD_MODE or uid in ("local", "local-dev"):
|
||||||
raw_header = request.headers.get("x-cf-session", "") or cookie_header
|
return CloudUser(user_id=uid, tier=tier, shared_db=_LOCAL_SNIPE_DB, user_db=_LOCAL_SNIPE_DB)
|
||||||
|
if uid.startswith("anon-"):
|
||||||
if not raw_header:
|
return CloudUser(user_id=uid, tier=tier, shared_db=_shared_db_path(), user_db=_anon_db_path())
|
||||||
# No session at all — check for a guest UUID cookie set by /api/session
|
return CloudUser(user_id=uid, tier=tier, shared_db=_shared_db_path(), user_db=_user_db_path(uid))
|
||||||
guest_uuid = _extract_guest_token(cookie_header)
|
|
||||||
user_id = f"guest:{guest_uuid}" if guest_uuid else "anonymous"
|
|
||||||
return CloudUser(
|
|
||||||
user_id=user_id,
|
|
||||||
tier="free",
|
|
||||||
shared_db=_shared_db_path(),
|
|
||||||
user_db=_anon_db_path(),
|
|
||||||
)
|
|
||||||
|
|
||||||
token = _extract_session_token(raw_header)
|
|
||||||
if not token:
|
|
||||||
guest_uuid = _extract_guest_token(cookie_header)
|
|
||||||
user_id = f"guest:{guest_uuid}" if guest_uuid else "anonymous"
|
|
||||||
return CloudUser(
|
|
||||||
user_id=user_id,
|
|
||||||
tier="free",
|
|
||||||
shared_db=_shared_db_path(),
|
|
||||||
user_db=_anon_db_path(),
|
|
||||||
)
|
|
||||||
|
|
||||||
user_id = validate_session_jwt(token)
|
|
||||||
_ensure_provisioned(user_id)
|
|
||||||
tier = _fetch_cloud_tier(user_id)
|
|
||||||
|
|
||||||
return CloudUser(
|
|
||||||
user_id=user_id,
|
|
||||||
tier=tier,
|
|
||||||
shared_db=_shared_db_path(),
|
|
||||||
user_db=_user_db_path(user_id),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def require_tier(min_tier: str):
|
def require_tier(min_tier: str):
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue