diff --git a/app/cloud_session.py b/app/cloud_session.py index ba35bbb..8aa642b 100644 --- a/app/cloud_session.py +++ b/app/cloud_session.py @@ -76,7 +76,7 @@ def _is_bypass_ip(ip: str) -> bool: _LOCAL_KIWI_DB: Path = Path(os.environ.get("KIWI_DB", "data/kiwi.db")) -_TIER_CACHE: dict[str, tuple[str, float]] = {} +_TIER_CACHE: dict[str, tuple[dict, float]] = {} _TIER_CACHE_TTL = 300 # 5 minutes TIERS = ["free", "paid", "premium", "ultra"] @@ -90,6 +90,8 @@ class CloudUser: tier: str # free | paid | premium | ultra | local db: Path # per-user SQLite DB path has_byok: bool # True if a configured LLM backend is present in llm.yaml + household_id: str | None = None + is_household_owner: bool = False # ── JWT validation ───────────────────────────────────────────────────────────── @@ -130,14 +132,16 @@ def _ensure_provisioned(user_id: str) -> None: log.warning("Heimdall provision failed for user %s: %s", user_id, exc) -def _fetch_cloud_tier(user_id: str) -> str: +def _fetch_cloud_tier(user_id: str) -> tuple[str, str | None, bool]: + """Returns (tier, household_id | None, is_household_owner).""" now = time.monotonic() cached = _TIER_CACHE.get(user_id) if cached and (now - cached[1]) < _TIER_CACHE_TTL: - return cached[0] + entry = cached[0] + return entry["tier"], entry.get("household_id"), entry.get("is_household_owner", False) if not HEIMDALL_ADMIN_TOKEN: - return "free" + return "free", None, False try: resp = requests.post( f"{HEIMDALL_URL}/admin/cloud/resolve", @@ -145,17 +149,23 @@ def _fetch_cloud_tier(user_id: str) -> str: headers={"Authorization": f"Bearer {HEIMDALL_ADMIN_TOKEN}"}, timeout=5, ) - tier = resp.json().get("tier", "free") if resp.ok else "free" + data = resp.json() if resp.ok else {} + tier = data.get("tier", "free") + household_id = data.get("household_id") + is_owner = data.get("is_household_owner", False) except Exception as exc: log.warning("Heimdall tier resolve failed for user %s: %s", user_id, exc) - tier = "free" + tier, household_id, is_owner = "free", None, False - _TIER_CACHE[user_id] = (tier, now) - return tier + _TIER_CACHE[user_id] = ({"tier": tier, "household_id": household_id, "is_household_owner": is_owner}, now) + return tier, household_id, is_owner -def _user_db_path(user_id: str) -> Path: - path = CLOUD_DATA_ROOT / user_id / "kiwi.db" +def _user_db_path(user_id: str, household_id: str | None = None) -> Path: + if household_id: + path = CLOUD_DATA_ROOT / f"household_{household_id}" / "kiwi.db" + else: + path = CLOUD_DATA_ROOT / user_id / "kiwi.db" path.parent.mkdir(parents=True, exist_ok=True) return path @@ -225,8 +235,15 @@ def get_session(request: Request) -> CloudUser: user_id = validate_session_jwt(token) _ensure_provisioned(user_id) - tier = _fetch_cloud_tier(user_id) - return CloudUser(user_id=user_id, tier=tier, db=_user_db_path(user_id), has_byok=has_byok) + tier, household_id, is_household_owner = _fetch_cloud_tier(user_id) + return CloudUser( + user_id=user_id, + tier=tier, + db=_user_db_path(user_id, household_id=household_id), + has_byok=has_byok, + household_id=household_id, + is_household_owner=is_household_owner, + ) def require_tier(min_tier: str): diff --git a/tests/test_household.py b/tests/test_household.py new file mode 100644 index 0000000..40a3f9a --- /dev/null +++ b/tests/test_household.py @@ -0,0 +1,38 @@ +"""Tests for household session resolution in cloud_session.py.""" +import os +from pathlib import Path +from unittest.mock import patch, MagicMock +import pytest + +os.environ.setdefault("CLOUD_MODE", "false") + +from app.cloud_session import ( + CloudUser, + _user_db_path, + CLOUD_DATA_ROOT, +) + + +def test_clouduser_has_household_fields(): + u = CloudUser( + user_id="u1", tier="premium", db=Path("/tmp/u1.db"), + has_byok=False, household_id="hh-1", is_household_owner=True + ) + assert u.household_id == "hh-1" + assert u.is_household_owner is True + + +def test_clouduser_household_defaults_none(): + u = CloudUser(user_id="u1", tier="free", db=Path("/tmp/u1.db"), has_byok=False) + assert u.household_id is None + assert u.is_household_owner is False + + +def test_user_db_path_personal(): + path = _user_db_path("abc123", household_id=None) + assert path == CLOUD_DATA_ROOT / "abc123" / "kiwi.db" + + +def test_user_db_path_household(): + path = _user_db_path("abc123", household_id="hh-xyz") + assert path == CLOUD_DATA_ROOT / "household_hh-xyz" / "kiwi.db"