Compare commits

..

1 commit

6 changed files with 17 additions and 82 deletions

View file

@ -76,7 +76,7 @@ def _is_bypass_ip(ip: str) -> bool:
_LOCAL_KIWI_DB: Path = Path(os.environ.get("KIWI_DB", "data/kiwi.db")) _LOCAL_KIWI_DB: Path = Path(os.environ.get("KIWI_DB", "data/kiwi.db"))
_TIER_CACHE: dict[str, tuple[dict, float]] = {} _TIER_CACHE: dict[str, tuple[str, float]] = {}
_TIER_CACHE_TTL = 300 # 5 minutes _TIER_CACHE_TTL = 300 # 5 minutes
TIERS = ["free", "paid", "premium", "ultra"] TIERS = ["free", "paid", "premium", "ultra"]
@ -90,8 +90,6 @@ class CloudUser:
tier: str # free | paid | premium | ultra | local tier: str # free | paid | premium | ultra | local
db: Path # per-user SQLite DB path db: Path # per-user SQLite DB path
has_byok: bool # True if a configured LLM backend is present in llm.yaml has_byok: bool # True if a configured LLM backend is present in llm.yaml
household_id: str | None = None
is_household_owner: bool = False
# ── JWT validation ───────────────────────────────────────────────────────────── # ── JWT validation ─────────────────────────────────────────────────────────────
@ -132,16 +130,14 @@ def _ensure_provisioned(user_id: str) -> None:
log.warning("Heimdall provision failed for user %s: %s", user_id, exc) log.warning("Heimdall provision failed for user %s: %s", user_id, exc)
def _fetch_cloud_tier(user_id: str) -> tuple[str, str | None, bool]: def _fetch_cloud_tier(user_id: str) -> str:
"""Returns (tier, household_id | None, is_household_owner)."""
now = time.monotonic() now = time.monotonic()
cached = _TIER_CACHE.get(user_id) cached = _TIER_CACHE.get(user_id)
if cached and (now - cached[1]) < _TIER_CACHE_TTL: if cached and (now - cached[1]) < _TIER_CACHE_TTL:
entry = cached[0] return cached[0]
return entry["tier"], entry.get("household_id"), entry.get("is_household_owner", False)
if not HEIMDALL_ADMIN_TOKEN: if not HEIMDALL_ADMIN_TOKEN:
return "free", None, False return "free"
try: try:
resp = requests.post( resp = requests.post(
f"{HEIMDALL_URL}/admin/cloud/resolve", f"{HEIMDALL_URL}/admin/cloud/resolve",
@ -149,23 +145,17 @@ def _fetch_cloud_tier(user_id: str) -> tuple[str, str | None, bool]:
headers={"Authorization": f"Bearer {HEIMDALL_ADMIN_TOKEN}"}, headers={"Authorization": f"Bearer {HEIMDALL_ADMIN_TOKEN}"},
timeout=5, timeout=5,
) )
data = resp.json() if resp.ok else {} tier = resp.json().get("tier", "free") if resp.ok else "free"
tier = data.get("tier", "free")
household_id = data.get("household_id")
is_owner = data.get("is_household_owner", False)
except Exception as exc: except Exception as exc:
log.warning("Heimdall tier resolve failed for user %s: %s", user_id, exc) log.warning("Heimdall tier resolve failed for user %s: %s", user_id, exc)
tier, household_id, is_owner = "free", None, False tier = "free"
_TIER_CACHE[user_id] = ({"tier": tier, "household_id": household_id, "is_household_owner": is_owner}, now) _TIER_CACHE[user_id] = (tier, now)
return tier, household_id, is_owner return tier
def _user_db_path(user_id: str, household_id: str | None = None) -> Path: def _user_db_path(user_id: str) -> Path:
if household_id: path = CLOUD_DATA_ROOT / user_id / "kiwi.db"
path = CLOUD_DATA_ROOT / f"household_{household_id}" / "kiwi.db"
else:
path = CLOUD_DATA_ROOT / user_id / "kiwi.db"
path.parent.mkdir(parents=True, exist_ok=True) path.parent.mkdir(parents=True, exist_ok=True)
return path return path
@ -208,6 +198,8 @@ def get_session(request: Request) -> CloudUser:
if not CLOUD_MODE: if not CLOUD_MODE:
return CloudUser(user_id="local", tier="local", db=_LOCAL_KIWI_DB, has_byok=has_byok) return CloudUser(user_id="local", tier="local", db=_LOCAL_KIWI_DB, has_byok=has_byok)
# Prefer X-Real-IP (set by nginx from the actual client address) over the
# TCP peer address (which is nginx's container IP when behind the proxy).
# Prefer X-Real-IP (set by nginx from the actual client address) over the # Prefer X-Real-IP (set by nginx from the actual client address) over the
# TCP peer address (which is nginx's container IP when behind the proxy). # TCP peer address (which is nginx's container IP when behind the proxy).
client_ip = ( client_ip = (
@ -233,15 +225,8 @@ def get_session(request: Request) -> CloudUser:
user_id = validate_session_jwt(token) user_id = validate_session_jwt(token)
_ensure_provisioned(user_id) _ensure_provisioned(user_id)
tier, household_id, is_household_owner = _fetch_cloud_tier(user_id) tier = _fetch_cloud_tier(user_id)
return CloudUser( return CloudUser(user_id=user_id, tier=tier, db=_user_db_path(user_id), has_byok=has_byok)
user_id=user_id,
tier=tier,
db=_user_db_path(user_id, household_id=household_id),
has_byok=has_byok,
household_id=household_id,
is_household_owner=is_household_owner,
)
def require_tier(min_tier: str): def require_tier(min_tier: str):

View file

@ -1,10 +0,0 @@
-- 017_household_invites.sql
CREATE TABLE IF NOT EXISTS household_invites (
token TEXT PRIMARY KEY,
household_id TEXT NOT NULL,
created_by TEXT NOT NULL,
created_at TEXT NOT NULL DEFAULT (datetime('now')),
expires_at TEXT NOT NULL,
used_at TEXT,
used_by TEXT
);

View file

@ -33,7 +33,7 @@ def _try_docuvision(image_path: str | Path) -> str | None:
if not cf_orch_url: if not cf_orch_url:
return None return None
try: try:
from circuitforge_orch.client import CFOrchClient from circuitforge_core.resources import CFOrchClient
from app.services.ocr.docuvision_client import DocuvisionClient from app.services.ocr.docuvision_client import DocuvisionClient
client = CFOrchClient(cf_orch_url) client = CFOrchClient(cf_orch_url)

View file

@ -143,7 +143,7 @@ class LLMRecipeGenerator:
cf_orch_url = os.environ.get("CF_ORCH_URL") cf_orch_url = os.environ.get("CF_ORCH_URL")
if cf_orch_url: if cf_orch_url:
try: try:
from circuitforge_orch.client import CFOrchClient from circuitforge_core.resources import CFOrchClient
client = CFOrchClient(cf_orch_url) client = CFOrchClient(cf_orch_url)
return client.allocate( return client.allocate(
service="vllm", service="vllm",

View file

@ -23,7 +23,7 @@ dependencies = [
"httpx>=0.27", "httpx>=0.27",
"requests>=2.31", "requests>=2.31",
# CircuitForge shared scaffold # CircuitForge shared scaffold
"circuitforge-core>=0.8.0", "circuitforge-core>=0.6.0",
] ]
[tool.setuptools.packages.find] [tool.setuptools.packages.find]

View file

@ -1,40 +0,0 @@
"""Tests for household session resolution in cloud_session.py."""
import os
from pathlib import Path
from unittest.mock import patch, MagicMock
import pytest
os.environ.setdefault("CLOUD_MODE", "false")
import app.cloud_session as cs
from app.cloud_session import (
CloudUser,
_user_db_path,
)
def test_clouduser_has_household_fields():
u = CloudUser(
user_id="u1", tier="premium", db=Path("/tmp/u1.db"),
has_byok=False, household_id="hh-1", is_household_owner=True
)
assert u.household_id == "hh-1"
assert u.is_household_owner is True
def test_clouduser_household_defaults_none():
u = CloudUser(user_id="u1", tier="free", db=Path("/tmp/u1.db"), has_byok=False)
assert u.household_id is None
assert u.is_household_owner is False
def test_user_db_path_personal(tmp_path, monkeypatch):
monkeypatch.setattr(cs, "CLOUD_DATA_ROOT", tmp_path)
result = cs._user_db_path("abc123")
assert result == tmp_path / "abc123" / "kiwi.db"
def test_user_db_path_household(tmp_path, monkeypatch):
monkeypatch.setattr(cs, "CLOUD_DATA_ROOT", tmp_path)
result = cs._user_db_path("abc123", household_id="hh-xyz")
assert result == tmp_path / "household_hh-xyz" / "kiwi.db"