Compare commits
4 commits
main
...
feature/or
| Author | SHA1 | Date | |
|---|---|---|---|
| e605954254 | |||
| ed6813713e | |||
| 9985d12156 | |||
| 9602f84e62 |
6 changed files with 82 additions and 17 deletions
|
|
@ -76,7 +76,7 @@ def _is_bypass_ip(ip: str) -> bool:
|
|||
|
||||
_LOCAL_KIWI_DB: Path = Path(os.environ.get("KIWI_DB", "data/kiwi.db"))
|
||||
|
||||
_TIER_CACHE: dict[str, tuple[str, float]] = {}
|
||||
_TIER_CACHE: dict[str, tuple[dict, float]] = {}
|
||||
_TIER_CACHE_TTL = 300 # 5 minutes
|
||||
|
||||
TIERS = ["free", "paid", "premium", "ultra"]
|
||||
|
|
@ -90,6 +90,8 @@ class CloudUser:
|
|||
tier: str # free | paid | premium | ultra | local
|
||||
db: Path # per-user SQLite DB path
|
||||
has_byok: bool # True if a configured LLM backend is present in llm.yaml
|
||||
household_id: str | None = None
|
||||
is_household_owner: bool = False
|
||||
|
||||
|
||||
# ── JWT validation ─────────────────────────────────────────────────────────────
|
||||
|
|
@ -130,14 +132,16 @@ def _ensure_provisioned(user_id: str) -> None:
|
|||
log.warning("Heimdall provision failed for user %s: %s", user_id, exc)
|
||||
|
||||
|
||||
def _fetch_cloud_tier(user_id: str) -> str:
|
||||
def _fetch_cloud_tier(user_id: str) -> tuple[str, str | None, bool]:
|
||||
"""Returns (tier, household_id | None, is_household_owner)."""
|
||||
now = time.monotonic()
|
||||
cached = _TIER_CACHE.get(user_id)
|
||||
if cached and (now - cached[1]) < _TIER_CACHE_TTL:
|
||||
return cached[0]
|
||||
entry = cached[0]
|
||||
return entry["tier"], entry.get("household_id"), entry.get("is_household_owner", False)
|
||||
|
||||
if not HEIMDALL_ADMIN_TOKEN:
|
||||
return "free"
|
||||
return "free", None, False
|
||||
try:
|
||||
resp = requests.post(
|
||||
f"{HEIMDALL_URL}/admin/cloud/resolve",
|
||||
|
|
@ -145,17 +149,23 @@ def _fetch_cloud_tier(user_id: str) -> str:
|
|||
headers={"Authorization": f"Bearer {HEIMDALL_ADMIN_TOKEN}"},
|
||||
timeout=5,
|
||||
)
|
||||
tier = resp.json().get("tier", "free") if resp.ok else "free"
|
||||
data = resp.json() if resp.ok else {}
|
||||
tier = data.get("tier", "free")
|
||||
household_id = data.get("household_id")
|
||||
is_owner = data.get("is_household_owner", False)
|
||||
except Exception as exc:
|
||||
log.warning("Heimdall tier resolve failed for user %s: %s", user_id, exc)
|
||||
tier = "free"
|
||||
tier, household_id, is_owner = "free", None, False
|
||||
|
||||
_TIER_CACHE[user_id] = (tier, now)
|
||||
return tier
|
||||
_TIER_CACHE[user_id] = ({"tier": tier, "household_id": household_id, "is_household_owner": is_owner}, now)
|
||||
return tier, household_id, is_owner
|
||||
|
||||
|
||||
def _user_db_path(user_id: str) -> Path:
|
||||
path = CLOUD_DATA_ROOT / user_id / "kiwi.db"
|
||||
def _user_db_path(user_id: str, household_id: str | None = None) -> Path:
|
||||
if household_id:
|
||||
path = CLOUD_DATA_ROOT / f"household_{household_id}" / "kiwi.db"
|
||||
else:
|
||||
path = CLOUD_DATA_ROOT / user_id / "kiwi.db"
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
return path
|
||||
|
||||
|
|
@ -198,8 +208,6 @@ def get_session(request: Request) -> CloudUser:
|
|||
if not CLOUD_MODE:
|
||||
return CloudUser(user_id="local", tier="local", db=_LOCAL_KIWI_DB, has_byok=has_byok)
|
||||
|
||||
# Prefer X-Real-IP (set by nginx from the actual client address) over the
|
||||
# TCP peer address (which is nginx's container IP when behind the proxy).
|
||||
# Prefer X-Real-IP (set by nginx from the actual client address) over the
|
||||
# TCP peer address (which is nginx's container IP when behind the proxy).
|
||||
client_ip = (
|
||||
|
|
@ -225,8 +233,15 @@ def get_session(request: Request) -> CloudUser:
|
|||
|
||||
user_id = validate_session_jwt(token)
|
||||
_ensure_provisioned(user_id)
|
||||
tier = _fetch_cloud_tier(user_id)
|
||||
return CloudUser(user_id=user_id, tier=tier, db=_user_db_path(user_id), has_byok=has_byok)
|
||||
tier, household_id, is_household_owner = _fetch_cloud_tier(user_id)
|
||||
return CloudUser(
|
||||
user_id=user_id,
|
||||
tier=tier,
|
||||
db=_user_db_path(user_id, household_id=household_id),
|
||||
has_byok=has_byok,
|
||||
household_id=household_id,
|
||||
is_household_owner=is_household_owner,
|
||||
)
|
||||
|
||||
|
||||
def require_tier(min_tier: str):
|
||||
|
|
|
|||
10
app/db/migrations/017_household_invites.sql
Normal file
10
app/db/migrations/017_household_invites.sql
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
-- 017_household_invites.sql
|
||||
CREATE TABLE IF NOT EXISTS household_invites (
|
||||
token TEXT PRIMARY KEY,
|
||||
household_id TEXT NOT NULL,
|
||||
created_by TEXT NOT NULL,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now')),
|
||||
expires_at TEXT NOT NULL,
|
||||
used_at TEXT,
|
||||
used_by TEXT
|
||||
);
|
||||
|
|
@ -33,7 +33,7 @@ def _try_docuvision(image_path: str | Path) -> str | None:
|
|||
if not cf_orch_url:
|
||||
return None
|
||||
try:
|
||||
from circuitforge_core.resources import CFOrchClient
|
||||
from circuitforge_orch.client import CFOrchClient
|
||||
from app.services.ocr.docuvision_client import DocuvisionClient
|
||||
|
||||
client = CFOrchClient(cf_orch_url)
|
||||
|
|
|
|||
|
|
@ -143,7 +143,7 @@ class LLMRecipeGenerator:
|
|||
cf_orch_url = os.environ.get("CF_ORCH_URL")
|
||||
if cf_orch_url:
|
||||
try:
|
||||
from circuitforge_core.resources import CFOrchClient
|
||||
from circuitforge_orch.client import CFOrchClient
|
||||
client = CFOrchClient(cf_orch_url)
|
||||
return client.allocate(
|
||||
service="vllm",
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ dependencies = [
|
|||
"httpx>=0.27",
|
||||
"requests>=2.31",
|
||||
# CircuitForge shared scaffold
|
||||
"circuitforge-core>=0.6.0",
|
||||
"circuitforge-core>=0.8.0",
|
||||
]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
|
|
|
|||
40
tests/test_household.py
Normal file
40
tests/test_household.py
Normal file
|
|
@ -0,0 +1,40 @@
|
|||
"""Tests for household session resolution in cloud_session.py."""
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
import pytest
|
||||
|
||||
os.environ.setdefault("CLOUD_MODE", "false")
|
||||
|
||||
import app.cloud_session as cs
|
||||
from app.cloud_session import (
|
||||
CloudUser,
|
||||
_user_db_path,
|
||||
)
|
||||
|
||||
|
||||
def test_clouduser_has_household_fields():
|
||||
u = CloudUser(
|
||||
user_id="u1", tier="premium", db=Path("/tmp/u1.db"),
|
||||
has_byok=False, household_id="hh-1", is_household_owner=True
|
||||
)
|
||||
assert u.household_id == "hh-1"
|
||||
assert u.is_household_owner is True
|
||||
|
||||
|
||||
def test_clouduser_household_defaults_none():
|
||||
u = CloudUser(user_id="u1", tier="free", db=Path("/tmp/u1.db"), has_byok=False)
|
||||
assert u.household_id is None
|
||||
assert u.is_household_owner is False
|
||||
|
||||
|
||||
def test_user_db_path_personal(tmp_path, monkeypatch):
|
||||
monkeypatch.setattr(cs, "CLOUD_DATA_ROOT", tmp_path)
|
||||
result = cs._user_db_path("abc123")
|
||||
assert result == tmp_path / "abc123" / "kiwi.db"
|
||||
|
||||
|
||||
def test_user_db_path_household(tmp_path, monkeypatch):
|
||||
monkeypatch.setattr(cs, "CLOUD_DATA_ROOT", tmp_path)
|
||||
result = cs._user_db_path("abc123", household_id="hh-xyz")
|
||||
assert result == tmp_path / "household_hh-xyz" / "kiwi.db"
|
||||
Loading…
Reference in a new issue