merge: feat/122-rate-limiting into freeze/rc-1

Per-user LLM rate limiting via slowapi: cloud-aware key function,
4 endpoint limits, demo bypass, SSRF and path traversal already in
fix/ci-ruff-lint merge.

Closes: #122
This commit is contained in:
pyr0ball 2026-06-14 12:41:18 -07:00
commit 88b6943527
5 changed files with 347 additions and 5 deletions

View file

@ -83,3 +83,10 @@ MNEMO_LLM_PROVIDER=ollama # ollama | openai | anthropic | custom
MNEMO_LLM_BASE_URL=http://ollama:11434/v1 # override for external LLM MNEMO_LLM_BASE_URL=http://ollama:11434/v1 # override for external LLM
MNEMO_LLM_API_KEY=ollama # "ollama" is a dummy value for local Ollama MNEMO_LLM_API_KEY=ollama # "ollama" is a dummy value for local Ollama
MNEMO_LLM_MODEL=llama3.2:3b # must be pulled in the ollama container MNEMO_LLM_MODEL=llama3.2:3b # must be pulled in the ollama container
# ── Rate limiting (LLM generation endpoints) ─────────────────────────────────
LLM_RATE_COVER_LETTER=20/hour
LLM_RATE_RESEARCH=10/hour
LLM_RATE_QA_SUGGEST=60/hour
LLM_RATE_SURVEY=30/hour
LLM_RATE_WIZARD=60/hour

View file

