feat(api): per-user LLM rate limiting via slowapi

Add scripts/rate_limit.py with cloud-aware key function:
- In cloud mode, extracts user_id from _request_db ContextVar path (part[-3])
  so each cloud user has their own rate limit bucket
- In demo mode, returns unique per-request key to disable limiting entirely
  (_demo_guard handles write-blocking; rate limiting would block the demo UX)
- Falls back to client IP for local/self-hosted installs

Wire limiter to 4 endpoints with conservative per-user limits:
- POST /generate/cover-letter: 20/hour
- POST /research/run: 10/hour
- POST /qa/suggest: 60/hour
- POST /survey/analyze: 30/hour

Add _demo_guard() to generate_research and suggest_qa_answer (was missing).
Fix pre-existing silent except in suggest_qa_answer: was bare except pass,
now logs warning with exc_info.

Add _RL_WIZARD placeholder constant with TODO to wire to wizard/ai/interview
after feat/77 merges (declared but intentionally not applied yet to avoid
false sense of security — comment makes the gap explicit).

18 tests covering cloud user isolation, demo bypass, IP fallback, all 4
endpoints returning 429 on excess, retry_after header, and demo guard.

Closes: #122
This commit is contained in:
pyr0ball 2026-06-14 12:14:21 -07:00
parent 3048d8e2f4
commit d801650db1
5 changed files with 347 additions and 5 deletions

View file

@ -69,3 +69,10 @@ CF_SERVER_SECRET= # random 64-char hex — generate: openssl rand -
PLATFORM_DB_URL=postgresql://cf_platform:<password>@host.docker.internal:5433/circuitforge_platform
HEIMDALL_URL=http://cf-license:8000 # internal Docker URL; override for external access
HEIMDALL_ADMIN_TOKEN= # must match ADMIN_TOKEN in circuitforge-license .env
# ── Rate limiting (LLM generation endpoints) ─────────────────────────────────
LLM_RATE_COVER_LETTER=20/hour
LLM_RATE_RESEARCH=10/hour
LLM_RATE_QA_SUGGEST=60/hour
LLM_RATE_SURVEY=30/hour
LLM_RATE_WIZARD=60/hour

View file

