feat: local-first LLM config + hosted coordinator auth
LLMRouter env-var auto-config:
- No llm.yaml required — auto-configures from ANTHROPIC_API_KEY, OPENAI_API_KEY, or OLLAMA_HOST on first use
- Bare-metal self-hosters can run any CF product with just env vars
- Raises FileNotFoundError with an actionable message only when llm.yaml is missing and none of these env vars are set

CFOrchClient auth:
- Reads CF_LICENSE_KEY env var (or explicit api_key param)
- Sends Authorization: Bearer <key> on all allocation/release requests
- Required for the hosted public coordinator; no-op for local deployments

HeimdallAuthMiddleware (new):
- FastAPI middleware for cf-orch coordinator
- Enabled by HEIMDALL_URL env var; self-hosted deployments skip it
- 5-min TTL cache (matching Kiwi cloud session) keeps Heimdall off the per-allocation hot path
- /api/health exempt; free-tier keys rejected with 403 + reason
- 13 tests covering cache TTL, tier ranking, and middleware gating
This commit is contained in:
parent
9544f695e6
commit
3deae056de
5 changed files with 444 additions and 7 deletions
|
|
@ -17,13 +17,80 @@ CONFIG_PATH = Path.home() / ".config" / "circuitforge" / "llm.yaml"
|
||||||
|
|
||||||
class LLMRouter:
|
class LLMRouter:
|
||||||
def __init__(self, config_path: Path = CONFIG_PATH):
|
def __init__(self, config_path: Path = CONFIG_PATH):
|
||||||
if not config_path.exists():
|
if config_path.exists():
|
||||||
raise FileNotFoundError(
|
|
||||||
f"{config_path} not found. "
|
|
||||||
"Copy the llm.yaml.example to ~/.config/circuitforge/llm.yaml and configure your LLM backends."
|
|
||||||
)
|
|
||||||
with open(config_path) as f:
|
with open(config_path) as f:
|
||||||
self.config = yaml.safe_load(f)
|
self.config = yaml.safe_load(f)
|
||||||
|
else:
|
||||||
|
env_config = self._auto_config_from_env()
|
||||||
|
if env_config is None:
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"{config_path} not found and no LLM env vars detected. "
|
||||||
|
"Either copy llm.yaml.example to ~/.config/circuitforge/llm.yaml, "
|
||||||
|
"or set ANTHROPIC_API_KEY, OPENAI_API_KEY, or OLLAMA_HOST."
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"[LLMRouter] No llm.yaml found — using env-var auto-config "
|
||||||
|
"(backends: %s)", ", ".join(env_config["fallback_order"])
|
||||||
|
)
|
||||||
|
self.config = env_config
|
||||||
|
|
||||||
|
@staticmethod
def _auto_config_from_env() -> dict | None:
    """Build a minimal LLM config from well-known environment variables.

    Backends are appended to ``fallback_order`` in priority order:
      1. ANTHROPIC_API_KEY → ``anthropic`` backend
      2. OPENAI_API_KEY    → ``openai`` backend (openai_compat; endpoint is
         OPENAI_BASE_URL, defaulting to api.openai.com)
      3. Ollama            → ``ollama`` backend (openai_compat) at OLLAMA_HOST,
         or http://localhost:11434 when OLLAMA_HOST is unset or blank.
         Always appended last.

    ANTHROPIC_MODEL, OPENAI_MODEL and OLLAMA_MODEL override the default
    model names.

    Returns:
        ``{"backends": {...}, "fallback_order": [...]}``, or None when no API
        key is set and OLLAMA_HOST is unset or blank, so the caller can decide
        whether to raise or surface a user-facing message.
    """
    backends: dict = {}
    fallback_order: list[str] = []

    if os.environ.get("ANTHROPIC_API_KEY"):
        backends["anthropic"] = {
            "type": "anthropic",
            "enabled": True,
            "model": os.environ.get("ANTHROPIC_MODEL", "claude-haiku-4-5-20251001"),
            "api_key_env": "ANTHROPIC_API_KEY",
            "supports_images": True,
        }
        fallback_order.append("anthropic")

    if os.environ.get("OPENAI_API_KEY"):
        backends["openai"] = {
            "type": "openai_compat",
            "enabled": True,
            "base_url": os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1"),
            "model": os.environ.get("OPENAI_MODEL", "gpt-4o-mini"),
            "api_key": os.environ.get("OPENAI_API_KEY"),
            "supports_images": True,
        }
        fallback_order.append("openai")

    # A blank OLLAMA_HOST (e.g. `OLLAMA_HOST=` in a compose file) counts as
    # "not configured": it must neither yield a URL like "http:///v1" nor
    # suppress the None return below.
    explicit_ollama = os.environ.get("OLLAMA_HOST", "").strip()
    if not fallback_order and not explicit_ollama:
        # Nothing at all was configured — not even OLLAMA_HOST.
        return None

    # Ollama — always added when any config exists, as the lowest-cost local fallback.
    # Unreachable Ollama is harmless — _is_reachable() skips it gracefully.
    ollama_host = explicit_ollama or "http://localhost:11434"
    # Look for an actual scheme separator; a bare "http" prefix test would
    # mistake hostnames such as "httpbox:11434" for full URLs.
    if "://" not in ollama_host:
        ollama_host = f"http://{ollama_host}"
    backends["ollama"] = {
        "type": "openai_compat",
        "enabled": True,
        "base_url": ollama_host.rstrip("/") + "/v1",
        "model": os.environ.get("OLLAMA_MODEL", "llama3.2:3b"),
        "api_key": "any",
        "supports_images": False,
    }
    fallback_order.append("ollama")

    return {"backends": backends, "fallback_order": fallback_order}
