diff --git a/api/cloud_session.py b/api/cloud_session.py index 99d9de6..54ee454 100644 --- a/api/cloud_session.py +++ b/api/cloud_session.py @@ -1,11 +1,9 @@ """Cloud session resolution for Snipe FastAPI. -In local mode (CLOUD_MODE unset/false): all functions return a local CloudUser -with no auth checks, full tier access, and both DB paths pointing to SNIPE_DB. - -In cloud mode (CLOUD_MODE=true): validates the cf_session JWT injected by Caddy -as X-CF-Session, resolves user_id, auto-provisions a free Heimdall license on -first visit, fetches the tier, and returns per-user DB paths. +Delegates JWT validation, Heimdall provisioning, tier resolution, and guest +session management to circuitforge_core.CloudSessionFactory. Snipe-specific +CloudUser (shared_db + user_db paths), SessionFeatures, and DB helpers are +kept here. FastAPI usage: @app.get("/api/search") @@ -18,15 +16,12 @@ from __future__ import annotations import logging import os -import re -import time from dataclasses import dataclass from pathlib import Path from typing import Optional -import jwt as pyjwt -import requests -from fastapi import Depends, HTTPException, Request +from circuitforge_core.cloud_session import CloudSessionFactory as _CoreFactory +from fastapi import Depends, HTTPException, Request, Response log = logging.getLogger(__name__) @@ -34,20 +29,13 @@ log = logging.getLogger(__name__) CLOUD_MODE: bool = os.environ.get("CLOUD_MODE", "").lower() in ("1", "true", "yes") CLOUD_DATA_ROOT: Path = Path(os.environ.get("CLOUD_DATA_ROOT", "/devl/snipe-cloud-data")) -DIRECTUS_JWT_SECRET: str = os.environ.get("DIRECTUS_JWT_SECRET", "") -CF_SERVER_SECRET: str = os.environ.get("CF_SERVER_SECRET", "") -HEIMDALL_URL: str = os.environ.get("HEIMDALL_URL", "https://license.circuitforge.tech") -HEIMDALL_ADMIN_TOKEN: str = os.environ.get("HEIMDALL_ADMIN_TOKEN", "") -# Local-mode DB paths (ignored in cloud mode) _LOCAL_SNIPE_DB: Path = Path(os.environ.get("SNIPE_DB", "data/snipe.db")) -# Tier cache: user_id → (tier, fetched_at_epoch) -_TIER_CACHE: dict[str, tuple[str, float]] = {} -_TIER_CACHE_TTL = 300 # 5 minutes - TIERS = ["free", "paid", "premium", "ultra"] +_core = _CoreFactory(product="snipe") + # ── Domain ──────────────────────────────────────────────────────────────────── @@ -90,97 +78,6 @@ def compute_features(tier: str) -> SessionFeatures: ) -# ── JWT validation ──────────────────────────────────────────────────────────── - -def _extract_session_token(header_value: str) -> str: - """Extract cf_session value from a Cookie or X-CF-Session header string. - - Returns the JWT token string, or "" if no valid session token is found. - Cookie strings like "snipe_guest=abc123" (no cf_session key) return "" - so the caller falls through to the guest/anonymous path rather than - passing a non-JWT string to validate_session_jwt(). - """ - m = re.search(r'(?:^|;)\s*cf_session=([^;]+)', header_value) - if m: - return m.group(1).strip() - # Only treat as a raw JWT if it has exactly three base64url segments (header.payload.sig). - # Cookie strings like "snipe_guest=abc123" must NOT be forwarded to JWT validation. - stripped = header_value.strip() - if re.match(r'^[A-Za-z0-9\-_]+\.[A-Za-z0-9\-_]+\.[A-Za-z0-9\-_=]+$', stripped): - return stripped # bare JWT forwarded directly by Caddy - return "" # not a JWT and no cf_session cookie — treat as unauthenticated - - -def _extract_guest_token(cookie_header: str) -> str | None: - """Extract snipe_guest UUID from the Cookie header, if present.""" - m = re.search(r'(?:^|;)\s*snipe_guest=([^;]+)', cookie_header) - return m.group(1).strip() if m else None - - -def validate_session_jwt(token: str) -> str: - """Validate a cf_session JWT and return the Directus user_id. - - Uses HMAC-SHA256 verification against DIRECTUS_JWT_SECRET (same secret - cf-directus uses to sign session tokens). Returns user_id on success, - raises HTTPException(401) on failure. - - Directus 11+ uses 'id' (not 'sub') for the user UUID in its JWT payload. - """ - try: - payload = pyjwt.decode( - token, - DIRECTUS_JWT_SECRET, - algorithms=["HS256"], - options={"require": ["id", "exp"]}, - ) - return payload["id"] - except Exception as exc: - log.debug("JWT validation failed: %s", exc) - raise HTTPException(status_code=401, detail="Session invalid or expired") - - -# ── Heimdall integration ────────────────────────────────────────────────────── - -def _ensure_provisioned(user_id: str) -> None: - """Idempotent: create a free Heimdall license for this user if none exists.""" - if not HEIMDALL_ADMIN_TOKEN: - return - try: - requests.post( - f"{HEIMDALL_URL}/admin/provision", - json={"directus_user_id": user_id, "product": "snipe", "tier": "free"}, - headers={"Authorization": f"Bearer {HEIMDALL_ADMIN_TOKEN}"}, - timeout=5, - ) - except Exception as exc: - log.warning("Heimdall provision failed for user %s: %s", user_id, exc) - - -def _fetch_cloud_tier(user_id: str) -> str: - """Resolve tier from Heimdall with a 5-minute in-process cache.""" - now = time.monotonic() - cached = _TIER_CACHE.get(user_id) - if cached and (now - cached[1]) < _TIER_CACHE_TTL: - return cached[0] - - if not HEIMDALL_ADMIN_TOKEN: - return "free" - try: - resp = requests.post( - f"{HEIMDALL_URL}/admin/cloud/resolve", - json={"directus_user_id": user_id, "product": "snipe"}, - headers={"Authorization": f"Bearer {HEIMDALL_ADMIN_TOKEN}"}, - timeout=5, - ) - tier = resp.json().get("tier", "free") if resp.ok else "free" - except Exception as exc: - log.warning("Heimdall tier resolve failed for user %s: %s", user_id, exc) - tier = "free" - - _TIER_CACHE[user_id] = (tier, now) - return tier - - # ── DB path helpers ─────────────────────────────────────────────────────────── def _shared_db_path() -> Path: @@ -209,58 +106,25 @@ def _anon_db_path() -> Path: # ── FastAPI dependency ──────────────────────────────────────────────────────── -def get_session(request: Request) -> CloudUser: +def get_session(request: Request, response: Response) -> CloudUser: """FastAPI dependency — resolves the current user from the request. - Local mode: returns a fully-privileged "local" user pointing at SNIPE_DB. + Delegates auth/tier resolution to cf-core CloudSessionFactory, then maps + the result to Snipe's CloudUser with shared_db + user_db paths. + + Local mode: fully-privileged "local" user pointing at SNIPE_DB. Cloud mode: validates X-CF-Session JWT, provisions Heimdall license, resolves tier, returns per-user DB paths. - Unauthenticated cloud visitors: returns a free-tier anonymous user so - search and scoring work without an account. + Anonymous: guest session with free-tier access to shared scammer corpus. """ - if not CLOUD_MODE: - return CloudUser( - user_id="local", - tier="local", - shared_db=_LOCAL_SNIPE_DB, - user_db=_LOCAL_SNIPE_DB, - ) + core_user = _core.resolve(request, response) + uid, tier = core_user.user_id, core_user.tier - cookie_header = request.headers.get("cookie", "") - raw_header = request.headers.get("x-cf-session", "") or cookie_header - - if not raw_header: - # No session at all — check for a guest UUID cookie set by /api/session - guest_uuid = _extract_guest_token(cookie_header) - user_id = f"guest:{guest_uuid}" if guest_uuid else "anonymous" - return CloudUser( - user_id=user_id, - tier="free", - shared_db=_shared_db_path(), - user_db=_anon_db_path(), - ) - - token = _extract_session_token(raw_header) - if not token: - guest_uuid = _extract_guest_token(cookie_header) - user_id = f"guest:{guest_uuid}" if guest_uuid else "anonymous" - return CloudUser( - user_id=user_id, - tier="free", - shared_db=_shared_db_path(), - user_db=_anon_db_path(), - ) - - user_id = validate_session_jwt(token) - _ensure_provisioned(user_id) - tier = _fetch_cloud_tier(user_id) - - return CloudUser( - user_id=user_id, - tier=tier, - shared_db=_shared_db_path(), - user_db=_user_db_path(user_id), - ) + if not CLOUD_MODE or uid in ("local", "local-dev"): + return CloudUser(user_id=uid, tier=tier, shared_db=_LOCAL_SNIPE_DB, user_db=_LOCAL_SNIPE_DB) + if uid.startswith("anon-"): + return CloudUser(user_id=uid, tier=tier, shared_db=_shared_db_path(), user_db=_anon_db_path()) + return CloudUser(user_id=uid, tier=tier, shared_db=_shared_db_path(), user_db=_user_db_path(uid)) def require_tier(min_tier: str):