@ -39,6 +39,8 @@ if str(PEREGRINE_ROOT) not in sys.path:
from circuitforge_core.api import make_feedback_router as _make_feedback_router # noqa: E402
from circuitforge_core.config.settings import load_env as _load_env # noqa: E402
from scripts.credential_store import get_credential, set_credential # noqa: E402
from scripts.rate_limit import limiter, rate_limit_exceeded_handler # noqa: E402
from slowapi.errors import RateLimitExceeded # noqa: E402
DB_PATH = os.environ.get("STAGING_DB", "/devl/job-seeker/staging.db")
@ -47,6 +49,13 @@ _CLOUD_DATA_ROOT = Path(os.environ.get("CLOUD_DATA_ROOT", "/devl/menagerie-data
_DIRECTUS_SECRET = os.environ.get("DIRECTUS_JWT_SECRET", "")
IS_DEMO: bool = os.environ.get("DEMO_MODE", "").lower() in ("1", "true", "yes")
# ── Rate limiting (LLM generation endpoints) ──────────────────────────────────
_RL_COVER_LETTER = os.environ.get("LLM_RATE_COVER_LETTER", "20/hour")
_RL_RESEARCH = os.environ.get("LLM_RATE_RESEARCH", "10/hour")
_RL_QA_SUGGEST = os.environ.get("LLM_RATE_QA_SUGGEST", "60/hour")
_RL_SURVEY = os.environ.get("LLM_RATE_SURVEY", "30/hour")
_RL_WIZARD = os.environ.get("LLM_RATE_WIZARD", "60/hour") # TODO(#122): wire to wizard/ai/interview after feat/77 merges
# Resolve GPU inference server URL.
# Priority: GPU_SERVER_URL → CF_ORCH_URL (backward compat) → cloud default when licensed.
# Result is written back to CF_ORCH_URL so all downstream callers need no changes.
@ -109,6 +118,8 @@ async def lifespan(app: FastAPI):
app = FastAPI(title="Peregrine Dev API", lifespan=lifespan)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, rate_limit_exceeded_handler)
app.add_middleware(
CORSMiddleware,
@ -512,7 +523,8 @@ def save_cover_letter(job_id: int, body: CoverLetterBody):
# ── POST /api/jobs/:id/cover_letter/generate ─────────────────────────────────
@app.post("/api/jobs/{job_id}/cover_letter/generate")
def generate_cover_letter(job_id: int):
@limiter.limit(_RL_COVER_LETTER)
def generate_cover_letter(job_id: int, request: Request):
_demo_guard()
try:
from scripts.task_runner import submit_task
@ -565,7 +577,9 @@ def get_research_brief(job_id: int):
@app.post("/api/jobs/{job_id}/research/generate")
def generate_research(job_id: int):
@limiter.limit(_RL_RESEARCH)
def generate_research(job_id: int, request: Request):
_demo_guard()
try:
from scripts.task_runner import submit_task
task_id, is_new = submit_task(db_path=Path(_request_db.get() or DB_PATH), task_type="company_research", job_id=job_id)
@ -1520,7 +1534,8 @@ class SurveyAnalyzeBody(BaseModel):
@app.post("/api/jobs/{job_id}/survey/analyze")
def survey_analyze(job_id: int, body: SurveyAnalyzeBody):
@limiter.limit(_RL_SURVEY)
def survey_analyze(job_id: int, body: SurveyAnalyzeBody, request: Request):
if body.mode not in ("quick", "detailed"):
raise HTTPException(400, f"Invalid mode: {body.mode!r}")
import json as _json
@ -1735,8 +1750,10 @@ def save_qa(job_id: int, payload: QAPayload):
@app.post("/api/jobs/{job_id}/qa/suggest")
def suggest_qa_answer(job_id: int, payload: QASuggestPayload):
@limiter.limit(_RL_QA_SUGGEST)
def suggest_qa_answer(job_id: int, payload: QASuggestPayload, request: Request):
"""Synchronously generate an LLM answer for an application Q&A question."""
_demo_guard()
db = _get_db()
job_row = db.execute(
"SELECT title, company, description FROM jobs WHERE id = ?", (job_id,)
@ -1767,7 +1784,7 @@ def suggest_qa_answer(job_id: int, payload: QASuggestPayload):
parts.append(f"Summary: {resume_data['career_summary'][:400]}")
resume_context = "\n".join(parts)
except Exception:
pass
_log.warning("suggest_qa_answer: failed to load resume context", exc_info=True)
prompt = (
f"You are helping a job applicant answer an application question.\n\n"

View file

@ -63,6 +63,9 @@ dependencies:
# ── Auth / licensing ──────────────────────────────────────────────────────
- PyJWT>=2.8
# ── Rate limiting ─────────────────────────────────────────────────────────
- slowapi>=0.1.9 # per-user rate limiting on LLM endpoints
# ── Utilities ─────────────────────────────────────────────────────────────
- sqlalchemy
- tqdm

32
scripts/rate_limit.py Normal file
View file

@ -0,0 +1,32 @@
"""Per-user rate limiting for Peregrine LLM generation endpoints."""
from pathlib import Path
from slowapi import Limiter
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from starlette.requests import Request
from starlette.responses import JSONResponse
def _rate_key(request: Request) -> str:
"""Cloud mode: user_id from DB path. Local mode: client IP. Demo: unique key (no rate limit)."""
from dev_api import IS_DEMO, _CLOUD_MODE, _request_db # lazy import avoids circular
if IS_DEMO:
return f"demo-{id(request)}" # unique per request — effectively no rate limiting
db_path = _request_db.get()
if _CLOUD_MODE and db_path:
return Path(db_path).parts[-3] # user_id segment
return get_remote_address(request)
limiter = Limiter(key_func=_rate_key)
def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
"""Return 429 with Retry-After header."""
retry_after = getattr(exc, "retry_after", 60)
return JSONResponse(
status_code=429,
content={"error": "rate_limit_exceeded", "retry_after": retry_after},
headers={"Retry-After": str(retry_after)},
)

283
tests/test_rate_limiting.py Normal file
View file

@ -0,0 +1,283 @@
"""Tests for per-user rate limiting on LLM generation endpoints.
Covers:
- _rate_key() in demo mode returns unique per-request key (no rate limiting)
- _rate_key() in cloud mode returns user_id segment from DB path
- _rate_key() in local mode falls back to client IP address
- rate_limit_exceeded_handler() returns 429 with Retry-After header
- Integration: hitting rate limit on a decorated endpoint returns 429
"""
import json
import sqlite3
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from limits import parse as _limits_parse
from slowapi.errors import RateLimitExceeded
from slowapi.wrappers import Limit as _LimitWrapper
from starlette.requests import Request
from scripts.rate_limit import _rate_key, rate_limit_exceeded_handler
# ── Helpers ───────────────────────────────────────────────────────────────────
def _make_request(client_ip: str = "1.2.3.4") -> MagicMock:
"""Return a minimal mock Request with a client IP."""
req = MagicMock(spec=Request)
req.client = MagicMock()
req.client.host = client_ip
req.headers = {}
req.scope = {"type": "http"}
return req
def _make_rate_limit_exceeded(spec: str = "20/hour") -> RateLimitExceeded:
"""Construct a valid RateLimitExceeded (slowapi 0.1.9+ requires a Limit wrapper)."""
limit_item = _limits_parse(spec)
wrapper = _LimitWrapper(
limit=limit_item,
key_func=lambda r: "test",
scope=None,
per_method=False,
methods=None,
error_message=None,
exempt_when=None,
cost=1,
override_defaults=False,
)
return RateLimitExceeded(wrapper)
# ── Fixtures ──────────────────────────────────────────────────────────────────
@pytest.fixture()
def tmp_db(tmp_path):
"""Create a minimal staging.db in tmp_path and return its string path."""
db_path = tmp_path / "staging.db"
con = sqlite3.connect(str(db_path))
con.executescript("""
CREATE TABLE IF NOT EXISTS jobs (
id INTEGER PRIMARY KEY,
title TEXT, company TEXT, url TEXT, location TEXT,
is_remote INTEGER DEFAULT 0, salary TEXT,
match_score REAL, keyword_gaps TEXT, status TEXT,
interview_date TEXT, rejection_stage TEXT,
applied_at TEXT, phone_screen_at TEXT, interviewing_at TEXT,
offer_at TEXT, hired_at TEXT, survey_at TEXT
);
CREATE TABLE IF NOT EXISTS background_tasks (
id INTEGER PRIMARY KEY,
task_type TEXT,
job_id INTEGER,
status TEXT DEFAULT 'queued',
stage TEXT,
error TEXT,
params TEXT,
finished_at TEXT
);
""")
con.close()
return str(db_path)
@pytest.fixture()
def client(tmp_db, monkeypatch):
"""TestClient wired to a fresh isolated DB."""
monkeypatch.setenv("STAGING_DB", tmp_db)
import dev_api
monkeypatch.setattr(dev_api, "DB_PATH", tmp_db)
monkeypatch.setattr(
dev_api,
"_request_db",
type("CV", (), {"get": lambda self: tmp_db, "set": lambda *a: None})(),
)
return TestClient(dev_api.app)
# ── _rate_key(): demo mode ────────────────────────────────────────────────────
class TestRateKeyDemoMode:
def test_returns_unique_key_per_request(self):
"""In demo mode each request gets a unique key so no limiting occurs."""
req1 = _make_request()
req2 = _make_request()
with patch("dev_api.IS_DEMO", True), patch("dev_api._CLOUD_MODE", False):
key1 = _rate_key(req1)
key2 = _rate_key(req2)
assert key1.startswith("demo-")
assert key2.startswith("demo-")
assert key1 != key2 # unique per request object
def test_key_does_not_use_client_ip(self):
"""Demo key must not equal the client IP."""
req = _make_request(client_ip="9.9.9.9")
with patch("dev_api.IS_DEMO", True), patch("dev_api._CLOUD_MODE", False):
key = _rate_key(req)
assert "9.9.9.9" not in key
# ── _rate_key(): cloud mode ───────────────────────────────────────────────────
class TestRateKeyCloudMode:
def test_returns_user_id_from_db_path(self, tmp_path):
"""Cloud mode extracts user_id (3rd-from-end path segment)."""
cloud_db = str(tmp_path / "abc-user-123" / "peregrine" / "staging.db")
req = _make_request()
with (
patch("dev_api.IS_DEMO", False),
patch("dev_api._CLOUD_MODE", True),
patch("dev_api._request_db") as mock_cv,
):
mock_cv.get.return_value = cloud_db
key = _rate_key(req)
assert key == "abc-user-123"
def test_falls_back_to_ip_when_db_path_is_none(self):
"""Cloud mode without a DB path (unauthenticated) falls back to IP."""
req = _make_request(client_ip="10.0.0.1")
with (
patch("dev_api.IS_DEMO", False),
patch("dev_api._CLOUD_MODE", True),
patch("dev_api._request_db") as mock_cv,
):
mock_cv.get.return_value = None
key = _rate_key(req)
assert key == "10.0.0.1"
# ── _rate_key(): local mode ───────────────────────────────────────────────────
class TestRateKeyLocalMode:
def test_returns_client_ip(self):
"""Local (non-cloud, non-demo) mode uses the remote client IP."""
req = _make_request(client_ip="192.168.1.50")
with patch("dev_api.IS_DEMO", False), patch("dev_api._CLOUD_MODE", False):
key = _rate_key(req)
assert key == "192.168.1.50"
def test_different_ips_produce_different_keys(self):
"""Two distinct client IPs produce distinct rate limit keys."""
req_a = _make_request(client_ip="10.0.0.1")
req_b = _make_request(client_ip="10.0.0.2")
with patch("dev_api.IS_DEMO", False), patch("dev_api._CLOUD_MODE", False):
key_a = _rate_key(req_a)
key_b = _rate_key(req_b)
assert key_a != key_b
# ── rate_limit_exceeded_handler() ─────────────────────────────────────────────
class TestRateLimitExceededHandler:
def test_returns_429_status(self):
"""Handler always returns HTTP 429."""
req = _make_request()
exc = _make_rate_limit_exceeded("20/hour")
response = rate_limit_exceeded_handler(req, exc)
assert response.status_code == 429
def test_body_has_error_field(self):
"""Response body includes error: rate_limit_exceeded."""
req = _make_request()
exc = _make_rate_limit_exceeded("20/hour")
response = rate_limit_exceeded_handler(req, exc)
body = json.loads(response.body)
assert body["error"] == "rate_limit_exceeded"
def test_body_has_retry_after_field(self):
"""Response body includes retry_after value."""
req = _make_request()
exc = _make_rate_limit_exceeded("20/hour")
response = rate_limit_exceeded_handler(req, exc)
body = json.loads(response.body)
assert "retry_after" in body
def test_retry_after_header_present(self):
"""Retry-After HTTP header is set on the response."""
req = _make_request()
exc = _make_rate_limit_exceeded("20/hour")
response = rate_limit_exceeded_handler(req, exc)
assert "Retry-After" in response.headers
def test_retry_after_header_matches_body(self):
"""Retry-After header value matches the retry_after field in the body."""
req = _make_request()
exc = _make_rate_limit_exceeded("20/hour")
response = rate_limit_exceeded_handler(req, exc)
body = json.loads(response.body)
assert response.headers["Retry-After"] == str(body["retry_after"])
# ── Integration: 429 on rate-limited endpoints ────────────────────────────────
def _patch_limiter_to_raise(exc: RateLimitExceeded):
"""Context manager: make the slowapi limiter fire for any request."""
return patch(
"slowapi.extension.Limiter._check_request_limit",
side_effect=exc,
)
class TestRateLimitIntegration:
"""Verify that when the limiter fires, the app returns 429 via the exception handler."""
def test_cover_letter_generate_returns_429_on_limit(self, client):
"""When the rate limiter triggers, the cover letter endpoint returns 429."""
exc = _make_rate_limit_exceeded("20/hour")
with _patch_limiter_to_raise(exc):
resp = client.post("/api/jobs/1/cover_letter/generate")
assert resp.status_code == 429
def test_research_generate_returns_429_on_limit(self, client):
"""When the rate limiter triggers, the research endpoint returns 429."""
exc = _make_rate_limit_exceeded("10/hour")
with _patch_limiter_to_raise(exc):
resp = client.post("/api/jobs/1/research/generate")
assert resp.status_code == 429
def test_qa_suggest_returns_429_on_limit(self, client):
"""When the rate limiter triggers, the QA suggest endpoint returns 429."""
exc = _make_rate_limit_exceeded("60/hour")
with _patch_limiter_to_raise(exc):
resp = client.post(
"/api/jobs/1/qa/suggest",
json={"question": "Why do you want this job?", "items": []},
)
assert resp.status_code == 429
def test_survey_analyze_returns_429_on_limit(self, client):
"""When the rate limiter triggers, the survey analyze endpoint returns 429."""
exc = _make_rate_limit_exceeded("30/hour")
with _patch_limiter_to_raise(exc):
resp = client.post(
"/api/jobs/1/survey/analyze",
json={"text": "Q: ...", "mode": "quick"},
)
assert resp.status_code == 429
def test_cover_letter_generate_succeeds_when_not_limited(self, client):
"""Cover letter generate endpoint works normally when not rate-limited."""
with patch("scripts.task_runner.submit_task", return_value=(1, True)):
resp = client.post("/api/jobs/1/cover_letter/generate")
# 200 = task queued; 403 = demo/cloud guard; 404/422 = DB/payload issue
# Any non-5xx, non-429 response means the limiter did NOT block the request
assert resp.status_code in (200, 403, 404, 422)
def test_429_response_body_has_error_key(self, client):
"""429 responses from rate-limited endpoints include the error key."""
exc = _make_rate_limit_exceeded("20/hour")
with _patch_limiter_to_raise(exc):
resp = client.post("/api/jobs/1/cover_letter/generate")
assert resp.status_code == 429
body = resp.json()
assert body.get("error") == "rate_limit_exceeded"
def test_429_response_has_retry_after_header(self, client):
"""429 responses include a Retry-After header."""
exc = _make_rate_limit_exceeded("20/hour")
with _patch_limiter_to_raise(exc):
resp = client.post("/api/jobs/1/cover_letter/generate")
assert resp.status_code == 429
assert "retry-after" in resp.headers or "Retry-After" in resp.headers