|
||||||
|
|
||||||
def _is_reachable(self, base_url: str) -> bool:
|
def _is_reachable(self, base_url: str) -> bool:
|
||||||
"""Quick health-check ping. Returns True if backend is up."""
|
"""Quick health-check ping. Returns True if backend is up."""
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
from contextlib import contextmanager, asynccontextmanager
|
from contextlib import contextmanager, asynccontextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
@ -34,13 +35,25 @@ class CFOrchClient:
|
||||||
async with client.allocate_async("vllm", model_candidates=["Ouro-1.4B"]) as alloc:
|
async with client.allocate_async("vllm", model_candidates=["Ouro-1.4B"]) as alloc:
|
||||||
...
|
...
|
||||||
|
|
||||||
|
Authentication:
|
||||||
|
Pass api_key explicitly, or set CF_LICENSE_KEY env var. When set, every
|
||||||
|
request carries Authorization: Bearer <key>. Required for the hosted
|
||||||
|
CircuitForge coordinator (orch.circuitforge.tech); optional for local
|
||||||
|
self-hosted coordinators.
|
||||||
|
|
||||||
Raises ValueError immediately if coordinator_url is empty.
|
Raises ValueError immediately if coordinator_url is empty.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, coordinator_url: str) -> None:
|
def __init__(self, coordinator_url: str, api_key: str | None = None) -> None:
|
||||||
if not coordinator_url:
|
if not coordinator_url:
|
||||||
raise ValueError("coordinator_url is empty — cf-orch not configured")
|
raise ValueError("coordinator_url is empty — cf-orch not configured")
|
||||||
self._url = coordinator_url.rstrip("/")
|
self._url = coordinator_url.rstrip("/")
|
||||||
|
self._api_key = api_key or os.environ.get("CF_LICENSE_KEY", "")
|
||||||
|
|
||||||
|
def _headers(self) -> dict[str, str]:
    """Return the HTTP headers to attach to a coordinator request.

    Produces a single ``Authorization: Bearer <key>`` header when an API key
    is configured on this client, and an empty dict otherwise.
    """
    if not self._api_key:
        return {}
    return {"Authorization": f"Bearer {self._api_key}"}
