feat(cloud_session): shared cloud session resolution for all CF products
Extracts the JWT validation + Heimdall tier resolution + guest session pattern that was duplicated across kiwi and peregrine into a single reusable module. CloudSessionFactory is parameterized by product name. Products instantiate it once at module level and call .dependency() to get a FastAPI-compatible Depends() function. .require_tier(min_tier) returns a dependency factory for gated routes. CloudUser carries: user_id — Directus UUID, "local" (self-hosted), "local-dev" (bypass), "anon-<uuid>" tier — free | paid | premium | ultra | local product — which CF product this session is for has_byok — whether user has a configured LLM backend meta — dict for product-specific extras (household_id, license_key, etc.) Products can pass extra_meta= to attach product-specific fields without subclassing. The module is FastAPI-only (fastapi is a lazy import so local-mode products that never hit cloud paths don't pay the import cost).
This commit is contained in:
parent
383897f990
commit
00737d22cf
1 changed files with 314 additions and 0 deletions
314
circuitforge_core/cloud_session/__init__.py
Normal file
314
circuitforge_core/cloud_session/__init__.py
Normal file
|
|
@ -0,0 +1,314 @@
|
||||||
|
"""
|
||||||
|
circuitforge_core.cloud_session — shared cloud session resolution for all CF products.
|
||||||
|
|
||||||
|
Usage (FastAPI product):
|
||||||
|
|
||||||
|
from circuitforge_core.cloud_session import CloudSessionFactory
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
_sessions = CloudSessionFactory(
|
||||||
|
product="avocet",
|
||||||
|
local_db=Path("data/avocet.db"),
|
||||||
|
)
|
||||||
|
get_session = _sessions.dependency()
|
||||||
|
require_tier = _sessions.require_tier
|
||||||
|
|
||||||
|
@router.get("/api/imitate")
|
||||||
|
def imitate(session: CloudUser = Depends(get_session)):
|
||||||
|
# session.user_id is the Directus UUID for cloud users, "local" for self-hosted
|
||||||
|
...
|
||||||
|
|
||||||
|
Environment variables (set per-product via .env / compose):
|
||||||
|
CLOUD_MODE 1/true/yes to enable cloud auth (default: off)
|
||||||
|
CLOUD_DATA_ROOT Root directory for per-user data (default: /devl/<product>-cloud-data)
|
||||||
|
DIRECTUS_JWT_SECRET HS256 secret used to sign cf_session JWTs (required in cloud mode)
|
||||||
|
HEIMDALL_URL License server base URL (default: https://license.circuitforge.tech)
|
||||||
|
HEIMDALL_ADMIN_TOKEN Heimdall admin bearer token (required for tier resolution)
|
||||||
|
CF_SERVER_SECRET Server-side secret for deriving per-user encryption keys
|
||||||
|
CLOUD_AUTH_BYPASS_IPS Comma-separated IPs/CIDRs to skip JWT auth (dev LAN only)
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import ipaddress
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Callable
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
TIERS: list[str] = ["free", "paid", "premium", "ultra"]
|
||||||
|
|
||||||
|
# ── CloudUser ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class CloudUser:
|
||||||
|
"""Resolved user identity for one HTTP request.
|
||||||
|
|
||||||
|
user_id: Directus UUID for authenticated cloud users.
|
||||||
|
"local" for self-hosted / CLOUD_MODE=false.
|
||||||
|
"local-dev" for dev-bypass-IP sessions.
|
||||||
|
"anon-<uuid>" for unauthenticated guest visitors.
|
||||||
|
tier: free | paid | premium | ultra | local
|
||||||
|
product: Which CF product this session belongs to (e.g. "avocet").
|
||||||
|
meta: Product-specific extras (e.g. household_id for Kiwi).
|
||||||
|
Access via session.meta.get("household_id").
|
||||||
|
"""
|
||||||
|
user_id: str
|
||||||
|
tier: str
|
||||||
|
product: str
|
||||||
|
has_byok: bool = False
|
||||||
|
meta: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_bypass_nets(raw: str) -> tuple[list[ipaddress.IPv4Network | ipaddress.IPv6Network], frozenset[str]]:
|
||||||
|
nets: list[ipaddress.IPv4Network | ipaddress.IPv6Network] = []
|
||||||
|
ips: set[str] = set()
|
||||||
|
for entry in (e.strip() for e in raw.split(",") if e.strip()):
|
||||||
|
try:
|
||||||
|
nets.append(ipaddress.ip_network(entry, strict=False))
|
||||||
|
except ValueError:
|
||||||
|
ips.add(entry)
|
||||||
|
return nets, frozenset(ips)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_bypass_ip(
|
||||||
|
ip: str,
|
||||||
|
nets: list[ipaddress.IPv4Network | ipaddress.IPv6Network],
|
||||||
|
ips: frozenset[str],
|
||||||
|
) -> bool:
|
||||||
|
if not ip or (not nets and not ips):
|
||||||
|
return False
|
||||||
|
if ip in ips:
|
||||||
|
return True
|
||||||
|
try:
|
||||||
|
addr = ipaddress.ip_address(ip)
|
||||||
|
return any(addr in net for net in nets)
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_session_token(header_value: str) -> str:
|
||||||
|
"""Pull cf_session value out of a raw Cookie header or return the value as-is."""
|
||||||
|
m = re.search(r'(?:^|;)\s*cf_session=([^;]+)', header_value)
|
||||||
|
return m.group(1).strip() if m else header_value.strip()
|
||||||
|
|
||||||
|
|
||||||
|
# ── CloudSessionFactory ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class CloudSessionFactory:
|
||||||
|
"""Per-product session factory. Instantiate once at module level.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
product: Product code string (e.g. "avocet", "kiwi").
|
||||||
|
extra_meta: Optional async-or-sync callable that receives
|
||||||
|
(user_id: str, tier: str) and returns a dict merged
|
||||||
|
into CloudUser.meta. Use for product-specific fields
|
||||||
|
like household_id.
|
||||||
|
byok_detector: Callable() → bool. Override to detect BYOK for this
|
||||||
|
product's config path. Default: always False.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
product: str,
|
||||||
|
extra_meta: Callable[[str, str], dict[str, Any]] | None = None,
|
||||||
|
byok_detector: Callable[[], bool] | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.product = product
|
||||||
|
self._extra_meta = extra_meta
|
||||||
|
self._byok_detector = byok_detector or (lambda: False)
|
||||||
|
|
||||||
|
# Config — read from environment at construction time so tests can patch env
|
||||||
|
self._cloud_mode: bool = os.environ.get("CLOUD_MODE", "").lower() in ("1", "true", "yes")
|
||||||
|
self._directus_secret: str = os.environ.get("DIRECTUS_JWT_SECRET", "")
|
||||||
|
self._heimdall_url: str = os.environ.get("HEIMDALL_URL", "https://license.circuitforge.tech")
|
||||||
|
self._heimdall_token: str = os.environ.get("HEIMDALL_ADMIN_TOKEN", "")
|
||||||
|
self._cloud_data_root: Path = Path(
|
||||||
|
os.environ.get("CLOUD_DATA_ROOT", f"/devl/{product}-cloud-data")
|
||||||
|
)
|
||||||
|
|
||||||
|
_bypass_raw = os.environ.get("CLOUD_AUTH_BYPASS_IPS", "")
|
||||||
|
self._bypass_nets, self._bypass_ips = _parse_bypass_nets(_bypass_raw)
|
||||||
|
|
||||||
|
# Tier resolution cache: {user_id: (result_dict, timestamp)}
|
||||||
|
self._tier_cache: dict[str, tuple[dict, float]] = {}
|
||||||
|
self._tier_cache_ttl: float = 300.0 # 5 minutes
|
||||||
|
|
||||||
|
# ── JWT ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def validate_jwt(self, token: str) -> str:
|
||||||
|
"""Validate a cf_session JWT and return the Directus user_id. Raises HTTPException on failure."""
|
||||||
|
try:
|
||||||
|
import jwt as pyjwt # lazy — not needed in local mode
|
||||||
|
from fastapi import HTTPException
|
||||||
|
payload = pyjwt.decode(
|
||||||
|
token,
|
||||||
|
self._directus_secret,
|
||||||
|
algorithms=["HS256"],
|
||||||
|
options={"require": ["id", "exp"]},
|
||||||
|
)
|
||||||
|
return payload["id"]
|
||||||
|
except Exception as exc:
|
||||||
|
log.debug("JWT validation failed: %s", exc)
|
||||||
|
from fastapi import HTTPException
|
||||||
|
raise HTTPException(status_code=401, detail="Session invalid or expired")
|
||||||
|
|
||||||
|
# ── Heimdall ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _ensure_provisioned(self, user_id: str) -> None:
|
||||||
|
if not self._heimdall_token:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
import requests
|
||||||
|
requests.post(
|
||||||
|
f"{self._heimdall_url}/admin/provision",
|
||||||
|
json={"directus_user_id": user_id, "product": self.product, "tier": "free"},
|
||||||
|
headers={"Authorization": f"Bearer {self._heimdall_token}"},
|
||||||
|
timeout=5,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
log.warning("Heimdall provision failed for user %s: %s", user_id, exc)
|
||||||
|
|
||||||
|
def _resolve_tier(self, user_id: str) -> dict[str, Any]:
|
||||||
|
"""Returns dict with keys: tier, license_key (and any product extras)."""
|
||||||
|
now = time.monotonic()
|
||||||
|
cached = self._tier_cache.get(user_id)
|
||||||
|
if cached and (now - cached[1]) < self._tier_cache_ttl:
|
||||||
|
return cached[0]
|
||||||
|
|
||||||
|
result: dict[str, Any] = {"tier": "free", "license_key": None}
|
||||||
|
if self._heimdall_token:
|
||||||
|
try:
|
||||||
|
import requests
|
||||||
|
resp = requests.post(
|
||||||
|
f"{self._heimdall_url}/admin/cloud/resolve",
|
||||||
|
json={"directus_user_id": user_id, "product": self.product},
|
||||||
|
headers={"Authorization": f"Bearer {self._heimdall_token}"},
|
||||||
|
timeout=5,
|
||||||
|
)
|
||||||
|
if resp.ok:
|
||||||
|
data = resp.json()
|
||||||
|
result["tier"] = data.get("tier", "free")
|
||||||
|
result["license_key"] = data.get("key_display")
|
||||||
|
# Forward any extra fields Heimdall returns (household_id etc.)
|
||||||
|
result.update({k: v for k, v in data.items() if k not in result})
|
||||||
|
except Exception as exc:
|
||||||
|
log.warning("Heimdall tier resolve failed for %s: %s", user_id, exc)
|
||||||
|
else:
|
||||||
|
log.debug("HEIMDALL_ADMIN_TOKEN not set — defaulting tier to free")
|
||||||
|
|
||||||
|
self._tier_cache[user_id] = (result, now)
|
||||||
|
return result
|
||||||
|
|
||||||
|
# ── Guest sessions ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_GUEST_COOKIE = "cf_guest_id"
|
||||||
|
_GUEST_COOKIE_MAX_AGE = 60 * 60 * 24 * 90 # 90 days
|
||||||
|
|
||||||
|
def _resolve_guest(self, request: Any, response: Any) -> CloudUser:
|
||||||
|
guest_id = (request.cookies.get(self._GUEST_COOKIE) or "").strip()
|
||||||
|
if not guest_id:
|
||||||
|
guest_id = str(uuid.uuid4())
|
||||||
|
is_https = request.headers.get("x-forwarded-proto", "http").lower() == "https"
|
||||||
|
response.set_cookie(
|
||||||
|
key=self._GUEST_COOKIE,
|
||||||
|
value=guest_id,
|
||||||
|
max_age=self._GUEST_COOKIE_MAX_AGE,
|
||||||
|
httponly=True,
|
||||||
|
samesite="lax",
|
||||||
|
secure=is_https,
|
||||||
|
)
|
||||||
|
return CloudUser(
|
||||||
|
user_id=f"anon-{guest_id}",
|
||||||
|
tier="free",
|
||||||
|
product=self.product,
|
||||||
|
has_byok=self._byok_detector(),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Core resolver ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def resolve(self, request: Any, response: Any) -> CloudUser:
|
||||||
|
"""Resolve the CloudUser for a FastAPI request. Suitable as a Depends() target."""
|
||||||
|
has_byok = self._byok_detector()
|
||||||
|
|
||||||
|
if not self._cloud_mode:
|
||||||
|
return CloudUser(user_id="local", tier="local", product=self.product, has_byok=has_byok)
|
||||||
|
|
||||||
|
client_ip = (
|
||||||
|
request.headers.get("x-real-ip", "")
|
||||||
|
or (request.client.host if request.client else "")
|
||||||
|
)
|
||||||
|
if _is_bypass_ip(client_ip, self._bypass_nets, self._bypass_ips):
|
||||||
|
log.debug("Bypass IP %s — returning local-dev session for product %s", client_ip, self.product)
|
||||||
|
return CloudUser(user_id="local-dev", tier="local", product=self.product, has_byok=has_byok)
|
||||||
|
|
||||||
|
raw_session = (
|
||||||
|
request.headers.get("x-cf-session", "").strip()
|
||||||
|
or request.cookies.get("cf_session", "").strip()
|
||||||
|
)
|
||||||
|
if not raw_session:
|
||||||
|
return self._resolve_guest(request, response)
|
||||||
|
|
||||||
|
token = _extract_session_token(raw_session)
|
||||||
|
if not token:
|
||||||
|
return self._resolve_guest(request, response)
|
||||||
|
|
||||||
|
user_id = self.validate_jwt(token)
|
||||||
|
self._ensure_provisioned(user_id)
|
||||||
|
tier_data = self._resolve_tier(user_id)
|
||||||
|
tier = tier_data.get("tier", "free")
|
||||||
|
|
||||||
|
meta: dict[str, Any] = {}
|
||||||
|
if self._extra_meta:
|
||||||
|
meta = self._extra_meta(user_id, tier) or {}
|
||||||
|
# Merge any extra fields from Heimdall response (e.g. household_id)
|
||||||
|
meta.update({k: v for k, v in tier_data.items() if k not in ("tier", "license_key")})
|
||||||
|
meta["license_key"] = tier_data.get("license_key")
|
||||||
|
|
||||||
|
return CloudUser(
|
||||||
|
user_id=user_id,
|
||||||
|
tier=tier,
|
||||||
|
product=self.product,
|
||||||
|
has_byok=has_byok,
|
||||||
|
meta=meta,
|
||||||
|
)
|
||||||
|
|
||||||
|
def dependency(self) -> Callable[[Any, Any], CloudUser]:
|
||||||
|
"""Return a FastAPI-compatible dependency function (use with Depends())."""
|
||||||
|
factory = self
|
||||||
|
|
||||||
|
def _get_session(request: Any, response: Any) -> CloudUser:
|
||||||
|
return factory.resolve(request, response)
|
||||||
|
|
||||||
|
return _get_session
|
||||||
|
|
||||||
|
def require_tier(self, min_tier: str) -> Callable:
|
||||||
|
"""Dependency factory — raises 403 if the session tier is below min_tier."""
|
||||||
|
from fastapi import Depends, HTTPException
|
||||||
|
min_idx = TIERS.index(min_tier)
|
||||||
|
get_session = self.dependency()
|
||||||
|
|
||||||
|
def _check(session: CloudUser = Depends(get_session)) -> CloudUser:
|
||||||
|
if session.tier in ("local", "local-dev"):
|
||||||
|
return session
|
||||||
|
try:
|
||||||
|
if TIERS.index(session.tier) < min_idx:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail=f"This feature requires {min_tier} tier or above.",
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
raise HTTPException(status_code=403, detail="Unknown tier.")
|
||||||
|
return session
|
||||||
|
|
||||||
|
return _check
|
||||||
Loading…
Reference in a new issue