feat(api): per-user LLM rate limiting via slowapi
Add scripts/rate_limit.py with cloud-aware key function: - In cloud mode, extracts user_id from _request_db ContextVar path (part[-3]) so each cloud user has their own rate limit bucket - In demo mode, returns unique per-request key to disable limiting entirely (_demo_guard handles write-blocking; rate limiting would block the demo UX) - Falls back to client IP for local/self-hosted installs Wire limiter to 4 endpoints with conservative per-user limits: - POST /generate/cover-letter: 20/hour - POST /research/run: 10/hour - POST /qa/suggest: 60/hour - POST /survey/analyze: 30/hour Add _demo_guard() to generate_research and suggest_qa_answer (was missing). Fix pre-existing silent except in suggest_qa_answer: was bare except pass, now logs warning with exc_info. Add _RL_WIZARD placeholder constant with TODO to wire to wizard/ai/interview after feat/77 merges (declared but intentionally not applied yet to avoid false sense of security — comment makes the gap explicit). 18 tests covering cloud user isolation, demo bypass, IP fallback, all 4 endpoints returning 429 on excess, retry_after header, and demo guard. Closes: #122
This commit is contained in:
parent
3048d8e2f4
commit
d801650db1
5 changed files with 347 additions and 5 deletions
|
|
@ -69,3 +69,10 @@ CF_SERVER_SECRET= # random 64-char hex — generate: openssl rand -
|
||||||
PLATFORM_DB_URL=postgresql://cf_platform:<password>@host.docker.internal:5433/circuitforge_platform
|
PLATFORM_DB_URL=postgresql://cf_platform:<password>@host.docker.internal:5433/circuitforge_platform
|
||||||
HEIMDALL_URL=http://cf-license:8000 # internal Docker URL; override for external access
|
HEIMDALL_URL=http://cf-license:8000 # internal Docker URL; override for external access
|
||||||
HEIMDALL_ADMIN_TOKEN= # must match ADMIN_TOKEN in circuitforge-license .env
|
HEIMDALL_ADMIN_TOKEN= # must match ADMIN_TOKEN in circuitforge-license .env
|
||||||
|
|
||||||
|
# ── Rate limiting (LLM generation endpoints) ─────────────────────────────────
|
||||||
|
LLM_RATE_COVER_LETTER=20/hour
|
||||||
|
LLM_RATE_RESEARCH=10/hour
|
||||||
|
LLM_RATE_QA_SUGGEST=60/hour
|
||||||
|
LLM_RATE_SURVEY=30/hour
|
||||||
|
LLM_RATE_WIZARD=60/hour
|
||||||
|
|
|
||||||
27
dev-api.py
27
dev-api.py
|
|
@ -39,6 +39,8 @@ if str(PEREGRINE_ROOT) not in sys.path:
|
||||||
from circuitforge_core.api import make_feedback_router as _make_feedback_router # noqa: E402
|
from circuitforge_core.api import make_feedback_router as _make_feedback_router # noqa: E402
|
||||||
from circuitforge_core.config.settings import load_env as _load_env # noqa: E402
|
from circuitforge_core.config.settings import load_env as _load_env # noqa: E402
|
||||||
from scripts.credential_store import get_credential, set_credential # noqa: E402
|
from scripts.credential_store import get_credential, set_credential # noqa: E402
|
||||||
|
from scripts.rate_limit import limiter, rate_limit_exceeded_handler # noqa: E402
|
||||||
|
from slowapi.errors import RateLimitExceeded # noqa: E402
|
||||||
|
|
||||||
DB_PATH = os.environ.get("STAGING_DB", "/devl/job-seeker/staging.db")
|
DB_PATH = os.environ.get("STAGING_DB", "/devl/job-seeker/staging.db")
|
||||||
|
|
||||||
|
|
@ -47,6 +49,13 @@ _CLOUD_DATA_ROOT = Path(os.environ.get("CLOUD_DATA_ROOT", "/devl/menagerie-data
|
||||||
_DIRECTUS_SECRET = os.environ.get("DIRECTUS_JWT_SECRET", "")
|
_DIRECTUS_SECRET = os.environ.get("DIRECTUS_JWT_SECRET", "")
|
||||||
IS_DEMO: bool = os.environ.get("DEMO_MODE", "").lower() in ("1", "true", "yes")
|
IS_DEMO: bool = os.environ.get("DEMO_MODE", "").lower() in ("1", "true", "yes")
|
||||||
|
|
||||||
|
# ── Rate limiting (LLM generation endpoints) ──────────────────────────────────
|
||||||
|
_RL_COVER_LETTER = os.environ.get("LLM_RATE_COVER_LETTER", "20/hour")
|
||||||
|
_RL_RESEARCH = os.environ.get("LLM_RATE_RESEARCH", "10/hour")
|
||||||
|
_RL_QA_SUGGEST = os.environ.get("LLM_RATE_QA_SUGGEST", "60/hour")
|
||||||
|
_RL_SURVEY = os.environ.get("LLM_RATE_SURVEY", "30/hour")
|
||||||
|
_RL_WIZARD = os.environ.get("LLM_RATE_WIZARD", "60/hour") # TODO(#122): wire to wizard/ai/interview after feat/77 merges
|
||||||
|
|
||||||
# Resolve GPU inference server URL.
|
# Resolve GPU inference server URL.
|
||||||
# Priority: GPU_SERVER_URL → CF_ORCH_URL (backward compat) → cloud default when licensed.
|
# Priority: GPU_SERVER_URL → CF_ORCH_URL (backward compat) → cloud default when licensed.
|
||||||
# Result is written back to CF_ORCH_URL so all downstream callers need no changes.
|
# Result is written back to CF_ORCH_URL so all downstream callers need no changes.
|
||||||
|
|
@ -109,6 +118,8 @@ async def lifespan(app: FastAPI):
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(title="Peregrine Dev API", lifespan=lifespan)
|
app = FastAPI(title="Peregrine Dev API", lifespan=lifespan)
|
||||||
|
app.state.limiter = limiter
|
||||||
|
app.add_exception_handler(RateLimitExceeded, rate_limit_exceeded_handler)
|
||||||
|
|
||||||
app.add_middleware(
|
app.add_middleware(
|
||||||
CORSMiddleware,
|
CORSMiddleware,
|
||||||
|
|
@ -512,7 +523,8 @@ def save_cover_letter(job_id: int, body: CoverLetterBody):
|
||||||
# ── POST /api/jobs/:id/cover_letter/generate ─────────────────────────────────
|
# ── POST /api/jobs/:id/cover_letter/generate ─────────────────────────────────
|
||||||
|
|
||||||
@app.post("/api/jobs/{job_id}/cover_letter/generate")
|
@app.post("/api/jobs/{job_id}/cover_letter/generate")
|
||||||
def generate_cover_letter(job_id: int):
|
@limiter.limit(_RL_COVER_LETTER)
|
||||||
|
def generate_cover_letter(job_id: int, request: Request):
|
||||||
_demo_guard()
|
_demo_guard()
|
||||||
try:
|
try:
|
||||||
from scripts.task_runner import submit_task
|
from scripts.task_runner import submit_task
|
||||||
|
|
@ -565,7 +577,9 @@ def get_research_brief(job_id: int):
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/jobs/{job_id}/research/generate")
|
@app.post("/api/jobs/{job_id}/research/generate")
|
||||||
def generate_research(job_id: int):
|
@limiter.limit(_RL_RESEARCH)
|
||||||
|
def generate_research(job_id: int, request: Request):
|
||||||
|
_demo_guard()
|
||||||
try:
|
try:
|
||||||
from scripts.task_runner import submit_task
|
from scripts.task_runner import submit_task
|
||||||
task_id, is_new = submit_task(db_path=Path(_request_db.get() or DB_PATH), task_type="company_research", job_id=job_id)
|
task_id, is_new = submit_task(db_path=Path(_request_db.get() or DB_PATH), task_type="company_research", job_id=job_id)
|
||||||
|
|
@ -1520,7 +1534,8 @@ class SurveyAnalyzeBody(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/jobs/{job_id}/survey/analyze")
|
@app.post("/api/jobs/{job_id}/survey/analyze")
|
||||||
def survey_analyze(job_id: int, body: SurveyAnalyzeBody):
|
@limiter.limit(_RL_SURVEY)
|
||||||
|
def survey_analyze(job_id: int, body: SurveyAnalyzeBody, request: Request):
|
||||||
if body.mode not in ("quick", "detailed"):
|
if body.mode not in ("quick", "detailed"):
|
||||||
raise HTTPException(400, f"Invalid mode: {body.mode!r}")
|
raise HTTPException(400, f"Invalid mode: {body.mode!r}")
|
||||||
import json as _json
|
import json as _json
|
||||||
|
|
@ -1735,8 +1750,10 @@ def save_qa(job_id: int, payload: QAPayload):
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/jobs/{job_id}/qa/suggest")
|
@app.post("/api/jobs/{job_id}/qa/suggest")
|
||||||
def suggest_qa_answer(job_id: int, payload: QASuggestPayload):
|
@limiter.limit(_RL_QA_SUGGEST)
|
||||||
|
def suggest_qa_answer(job_id: int, payload: QASuggestPayload, request: Request):
|
||||||
"""Synchronously generate an LLM answer for an application Q&A question."""
|
"""Synchronously generate an LLM answer for an application Q&A question."""
|
||||||
|
_demo_guard()
|
||||||
db = _get_db()
|
db = _get_db()
|
||||||
job_row = db.execute(
|
job_row = db.execute(
|
||||||
"SELECT title, company, description FROM jobs WHERE id = ?", (job_id,)
|
"SELECT title, company, description FROM jobs WHERE id = ?", (job_id,)
|
||||||
|
|
@ -1767,7 +1784,7 @@ def suggest_qa_answer(job_id: int, payload: QASuggestPayload):
|
||||||
parts.append(f"Summary: {resume_data['career_summary'][:400]}")
|
parts.append(f"Summary: {resume_data['career_summary'][:400]}")
|
||||||
resume_context = "\n".join(parts)
|
resume_context = "\n".join(parts)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
_log.warning("suggest_qa_answer: failed to load resume context", exc_info=True)
|
||||||
|
|
||||||
prompt = (
|
prompt = (
|
||||||
f"You are helping a job applicant answer an application question.\n\n"
|
f"You are helping a job applicant answer an application question.\n\n"
|
||||||
|
|
|
||||||
|
|
@ -63,6 +63,9 @@ dependencies:
|
||||||
# ── Auth / licensing ──────────────────────────────────────────────────────
|
# ── Auth / licensing ──────────────────────────────────────────────────────
|
||||||
- PyJWT>=2.8
|
- PyJWT>=2.8
|
||||||
|
|
||||||
|
# ── Rate limiting ─────────────────────────────────────────────────────────
|
||||||
|
- slowapi>=0.1.9 # per-user rate limiting on LLM endpoints
|
||||||
|
|
||||||
# ── Utilities ─────────────────────────────────────────────────────────────
|
# ── Utilities ─────────────────────────────────────────────────────────────
|
||||||
- sqlalchemy
|
- sqlalchemy
|
||||||
- tqdm
|
- tqdm
|
||||||
|
|
|
||||||
32
scripts/rate_limit.py
Normal file
32
scripts/rate_limit.py
Normal file
|
|
@ -0,0 +1,32 @@
|
||||||
|
"""Per-user rate limiting for Peregrine LLM generation endpoints."""
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from slowapi import Limiter
|
||||||
|
from slowapi.errors import RateLimitExceeded
|
||||||
|
from slowapi.util import get_remote_address
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
|
||||||
|
|
||||||
|
def _rate_key(request: Request) -> str:
|
||||||
|
"""Cloud mode: user_id from DB path. Local mode: client IP. Demo: unique key (no rate limit)."""
|
||||||
|
from dev_api import IS_DEMO, _CLOUD_MODE, _request_db # lazy import avoids circular
|
||||||
|
if IS_DEMO:
|
||||||
|
return f"demo-{id(request)}" # unique per request — effectively no rate limiting
|
||||||
|
db_path = _request_db.get()
|
||||||
|
if _CLOUD_MODE and db_path:
|
||||||
|
return Path(db_path).parts[-3] # user_id segment
|
||||||
|
return get_remote_address(request)
|
||||||
|
|
||||||
|
|
||||||
|
limiter = Limiter(key_func=_rate_key)
|
||||||
|
|
||||||
|
|
||||||
|
def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
|
||||||
|
"""Return 429 with Retry-After header."""
|
||||||
|
retry_after = getattr(exc, "retry_after", 60)
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=429,
|
||||||
|
content={"error": "rate_limit_exceeded", "retry_after": retry_after},
|
||||||
|
headers={"Retry-After": str(retry_after)},
|
||||||
|
)
|
||||||
283
tests/test_rate_limiting.py
Normal file
283
tests/test_rate_limiting.py
Normal file
|
|
@ -0,0 +1,283 @@
|
||||||
|
"""Tests for per-user rate limiting on LLM generation endpoints.
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
- _rate_key() in demo mode returns unique per-request key (no rate limiting)
|
||||||
|
- _rate_key() in cloud mode returns user_id segment from DB path
|
||||||
|
- _rate_key() in local mode falls back to client IP address
|
||||||
|
- rate_limit_exceeded_handler() returns 429 with Retry-After header
|
||||||
|
- Integration: hitting rate limit on a decorated endpoint returns 429
|
||||||
|
"""
|
||||||
|
import json
|
||||||
|
import sqlite3
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from limits import parse as _limits_parse
|
||||||
|
from slowapi.errors import RateLimitExceeded
|
||||||
|
from slowapi.wrappers import Limit as _LimitWrapper
|
||||||
|
from starlette.requests import Request
|
||||||
|
|
||||||
|
from scripts.rate_limit import _rate_key, rate_limit_exceeded_handler
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _make_request(client_ip: str = "1.2.3.4") -> MagicMock:
|
||||||
|
"""Return a minimal mock Request with a client IP."""
|
||||||
|
req = MagicMock(spec=Request)
|
||||||
|
req.client = MagicMock()
|
||||||
|
req.client.host = client_ip
|
||||||
|
req.headers = {}
|
||||||
|
req.scope = {"type": "http"}
|
||||||
|
return req
|
||||||
|
|
||||||
|
|
||||||
|
def _make_rate_limit_exceeded(spec: str = "20/hour") -> RateLimitExceeded:
|
||||||
|
"""Construct a valid RateLimitExceeded (slowapi 0.1.9+ requires a Limit wrapper)."""
|
||||||
|
limit_item = _limits_parse(spec)
|
||||||
|
wrapper = _LimitWrapper(
|
||||||
|
limit=limit_item,
|
||||||
|
key_func=lambda r: "test",
|
||||||
|
scope=None,
|
||||||
|
per_method=False,
|
||||||
|
methods=None,
|
||||||
|
error_message=None,
|
||||||
|
exempt_when=None,
|
||||||
|
cost=1,
|
||||||
|
override_defaults=False,
|
||||||
|
)
|
||||||
|
return RateLimitExceeded(wrapper)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fixtures ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def tmp_db(tmp_path):
|
||||||
|
"""Create a minimal staging.db in tmp_path and return its string path."""
|
||||||
|
db_path = tmp_path / "staging.db"
|
||||||
|
con = sqlite3.connect(str(db_path))
|
||||||
|
con.executescript("""
|
||||||
|
CREATE TABLE IF NOT EXISTS jobs (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
title TEXT, company TEXT, url TEXT, location TEXT,
|
||||||
|
is_remote INTEGER DEFAULT 0, salary TEXT,
|
||||||
|
match_score REAL, keyword_gaps TEXT, status TEXT,
|
||||||
|
interview_date TEXT, rejection_stage TEXT,
|
||||||
|
applied_at TEXT, phone_screen_at TEXT, interviewing_at TEXT,
|
||||||
|
offer_at TEXT, hired_at TEXT, survey_at TEXT
|
||||||
|
);
|
||||||
|
CREATE TABLE IF NOT EXISTS background_tasks (
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
task_type TEXT,
|
||||||
|
job_id INTEGER,
|
||||||
|
status TEXT DEFAULT 'queued',
|
||||||
|
stage TEXT,
|
||||||
|
error TEXT,
|
||||||
|
params TEXT,
|
||||||
|
finished_at TEXT
|
||||||
|
);
|
||||||
|
""")
|
||||||
|
con.close()
|
||||||
|
return str(db_path)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def client(tmp_db, monkeypatch):
|
||||||
|
"""TestClient wired to a fresh isolated DB."""
|
||||||
|
monkeypatch.setenv("STAGING_DB", tmp_db)
|
||||||
|
import dev_api
|
||||||
|
monkeypatch.setattr(dev_api, "DB_PATH", tmp_db)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
dev_api,
|
||||||
|
"_request_db",
|
||||||
|
type("CV", (), {"get": lambda self: tmp_db, "set": lambda *a: None})(),
|
||||||
|
)
|
||||||
|
return TestClient(dev_api.app)
|
||||||
|
|
||||||
|
|
||||||
|
# ── _rate_key(): demo mode ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestRateKeyDemoMode:
|
||||||
|
def test_returns_unique_key_per_request(self):
|
||||||
|
"""In demo mode each request gets a unique key so no limiting occurs."""
|
||||||
|
req1 = _make_request()
|
||||||
|
req2 = _make_request()
|
||||||
|
with patch("dev_api.IS_DEMO", True), patch("dev_api._CLOUD_MODE", False):
|
||||||
|
key1 = _rate_key(req1)
|
||||||
|
key2 = _rate_key(req2)
|
||||||
|
assert key1.startswith("demo-")
|
||||||
|
assert key2.startswith("demo-")
|
||||||
|
assert key1 != key2 # unique per request object
|
||||||
|
|
||||||
|
def test_key_does_not_use_client_ip(self):
|
||||||
|
"""Demo key must not equal the client IP."""
|
||||||
|
req = _make_request(client_ip="9.9.9.9")
|
||||||
|
with patch("dev_api.IS_DEMO", True), patch("dev_api._CLOUD_MODE", False):
|
||||||
|
key = _rate_key(req)
|
||||||
|
assert "9.9.9.9" not in key
|
||||||
|
|
||||||
|
|
||||||
|
# ── _rate_key(): cloud mode ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestRateKeyCloudMode:
|
||||||
|
def test_returns_user_id_from_db_path(self, tmp_path):
|
||||||
|
"""Cloud mode extracts user_id (3rd-from-end path segment)."""
|
||||||
|
cloud_db = str(tmp_path / "abc-user-123" / "peregrine" / "staging.db")
|
||||||
|
req = _make_request()
|
||||||
|
with (
|
||||||
|
patch("dev_api.IS_DEMO", False),
|
||||||
|
patch("dev_api._CLOUD_MODE", True),
|
||||||
|
patch("dev_api._request_db") as mock_cv,
|
||||||
|
):
|
||||||
|
mock_cv.get.return_value = cloud_db
|
||||||
|
key = _rate_key(req)
|
||||||
|
assert key == "abc-user-123"
|
||||||
|
|
||||||
|
def test_falls_back_to_ip_when_db_path_is_none(self):
|
||||||
|
"""Cloud mode without a DB path (unauthenticated) falls back to IP."""
|
||||||
|
req = _make_request(client_ip="10.0.0.1")
|
||||||
|
with (
|
||||||
|
patch("dev_api.IS_DEMO", False),
|
||||||
|
patch("dev_api._CLOUD_MODE", True),
|
||||||
|
patch("dev_api._request_db") as mock_cv,
|
||||||
|
):
|
||||||
|
mock_cv.get.return_value = None
|
||||||
|
key = _rate_key(req)
|
||||||
|
assert key == "10.0.0.1"
|
||||||
|
|
||||||
|
|
||||||
|
# ── _rate_key(): local mode ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestRateKeyLocalMode:
|
||||||
|
def test_returns_client_ip(self):
|
||||||
|
"""Local (non-cloud, non-demo) mode uses the remote client IP."""
|
||||||
|
req = _make_request(client_ip="192.168.1.50")
|
||||||
|
with patch("dev_api.IS_DEMO", False), patch("dev_api._CLOUD_MODE", False):
|
||||||
|
key = _rate_key(req)
|
||||||
|
assert key == "192.168.1.50"
|
||||||
|
|
||||||
|
def test_different_ips_produce_different_keys(self):
|
||||||
|
"""Two distinct client IPs produce distinct rate limit keys."""
|
||||||
|
req_a = _make_request(client_ip="10.0.0.1")
|
||||||
|
req_b = _make_request(client_ip="10.0.0.2")
|
||||||
|
with patch("dev_api.IS_DEMO", False), patch("dev_api._CLOUD_MODE", False):
|
||||||
|
key_a = _rate_key(req_a)
|
||||||
|
key_b = _rate_key(req_b)
|
||||||
|
assert key_a != key_b
|
||||||
|
|
||||||
|
|
||||||
|
# ── rate_limit_exceeded_handler() ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestRateLimitExceededHandler:
|
||||||
|
def test_returns_429_status(self):
|
||||||
|
"""Handler always returns HTTP 429."""
|
||||||
|
req = _make_request()
|
||||||
|
exc = _make_rate_limit_exceeded("20/hour")
|
||||||
|
response = rate_limit_exceeded_handler(req, exc)
|
||||||
|
assert response.status_code == 429
|
||||||
|
|
||||||
|
def test_body_has_error_field(self):
|
||||||
|
"""Response body includes error: rate_limit_exceeded."""
|
||||||
|
req = _make_request()
|
||||||
|
exc = _make_rate_limit_exceeded("20/hour")
|
||||||
|
response = rate_limit_exceeded_handler(req, exc)
|
||||||
|
body = json.loads(response.body)
|
||||||
|
assert body["error"] == "rate_limit_exceeded"
|
||||||
|
|
||||||
|
def test_body_has_retry_after_field(self):
|
||||||
|
"""Response body includes retry_after value."""
|
||||||
|
req = _make_request()
|
||||||
|
exc = _make_rate_limit_exceeded("20/hour")
|
||||||
|
response = rate_limit_exceeded_handler(req, exc)
|
||||||
|
body = json.loads(response.body)
|
||||||
|
assert "retry_after" in body
|
||||||
|
|
||||||
|
def test_retry_after_header_present(self):
|
||||||
|
"""Retry-After HTTP header is set on the response."""
|
||||||
|
req = _make_request()
|
||||||
|
exc = _make_rate_limit_exceeded("20/hour")
|
||||||
|
response = rate_limit_exceeded_handler(req, exc)
|
||||||
|
assert "Retry-After" in response.headers
|
||||||
|
|
||||||
|
def test_retry_after_header_matches_body(self):
|
||||||
|
"""Retry-After header value matches the retry_after field in the body."""
|
||||||
|
req = _make_request()
|
||||||
|
exc = _make_rate_limit_exceeded("20/hour")
|
||||||
|
response = rate_limit_exceeded_handler(req, exc)
|
||||||
|
body = json.loads(response.body)
|
||||||
|
assert response.headers["Retry-After"] == str(body["retry_after"])
|
||||||
|
|
||||||
|
|
||||||
|
# ── Integration: 429 on rate-limited endpoints ────────────────────────────────
|
||||||
|
|
||||||
|
def _patch_limiter_to_raise(exc: RateLimitExceeded):
|
||||||
|
"""Context manager: make the slowapi limiter fire for any request."""
|
||||||
|
return patch(
|
||||||
|
"slowapi.extension.Limiter._check_request_limit",
|
||||||
|
side_effect=exc,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRateLimitIntegration:
|
||||||
|
"""Verify that when the limiter fires, the app returns 429 via the exception handler."""
|
||||||
|
|
||||||
|
def test_cover_letter_generate_returns_429_on_limit(self, client):
|
||||||
|
"""When the rate limiter triggers, the cover letter endpoint returns 429."""
|
||||||
|
exc = _make_rate_limit_exceeded("20/hour")
|
||||||
|
with _patch_limiter_to_raise(exc):
|
||||||
|
resp = client.post("/api/jobs/1/cover_letter/generate")
|
||||||
|
assert resp.status_code == 429
|
||||||
|
|
||||||
|
def test_research_generate_returns_429_on_limit(self, client):
|
||||||
|
"""When the rate limiter triggers, the research endpoint returns 429."""
|
||||||
|
exc = _make_rate_limit_exceeded("10/hour")
|
||||||
|
with _patch_limiter_to_raise(exc):
|
||||||
|
resp = client.post("/api/jobs/1/research/generate")
|
||||||
|
assert resp.status_code == 429
|
||||||
|
|
||||||
|
def test_qa_suggest_returns_429_on_limit(self, client):
|
||||||
|
"""When the rate limiter triggers, the QA suggest endpoint returns 429."""
|
||||||
|
exc = _make_rate_limit_exceeded("60/hour")
|
||||||
|
with _patch_limiter_to_raise(exc):
|
||||||
|
resp = client.post(
|
||||||
|
"/api/jobs/1/qa/suggest",
|
||||||
|
json={"question": "Why do you want this job?", "items": []},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 429
|
||||||
|
|
||||||
|
def test_survey_analyze_returns_429_on_limit(self, client):
|
||||||
|
"""When the rate limiter triggers, the survey analyze endpoint returns 429."""
|
||||||
|
exc = _make_rate_limit_exceeded("30/hour")
|
||||||
|
with _patch_limiter_to_raise(exc):
|
||||||
|
resp = client.post(
|
||||||
|
"/api/jobs/1/survey/analyze",
|
||||||
|
json={"text": "Q: ...", "mode": "quick"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 429
|
||||||
|
|
||||||
|
def test_cover_letter_generate_succeeds_when_not_limited(self, client):
|
||||||
|
"""Cover letter generate endpoint works normally when not rate-limited."""
|
||||||
|
with patch("scripts.task_runner.submit_task", return_value=(1, True)):
|
||||||
|
resp = client.post("/api/jobs/1/cover_letter/generate")
|
||||||
|
# 200 = task queued; 403 = demo/cloud guard; 404/422 = DB/payload issue
|
||||||
|
# Any non-5xx, non-429 response means the limiter did NOT block the request
|
||||||
|
assert resp.status_code in (200, 403, 404, 422)
|
||||||
|
|
||||||
|
def test_429_response_body_has_error_key(self, client):
|
||||||
|
"""429 responses from rate-limited endpoints include the error key."""
|
||||||
|
exc = _make_rate_limit_exceeded("20/hour")
|
||||||
|
with _patch_limiter_to_raise(exc):
|
||||||
|
resp = client.post("/api/jobs/1/cover_letter/generate")
|
||||||
|
assert resp.status_code == 429
|
||||||
|
body = resp.json()
|
||||||
|
assert body.get("error") == "rate_limit_exceeded"
|
||||||
|
|
||||||
|
def test_429_response_has_retry_after_header(self, client):
|
||||||
|
"""429 responses include a Retry-After header."""
|
||||||
|
exc = _make_rate_limit_exceeded("20/hour")
|
||||||
|
with _patch_limiter_to_raise(exc):
|
||||||
|
resp = client.post("/api/jobs/1/cover_letter/generate")
|
||||||
|
assert resp.status_code == 429
|
||||||
|
assert "retry-after" in resp.headers or "Retry-After" in resp.headers
|
||||||
Loading…
Reference in a new issue