|
||||||
|
|
||||||
def _build_body(self, model_candidates: list[str] | None, ttl_s: float, caller: str) -> dict:
|
def _build_body(self, model_candidates: list[str] | None, ttl_s: float, caller: str) -> dict:
|
||||||
return {
|
return {
|
||||||
|
|
@ -74,6 +87,7 @@ class CFOrchClient:
|
||||||
resp = httpx.post(
|
resp = httpx.post(
|
||||||
f"{self._url}/api/services/{service}/allocate",
|
f"{self._url}/api/services/{service}/allocate",
|
||||||
json=self._build_body(model_candidates, ttl_s, caller),
|
json=self._build_body(model_candidates, ttl_s, caller),
|
||||||
|
headers=self._headers(),
|
||||||
timeout=120.0,
|
timeout=120.0,
|
||||||
)
|
)
|
||||||
if not resp.is_success:
|
if not resp.is_success:
|
||||||
|
|
@ -88,6 +102,7 @@ class CFOrchClient:
|
||||||
try:
|
try:
|
||||||
httpx.delete(
|
httpx.delete(
|
||||||
f"{self._url}/api/services/{service}/allocations/{alloc.allocation_id}",
|
f"{self._url}/api/services/{service}/allocations/{alloc.allocation_id}",
|
||||||
|
headers=self._headers(),
|
||||||
timeout=10.0,
|
timeout=10.0,
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
|
|
@ -107,6 +122,7 @@ class CFOrchClient:
|
||||||
resp = await client.post(
|
resp = await client.post(
|
||||||
f"{self._url}/api/services/{service}/allocate",
|
f"{self._url}/api/services/{service}/allocate",
|
||||||
json=self._build_body(model_candidates, ttl_s, caller),
|
json=self._build_body(model_candidates, ttl_s, caller),
|
||||||
|
headers=self._headers(),
|
||||||
)
|
)
|
||||||
if not resp.is_success:
|
if not resp.is_success:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
|
@ -120,6 +136,7 @@ class CFOrchClient:
|
||||||
try:
|
try:
|
||||||
await client.delete(
|
await client.delete(
|
||||||
f"{self._url}/api/services/{service}/allocations/{alloc.allocation_id}",
|
f"{self._url}/api/services/{service}/allocations/{alloc.allocation_id}",
|
||||||
|
headers=self._headers(),
|
||||||
timeout=10.0,
|
timeout=10.0,
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
|
|
|
||||||
|
|
@ -133,6 +133,14 @@ def create_coordinator_app(
|
||||||
|
|
||||||
app = FastAPI(title="cf-orch-coordinator", lifespan=_lifespan)
|
app = FastAPI(title="cf-orch-coordinator", lifespan=_lifespan)
|
||||||
|
|
||||||
|
# Optional Heimdall auth — enabled when HEIMDALL_URL env var is set.
|
||||||
|
# Self-hosted coordinators skip this entirely; the CF-hosted public endpoint
|
||||||
|
# (orch.circuitforge.tech) sets HEIMDALL_URL to gate paid+ access.
|
||||||
|
from circuitforge_core.resources.coordinator.auth import HeimdallAuthMiddleware
|
||||||
|
_auth = HeimdallAuthMiddleware.from_env()
|
||||||
|
if _auth is not None:
|
||||||
|
app.middleware("http")(_auth)
|
||||||
|
|
||||||
@app.get("/", response_class=HTMLResponse, include_in_schema=False)
|
@app.get("/", response_class=HTMLResponse, include_in_schema=False)
|
||||||
def dashboard() -> HTMLResponse:
|
def dashboard() -> HTMLResponse:
|
||||||
return HTMLResponse(content=_DASHBOARD_HTML)
|
return HTMLResponse(content=_DASHBOARD_HTML)
|
||||||
|
|
|
||||||
197
circuitforge_core/resources/coordinator/auth.py
Normal file
197
circuitforge_core/resources/coordinator/auth.py
Normal file
|
|
@ -0,0 +1,197 @@
|
||||||
|
"""
|
||||||
|
cf-orch coordinator auth middleware.
|
||||||
|
|
||||||
|
When HEIMDALL_URL is set, every request except the exempt paths (/api/health, /, /openapi.json, /docs, /redoc) must carry:
|
||||||
|
Authorization: Bearer <CF license key>
|
||||||
|
|
||||||
|
The key is validated against Heimdall and the result cached for
|
||||||
|
CACHE_TTL_S seconds (default 300 / 5 min). This keeps Heimdall out of the
|
||||||
|
per-allocation hot path while keeping revocation latency bounded.
|
||||||
|
|
||||||
|
When HEIMDALL_URL is not set, auth is disabled — self-hosted deployments work
|
||||||
|
with no configuration change.
|
||||||
|
|
||||||
|
Environment variables
|
||||||
|
---------------------
|
||||||
|
HEIMDALL_URL Heimdall base URL, e.g. https://license.circuitforge.tech
|
||||||
|
When absent, auth is skipped entirely.
|
||||||
|
HEIMDALL_MIN_TIER Minimum tier required (default: "paid").
|
||||||
|
Accepted values: free, paid, premium, ultra.
|
||||||
|
CF_ORCH_AUTH_SECRET Shared secret sent to Heimdall so it can distinguish
|
||||||
|
coordinator service calls from end-user requests.
|
||||||
|
Must match the COORDINATOR_SECRET env var on Heimdall.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from threading import Lock
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
from fastapi import Request
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Unauthenticated paths — health check must always be accessible for monitoring.
|
||||||
|
_EXEMPT_PATHS: frozenset[str] = frozenset({"/api/health", "/", "/openapi.json", "/docs", "/redoc"})
|
||||||
|
|
||||||
|
_TIER_ORDER: dict[str, int] = {"free": 0, "paid": 1, "premium": 2, "ultra": 3}
|
||||||
|
|
||||||
|
CACHE_TTL_S: float = 300.0 # 5 minutes — matches Kiwi cloud session TTL
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class _CacheEntry:
    """One cached validation result for a single license key."""

    # True when the key was reported as valid.
    valid: bool
    # Tier name returned with the validation result (may be empty).
    tier: str
    # User identifier returned with the validation result (may be empty).
    user_id: str
    # time.monotonic() value after which this entry counts as expired.
    expires_at: float
|
||||||
|
|
||||||
|
|
||||||
|
class _ValidationCache:
    """Thread-safe, size-bounded TTL cache for Heimdall validation results.

    Entries expire ``ttl_s`` seconds after they are stored. Expired entries
    are dropped when they are looked up and whenever a new entry is stored,
    and the cache never holds more than ``max_entries`` keys. Without these
    bounds every distinct (including bogus) license key seen would stay in
    memory forever.
    """

    def __init__(self, ttl_s: float = CACHE_TTL_S, max_entries: int = 10_000) -> None:
        """Create an empty cache.

        Args:
            ttl_s: Lifetime of each entry, in seconds.
            max_entries: Upper bound on stored keys (values below 1 are
                treated as 1). The oldest entry is evicted when full.
        """
        self._ttl = ttl_s
        self._max_entries = max(1, max_entries)
        # Insertion-ordered: set() always re-inserts at the end, so with a
        # single fixed TTL the first key is the one that expires soonest.
        self._store: dict[str, _CacheEntry] = {}
        self._lock = Lock()

    def _drop_stale_locked(self, now: float) -> None:
        """Drop expired entries from the front of the store (lock must be held)."""
        while self._store:
            oldest = next(iter(self._store))
            if now <= self._store[oldest].expires_at:
                break
            del self._store[oldest]

    def get(self, key: str) -> _CacheEntry | None:
        """Return the live entry for *key*, or None if absent or expired."""
        with self._lock:
            entry = self._store.get(key)
            if entry is None:
                return None
            if time.monotonic() > entry.expires_at:
                # Evict eagerly so stale keys do not linger until prune().
                del self._store[key]
                return None
            return entry

    def set(self, key: str, valid: bool, tier: str, user_id: str) -> None:
        """Store a validation result for *key*, expiring ``ttl_s`` from now."""
        with self._lock:
            now = time.monotonic()
            # Remove first so the refreshed entry moves to the end of the order.
            self._store.pop(key, None)
            self._drop_stale_locked(now)
            while len(self._store) >= self._max_entries:
                del self._store[next(iter(self._store))]
            self._store[key] = _CacheEntry(
                valid=valid,
                tier=tier,
                user_id=user_id,
                expires_at=now + self._ttl,
            )

    def evict(self, key: str) -> None:
        """Remove *key* if present; a missing key is not an error."""
        with self._lock:
            self._store.pop(key, None)

    def prune(self) -> int:
        """Remove expired entries. Returns count removed."""
        now = time.monotonic()
        with self._lock:
            expired = [k for k, e in self._store.items() if now > e.expires_at]
            for k in expired:
                del self._store[k]
            return len(expired)
|
||||||
|
|
||||||
|
|
||||||
|
class HeimdallAuthMiddleware:
    """
    HTTP middleware that validates CF license keys against Heimdall.

    Attach to a FastAPI app via app.middleware("http"):

        middleware = HeimdallAuthMiddleware.from_env()
        if middleware:
            app.middleware("http")(middleware)

    Requests whose path is in _EXEMPT_PATHS pass straight through. Every
    other request must carry ``Authorization: Bearer <license key>``:
    a missing or empty credential yields 401, a rejected key or an
    insufficient tier yields 403 with the reason in ``detail``.
    """

    def __init__(
        self,
        heimdall_url: str,
        min_tier: str = "paid",
        auth_secret: str = "",
        cache_ttl_s: float = CACHE_TTL_S,
    ) -> None:
        """Configure the middleware.

        Args:
            heimdall_url: Heimdall base URL; a trailing slash is removed.
            min_tier: Minimum tier name (a key of _TIER_ORDER). Case and
                surrounding whitespace are ignored; an unknown name falls
                back to "paid" with a warning.
            auth_secret: Optional value sent as ``X-Coordinator-Secret``.
            cache_ttl_s: Lifetime of cached validation results, in seconds.
        """
        self._heimdall = heimdall_url.rstrip("/")
        tier_name = min_tier.strip().lower()
        if tier_name not in _TIER_ORDER:
            # Previously an unknown name silently ranked as "paid" while the
            # bogus name was still sent to Heimdall and shown to callers.
            logger.warning(
                "[cf-orch auth] Unknown min_tier %r — falling back to 'paid'", min_tier
            )
            tier_name = "paid"
        self._min_tier = tier_name
        self._min_tier_rank = _TIER_ORDER[tier_name]
        self._auth_secret = auth_secret
        self._cache = _ValidationCache(ttl_s=cache_ttl_s)
        logger.info(
            "[cf-orch auth] Heimdall auth enabled — url=%s min_tier=%s ttl=%ss",
            self._heimdall, tier_name, cache_ttl_s,
        )

    @classmethod
    def from_env(cls) -> "HeimdallAuthMiddleware | None":
        """Return a configured middleware instance, or None if HEIMDALL_URL is not set."""
        url = os.environ.get("HEIMDALL_URL", "")
        if not url:
            logger.info("[cf-orch auth] HEIMDALL_URL not set — auth disabled (self-hosted mode)")
            return None
        return cls(
            heimdall_url=url,
            min_tier=os.environ.get("HEIMDALL_MIN_TIER", "paid"),
            auth_secret=os.environ.get("CF_ORCH_AUTH_SECRET", ""),
        )

    @staticmethod
    def _key_hint(license_key: str) -> str:
        """Return a log-safe hint for *license_key* (never the whole key)."""
        # A 6-character suffix of a very short key would reveal most of it.
        return license_key[-6:] if len(license_key) > 12 else "<short>"

    def _decide(self, valid: bool, tier: str) -> tuple[bool, str]:
        """Turn a validation result into (authorized, reason_if_denied)."""
        if not valid:
            return False, "license key invalid or expired"
        if _TIER_ORDER.get(tier, -1) < self._min_tier_rank:
            return False, f"feature requires {self._min_tier} tier (have: {tier})"
        return True, ""

    def _validate_against_heimdall(self, license_key: str) -> tuple[bool, str, str]:
        """
        Call Heimdall's /licenses/verify endpoint.

        Returns (valid, tier, user_id).
        On any network or parse error, returns (False, "", "") — fail closed.
        This call blocks for up to 5 seconds; do not run it on the event loop.
        """
        headers: dict[str, str] = {"Content-Type": "application/json"}
        if self._auth_secret:
            headers["X-Coordinator-Secret"] = self._auth_secret
        try:
            resp = httpx.post(
                f"{self._heimdall}/licenses/verify",
                json={"key": license_key, "min_tier": self._min_tier},
                headers=headers,
                timeout=5.0,
            )
            if resp.status_code != 200:
                # 401/403 from Heimdall = key invalid/insufficient tier
                logger.debug(
                    "[cf-orch auth] Heimdall returned %s for key ...%s",
                    resp.status_code, self._key_hint(license_key),
                )
                return False, "", ""
            data = resp.json()
            if not isinstance(data, dict):
                raise ValueError("verify response is not a JSON object")
        except Exception as exc:  # trust boundary: any failure must deny access
            logger.warning("[cf-orch auth] Heimdall unreachable — failing closed: %s", exc)
            return False, "", ""

        tier = data.get("tier", "")
        user_id = data.get("user_id", "")
        return (
            # Only a real JSON `true` authorizes; truthy strings do not.
            data.get("valid") is True,
            tier if isinstance(tier, str) else "",
            user_id if isinstance(user_id, str) else str(user_id or ""),
        )

    def _check_key(self, license_key: str) -> tuple[bool, str]:
        """
        Validate key (cache-first). Returns (authorized, reason_if_denied).

        May block on a Heimdall round-trip when the key is not cached.
        """
        cached = self._cache.get(license_key)
        if cached is not None:
            return self._decide(cached.valid, cached.tier)

        valid, tier, user_id = self._validate_against_heimdall(license_key)
        # NOTE(review): a fail-closed result caused by a Heimdall outage is
        # cached like a genuine rejection, so a valid key can stay locked out
        # for up to one TTL after Heimdall recovers — confirm this is intended.
        self._cache.set(license_key, valid=valid, tier=tier, user_id=user_id)
        return self._decide(valid, tier)

    async def __call__(self, request: Request, call_next):  # type: ignore[no-untyped-def]
        """Gate one request: pass it on when authorized, else answer 401/403."""
        import asyncio  # local import: only this coroutine needs it

        if request.url.path in _EXEMPT_PATHS:
            return await call_next(request)

        # The auth scheme name is case-insensitive; the credential follows
        # the first space.
        scheme, _, credential = request.headers.get("Authorization", "").partition(" ")
        license_key = credential.strip()
        if scheme.lower() != "bearer" or not license_key:
            return JSONResponse(
                status_code=401,
                content={"detail": "Authorization: Bearer <license_key> required"},
                headers={"WWW-Authenticate": "Bearer"},
            )

        cached = self._cache.get(license_key)
        if cached is not None:
            # Hot path: answer from the cache without leaving the event loop.
            authorized, reason = self._decide(cached.valid, cached.tier)
        else:
            # The Heimdall call is synchronous; run it in a worker thread so a
            # slow Heimdall cannot stall every other request.
            authorized, reason = await asyncio.to_thread(self._check_key, license_key)

        if not authorized:
            return JSONResponse(status_code=403, content={"detail": reason})

        return await call_next(request)
|
||||||
148
tests/test_resources/test_coordinator_auth.py
Normal file
148
tests/test_resources/test_coordinator_auth.py
Normal file
|
|
@ -0,0 +1,148 @@
|
||||||
|
"""Tests for HeimdallAuthMiddleware — TTL cache and request gating."""
|
||||||
|
import time
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
from circuitforge_core.resources.coordinator.auth import (
|
||||||
|
HeimdallAuthMiddleware,
|
||||||
|
_ValidationCache,
|
||||||
|
CACHE_TTL_S,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cache unit tests ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_cache_miss_returns_none():
    """Looking up a key that was never stored yields None."""
    empty_cache = _ValidationCache()
    lookup = empty_cache.get("nonexistent")
    assert lookup is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_cache_stores_and_retrieves():
    """A stored validation result comes back with every field intact."""
    cache = _ValidationCache()
    cache.set("key1", valid=True, tier="paid", user_id="u1")
    entry = cache.get("key1")
    assert entry is not None
    assert entry.valid is True
    assert entry.tier == "paid"
    # user_id was stored but never checked before.
    assert entry.user_id == "u1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_cache_entry_expires(monkeypatch):
    """An entry is served until its TTL elapses and not afterwards."""
    # Drive the cache with a fake monotonic clock instead of sleeping, so the
    # test is instant and cannot flake on a slow machine.
    clock = [1000.0]
    monkeypatch.setattr(time, "monotonic", lambda: clock[0])

    cache = _ValidationCache(ttl_s=60.0)
    cache.set("key1", valid=True, tier="paid", user_id="u1")
    assert cache.get("key1") is not None

    clock[0] += 60.5
    assert cache.get("key1") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_cache_evict_removes_key():
    """Evicting a stored key makes later lookups miss."""
    store = _ValidationCache()
    store.set("key1", valid=True, tier="paid", user_id="u1")
    store.evict("key1")
    after_evict = store.get("key1")
    assert after_evict is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_cache_prune_removes_expired(monkeypatch):
    """prune() leaves live entries alone and removes every expired one."""
    # Fake monotonic clock: no real sleeping, fully deterministic.
    clock = [1000.0]
    monkeypatch.setattr(time, "monotonic", lambda: clock[0])

    cache = _ValidationCache(ttl_s=60.0)
    cache.set("k1", valid=True, tier="paid", user_id="")
    cache.set("k2", valid=True, tier="paid", user_id="")
    assert cache.prune() == 0

    clock[0] += 60.5
    removed = cache.prune()
    assert removed == 2
    assert cache.get("k1") is None
    assert cache.get("k2") is None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Middleware integration tests ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _make_app_with_auth(middleware: HeimdallAuthMiddleware) -> TestClient:
    """Build a throwaway app guarded by *middleware* and return a client for it.

    The app exposes GET /api/health and POST /api/services/vllm/allocate.
    """
    guarded_app = FastAPI()

    @guarded_app.get("/api/health")
    def health():
        return {"status": "ok"}

    @guarded_app.post("/api/services/vllm/allocate")
    def allocate():
        return {"allocation_id": "abc", "url": "http://gpu:8000"}

    guarded_app.middleware("http")(middleware)
    return TestClient(guarded_app, raise_server_exceptions=False)
|
||||||
|
|
||||||
|
|
||||||
|
def _patched_middleware(valid: bool, tier: str = "paid") -> HeimdallAuthMiddleware:
    """Return a middleware whose Heimdall call is pre-mocked.

    The mocked call reports (valid, tier, "user-1"), with an empty user id
    when *valid* is false.
    """
    canned_result = (valid, tier, "user-1" if valid else "")
    middleware = HeimdallAuthMiddleware(
        heimdall_url="http://heimdall.test",
        min_tier="paid",
    )
    heimdall_stub = MagicMock(return_value=canned_result)
    middleware._validate_against_heimdall = heimdall_stub  # type: ignore[method-assign]
    return middleware
|
||||||
|
|
||||||
|
|
||||||
|
def test_health_exempt_no_auth_required():
    """The health endpoint answers 200 without any Authorization header."""
    client = _make_app_with_auth(_patched_middleware(valid=True))
    health_resp = client.get("/api/health")
    assert health_resp.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
def test_missing_auth_header_returns_401():
    """A gated request without an Authorization header is answered with 401."""
    client = _make_app_with_auth(_patched_middleware(valid=True))
    unauthenticated = client.post("/api/services/vllm/allocate")
    assert unauthenticated.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_key_returns_403():
    """A key that validation rejects is answered with 403."""
    client = _make_app_with_auth(_patched_middleware(valid=False))
    bad_key_headers = {"Authorization": "Bearer BAD-KEY"}
    rejected = client.post("/api/services/vllm/allocate", headers=bad_key_headers)
    assert rejected.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
def test_valid_paid_key_passes():
    """A valid paid-tier key reaches the route and gets 200."""
    client = _make_app_with_auth(_patched_middleware(valid=True, tier="paid"))
    good_key_headers = {"Authorization": "Bearer CFG-KIWI-GOOD-GOOD-GOOD"}
    accepted = client.post("/api/services/vllm/allocate", headers=good_key_headers)
    assert accepted.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
def test_free_tier_key_rejected_when_min_is_paid():
    """A valid free-tier key gets 403 and the detail names the required tier."""
    client = _make_app_with_auth(_patched_middleware(valid=True, tier="free"))
    free_key_headers = {"Authorization": "Bearer CFG-KIWI-FREE-FREE-FREE"}
    denied = client.post("/api/services/vllm/allocate", headers=free_key_headers)
    assert denied.status_code == 403
    detail = denied.json()["detail"]
    assert "paid" in detail
|
||||||
|
|
||||||
|
|
||||||
|
def test_cache_prevents_second_heimdall_call():
    """Two requests with the same key trigger a single Heimdall validation."""
    middleware = _patched_middleware(valid=True, tier="paid")
    client = _make_app_with_auth(middleware)
    auth_headers = {"Authorization": "Bearer " + "CFG-KIWI-CACHED-KEY-1"}
    for _ in range(2):
        client.post("/api/services/vllm/allocate", headers=auth_headers)
    # Heimdall should only have been called once — second hit is from cache
    heimdall_calls = middleware._validate_against_heimdall.call_count  # type: ignore[attr-defined]
    assert heimdall_calls == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_env_returns_none_without_heimdall_url(monkeypatch):
    """from_env() returns None when HEIMDALL_URL is absent."""
    monkeypatch.delenv("HEIMDALL_URL", raising=False)
    built = HeimdallAuthMiddleware.from_env()
    assert built is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_from_env_returns_middleware_when_set(monkeypatch):
    """from_env() builds a middleware pointing at HEIMDALL_URL when it is set."""
    heimdall_base = "http://heimdall.test"
    monkeypatch.setenv("HEIMDALL_URL", heimdall_base)
    built = HeimdallAuthMiddleware.from_env()
    assert built is not None
    assert built._heimdall == "http://heimdall.test"
|
||||||
Loading…
Reference in a new issue