@ -41,6 +41,8 @@ from circuitforge_core.api import make_feedback_router as _make_feedback_router
from circuitforge_core.config.settings import load_env as _load_env # noqa: E402 from circuitforge_core.config.settings import load_env as _load_env # noqa: E402
from circuitforge_core.sync import SyncConfig, make_sync_router # noqa: E402 from circuitforge_core.sync import SyncConfig, make_sync_router # noqa: E402
from scripts.credential_store import get_credential, set_credential # noqa: E402 from scripts.credential_store import get_credential, set_credential # noqa: E402
from scripts.rate_limit import limiter, rate_limit_exceeded_handler # noqa: E402
from slowapi.errors import RateLimitExceeded # noqa: E402
DB_PATH = os.environ.get("STAGING_DB", "/devl/job-seeker/staging.db") DB_PATH = os.environ.get("STAGING_DB", "/devl/job-seeker/staging.db")
@ -73,6 +75,13 @@ def _is_ssrf_host(host: str) -> bool:
return True # fail closed on resolution errors return True # fail closed on resolution errors
IS_DEMO: bool = os.environ.get("DEMO_MODE", "").lower() in ("1", "true", "yes") IS_DEMO: bool = os.environ.get("DEMO_MODE", "").lower() in ("1", "true", "yes")
# ── Rate limiting (LLM generation endpoints) ──────────────────────────────────
_RL_COVER_LETTER = os.environ.get("LLM_RATE_COVER_LETTER", "20/hour")
_RL_RESEARCH = os.environ.get("LLM_RATE_RESEARCH", "10/hour")
_RL_QA_SUGGEST = os.environ.get("LLM_RATE_QA_SUGGEST", "60/hour")
_RL_SURVEY = os.environ.get("LLM_RATE_SURVEY", "30/hour")
_RL_WIZARD = os.environ.get("LLM_RATE_WIZARD", "60/hour") # TODO(#122): wire to wizard/ai/interview after feat/77 merges
# Resolve GPU inference server URL. # Resolve GPU inference server URL.
# Priority: GPU_SERVER_URL → CF_ORCH_URL (backward compat) → cloud default when licensed. # Priority: GPU_SERVER_URL → CF_ORCH_URL (backward compat) → cloud default when licensed.
# Result is written back to CF_ORCH_URL so all downstream callers need no changes. # Result is written back to CF_ORCH_URL so all downstream callers need no changes.
@ -135,6 +144,8 @@ async def lifespan(app: FastAPI):
app = FastAPI(title="Peregrine Dev API", lifespan=lifespan) app = FastAPI(title="Peregrine Dev API", lifespan=lifespan)
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, rate_limit_exceeded_handler)
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
@ -579,7 +590,8 @@ def save_cover_letter(job_id: int, body: CoverLetterBody):
# ── POST /api/jobs/:id/cover_letter/generate ───────────────────────────────── # ── POST /api/jobs/:id/cover_letter/generate ─────────────────────────────────
@app.post("/api/jobs/{job_id}/cover_letter/generate") @app.post("/api/jobs/{job_id}/cover_letter/generate")
def generate_cover_letter(job_id: int): @limiter.limit(_RL_COVER_LETTER)
def generate_cover_letter(job_id: int, request: Request):
_demo_guard() _demo_guard()
try: try:
from scripts.task_runner import submit_task from scripts.task_runner import submit_task
@ -632,7 +644,9 @@ def get_research_brief(job_id: int):
@app.post("/api/jobs/{job_id}/research/generate") @app.post("/api/jobs/{job_id}/research/generate")
def generate_research(job_id: int): @limiter.limit(_RL_RESEARCH)
def generate_research(job_id: int, request: Request):
_demo_guard()
try: try:
from scripts.task_runner import submit_task from scripts.task_runner import submit_task
task_id, is_new = submit_task(db_path=Path(_request_db.get() or DB_PATH), task_type="company_research", job_id=job_id) task_id, is_new = submit_task(db_path=Path(_request_db.get() or DB_PATH), task_type="company_research", job_id=job_id)
@ -1587,7 +1601,8 @@ class SurveyAnalyzeBody(BaseModel):
@app.post("/api/jobs/{job_id}/survey/analyze") @app.post("/api/jobs/{job_id}/survey/analyze")
def survey_analyze(job_id: int, body: SurveyAnalyzeBody): @limiter.limit(_RL_SURVEY)
def survey_analyze(job_id: int, body: SurveyAnalyzeBody, request: Request):
if body.mode not in ("quick", "detailed"): if body.mode not in ("quick", "detailed"):
raise HTTPException(400, f"Invalid mode: {body.mode!r}") raise HTTPException(400, f"Invalid mode: {body.mode!r}")
import json as _json import json as _json
@ -1802,8 +1817,10 @@ def save_qa(job_id: int, payload: QAPayload):
@app.post("/api/jobs/{job_id}/qa/suggest") @app.post("/api/jobs/{job_id}/qa/suggest")
def suggest_qa_answer(job_id: int, payload: QASuggestPayload): @limiter.limit(_RL_QA_SUGGEST)
def suggest_qa_answer(job_id: int, payload: QASuggestPayload, request: Request):
"""Synchronously generate an LLM answer for an application Q&A question.""" """Synchronously generate an LLM answer for an application Q&A question."""
_demo_guard()
db = _get_db() db = _get_db()
job_row = db.execute( job_row = db.execute(
"SELECT title, company, description FROM jobs WHERE id = ?", (job_id,) "SELECT title, company, description FROM jobs WHERE id = ?", (job_id,)
@ -1834,7 +1851,7 @@ def suggest_qa_answer(job_id: int, payload: QASuggestPayload):
parts.append(f"Summary: {resume_data['career_summary'][:400]}") parts.append(f"Summary: {resume_data['career_summary'][:400]}")
resume_context = "\n".join(parts) resume_context = "\n".join(parts)
except Exception: except Exception:
pass _log.warning("suggest_qa_answer: failed to load resume context", exc_info=True)
prompt = ( prompt = (
f"You are helping a job applicant answer an application question.\n\n" f"You are helping a job applicant answer an application question.\n\n"

View file

@ -63,6 +63,9 @@ dependencies:
# ── Auth / licensing ────────────────────────────────────────────────────── # ── Auth / licensing ──────────────────────────────────────────────────────
- PyJWT>=2.13.0 # 2.11 has sig bypass CVEs (PYSEC-2026-120/175-179); used for cloud session routing - PyJWT>=2.13.0 # 2.11 has sig bypass CVEs (PYSEC-2026-120/175-179); used for cloud session routing
# ── Rate limiting ─────────────────────────────────────────────────────────
- slowapi>=0.1.9 # per-user rate limiting on LLM endpoints
# ── Utilities ───────────────────────────────────────────────────────────── # ── Utilities ─────────────────────────────────────────────────────────────
- sqlalchemy - sqlalchemy
- tqdm - tqdm

32
scripts/rate_limit.py Normal file
View file

@ -0,0 +1,32 @@
"""Per-user rate limiting for Peregrine LLM generation endpoints."""
from pathlib import Path
from slowapi import Limiter
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from starlette.requests import Request
from starlette.responses import JSONResponse
def _rate_key(request: Request) -> str:
"""Cloud mode: user_id from DB path. Local mode: client IP. Demo: unique key (no rate limit)."""
from dev_api import IS_DEMO, _CLOUD_MODE, _request_db # lazy import avoids circular
if IS_DEMO:
return f"demo-{id(request)}" # unique per request — effectively no rate limiting
db_path = _request_db.get()
if _CLOUD_MODE and db_path:
return Path(db_path).parts[-3] # user_id segment
return get_remote_address(request)
limiter = Limiter(key_func=_rate_key)
def rate_limit_exceeded_handler(request: Request, exc: RateLimitExceeded) -> JSONResponse:
"""Return 429 with Retry-After header."""
retry_after = getattr(exc, "retry_after", 60)
return JSONResponse(
status_code=429,
content={"error": "rate_limit_exceeded", "retry_after": retry_after},
headers={"Retry-After": str(retry_after)},
)

283
tests/test_rate_limiting.py Normal file
View file

@ -0,0 +1,283 @@
"""Tests for per-user rate limiting on LLM generation endpoints.
Covers:
- _rate_key() in demo mode returns unique per-request key (no rate limiting)
- _rate_key() in cloud mode returns user_id segment from DB path
- _rate_key() in local mode falls back to client IP address
- rate_limit_exceeded_handler() returns 429 with Retry-After header
- Integration: hitting rate limit on a decorated endpoint returns 429
"""
import json
import sqlite3
from pathlib import Path
from unittest.mock import MagicMock, patch
import pytest
from fastapi.testclient import TestClient
from limits import parse as _limits_parse
from slowapi.errors import RateLimitExceeded
from slowapi.wrappers import Limit as _LimitWrapper
from starlette.requests import Request
from scripts.rate_limit import _rate_key, rate_limit_exceeded_handler
# ── Helpers ───────────────────────────────────────────────────────────────────
def _make_request(client_ip: str = "1.2.3.4") -> MagicMock:
"""Return a minimal mock Request with a client IP."""
req = MagicMock(spec=Request)
req.client = MagicMock()
req.client.host = client_ip
req.headers = {}
req.scope = {"type": "http"}
return req
def _make_rate_limit_exceeded(spec: str = "20/hour") -> RateLimitExceeded:
"""Construct a valid RateLimitExceeded (slowapi 0.1.9+ requires a Limit wrapper)."""
limit_item = _limits_parse(spec)
wrapper = _LimitWrapper(
limit=limit_item,
key_func=lambda r: "test",
scope=None,
per_method=False,
methods=None,
error_message=None,
exempt_when=None,
cost=1,
override_defaults=False,
)
return RateLimitExceeded(wrapper)
# ── Fixtures ──────────────────────────────────────────────────────────────────
@pytest.fixture()
def tmp_db(tmp_path):
"""Create a minimal staging.db in tmp_path and return its string path."""
db_path = tmp_path / "staging.db"
con = sqlite3.connect(str(db_path))
con.executescript("""
CREATE TABLE IF NOT EXISTS jobs (
id INTEGER PRIMARY KEY,
title TEXT, company TEXT, url TEXT, location TEXT,
is_remote INTEGER DEFAULT 0, salary TEXT,
match_score REAL, keyword_gaps TEXT, status TEXT,
interview_date TEXT, rejection_stage TEXT,
applied_at TEXT, phone_screen_at TEXT, interviewing_at TEXT,
offer_at TEXT, hired_at TEXT, survey_at TEXT
);
CREATE TABLE IF NOT EXISTS background_tasks (
id INTEGER PRIMARY KEY,
task_type TEXT,
job_id INTEGER,
status TEXT DEFAULT 'queued',
stage TEXT,
error TEXT,
params TEXT,
finished_at TEXT
);
""")
con.close()
return str(db_path)
@pytest.fixture()
def client(tmp_db, monkeypatch):
"""TestClient wired to a fresh isolated DB."""
monkeypatch.setenv("STAGING_DB", tmp_db)
import dev_api
monkeypatch.setattr(dev_api, "DB_PATH", tmp_db)
monkeypatch.setattr(
dev_api,
"_request_db",
type("CV", (), {"get": lambda self: tmp_db, "set": lambda *a: None})(),
)
return TestClient(dev_api.app)
# ── _rate_key(): demo mode ────────────────────────────────────────────────────
class TestRateKeyDemoMode:
def test_returns_unique_key_per_request(self):
"""In demo mode each request gets a unique key so no limiting occurs."""
req1 = _make_request()
req2 = _make_request()
with patch("dev_api.IS_DEMO", True), patch("dev_api._CLOUD_MODE", False):
key1 = _rate_key(req1)
key2 = _rate_key(req2)
assert key1.startswith("demo-")
assert key2.startswith("demo-")
assert key1 != key2 # unique per request object
def test_key_does_not_use_client_ip(self):
"""Demo key must not equal the client IP."""
req = _make_request(client_ip="9.9.9.9")
with patch("dev_api.IS_DEMO", True), patch("dev_api._CLOUD_MODE", False):
key = _rate_key(req)
assert "9.9.9.9" not in key
# ── _rate_key(): cloud mode ───────────────────────────────────────────────────
class TestRateKeyCloudMode:
def test_returns_user_id_from_db_path(self, tmp_path):
"""Cloud mode extracts user_id (3rd-from-end path segment)."""
cloud_db = str(tmp_path / "abc-user-123" / "peregrine" / "staging.db")
req = _make_request()
with (
patch("dev_api.IS_DEMO", False),
patch("dev_api._CLOUD_MODE", True),
patch("dev_api._request_db") as mock_cv,
):
mock_cv.get.return_value = cloud_db
key = _rate_key(req)
assert key == "abc-user-123"
def test_falls_back_to_ip_when_db_path_is_none(self):
"""Cloud mode without a DB path (unauthenticated) falls back to IP."""
req = _make_request(client_ip="10.0.0.1")
with (
patch("dev_api.IS_DEMO", False),
patch("dev_api._CLOUD_MODE", True),
patch("dev_api._request_db") as mock_cv,
):
mock_cv.get.return_value = None
key = _rate_key(req)
assert key == "10.0.0.1"
# ── _rate_key(): local mode ───────────────────────────────────────────────────
class TestRateKeyLocalMode:
def test_returns_client_ip(self):
"""Local (non-cloud, non-demo) mode uses the remote client IP."""
req = _make_request(client_ip="192.168.1.50")
with patch("dev_api.IS_DEMO", False), patch("dev_api._CLOUD_MODE", False):
key = _rate_key(req)
assert key == "192.168.1.50"
def test_different_ips_produce_different_keys(self):
"""Two distinct client IPs produce distinct rate limit keys."""
req_a = _make_request(client_ip="10.0.0.1")
req_b = _make_request(client_ip="10.0.0.2")
with patch("dev_api.IS_DEMO", False), patch("dev_api._CLOUD_MODE", False):
key_a = _rate_key(req_a)
key_b = _rate_key(req_b)
assert key_a != key_b
# ── rate_limit_exceeded_handler() ─────────────────────────────────────────────
class TestRateLimitExceededHandler:
def test_returns_429_status(self):
"""Handler always returns HTTP 429."""
req = _make_request()
exc = _make_rate_limit_exceeded("20/hour")
response = rate_limit_exceeded_handler(req, exc)
assert response.status_code == 429
def test_body_has_error_field(self):
"""Response body includes error: rate_limit_exceeded."""
req = _make_request()
exc = _make_rate_limit_exceeded("20/hour")
response = rate_limit_exceeded_handler(req, exc)
body = json.loads(response.body)
assert body["error"] == "rate_limit_exceeded"
def test_body_has_retry_after_field(self):
"""Response body includes retry_after value."""
req = _make_request()
exc = _make_rate_limit_exceeded("20/hour")
response = rate_limit_exceeded_handler(req, exc)
body = json.loads(response.body)
assert "retry_after" in body
def test_retry_after_header_present(self):
"""Retry-After HTTP header is set on the response."""
req = _make_request()
exc = _make_rate_limit_exceeded("20/hour")
response = rate_limit_exceeded_handler(req, exc)
assert "Retry-After" in response.headers
def test_retry_after_header_matches_body(self):
"""Retry-After header value matches the retry_after field in the body."""
req = _make_request()
exc = _make_rate_limit_exceeded("20/hour")
response = rate_limit_exceeded_handler(req, exc)
body = json.loads(response.body)
assert response.headers["Retry-After"] == str(body["retry_after"])
# ── Integration: 429 on rate-limited endpoints ────────────────────────────────
def _patch_limiter_to_raise(exc: RateLimitExceeded):
"""Context manager: make the slowapi limiter fire for any request."""
return patch(
"slowapi.extension.Limiter._check_request_limit",
side_effect=exc,
)
class TestRateLimitIntegration:
"""Verify that when the limiter fires, the app returns 429 via the exception handler."""
def test_cover_letter_generate_returns_429_on_limit(self, client):
"""When the rate limiter triggers, the cover letter endpoint returns 429."""
exc = _make_rate_limit_exceeded("20/hour")
with _patch_limiter_to_raise(exc):
resp = client.post("/api/jobs/1/cover_letter/generate")
assert resp.status_code == 429
def test_research_generate_returns_429_on_limit(self, client):
"""When the rate limiter triggers, the research endpoint returns 429."""
exc = _make_rate_limit_exceeded("10/hour")
with _patch_limiter_to_raise(exc):
resp = client.post("/api/jobs/1/research/generate")
assert resp.status_code == 429
def test_qa_suggest_returns_429_on_limit(self, client):
"""When the rate limiter triggers, the QA suggest endpoint returns 429."""
exc = _make_rate_limit_exceeded("60/hour")
with _patch_limiter_to_raise(exc):
resp = client.post(
"/api/jobs/1/qa/suggest",
json={"question": "Why do you want this job?", "items": []},
)
assert resp.status_code == 429
def test_survey_analyze_returns_429_on_limit(self, client):
"""When the rate limiter triggers, the survey analyze endpoint returns 429."""
exc = _make_rate_limit_exceeded("30/hour")
with _patch_limiter_to_raise(exc):
resp = client.post(
"/api/jobs/1/survey/analyze",
json={"text": "Q: ...", "mode": "quick"},
)
assert resp.status_code == 429
def test_cover_letter_generate_succeeds_when_not_limited(self, client):
"""Cover letter generate endpoint works normally when not rate-limited."""
with patch("scripts.task_runner.submit_task", return_value=(1, True)):
resp = client.post("/api/jobs/1/cover_letter/generate")
# 200 = task queued; 403 = demo/cloud guard; 404/422 = DB/payload issue
# Any non-5xx, non-429 response means the limiter did NOT block the request
assert resp.status_code in (200, 403, 404, 422)
def test_429_response_body_has_error_key(self, client):
"""429 responses from rate-limited endpoints include the error key."""
exc = _make_rate_limit_exceeded("20/hour")
with _patch_limiter_to_raise(exc):
resp = client.post("/api/jobs/1/cover_letter/generate")
assert resp.status_code == 429
body = resp.json()
assert body.get("error") == "rate_limit_exceeded"
def test_429_response_has_retry_after_header(self, client):
"""429 responses include a Retry-After header."""
exc = _make_rate_limit_exceeded("20/hour")
with _patch_limiter_to_raise(exc):
resp = client.post("/api/jobs/1/cover_letter/generate")
assert resp.status_code == 429
assert "retry-after" in resp.headers or "Retry-After" in resp.headers