feat(wizard): backend AI interview endpoints + BYOK tier flag
This commit is contained in:
parent
3048d8e2f4
commit
6327a4cdd9
3 changed files with 441 additions and 0 deletions
|
|
@ -41,6 +41,7 @@ FEATURES: dict[str, str] = {
|
|||
"llm_voice_guidelines": "premium",
|
||||
"llm_job_titles": "paid",
|
||||
"llm_mission_notes": "paid",
|
||||
"llm_ai_wizard": "paid",
|
||||
|
||||
# Orchestration — stays gated (background data pipeline, not just an LLM call)
|
||||
"llm_keywords_blocklist": "paid",
|
||||
|
|
@ -79,6 +80,7 @@ BYOK_UNLOCKABLE: frozenset[str] = frozenset({
|
|||
"llm_voice_guidelines",
|
||||
"llm_job_titles",
|
||||
"llm_mission_notes",
|
||||
"llm_ai_wizard",
|
||||
"company_research",
|
||||
"interview_prep",
|
||||
"survey_assistant",
|
||||
|
|
|
|||
105
dev-api.py
105
dev-api.py
|
|
@ -2694,6 +2694,9 @@ def get_app_config():
|
|||
except Exception:
|
||||
wizard_complete = False
|
||||
|
||||
from app.wizard.tiers import has_configured_llm
|
||||
byok_unlocked = has_configured_llm()
|
||||
|
||||
return {
|
||||
"isCloud": os.environ.get("CLOUD_MODE", "").lower() in ("1", "true"),
|
||||
"isDemo": os.environ.get("DEMO_MODE", "").lower() in ("1", "true", "yes"),
|
||||
|
|
@ -2702,6 +2705,7 @@ def get_app_config():
|
|||
"contractedClient": os.environ.get("CONTRACTED_CLIENT", "").lower() in ("1", "true"),
|
||||
"inferenceProfile": profile if profile in valid_profiles else "cpu",
|
||||
"wizardComplete": wizard_complete,
|
||||
"byokUnlocked": byok_unlocked,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -4511,6 +4515,107 @@ def wizard_complete():
|
|||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
# ── AI Interview Wizard (BSL 1.1) ─────────────────────────────────────────────
|
||||
|
||||
_AI_WIZARD_SYSTEM_PROMPT = """You are a friendly, patient assistant helping someone set up their job search profile. Your goal is to gather the following information through natural conversation:
|
||||
|
||||
- name (string): their full name
|
||||
- email (string): their preferred contact email
|
||||
- career_summary (string): 1-2 sentence background summary
|
||||
- candidate_voice (string): their preferred writing voice/tone for cover letters
|
||||
- mission_preferences (list of strings): industries or causes they care about
|
||||
- candidate_accessibility_focus (bool): whether to include accessibility culture in company research
|
||||
- candidate_lgbtq_focus (bool): whether to include LGBTQIA+ inclusion signals in company research
|
||||
- linkedin (string, optional): their LinkedIn URL
|
||||
|
||||
Rules:
|
||||
1. Ask one or two questions at a time — never overwhelm
|
||||
2. Always remind them they can skip any question
|
||||
3. For candidate_voice, offer these options if they struggle: "professional and direct", "warm and conversational", "concise and clear", "enthusiastic and personable"
|
||||
4. For candidate_accessibility_focus and candidate_lgbtq_focus, use plain language: "Would you like me to look into whether companies actively support employees with disabilities or neurodivergent needs?" and "Would you like me to check whether companies have strong LGBTQIA+ inclusion policies?"
|
||||
5. When you have gathered enough information or the user says they are done, set complete to true
|
||||
|
||||
You must ALWAYS respond with valid JSON in this exact format:
|
||||
{"reply": "your conversational message here", "extracted_fields": {"name": "...", ...}, "complete": false}
|
||||
|
||||
Only include fields in extracted_fields that you are confident about from the conversation. Do not include fields the user hasn't mentioned. Infer complete=true when all required fields (name, email, career_summary) are gathered or when user explicitly says done."""
|
||||
|
||||
|
||||
class WizardInterviewRequest(BaseModel):
|
||||
history: list[dict] # [{"role": "user"|"assistant", "content": "..."}]
|
||||
profile_so_far: dict = {}
|
||||
|
||||
|
||||
class WizardFinalizeRequest(BaseModel):
|
||||
profile: dict
|
||||
|
||||
|
||||
_WIZARD_ALLOWED_FIELDS: frozenset[str] = frozenset({
|
||||
"name",
|
||||
"email",
|
||||
"career_summary",
|
||||
"candidate_voice",
|
||||
"mission_preferences",
|
||||
"candidate_accessibility_focus",
|
||||
"candidate_lgbtq_focus",
|
||||
"linkedin",
|
||||
})
|
||||
|
||||
|
||||
@app.post("/api/wizard/ai/interview")
|
||||
def wizard_ai_interview(request: WizardInterviewRequest):
|
||||
"""Conduct one turn of the AI-guided profile interview. Tier-gated (BYOK-unlockable)."""
|
||||
from app.wizard.tiers import can_use, has_configured_llm
|
||||
|
||||
tier = _get_effective_tier()
|
||||
if not can_use(tier, "llm_ai_wizard", has_byok=has_configured_llm()):
|
||||
raise HTTPException(402, detail={"error": "tier_required"})
|
||||
|
||||
# Build conversation prompt from history
|
||||
conversation_lines = []
|
||||
for msg in request.history:
|
||||
role = msg.get("role", "user")
|
||||
content = msg.get("content", "")
|
||||
if role == "user":
|
||||
conversation_lines.append(f"User: {content}")
|
||||
else:
|
||||
conversation_lines.append(f"Assistant: {content}")
|
||||
|
||||
prompt = "\n".join(conversation_lines) if conversation_lines else "User: (starting conversation)"
|
||||
|
||||
try:
|
||||
from scripts.llm_router import LLMRouter
|
||||
response_text = LLMRouter().complete(prompt, system=_AI_WIZARD_SYSTEM_PROMPT)
|
||||
except Exception as exc:
|
||||
raise HTTPException(500, detail={"error": "llm_error", "message": str(exc)})
|
||||
|
||||
try:
|
||||
parsed = json.loads(response_text)
|
||||
return {
|
||||
"reply": parsed.get("reply", ""),
|
||||
"extracted_fields": parsed.get("extracted_fields", {}),
|
||||
"complete": bool(parsed.get("complete", False)),
|
||||
}
|
||||
except (json.JSONDecodeError, AttributeError):
|
||||
return {"reply": response_text, "extracted_fields": {}, "complete": False}
|
||||
|
||||
|
||||
@app.post("/api/wizard/ai/finalize")
|
||||
def wizard_ai_finalize(request: WizardFinalizeRequest):
|
||||
"""Merge AI-collected wizard fields into user.yaml. Only allowed fields are written."""
|
||||
yaml_path = _user_yaml_path()
|
||||
current = load_user_profile(yaml_path)
|
||||
|
||||
merged_keys = []
|
||||
for key, value in request.profile.items():
|
||||
if key in _WIZARD_ALLOWED_FIELDS:
|
||||
current[key] = value
|
||||
merged_keys.append(key)
|
||||
|
||||
save_user_profile(yaml_path, current)
|
||||
return {"saved": True, "fields": merged_keys}
|
||||
|
||||
|
||||
# ── Messaging models ──────────────────────────────────────────────────────────
|
||||
|
||||
class MessageCreateBody(BaseModel):
|
||||
|
|
|
|||
334
tests/test_wizard_ai.py
Normal file
334
tests/test_wizard_ai.py
Normal file
|
|
@ -0,0 +1,334 @@
|
|||
"""Tests for AI interview wizard endpoints (POST /api/wizard/ai/*)."""
|
||||
import json
|
||||
import sys
|
||||
import yaml
|
||||
import pytest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
# ── Path bootstrap ────────────────────────────────────────────────────────────
|
||||
_REPO = Path(__file__).parent.parent
|
||||
if str(_REPO) not in sys.path:
|
||||
sys.path.insert(0, str(_REPO))
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client():
|
||||
from dev_api import app
|
||||
from fastapi.testclient import TestClient
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────────────
|
||||
|
||||
def _write_user_yaml(path: Path, data: dict | None = None) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
payload = data if data is not None else {}
|
||||
path.write_text(yaml.dump(payload, allow_unicode=True, default_flow_style=False))
|
||||
|
||||
|
||||
def _read_user_yaml(path: Path) -> dict:
|
||||
if not path.exists():
|
||||
return {}
|
||||
return yaml.safe_load(path.read_text()) or {}
|
||||
|
||||
|
||||
# ── GET /api/config/app — byokUnlocked field ──────────────────────────────────
|
||||
|
||||
class TestAppConfigByokField:
|
||||
def test_byok_unlocked_false_when_no_llm_configured(self, client, tmp_path):
|
||||
yaml_path = tmp_path / "config" / "user.yaml"
|
||||
_write_user_yaml(yaml_path, {"wizard_complete": True})
|
||||
with patch("dev_api._user_yaml_path", return_value=str(yaml_path)):
|
||||
with patch("app.wizard.tiers.has_configured_llm", return_value=False):
|
||||
r = client.get("/api/config/app")
|
||||
assert r.status_code == 200
|
||||
assert r.json()["byokUnlocked"] is False
|
||||
|
||||
def test_byok_unlocked_true_when_llm_configured(self, client, tmp_path):
|
||||
yaml_path = tmp_path / "config" / "user.yaml"
|
||||
_write_user_yaml(yaml_path, {"wizard_complete": True})
|
||||
with patch("dev_api._user_yaml_path", return_value=str(yaml_path)):
|
||||
with patch("app.wizard.tiers.has_configured_llm", return_value=True):
|
||||
r = client.get("/api/config/app")
|
||||
assert r.status_code == 200
|
||||
assert r.json()["byokUnlocked"] is True
|
||||
|
||||
|
||||
# ── POST /api/wizard/ai/interview — tier gate ─────────────────────────────────
|
||||
|
||||
class TestWizardAIInterviewTierGate:
|
||||
def test_returns_402_when_tier_blocked(self, client):
|
||||
"""Free tier with no BYOK: expect 402."""
|
||||
with patch("dev_api._get_effective_tier", return_value="free"):
|
||||
with patch("app.wizard.tiers.has_configured_llm", return_value=False):
|
||||
r = client.post(
|
||||
"/api/wizard/ai/interview",
|
||||
json={"history": [{"role": "user", "content": "Hello"}]},
|
||||
)
|
||||
assert r.status_code == 402
|
||||
assert r.json()["detail"]["error"] == "tier_required"
|
||||
|
||||
def test_returns_402_for_free_tier_without_byok(self, client):
|
||||
"""Explicit check that free tier without LLM configured is gated."""
|
||||
with patch("dev_api._get_effective_tier", return_value="free"):
|
||||
with patch("app.wizard.tiers.has_configured_llm", return_value=False):
|
||||
r = client.post(
|
||||
"/api/wizard/ai/interview",
|
||||
json={"history": [], "profile_so_far": {}},
|
||||
)
|
||||
assert r.status_code == 402
|
||||
|
||||
def test_free_tier_with_byok_is_allowed(self, client):
|
||||
"""Free tier with BYOK configured: tier gate passes (mocked LLM response)."""
|
||||
llm_reply = json.dumps({
|
||||
"reply": "Hello! What's your name?",
|
||||
"extracted_fields": {},
|
||||
"complete": False,
|
||||
})
|
||||
with patch("dev_api._get_effective_tier", return_value="free"):
|
||||
with patch("app.wizard.tiers.has_configured_llm", return_value=True):
|
||||
with patch("scripts.llm_router.LLMRouter") as mock_cls:
|
||||
mock_cls.return_value.complete.return_value = llm_reply
|
||||
r = client.post(
|
||||
"/api/wizard/ai/interview",
|
||||
json={"history": [], "profile_so_far": {}},
|
||||
)
|
||||
assert r.status_code == 200
|
||||
|
||||
|
||||
# ── POST /api/wizard/ai/interview — LLM mocked responses ─────────────────────
|
||||
|
||||
class TestWizardAIInterviewLLM:
|
||||
def _paid_byok_patches(self):
|
||||
"""Context managers for paid tier + BYOK."""
|
||||
return (
|
||||
patch("dev_api._get_effective_tier", return_value="paid"),
|
||||
patch("app.wizard.tiers.has_configured_llm", return_value=True),
|
||||
)
|
||||
|
||||
def test_returns_valid_reply_structure(self, client):
|
||||
llm_reply = json.dumps({
|
||||
"reply": "Great to meet you! What's your preferred contact email?",
|
||||
"extracted_fields": {"name": "Alex Rivera"},
|
||||
"complete": False,
|
||||
})
|
||||
with patch("dev_api._get_effective_tier", return_value="paid"):
|
||||
with patch("app.wizard.tiers.has_configured_llm", return_value=True):
|
||||
with patch("scripts.llm_router.LLMRouter") as mock_cls:
|
||||
mock_cls.return_value.complete.return_value = llm_reply
|
||||
r = client.post(
|
||||
"/api/wizard/ai/interview",
|
||||
json={
|
||||
"history": [
|
||||
{"role": "user", "content": "My name is Alex Rivera"},
|
||||
],
|
||||
},
|
||||
)
|
||||
assert r.status_code == 200
|
||||
body = r.json()
|
||||
assert body["reply"] == "Great to meet you! What's your preferred contact email?"
|
||||
assert body["extracted_fields"] == {"name": "Alex Rivera"}
|
||||
assert body["complete"] is False
|
||||
|
||||
def test_returns_complete_true_when_llm_signals_done(self, client):
|
||||
llm_reply = json.dumps({
|
||||
"reply": "You're all set! Your profile is complete.",
|
||||
"extracted_fields": {
|
||||
"name": "Alex",
|
||||
"email": "alex@example.com",
|
||||
"career_summary": "Backend engineer with 5 years experience.",
|
||||
},
|
||||
"complete": True,
|
||||
})
|
||||
with patch("dev_api._get_effective_tier", return_value="paid"):
|
||||
with patch("app.wizard.tiers.has_configured_llm", return_value=True):
|
||||
with patch("scripts.llm_router.LLMRouter") as mock_cls:
|
||||
mock_cls.return_value.complete.return_value = llm_reply
|
||||
r = client.post(
|
||||
"/api/wizard/ai/interview",
|
||||
json={
|
||||
"history": [
|
||||
{"role": "user", "content": "I'm done"},
|
||||
],
|
||||
},
|
||||
)
|
||||
assert r.status_code == 200
|
||||
body = r.json()
|
||||
assert body["complete"] is True
|
||||
assert "name" in body["extracted_fields"]
|
||||
|
||||
def test_fallback_when_llm_returns_non_json(self, client):
|
||||
"""If LLM returns non-JSON, the endpoint still returns 200 with raw reply."""
|
||||
with patch("dev_api._get_effective_tier", return_value="paid"):
|
||||
with patch("app.wizard.tiers.has_configured_llm", return_value=True):
|
||||
with patch("scripts.llm_router.LLMRouter") as mock_cls:
|
||||
mock_cls.return_value.complete.return_value = "Hello, what is your name?"
|
||||
r = client.post(
|
||||
"/api/wizard/ai/interview",
|
||||
json={"history": []},
|
||||
)
|
||||
assert r.status_code == 200
|
||||
body = r.json()
|
||||
assert body["reply"] == "Hello, what is your name?"
|
||||
assert body["extracted_fields"] == {}
|
||||
assert body["complete"] is False
|
||||
|
||||
def test_history_passed_to_llm(self, client):
|
||||
"""Verify the history turns are included in the prompt sent to the LLM."""
|
||||
llm_reply = json.dumps({"reply": "OK", "extracted_fields": {}, "complete": False})
|
||||
captured_calls = []
|
||||
with patch("dev_api._get_effective_tier", return_value="paid"):
|
||||
with patch("app.wizard.tiers.has_configured_llm", return_value=True):
|
||||
with patch("scripts.llm_router.LLMRouter") as mock_cls:
|
||||
mock_cls.return_value.complete.side_effect = (
|
||||
lambda prompt, system=None: (captured_calls.append(prompt) or llm_reply)
|
||||
)
|
||||
client.post(
|
||||
"/api/wizard/ai/interview",
|
||||
json={
|
||||
"history": [
|
||||
{"role": "user", "content": "I am Alex"},
|
||||
{"role": "assistant", "content": "Nice to meet you Alex!"},
|
||||
{"role": "user", "content": "My email is alex@test.com"},
|
||||
],
|
||||
},
|
||||
)
|
||||
assert len(captured_calls) == 1
|
||||
prompt = captured_calls[0]
|
||||
assert "I am Alex" in prompt
|
||||
assert "alex@test.com" in prompt
|
||||
|
||||
def test_llm_error_returns_500(self, client):
|
||||
"""If LLM raises, the endpoint returns 500."""
|
||||
with patch("dev_api._get_effective_tier", return_value="paid"):
|
||||
with patch("app.wizard.tiers.has_configured_llm", return_value=True):
|
||||
with patch("scripts.llm_router.LLMRouter") as mock_cls:
|
||||
mock_cls.return_value.complete.side_effect = RuntimeError("no backends")
|
||||
r = client.post(
|
||||
"/api/wizard/ai/interview",
|
||||
json={"history": [{"role": "user", "content": "hi"}]},
|
||||
)
|
||||
assert r.status_code == 500
|
||||
|
||||
|
||||
# ── POST /api/wizard/ai/finalize ──────────────────────────────────────────────
|
||||
|
||||
class TestWizardAIFinalize:
|
||||
def test_merges_allowed_fields_into_user_yaml(self, client, tmp_path):
|
||||
yaml_path = tmp_path / "config" / "user.yaml"
|
||||
_write_user_yaml(yaml_path, {"tier": "paid", "wizard_complete": True})
|
||||
with patch("dev_api._user_yaml_path", return_value=str(yaml_path)):
|
||||
r = client.post(
|
||||
"/api/wizard/ai/finalize",
|
||||
json={
|
||||
"profile": {
|
||||
"name": "Jordan Lee",
|
||||
"email": "jordan@example.com",
|
||||
"career_summary": "Full-stack developer with 8 years experience.",
|
||||
"candidate_voice": "warm and conversational",
|
||||
}
|
||||
},
|
||||
)
|
||||
assert r.status_code == 200
|
||||
body = r.json()
|
||||
assert body["saved"] is True
|
||||
assert set(body["fields"]) == {"name", "email", "career_summary", "candidate_voice"}
|
||||
|
||||
saved = _read_user_yaml(yaml_path)
|
||||
assert saved["name"] == "Jordan Lee"
|
||||
assert saved["email"] == "jordan@example.com"
|
||||
assert saved["career_summary"] == "Full-stack developer with 8 years experience."
|
||||
assert saved["candidate_voice"] == "warm and conversational"
|
||||
|
||||
def test_does_not_clobber_existing_non_wizard_keys(self, client, tmp_path):
|
||||
"""Keys like tier, wizard_complete must not be overwritten by finalize."""
|
||||
yaml_path = tmp_path / "config" / "user.yaml"
|
||||
_write_user_yaml(yaml_path, {
|
||||
"tier": "premium",
|
||||
"wizard_complete": True,
|
||||
"inference_profile": "single-gpu",
|
||||
})
|
||||
with patch("dev_api._user_yaml_path", return_value=str(yaml_path)):
|
||||
r = client.post(
|
||||
"/api/wizard/ai/finalize",
|
||||
json={
|
||||
"profile": {
|
||||
"name": "Sam Park",
|
||||
"tier": "free", # attempt to downgrade — must be blocked
|
||||
"wizard_complete": False, # attempt to reset — must be blocked
|
||||
}
|
||||
},
|
||||
)
|
||||
assert r.status_code == 200
|
||||
saved = _read_user_yaml(yaml_path)
|
||||
# Non-wizard keys are preserved
|
||||
assert saved["tier"] == "premium"
|
||||
assert saved["wizard_complete"] is True
|
||||
assert saved["inference_profile"] == "single-gpu"
|
||||
# Allowed wizard key is written
|
||||
assert saved["name"] == "Sam Park"
|
||||
|
||||
def test_unknown_keys_are_silently_ignored(self, client, tmp_path):
|
||||
yaml_path = tmp_path / "config" / "user.yaml"
|
||||
_write_user_yaml(yaml_path, {})
|
||||
with patch("dev_api._user_yaml_path", return_value=str(yaml_path)):
|
||||
r = client.post(
|
||||
"/api/wizard/ai/finalize",
|
||||
json={
|
||||
"profile": {
|
||||
"email": "test@example.com",
|
||||
"injected_field": "should be ignored",
|
||||
"admin": True,
|
||||
}
|
||||
},
|
||||
)
|
||||
assert r.status_code == 200
|
||||
saved = _read_user_yaml(yaml_path)
|
||||
assert saved["email"] == "test@example.com"
|
||||
assert "injected_field" not in saved
|
||||
assert "admin" not in saved
|
||||
|
||||
def test_all_allowed_fields_are_written(self, client, tmp_path):
|
||||
"""All allowed wizard fields are accepted when provided."""
|
||||
yaml_path = tmp_path / "config" / "user.yaml"
|
||||
_write_user_yaml(yaml_path, {})
|
||||
full_profile = {
|
||||
"name": "Casey Morgan",
|
||||
"email": "casey@example.com",
|
||||
"career_summary": "Designer turned product manager.",
|
||||
"candidate_voice": "professional and direct",
|
||||
"mission_preferences": ["education", "social_impact"],
|
||||
"candidate_accessibility_focus": True,
|
||||
"candidate_lgbtq_focus": True,
|
||||
"linkedin": "https://linkedin.com/in/casey",
|
||||
}
|
||||
with patch("dev_api._user_yaml_path", return_value=str(yaml_path)):
|
||||
r = client.post("/api/wizard/ai/finalize", json={"profile": full_profile})
|
||||
assert r.status_code == 200
|
||||
saved = _read_user_yaml(yaml_path)
|
||||
for key, value in full_profile.items():
|
||||
assert saved[key] == value, f"Expected {key}={value!r}, got {saved.get(key)!r}"
|
||||
|
||||
def test_empty_profile_returns_saved_true(self, client, tmp_path):
|
||||
yaml_path = tmp_path / "config" / "user.yaml"
|
||||
_write_user_yaml(yaml_path, {"name": "Existing"})
|
||||
with patch("dev_api._user_yaml_path", return_value=str(yaml_path)):
|
||||
r = client.post("/api/wizard/ai/finalize", json={"profile": {}})
|
||||
assert r.status_code == 200
|
||||
assert r.json()["saved"] is True
|
||||
assert r.json()["fields"] == []
|
||||
# Existing data is preserved
|
||||
assert _read_user_yaml(yaml_path)["name"] == "Existing"
|
||||
|
||||
def test_mission_preferences_list_written_correctly(self, client, tmp_path):
|
||||
yaml_path = tmp_path / "config" / "user.yaml"
|
||||
_write_user_yaml(yaml_path, {})
|
||||
with patch("dev_api._user_yaml_path", return_value=str(yaml_path)):
|
||||
r = client.post(
|
||||
"/api/wizard/ai/finalize",
|
||||
json={"profile": {"mission_preferences": ["music", "animal_welfare"]}},
|
||||
)
|
||||
assert r.status_code == 200
|
||||
saved = _read_user_yaml(yaml_path)
|
||||
assert saved["mission_preferences"] == ["music", "animal_welfare"]
|
||||
Loading…
Reference in a new issue