feat(cloud_session): shared cloud session resolution for all CF products
Extracts the JWT validation + Heimdall tier resolution + guest session pattern that was duplicated across kiwi and peregrine into a single reusable module. CloudSessionFactory is parameterized by product name. Products instantiate it once at module level and call .dependency() to get a FastAPI-compatible Depends() function. .require_tier(min_tier) returns a dependency factory for gated routes. CloudUser carries: user_id — Directus UUID, "local" (self-hosted), "local-dev" (bypass), "anon-<uuid>" tier — free | paid | premium | ultra | local product — which CF product this session is for has_byok — whether user has a configured LLM backend meta — dict for product-specific extras (household_id, license_key, etc.) Products can pass extra_meta= to attach product-specific fields without subclassing. The module is FastAPI-only (fastapi is a lazy import so local-mode products that never hit cloud paths don't pay the import cost).
This commit is contained in:
parent
383897f990
commit
00737d22cf
1 changed files with 314 additions and 0 deletions
314
circuitforge_core/cloud_session/__init__.py
Normal file
314
circuitforge_core/cloud_session/__init__.py
Normal file
|
|
@ -0,0 +1,314 @@
|
|||
"""
|
||||
circuitforge_core.cloud_session — shared cloud session resolution for all CF products.
|
||||
|
||||
Usage (FastAPI product):
|
||||
|
||||
from circuitforge_core.cloud_session import CloudSessionFactory
|
||||
from pathlib import Path
|
||||
|
||||
_sessions = CloudSessionFactory(
|
||||
product="avocet",
|
||||
local_db=Path("data/avocet.db"),
|
||||
)
|
||||
get_session = _sessions.dependency()
|
||||
require_tier = _sessions.require_tier
|
||||
|
||||
@router.get("/api/imitate")
|
||||
def imitate(session: CloudUser = Depends(get_session)):
|
||||
# session.user_id is the Directus UUID for cloud users, "local" for self-hosted
|
||||
...
|
||||
|
||||
Environment variables (set per-product via .env / compose):
|
||||
CLOUD_MODE 1/true/yes to enable cloud auth (default: off)
|
||||
CLOUD_DATA_ROOT Root directory for per-user data (default: /devl/<product>-cloud-data)
|
||||
DIRECTUS_JWT_SECRET HS256 secret used to sign cf_session JWTs (required in cloud mode)
|
||||
HEIMDALL_URL License server base URL (default: https://license.circuitforge.tech)
|
||||
HEIMDALL_ADMIN_TOKEN Heimdall admin bearer token (required for tier resolution)
|
||||
CF_SERVER_SECRET Server-side secret for deriving per-user encryption keys
|
||||
CLOUD_AUTH_BYPASS_IPS Comma-separated IPs/CIDRs to skip JWT auth (dev LAN only)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import ipaddress
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
TIERS: list[str] = ["free", "paid", "premium", "ultra"]
|
||||
|
||||
# ── CloudUser ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CloudUser:
|
||||
"""Resolved user identity for one HTTP request.
|
||||
|
||||
user_id: Directus UUID for authenticated cloud users.
|
||||
"local" for self-hosted / CLOUD_MODE=false.
|
||||
"local-dev" for dev-bypass-IP sessions.
|
||||
"anon-<uuid>" for unauthenticated guest visitors.
|
||||
tier: free | paid | premium | ultra | local
|
||||
product: Which CF product this session belongs to (e.g. "avocet").
|
||||
meta: Product-specific extras (e.g. household_id for Kiwi).
|
||||
Access via session.meta.get("household_id").
|
||||
"""
|
||||
user_id: str
|
||||
tier: str
|
||||
product: str
|
||||
has_byok: bool = False
|
||||
meta: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _parse_bypass_nets(raw: str) -> tuple[list[ipaddress.IPv4Network | ipaddress.IPv6Network], frozenset[str]]:
|
||||
nets: list[ipaddress.IPv4Network | ipaddress.IPv6Network] = []
|
||||
ips: set[str] = set()
|
||||
for entry in (e.strip() for e in raw.split(",") if e.strip()):
|
||||
try:
|
||||
nets.append(ipaddress.ip_network(entry, strict=False))
|
||||
except ValueError:
|
||||
ips.add(entry)
|
||||
return nets, frozenset(ips)
|
||||
|
||||
|
||||
def _is_bypass_ip(
|
||||
ip: str,
|
||||
nets: list[ipaddress.IPv4Network | ipaddress.IPv6Network],
|
||||
ips: frozenset[str],
|
||||
) -> bool:
|
||||
if not ip or (not nets and not ips):
|
||||
return False
|
||||
if ip in ips:
|
||||
return True
|
||||
try:
|
||||
addr = ipaddress.ip_address(ip)
|
||||
return any(addr in net for net in nets)
|
||||
except ValueError:
|
||||
return False
|
||||
|
||||
|
||||
def _extract_session_token(header_value: str) -> str:
|
||||
"""Pull cf_session value out of a raw Cookie header or return the value as-is."""
|
||||
m = re.search(r'(?:^|;)\s*cf_session=([^;]+)', header_value)
|
||||
return m.group(1).strip() if m else header_value.strip()
|
||||
|
||||
|
||||
# ── CloudSessionFactory ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class CloudSessionFactory:
|
||||
"""Per-product session factory. Instantiate once at module level.
|
||||
|
||||
Args:
|
||||
product: Product code string (e.g. "avocet", "kiwi").
|
||||
extra_meta: Optional async-or-sync callable that receives
|
||||
(user_id: str, tier: str) and returns a dict merged
|
||||
into CloudUser.meta. Use for product-specific fields
|
||||
like household_id.
|
||||
byok_detector: Callable() → bool. Override to detect BYOK for this
|
||||
product's config path. Default: always False.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
product: str,
|
||||
extra_meta: Callable[[str, str], dict[str, Any]] | None = None,
|
||||
byok_detector: Callable[[], bool] | None = None,
|
||||
) -> None:
|
||||
self.product = product
|
||||
self._extra_meta = extra_meta
|
||||
self._byok_detector = byok_detector or (lambda: False)
|
||||
|
||||
# Config — read from environment at construction time so tests can patch env
|
||||
self._cloud_mode: bool = os.environ.get("CLOUD_MODE", "").lower() in ("1", "true", "yes")
|
||||
self._directus_secret: str = os.environ.get("DIRECTUS_JWT_SECRET", "")
|
||||
self._heimdall_url: str = os.environ.get("HEIMDALL_URL", "https://license.circuitforge.tech")
|
||||
self._heimdall_token: str = os.environ.get("HEIMDALL_ADMIN_TOKEN", "")
|
||||
self._cloud_data_root: Path = Path(
|
||||
os.environ.get("CLOUD_DATA_ROOT", f"/devl/{product}-cloud-data")
|
||||
)
|
||||
|
||||
_bypass_raw = os.environ.get("CLOUD_AUTH_BYPASS_IPS", "")
|
||||
self._bypass_nets, self._bypass_ips = _parse_bypass_nets(_bypass_raw)
|
||||
|
||||
# Tier resolution cache: {user_id: (result_dict, timestamp)}
|
||||
self._tier_cache: dict[str, tuple[dict, float]] = {}
|
||||
self._tier_cache_ttl: float = 300.0 # 5 minutes
|
||||
|
||||
# ── JWT ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def validate_jwt(self, token: str) -> str:
|
||||
"""Validate a cf_session JWT and return the Directus user_id. Raises HTTPException on failure."""
|
||||
try:
|
||||
import jwt as pyjwt # lazy — not needed in local mode
|
||||
from fastapi import HTTPException
|
||||
payload = pyjwt.decode(
|
||||
token,
|
||||
self._directus_secret,
|
||||
algorithms=["HS256"],
|
||||
options={"require": ["id", "exp"]},
|
||||
)
|
||||
return payload["id"]
|
||||
except Exception as exc:
|
||||
log.debug("JWT validation failed: %s", exc)
|
||||
from fastapi import HTTPException
|
||||
raise HTTPException(status_code=401, detail="Session invalid or expired")
|
||||
|
||||
# ── Heimdall ──────────────────────────────────────────────────────────────
|
||||
|
||||
def _ensure_provisioned(self, user_id: str) -> None:
|
||||
if not self._heimdall_token:
|
||||
return
|
||||
try:
|
||||
import requests
|
||||
requests.post(
|
||||
f"{self._heimdall_url}/admin/provision",
|
||||
json={"directus_user_id": user_id, "product": self.product, "tier": "free"},
|
||||
headers={"Authorization": f"Bearer {self._heimdall_token}"},
|
||||
timeout=5,
|
||||
)
|
||||
except Exception as exc:
|
||||
log.warning("Heimdall provision failed for user %s: %s", user_id, exc)
|
||||
|
||||
def _resolve_tier(self, user_id: str) -> dict[str, Any]:
|
||||
"""Returns dict with keys: tier, license_key (and any product extras)."""
|
||||
now = time.monotonic()
|
||||
cached = self._tier_cache.get(user_id)
|
||||
if cached and (now - cached[1]) < self._tier_cache_ttl:
|
||||
return cached[0]
|
||||
|
||||
result: dict[str, Any] = {"tier": "free", "license_key": None}
|
||||
if self._heimdall_token:
|
||||
try:
|
||||
import requests
|
||||
resp = requests.post(
|
||||
f"{self._heimdall_url}/admin/cloud/resolve",
|
||||
json={"directus_user_id": user_id, "product": self.product},
|
||||
headers={"Authorization": f"Bearer {self._heimdall_token}"},
|
||||
timeout=5,
|
||||
)
|
||||
if resp.ok:
|
||||
data = resp.json()
|
||||
result["tier"] = data.get("tier", "free")
|
||||
result["license_key"] = data.get("key_display")
|
||||
# Forward any extra fields Heimdall returns (household_id etc.)
|
||||
result.update({k: v for k, v in data.items() if k not in result})
|
||||
except Exception as exc:
|
||||
log.warning("Heimdall tier resolve failed for %s: %s", user_id, exc)
|
||||
else:
|
||||
log.debug("HEIMDALL_ADMIN_TOKEN not set — defaulting tier to free")
|
||||
|
||||
self._tier_cache[user_id] = (result, now)
|
||||
return result
|
||||
|
||||
# ── Guest sessions ────────────────────────────────────────────────────────
|
||||
|
||||
_GUEST_COOKIE = "cf_guest_id"
|
||||
_GUEST_COOKIE_MAX_AGE = 60 * 60 * 24 * 90 # 90 days
|
||||
|
||||
def _resolve_guest(self, request: Any, response: Any) -> CloudUser:
|
||||
guest_id = (request.cookies.get(self._GUEST_COOKIE) or "").strip()
|
||||
if not guest_id:
|
||||
guest_id = str(uuid.uuid4())
|
||||
is_https = request.headers.get("x-forwarded-proto", "http").lower() == "https"
|
||||
response.set_cookie(
|
||||
key=self._GUEST_COOKIE,
|
||||
value=guest_id,
|
||||
max_age=self._GUEST_COOKIE_MAX_AGE,
|
||||
httponly=True,
|
||||
samesite="lax",
|
||||
secure=is_https,
|
||||
)
|
||||
return CloudUser(
|
||||
user_id=f"anon-{guest_id}",
|
||||
tier="free",
|
||||
product=self.product,
|
||||
has_byok=self._byok_detector(),
|
||||
)
|
||||
|
||||
# ── Core resolver ─────────────────────────────────────────────────────────
|
||||
|
||||
def resolve(self, request: Any, response: Any) -> CloudUser:
|
||||
"""Resolve the CloudUser for a FastAPI request. Suitable as a Depends() target."""
|
||||
has_byok = self._byok_detector()
|
||||
|
||||
if not self._cloud_mode:
|
||||
return CloudUser(user_id="local", tier="local", product=self.product, has_byok=has_byok)
|
||||
|
||||
client_ip = (
|
||||
request.headers.get("x-real-ip", "")
|
||||
or (request.client.host if request.client else "")
|
||||
)
|
||||
if _is_bypass_ip(client_ip, self._bypass_nets, self._bypass_ips):
|
||||
log.debug("Bypass IP %s — returning local-dev session for product %s", client_ip, self.product)
|
||||
return CloudUser(user_id="local-dev", tier="local", product=self.product, has_byok=has_byok)
|
||||
|
||||
raw_session = (
|
||||
request.headers.get("x-cf-session", "").strip()
|
||||
or request.cookies.get("cf_session", "").strip()
|
||||
)
|
||||
if not raw_session:
|
||||
return self._resolve_guest(request, response)
|
||||
|
||||
token = _extract_session_token(raw_session)
|
||||
if not token:
|
||||
return self._resolve_guest(request, response)
|
||||
|
||||
user_id = self.validate_jwt(token)
|
||||
self._ensure_provisioned(user_id)
|
||||
tier_data = self._resolve_tier(user_id)
|
||||
tier = tier_data.get("tier", "free")
|
||||
|
||||
meta: dict[str, Any] = {}
|
||||
if self._extra_meta:
|
||||
meta = self._extra_meta(user_id, tier) or {}
|
||||
# Merge any extra fields from Heimdall response (e.g. household_id)
|
||||
meta.update({k: v for k, v in tier_data.items() if k not in ("tier", "license_key")})
|
||||
meta["license_key"] = tier_data.get("license_key")
|
||||
|
||||
return CloudUser(
|
||||
user_id=user_id,
|
||||
tier=tier,
|
||||
product=self.product,
|
||||
has_byok=has_byok,
|
||||
meta=meta,
|
||||
)
|
||||
|
||||
def dependency(self) -> Callable[[Any, Any], CloudUser]:
|
||||
"""Return a FastAPI-compatible dependency function (use with Depends())."""
|
||||
factory = self
|
||||
|
||||
def _get_session(request: Any, response: Any) -> CloudUser:
|
||||
return factory.resolve(request, response)
|
||||
|
||||
return _get_session
|
||||
|
||||
def require_tier(self, min_tier: str) -> Callable:
|
||||
"""Dependency factory — raises 403 if the session tier is below min_tier."""
|
||||
from fastapi import Depends, HTTPException
|
||||
min_idx = TIERS.index(min_tier)
|
||||
get_session = self.dependency()
|
||||
|
||||
def _check(session: CloudUser = Depends(get_session)) -> CloudUser:
|
||||
if session.tier in ("local", "local-dev"):
|
||||
return session
|
||||
try:
|
||||
if TIERS.index(session.tier) < min_idx:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail=f"This feature requires {min_tier} tier or above.",
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=403, detail="Unknown tier.")
|
||||
return session
|
||||
|
||||
return _check
|
||||
Loading…
Reference in a new issue