feat: extend CloudUser with household_id + update session resolution

Add household_id and is_household_owner fields to CloudUser dataclass.
Update _user_db_path to route household members to a shared DB path.
Update _fetch_cloud_tier to return a 3-tuple and cache a dict.
Update get_session to unpack and propagate household fields.
This commit is contained in:
pyr0ball 2026-04-04 22:30:07 -07:00
parent 9602f84e62
commit 9985d12156
2 changed files with 67 additions and 12 deletions

View file

@ -76,7 +76,7 @@ def _is_bypass_ip(ip: str) -> bool:
_LOCAL_KIWI_DB: Path = Path(os.environ.get("KIWI_DB", "data/kiwi.db"))
_TIER_CACHE: dict[str, tuple[str, float]] = {}
_TIER_CACHE: dict[str, tuple[dict, float]] = {}
_TIER_CACHE_TTL = 300 # 5 minutes
TIERS = ["free", "paid", "premium", "ultra"]
@ -90,6 +90,8 @@ class CloudUser:
tier: str # free | paid | premium | ultra | local
db: Path # per-user SQLite DB path
has_byok: bool # True if a configured LLM backend is present in llm.yaml
household_id: str | None = None
is_household_owner: bool = False
# ── JWT validation ─────────────────────────────────────────────────────────────
@ -130,14 +132,16 @@ def _ensure_provisioned(user_id: str) -> None:
log.warning("Heimdall provision failed for user %s: %s", user_id, exc)
def _fetch_cloud_tier(user_id: str) -> str:
def _fetch_cloud_tier(user_id: str) -> tuple[str, str | None, bool]:
"""Returns (tier, household_id | None, is_household_owner)."""
now = time.monotonic()
cached = _TIER_CACHE.get(user_id)
if cached and (now - cached[1]) < _TIER_CACHE_TTL:
return cached[0]
entry = cached[0]
return entry["tier"], entry.get("household_id"), entry.get("is_household_owner", False)
if not HEIMDALL_ADMIN_TOKEN:
return "free"
return "free", None, False
try:
resp = requests.post(
f"{HEIMDALL_URL}/admin/cloud/resolve",
@ -145,17 +149,23 @@ def _fetch_cloud_tier(user_id: str) -> str:
headers={"Authorization": f"Bearer {HEIMDALL_ADMIN_TOKEN}"},
timeout=5,
)
tier = resp.json().get("tier", "free") if resp.ok else "free"
data = resp.json() if resp.ok else {}
tier = data.get("tier", "free")
household_id = data.get("household_id")
is_owner = data.get("is_household_owner", False)
except Exception as exc:
log.warning("Heimdall tier resolve failed for user %s: %s", user_id, exc)
tier = "free"
tier, household_id, is_owner = "free", None, False
_TIER_CACHE[user_id] = (tier, now)
return tier
_TIER_CACHE[user_id] = ({"tier": tier, "household_id": household_id, "is_household_owner": is_owner}, now)
return tier, household_id, is_owner
def _user_db_path(user_id: str) -> Path:
path = CLOUD_DATA_ROOT / user_id / "kiwi.db"
def _user_db_path(user_id: str, household_id: str | None = None) -> Path:
if household_id:
path = CLOUD_DATA_ROOT / f"household_{household_id}" / "kiwi.db"
else:
path = CLOUD_DATA_ROOT / user_id / "kiwi.db"
path.parent.mkdir(parents=True, exist_ok=True)
return path
@ -225,8 +235,15 @@ def get_session(request: Request) -> CloudUser:
user_id = validate_session_jwt(token)
_ensure_provisioned(user_id)
tier = _fetch_cloud_tier(user_id)
return CloudUser(user_id=user_id, tier=tier, db=_user_db_path(user_id), has_byok=has_byok)
tier, household_id, is_household_owner = _fetch_cloud_tier(user_id)
return CloudUser(
user_id=user_id,
tier=tier,
db=_user_db_path(user_id, household_id=household_id),
has_byok=has_byok,
household_id=household_id,
is_household_owner=is_household_owner,
)
def require_tier(min_tier: str):

38
tests/test_household.py Normal file
View file

@ -0,0 +1,38 @@
"""Tests for household session resolution in cloud_session.py."""
import os
from pathlib import Path
from unittest.mock import patch, MagicMock
import pytest
os.environ.setdefault("CLOUD_MODE", "false")
from app.cloud_session import (
CloudUser,
_user_db_path,
CLOUD_DATA_ROOT,
)
def test_clouduser_has_household_fields():
u = CloudUser(
user_id="u1", tier="premium", db=Path("/tmp/u1.db"),
has_byok=False, household_id="hh-1", is_household_owner=True
)
assert u.household_id == "hh-1"
assert u.is_household_owner is True
def test_clouduser_household_defaults_none():
u = CloudUser(user_id="u1", tier="free", db=Path("/tmp/u1.db"), has_byok=False)
assert u.household_id is None
assert u.is_household_owner is False
def test_user_db_path_personal():
path = _user_db_path("abc123", household_id=None)
assert path == CLOUD_DATA_ROOT / "abc123" / "kiwi.db"
def test_user_db_path_household():
path = _user_db_path("abc123", household_id="hh-xyz")
assert path == CLOUD_DATA_ROOT / "household_hh-xyz" / "kiwi.db"