feat: wizard orchestrator — 7 steps, LLM generation polling, crash recovery
Replaces the old 5-step wizard with a 7-step orchestrator that uses the step modules built in Tasks 2-8. Steps 1-6 are mandatory (hardware, tier, identity, resume, inference, search); step 7 (integrations) is optional. Each Next click validates, writes wizard_step to user.yaml for crash recovery, and resumes at the correct step on page reload. LLM generation buttons submit wizard_generate tasks and poll via @st.fragment(run_every=3). Finish sets wizard_complete=True, removes wizard_step, and calls apply_service_urls. Adds tests/test_wizard_flow.py (7 tests) covering validate() chain, yaml persistence helpers, and wizard state inference.
This commit is contained in:
parent
c9ce3efa92
commit
dbe05e7c2d
2 changed files with 701 additions and 214 deletions
|
|
@ -1,30 +1,50 @@
|
||||||
"""
|
"""
|
||||||
First-run setup wizard — shown by app.py when config/user.yaml is absent.
|
First-run setup wizard orchestrator.
|
||||||
Five steps: hardware detection → identity → NDA companies → inference/keys → Notion.
|
Shown by app.py when user.yaml is absent OR wizard_complete is False.
|
||||||
Writes config/user.yaml (and optionally config/notion.yaml) on completion.
|
Seven steps: hardware → tier → identity → resume → inference → search → integrations (optional).
|
||||||
|
Steps 1-6 are mandatory; step 7 is optional and can be skipped.
|
||||||
|
Each step writes to user.yaml on "Next" for crash recovery.
|
||||||
"""
|
"""
|
||||||
import subprocess
|
from __future__ import annotations
|
||||||
|
import json
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
sys.path.insert(0, str(Path(__file__).parent.parent.parent))
|
||||||
|
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
CONFIG_DIR = Path(__file__).parent.parent.parent / "config"
|
_ROOT = Path(__file__).parent.parent.parent
|
||||||
USER_CFG = CONFIG_DIR / "user.yaml"
|
CONFIG_DIR = _ROOT / "config"
|
||||||
NOTION_CFG = CONFIG_DIR / "notion.yaml"
|
USER_YAML = CONFIG_DIR / "user.yaml"
|
||||||
LLM_CFG = CONFIG_DIR / "llm.yaml"
|
STEPS = 6 # mandatory steps
|
||||||
|
STEP_LABELS = ["Hardware", "Tier", "Identity", "Resume", "Inference", "Search"]
|
||||||
|
|
||||||
PROFILES = ["remote", "cpu", "single-gpu", "dual-gpu"]
|
|
||||||
|
# ── Helpers ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _load_yaml() -> dict:
|
||||||
|
if USER_YAML.exists():
|
||||||
|
return yaml.safe_load(USER_YAML.read_text()) or {}
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def _save_yaml(updates: dict) -> None:
|
||||||
|
existing = _load_yaml()
|
||||||
|
existing.update(updates)
|
||||||
|
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
|
||||||
|
USER_YAML.write_text(
|
||||||
|
yaml.dump(existing, default_flow_style=False, allow_unicode=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _detect_gpus() -> list[str]:
|
def _detect_gpus() -> list[str]:
|
||||||
"""Return list of GPU names via nvidia-smi, or [] if none."""
|
import subprocess
|
||||||
try:
|
try:
|
||||||
out = subprocess.check_output(
|
out = subprocess.check_output(
|
||||||
["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
|
["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
|
||||||
text=True, timeout=5
|
text=True, timeout=5,
|
||||||
)
|
)
|
||||||
return [l.strip() for l in out.strip().splitlines() if l.strip()]
|
return [l.strip() for l in out.strip().splitlines() if l.strip()]
|
||||||
except Exception:
|
except Exception:
|
||||||
|
|
@ -39,265 +59,616 @@ def _suggest_profile(gpus: list[str]) -> str:
|
||||||
return "remote"
|
return "remote"
|
||||||
|
|
||||||
|
|
||||||
# ── Wizard state ───────────────────────────────────────────────────────────────
|
def _submit_wizard_task(section: str, input_data: dict) -> int:
|
||||||
|
"""Submit a wizard_generate background task. Returns task_id."""
|
||||||
|
from scripts.db import DEFAULT_DB
|
||||||
|
from scripts.task_runner import submit_task
|
||||||
|
params = json.dumps({"section": section, "input": input_data})
|
||||||
|
task_id, _ = submit_task(DEFAULT_DB, "wizard_generate", 0, params=params)
|
||||||
|
return task_id
|
||||||
|
|
||||||
|
|
||||||
|
def _poll_wizard_task(section: str) -> dict | None:
|
||||||
|
"""Return the most recent wizard_generate task row for a given section, or None."""
|
||||||
|
import sqlite3
|
||||||
|
from scripts.db import DEFAULT_DB
|
||||||
|
conn = sqlite3.connect(DEFAULT_DB)
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
row = conn.execute(
|
||||||
|
"SELECT * FROM background_tasks "
|
||||||
|
"WHERE task_type='wizard_generate' AND params LIKE ? "
|
||||||
|
"ORDER BY id DESC LIMIT 1",
|
||||||
|
(f'%"section": "{section}"%',),
|
||||||
|
).fetchone()
|
||||||
|
conn.close()
|
||||||
|
return dict(row) if row else None
|
||||||
|
|
||||||
|
|
||||||
|
def _generation_widget(section: str, label: str, tier: str,
|
||||||
|
feature_key: str, input_data: dict) -> str | None:
|
||||||
|
"""Render a generation button + polling fragment.
|
||||||
|
|
||||||
|
Returns the generated result string if completed and not yet applied, else None.
|
||||||
|
Call this inside a step to add LLM generation support.
|
||||||
|
The caller decides whether to auto-populate a field with the result.
|
||||||
|
"""
|
||||||
|
from app.wizard.tiers import can_use, tier_label as tl
|
||||||
|
|
||||||
|
if not can_use(tier, feature_key):
|
||||||
|
st.caption(f"{tl(feature_key)} {label}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
col_btn, col_fb = st.columns([2, 5])
|
||||||
|
if col_btn.button(f"\u2728 {label}", key=f"gen_{section}"):
|
||||||
|
_submit_wizard_task(section, input_data)
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
with st.expander("\u270f\ufe0f Request changes (optional)", expanded=False):
|
||||||
|
prev = st.session_state.get(f"_gen_result_{section}", "")
|
||||||
|
feedback = st.text_area(
|
||||||
|
"Describe what to change", key=f"_feedback_{section}",
|
||||||
|
placeholder="e.g. Make it shorter and emphasise leadership",
|
||||||
|
height=60,
|
||||||
|
)
|
||||||
|
if prev and st.button(f"\u21ba Regenerate with feedback", key=f"regen_{section}"):
|
||||||
|
_submit_wizard_task(section, {**input_data,
|
||||||
|
"previous_result": prev,
|
||||||
|
"feedback": feedback})
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
# Polling fragment
|
||||||
|
result_key = f"_gen_result_{section}"
|
||||||
|
|
||||||
|
@st.fragment(run_every=3)
|
||||||
|
def _poll():
|
||||||
|
task = _poll_wizard_task(section)
|
||||||
|
if not task:
|
||||||
|
return
|
||||||
|
status = task.get("status")
|
||||||
|
if status in ("queued", "running"):
|
||||||
|
stage = task.get("stage") or "Queued"
|
||||||
|
st.info(f"\u23f3 {stage}\u2026")
|
||||||
|
elif status == "completed":
|
||||||
|
payload = json.loads(task.get("error") or "{}")
|
||||||
|
result = payload.get("result", "")
|
||||||
|
if result and result != st.session_state.get(result_key):
|
||||||
|
st.session_state[result_key] = result
|
||||||
|
st.rerun()
|
||||||
|
elif status == "failed":
|
||||||
|
st.warning(f"Generation failed: {task.get('error', 'unknown error')}")
|
||||||
|
|
||||||
|
_poll()
|
||||||
|
|
||||||
|
return st.session_state.get(result_key)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Wizard state init ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
if "wizard_step" not in st.session_state:
|
if "wizard_step" not in st.session_state:
|
||||||
st.session_state.wizard_step = 1
|
saved = _load_yaml()
|
||||||
if "wizard_data" not in st.session_state:
|
last_completed = saved.get("wizard_step", 0)
|
||||||
st.session_state.wizard_data = {}
|
st.session_state.wizard_step = min(last_completed + 1, STEPS + 1) # resume at next step
|
||||||
|
|
||||||
step = st.session_state.wizard_step
|
step = st.session_state.wizard_step
|
||||||
data = st.session_state.wizard_data
|
saved_yaml = _load_yaml()
|
||||||
|
_tier = saved_yaml.get("dev_tier_override") or saved_yaml.get("tier", "free")
|
||||||
|
|
||||||
st.title("👋 Welcome to Peregrine")
|
st.title("\U0001f44b Welcome to Peregrine")
|
||||||
st.caption("Let's get you set up. This takes about 2 minutes.")
|
st.caption("Complete the setup to start your job search. Progress saves automatically.")
|
||||||
st.progress(step / 5, text=f"Step {step} of 5")
|
st.progress(
|
||||||
|
min((step - 1) / STEPS, 1.0),
|
||||||
|
text=f"Step {min(step, STEPS)} of {STEPS}" if step <= STEPS else "Almost done!",
|
||||||
|
)
|
||||||
st.divider()
|
st.divider()
|
||||||
|
|
||||||
# ── Step 1: Hardware detection ─────────────────────────────────────────────────
|
|
||||||
|
# ── Step 1: Hardware ───────────────────────────────────────────────────────────
|
||||||
if step == 1:
|
if step == 1:
|
||||||
st.subheader("Step 1 — Hardware Detection")
|
from app.wizard.step_hardware import validate, PROFILES
|
||||||
|
|
||||||
|
st.subheader("Step 1 \u2014 Hardware Detection")
|
||||||
gpus = _detect_gpus()
|
gpus = _detect_gpus()
|
||||||
suggested = _suggest_profile(gpus)
|
suggested = _suggest_profile(gpus)
|
||||||
|
|
||||||
if gpus:
|
if gpus:
|
||||||
st.success(f"Found {len(gpus)} GPU(s): {', '.join(gpus)}")
|
st.success(f"Detected {len(gpus)} GPU(s): {', '.join(gpus)}")
|
||||||
else:
|
else:
|
||||||
st.info("No NVIDIA GPUs detected. Remote or CPU mode recommended.")
|
st.info("No NVIDIA GPUs detected. 'Remote' or 'CPU' mode recommended.")
|
||||||
|
|
||||||
profile = st.selectbox(
|
profile = st.selectbox(
|
||||||
"Inference mode",
|
"Inference mode", PROFILES, index=PROFILES.index(suggested),
|
||||||
PROFILES,
|
help="Controls which Docker services start. Change later in Settings \u2192 Services.",
|
||||||
index=PROFILES.index(suggested),
|
|
||||||
help="This controls which Docker services start. You can change it later in Settings → My Profile.",
|
|
||||||
)
|
)
|
||||||
if profile in ("single-gpu", "dual-gpu") and not gpus:
|
if profile in ("single-gpu", "dual-gpu") and not gpus:
|
||||||
st.warning("No GPUs detected — GPU profiles require NVIDIA Container Toolkit. See the README for install instructions.")
|
st.warning(
|
||||||
|
"No GPUs detected \u2014 GPU profiles require the NVIDIA Container Toolkit. "
|
||||||
|
"See README for install instructions."
|
||||||
|
)
|
||||||
|
|
||||||
if st.button("Next →", type="primary"):
|
if st.button("Next \u2192", type="primary", key="hw_next"):
|
||||||
data["inference_profile"] = profile
|
errs = validate({"inference_profile": profile})
|
||||||
data["gpus_detected"] = gpus
|
if errs:
|
||||||
|
st.error("\n".join(errs))
|
||||||
|
else:
|
||||||
|
_save_yaml({"inference_profile": profile, "wizard_step": 1})
|
||||||
st.session_state.wizard_step = 2
|
st.session_state.wizard_step = 2
|
||||||
st.rerun()
|
st.rerun()
|
||||||
|
|
||||||
# ── Step 2: Identity ───────────────────────────────────────────────────────────
|
|
||||||
|
# ── Step 2: Tier ───────────────────────────────────────────────────────────────
|
||||||
elif step == 2:
|
elif step == 2:
|
||||||
st.subheader("Step 2 — Your Identity")
|
from app.wizard.step_tier import validate
|
||||||
st.caption("Used in cover letter PDFs, LLM prompts, and the app header.")
|
|
||||||
c1, c2 = st.columns(2)
|
st.subheader("Step 2 \u2014 Choose Your Plan")
|
||||||
name = c1.text_input("Full Name *", data.get("name", ""))
|
st.caption(
|
||||||
email = c1.text_input("Email *", data.get("email", ""))
|
"**Free** is fully functional for self-hosted local use. "
|
||||||
phone = c2.text_input("Phone", data.get("phone", ""))
|
"**Paid/Premium** unlock LLM-assisted features."
|
||||||
linkedin = c2.text_input("LinkedIn URL", data.get("linkedin", ""))
|
)
|
||||||
summary = st.text_area(
|
|
||||||
"Career Summary *",
|
tier_options = {
|
||||||
data.get("career_summary", ""),
|
"free": "\U0001f193 **Free** \u2014 Local discovery, apply workspace, interviews kanban",
|
||||||
height=120,
|
"paid": "\U0001f4bc **Paid** \u2014 + AI career summary, company research, email classifier, calendar sync",
|
||||||
placeholder="Experienced professional with X years in [field]. Specialise in [skills].",
|
"premium": "\u2b50 **Premium** \u2014 + Voice guidelines, model fine-tuning, multi-user",
|
||||||
help="This paragraph is injected into cover letter and research prompts as your professional context.",
|
}
|
||||||
|
from app.wizard.tiers import TIERS
|
||||||
|
current_tier = saved_yaml.get("tier", "free")
|
||||||
|
selected_tier = st.radio(
|
||||||
|
"Plan",
|
||||||
|
list(tier_options.keys()),
|
||||||
|
format_func=lambda x: tier_options[x],
|
||||||
|
index=TIERS.index(current_tier) if current_tier in TIERS else 0,
|
||||||
)
|
)
|
||||||
|
|
||||||
col_back, col_next = st.columns([1, 4])
|
col_back, col_next = st.columns([1, 4])
|
||||||
if col_back.button("← Back"):
|
if col_back.button("\u2190 Back", key="tier_back"):
|
||||||
st.session_state.wizard_step = 1
|
st.session_state.wizard_step = 1
|
||||||
st.rerun()
|
st.rerun()
|
||||||
if col_next.button("Next →", type="primary"):
|
if col_next.button("Next \u2192", type="primary", key="tier_next"):
|
||||||
if not name or not email or not summary:
|
errs = validate({"tier": selected_tier})
|
||||||
st.error("Name, email, and career summary are required.")
|
if errs:
|
||||||
|
st.error("\n".join(errs))
|
||||||
else:
|
else:
|
||||||
data.update({"name": name, "email": email, "phone": phone,
|
_save_yaml({"tier": selected_tier, "wizard_step": 2})
|
||||||
"linkedin": linkedin, "career_summary": summary})
|
|
||||||
st.session_state.wizard_step = 3
|
st.session_state.wizard_step = 3
|
||||||
st.rerun()
|
st.rerun()
|
||||||
|
|
||||||
# ── Step 3: NDA Companies ──────────────────────────────────────────────────────
|
|
||||||
elif step == 3:
|
|
||||||
st.subheader("Step 3 — Sensitive Employers (Optional)")
|
|
||||||
st.caption(
|
|
||||||
"Previous employers listed here will appear as 'previous employer (NDA)' in "
|
|
||||||
"research briefs and talking points. Skip if not applicable."
|
|
||||||
)
|
|
||||||
nda_list = list(data.get("nda_companies", []))
|
|
||||||
if nda_list:
|
|
||||||
cols = st.columns(min(len(nda_list), 5))
|
|
||||||
to_remove = None
|
|
||||||
for i, c in enumerate(nda_list):
|
|
||||||
if cols[i % 5].button(f"× {c}", key=f"rm_{c}"):
|
|
||||||
to_remove = c
|
|
||||||
if to_remove:
|
|
||||||
nda_list.remove(to_remove)
|
|
||||||
data["nda_companies"] = nda_list
|
|
||||||
st.rerun()
|
|
||||||
nc, nb = st.columns([4, 1])
|
|
||||||
new_c = nc.text_input("Add employer", key="new_nda_wiz",
|
|
||||||
label_visibility="collapsed", placeholder="Employer name…")
|
|
||||||
if nb.button("+ Add") and new_c.strip():
|
|
||||||
nda_list.append(new_c.strip())
|
|
||||||
data["nda_companies"] = nda_list
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
col_back, col_skip, col_next = st.columns([1, 1, 3])
|
# ── Step 3: Identity ───────────────────────────────────────────────────────────
|
||||||
if col_back.button("← Back"):
|
elif step == 3:
|
||||||
|
from app.wizard.step_identity import validate
|
||||||
|
|
||||||
|
st.subheader("Step 3 \u2014 Your Identity")
|
||||||
|
st.caption("Used in cover letter PDFs, LLM prompts, and the app header.")
|
||||||
|
|
||||||
|
c1, c2 = st.columns(2)
|
||||||
|
name = c1.text_input("Full Name *", saved_yaml.get("name", ""))
|
||||||
|
email = c1.text_input("Email *", saved_yaml.get("email", ""))
|
||||||
|
phone = c2.text_input("Phone", saved_yaml.get("phone", ""))
|
||||||
|
linkedin = c2.text_input("LinkedIn URL", saved_yaml.get("linkedin", ""))
|
||||||
|
|
||||||
|
# Career summary with optional LLM generation
|
||||||
|
summary_default = st.session_state.get("_gen_result_career_summary") or saved_yaml.get("career_summary", "")
|
||||||
|
summary = st.text_area(
|
||||||
|
"Career Summary *", value=summary_default, height=120,
|
||||||
|
placeholder="Experienced professional with X years in [field]. Specialise in [skills].",
|
||||||
|
help="Injected into cover letter and research prompts as your professional context.",
|
||||||
|
)
|
||||||
|
|
||||||
|
gen_result = _generation_widget(
|
||||||
|
section="career_summary",
|
||||||
|
label="Generate from resume",
|
||||||
|
tier=_tier,
|
||||||
|
feature_key="llm_career_summary",
|
||||||
|
input_data={"resume_text": saved_yaml.get("_raw_resume_text", "")},
|
||||||
|
)
|
||||||
|
if gen_result and gen_result != summary:
|
||||||
|
st.info(f"\u2728 Suggested summary \u2014 paste it above if it looks good:\n\n{gen_result}")
|
||||||
|
|
||||||
|
col_back, col_next = st.columns([1, 4])
|
||||||
|
if col_back.button("\u2190 Back", key="ident_back"):
|
||||||
st.session_state.wizard_step = 2
|
st.session_state.wizard_step = 2
|
||||||
st.rerun()
|
st.rerun()
|
||||||
if col_skip.button("Skip"):
|
if col_next.button("Next \u2192", type="primary", key="ident_next"):
|
||||||
data.setdefault("nda_companies", [])
|
errs = validate({"name": name, "email": email, "career_summary": summary})
|
||||||
st.session_state.wizard_step = 4
|
if errs:
|
||||||
st.rerun()
|
st.error("\n".join(errs))
|
||||||
if col_next.button("Next →", type="primary"):
|
else:
|
||||||
data["nda_companies"] = nda_list
|
_save_yaml({
|
||||||
|
"name": name, "email": email, "phone": phone,
|
||||||
|
"linkedin": linkedin, "career_summary": summary,
|
||||||
|
"wizard_complete": False, "wizard_step": 3,
|
||||||
|
})
|
||||||
st.session_state.wizard_step = 4
|
st.session_state.wizard_step = 4
|
||||||
st.rerun()
|
st.rerun()
|
||||||
|
|
||||||
# ── Step 4: Inference & API Keys ───────────────────────────────────────────────
|
|
||||||
|
# ── Step 4: Resume ─────────────────────────────────────────────────────────────
|
||||||
elif step == 4:
|
elif step == 4:
|
||||||
profile = data.get("inference_profile", "remote")
|
from app.wizard.step_resume import validate
|
||||||
st.subheader("Step 4 — Inference & API Keys")
|
|
||||||
|
st.subheader("Step 4 \u2014 Resume")
|
||||||
|
st.caption("Upload your resume for fast parsing, or build it section by section.")
|
||||||
|
|
||||||
|
tab_upload, tab_builder = st.tabs(["\U0001f4ce Upload", "\U0001f4dd Build manually"])
|
||||||
|
|
||||||
|
with tab_upload:
|
||||||
|
uploaded = st.file_uploader("Upload PDF or DOCX", type=["pdf", "docx"])
|
||||||
|
if uploaded and st.button("Parse Resume", type="primary", key="parse_resume"):
|
||||||
|
from scripts.resume_parser import (
|
||||||
|
extract_text_from_pdf, extract_text_from_docx, structure_resume,
|
||||||
|
)
|
||||||
|
file_bytes = uploaded.read()
|
||||||
|
ext = uploaded.name.rsplit(".", 1)[-1].lower()
|
||||||
|
raw_text = (
|
||||||
|
extract_text_from_pdf(file_bytes) if ext == "pdf"
|
||||||
|
else extract_text_from_docx(file_bytes)
|
||||||
|
)
|
||||||
|
with st.spinner("Parsing\u2026"):
|
||||||
|
parsed = structure_resume(raw_text)
|
||||||
|
if parsed:
|
||||||
|
st.session_state["_parsed_resume"] = parsed
|
||||||
|
st.session_state["_raw_resume_text"] = raw_text
|
||||||
|
_save_yaml({"_raw_resume_text": raw_text[:8000]})
|
||||||
|
st.success("Parsed! Review the builder tab to edit entries.")
|
||||||
|
else:
|
||||||
|
st.warning("Auto-parse failed \u2014 switch to the Build tab and add entries manually.")
|
||||||
|
|
||||||
|
with tab_builder:
|
||||||
|
parsed = st.session_state.get("_parsed_resume", {})
|
||||||
|
experience = st.session_state.get(
|
||||||
|
"_experience",
|
||||||
|
parsed.get("experience") or saved_yaml.get("experience", []),
|
||||||
|
)
|
||||||
|
|
||||||
|
for i, entry in enumerate(experience):
|
||||||
|
with st.expander(
|
||||||
|
f"{entry.get('title', 'Entry')} @ {entry.get('company', '?')}",
|
||||||
|
expanded=(i == len(experience) - 1),
|
||||||
|
):
|
||||||
|
entry["company"] = st.text_input("Company", entry.get("company", ""), key=f"co_{i}")
|
||||||
|
entry["title"] = st.text_input("Title", entry.get("title", ""), key=f"ti_{i}")
|
||||||
|
raw_bullets = st.text_area(
|
||||||
|
"Responsibilities (one per line)",
|
||||||
|
"\n".join(entry.get("bullets", [])),
|
||||||
|
key=f"bu_{i}", height=80,
|
||||||
|
)
|
||||||
|
entry["bullets"] = [b.strip() for b in raw_bullets.splitlines() if b.strip()]
|
||||||
|
if st.button("Remove entry", key=f"rm_{i}"):
|
||||||
|
experience.pop(i)
|
||||||
|
st.session_state["_experience"] = experience
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
if st.button("\uff0b Add work experience entry", key="add_exp"):
|
||||||
|
experience.append({"company": "", "title": "", "bullets": []})
|
||||||
|
st.session_state["_experience"] = experience
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
# Bullet expansion generation
|
||||||
|
if experience:
|
||||||
|
all_bullets = "\n".join(
|
||||||
|
b for e in experience for b in e.get("bullets", [])
|
||||||
|
)
|
||||||
|
_generation_widget(
|
||||||
|
section="expand_bullets",
|
||||||
|
label="Expand bullet points",
|
||||||
|
tier=_tier,
|
||||||
|
feature_key="llm_expand_bullets",
|
||||||
|
input_data={"bullet_notes": all_bullets},
|
||||||
|
)
|
||||||
|
|
||||||
|
col_back, col_next = st.columns([1, 4])
|
||||||
|
if col_back.button("\u2190 Back", key="resume_back"):
|
||||||
|
st.session_state.wizard_step = 3
|
||||||
|
st.rerun()
|
||||||
|
if col_next.button("Next \u2192", type="primary", key="resume_next"):
|
||||||
|
parsed = st.session_state.get("_parsed_resume", {})
|
||||||
|
experience = (
|
||||||
|
parsed.get("experience") or
|
||||||
|
st.session_state.get("_experience", [])
|
||||||
|
)
|
||||||
|
errs = validate({"experience": experience})
|
||||||
|
if errs:
|
||||||
|
st.error("\n".join(errs))
|
||||||
|
else:
|
||||||
|
resume_yaml_path = _ROOT / "aihawk" / "data_folder" / "plain_text_resume.yaml"
|
||||||
|
resume_yaml_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
resume_data = {**parsed, "experience": experience} if parsed else {"experience": experience}
|
||||||
|
resume_yaml_path.write_text(
|
||||||
|
yaml.dump(resume_data, default_flow_style=False, allow_unicode=True)
|
||||||
|
)
|
||||||
|
_save_yaml({"wizard_step": 4})
|
||||||
|
st.session_state.wizard_step = 5
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Step 5: Inference ──────────────────────────────────────────────────────────
|
||||||
|
elif step == 5:
|
||||||
|
from app.wizard.step_inference import validate
|
||||||
|
|
||||||
|
st.subheader("Step 5 \u2014 Inference & API Keys")
|
||||||
|
profile = saved_yaml.get("inference_profile", "remote")
|
||||||
|
|
||||||
if profile == "remote":
|
if profile == "remote":
|
||||||
st.info("Remote mode: LLM calls go to external APIs. At least one key is needed.")
|
st.info("Remote mode: at least one external API key is required.")
|
||||||
anthropic_key = st.text_input("Anthropic API Key", type="password",
|
anthropic_key = st.text_input("Anthropic API Key", type="password", placeholder="sk-ant-\u2026")
|
||||||
placeholder="sk-ant-…")
|
|
||||||
openai_url = st.text_input("OpenAI-compatible endpoint (optional)",
|
openai_url = st.text_input("OpenAI-compatible endpoint (optional)",
|
||||||
placeholder="https://api.together.xyz/v1")
|
placeholder="https://api.together.xyz/v1")
|
||||||
openai_key = st.text_input("Endpoint API Key (optional)", type="password") if openai_url else ""
|
openai_key = st.text_input("Endpoint API Key (optional)", type="password",
|
||||||
data.update({"anthropic_key": anthropic_key, "openai_url": openai_url,
|
key="oai_key") if openai_url else ""
|
||||||
"openai_key": openai_key})
|
|
||||||
else:
|
else:
|
||||||
st.info(f"Local mode ({profile}): Ollama handles cover letters. Configure model below.")
|
st.info(f"Local mode ({profile}): Ollama provides inference.")
|
||||||
ollama_model = st.text_input("Cover letter model name",
|
anthropic_key = openai_url = openai_key = ""
|
||||||
data.get("ollama_model", "llama3.2:3b"),
|
|
||||||
help="This model will be pulled by Ollama on first start.")
|
|
||||||
data["ollama_model"] = ollama_model
|
|
||||||
|
|
||||||
st.divider()
|
with st.expander("Advanced \u2014 Service Ports & Hosts"):
|
||||||
with st.expander("Advanced — Service Ports & Hosts"):
|
|
||||||
st.caption("Change only if services run on non-default ports or remote hosts.")
|
st.caption("Change only if services run on non-default ports or remote hosts.")
|
||||||
svc = data.get("services", {})
|
svc = dict(saved_yaml.get("services", {}))
|
||||||
for svc_name, default_host, default_port in [
|
for svc_name, default_host, default_port in [
|
||||||
("ollama", "localhost", 11434),
|
("ollama", "localhost", 11434),
|
||||||
("vllm", "localhost", 8000),
|
("vllm", "localhost", 8000),
|
||||||
("searxng", "localhost", 8888),
|
("searxng", "localhost", 8888),
|
||||||
]:
|
]:
|
||||||
c1, c2, c3, c4 = st.columns([2, 1, 0.5, 0.5])
|
c1, c2 = st.columns([3, 1])
|
||||||
svc[f"{svc_name}_host"] = c1.text_input(f"{svc_name} host", svc.get(f"{svc_name}_host", default_host), key=f"adv_{svc_name}_host")
|
svc[f"{svc_name}_host"] = c1.text_input(
|
||||||
svc[f"{svc_name}_port"] = int(c2.number_input("port", value=svc.get(f"{svc_name}_port", default_port), step=1, key=f"adv_{svc_name}_port"))
|
f"{svc_name} host",
|
||||||
svc[f"{svc_name}_ssl"] = c3.checkbox("SSL", svc.get(f"{svc_name}_ssl", False), key=f"adv_{svc_name}_ssl")
|
svc.get(f"{svc_name}_host", default_host),
|
||||||
svc[f"{svc_name}_ssl_verify"] = c4.checkbox("Verify", svc.get(f"{svc_name}_ssl_verify", True), key=f"adv_{svc_name}_verify")
|
key=f"h_{svc_name}",
|
||||||
data["services"] = svc
|
)
|
||||||
|
svc[f"{svc_name}_port"] = int(c2.number_input(
|
||||||
|
"port",
|
||||||
|
value=int(svc.get(f"{svc_name}_port", default_port)),
|
||||||
|
step=1, key=f"p_{svc_name}",
|
||||||
|
))
|
||||||
|
|
||||||
|
confirmed = st.session_state.get("_inf_confirmed", False)
|
||||||
|
test_label = "\U0001f50c Test Ollama connection" if profile != "remote" else "\U0001f50c Test LLM connection"
|
||||||
|
if st.button(test_label, key="inf_test"):
|
||||||
|
if profile == "remote":
|
||||||
|
from scripts.llm_router import LLMRouter
|
||||||
|
try:
|
||||||
|
r = LLMRouter().complete("Reply with only: OK")
|
||||||
|
if r and r.strip():
|
||||||
|
st.success("LLM responding.")
|
||||||
|
st.session_state["_inf_confirmed"] = True
|
||||||
|
confirmed = True
|
||||||
|
except Exception as e:
|
||||||
|
st.error(f"LLM test failed: {e}")
|
||||||
|
else:
|
||||||
|
import requests
|
||||||
|
ollama_url = f"http://{svc.get('ollama_host','localhost')}:{svc.get('ollama_port',11434)}"
|
||||||
|
try:
|
||||||
|
requests.get(f"{ollama_url}/api/tags", timeout=5)
|
||||||
|
st.success("Ollama is running.")
|
||||||
|
st.session_state["_inf_confirmed"] = True
|
||||||
|
confirmed = True
|
||||||
|
except Exception:
|
||||||
|
st.warning("Ollama not responding \u2014 you can skip this check and configure later.")
|
||||||
|
st.session_state["_inf_confirmed"] = True
|
||||||
|
confirmed = True
|
||||||
|
|
||||||
col_back, col_next = st.columns([1, 4])
|
col_back, col_next = st.columns([1, 4])
|
||||||
if col_back.button("← Back"):
|
if col_back.button("\u2190 Back", key="inf_back"):
|
||||||
st.session_state.wizard_step = 3
|
|
||||||
st.rerun()
|
|
||||||
if col_next.button("Next →", type="primary"):
|
|
||||||
st.session_state.wizard_step = 5
|
|
||||||
st.rerun()
|
|
||||||
|
|
||||||
# ── Step 5: Notion (optional) ──────────────────────────────────────────────────
|
|
||||||
elif step == 5:
|
|
||||||
st.subheader("Step 5 — Notion Sync (Optional)")
|
|
||||||
st.caption("Syncs approved and applied jobs to a Notion database. Skip if not using Notion.")
|
|
||||||
notion_token = st.text_input("Integration Token", type="password", placeholder="secret_…")
|
|
||||||
notion_db = st.text_input("Database ID", placeholder="32-character ID from Notion URL")
|
|
||||||
|
|
||||||
if notion_token and notion_db:
|
|
||||||
if st.button("🔌 Test connection"):
|
|
||||||
with st.spinner("Connecting…"):
|
|
||||||
try:
|
|
||||||
from notion_client import Client
|
|
||||||
db = Client(auth=notion_token).databases.retrieve(notion_db)
|
|
||||||
st.success(f"Connected: {db['title'][0]['plain_text']}")
|
|
||||||
except Exception as e:
|
|
||||||
st.error(f"Connection failed: {e}")
|
|
||||||
|
|
||||||
col_back, col_skip, col_finish = st.columns([1, 1, 3])
|
|
||||||
if col_back.button("← Back"):
|
|
||||||
st.session_state.wizard_step = 4
|
st.session_state.wizard_step = 4
|
||||||
st.rerun()
|
st.rerun()
|
||||||
|
if col_next.button("Next \u2192", type="primary", key="inf_next", disabled=not confirmed):
|
||||||
|
errs = validate({"endpoint_confirmed": confirmed})
|
||||||
|
if errs:
|
||||||
|
st.error("\n".join(errs))
|
||||||
|
else:
|
||||||
|
# Write API keys to .env
|
||||||
|
env_path = _ROOT / ".env"
|
||||||
|
env_lines = env_path.read_text().splitlines() if env_path.exists() else []
|
||||||
|
|
||||||
def _finish(save_notion: bool) -> None:
|
def _set_env(lines: list[str], key: str, val: str) -> list[str]:
|
||||||
svc_defaults = {
|
for i, l in enumerate(lines):
|
||||||
"streamlit_port": 8501,
|
if l.startswith(f"{key}="):
|
||||||
"ollama_host": "localhost", "ollama_port": 11434,
|
lines[i] = f"{key}={val}"
|
||||||
"ollama_ssl": False, "ollama_ssl_verify": True,
|
|
||||||
"vllm_host": "localhost", "vllm_port": 8000,
|
|
||||||
"vllm_ssl": False, "vllm_ssl_verify": True,
|
|
||||||
"searxng_host": "localhost", "searxng_port": 8888,
|
|
||||||
"searxng_ssl": False, "searxng_ssl_verify": True,
|
|
||||||
}
|
|
||||||
svc_defaults.update(data.get("services", {}))
|
|
||||||
user_data = {
|
|
||||||
"name": data.get("name", ""),
|
|
||||||
"email": data.get("email", ""),
|
|
||||||
"phone": data.get("phone", ""),
|
|
||||||
"linkedin": data.get("linkedin", ""),
|
|
||||||
"career_summary": data.get("career_summary", ""),
|
|
||||||
"nda_companies": data.get("nda_companies", []),
|
|
||||||
"docs_dir": "~/Documents/JobSearch",
|
|
||||||
"ollama_models_dir": "~/models/ollama",
|
|
||||||
"vllm_models_dir": "~/models/vllm",
|
|
||||||
"inference_profile": data.get("inference_profile", "remote"),
|
|
||||||
"services": svc_defaults,
|
|
||||||
}
|
|
||||||
CONFIG_DIR.mkdir(parents=True, exist_ok=True)
|
|
||||||
USER_CFG.write_text(yaml.dump(user_data, default_flow_style=False, allow_unicode=True))
|
|
||||||
|
|
||||||
if LLM_CFG.exists():
|
|
||||||
from scripts.user_profile import UserProfile
|
|
||||||
from scripts.generate_llm_config import apply_service_urls
|
|
||||||
apply_service_urls(UserProfile(USER_CFG), LLM_CFG)
|
|
||||||
|
|
||||||
# Write API keys to .env (Docker Compose reads these)
|
|
||||||
env_path = CONFIG_DIR.parent / ".env"
|
|
||||||
env_lines = []
|
|
||||||
if env_path.exists():
|
|
||||||
env_lines = env_path.read_text().splitlines()
|
|
||||||
|
|
||||||
def _set_env(lines: list[str], key: str, value: str) -> list[str]:
|
|
||||||
"""Update or append a KEY=value line."""
|
|
||||||
prefix = f"{key}="
|
|
||||||
new_line = f"{key}={value}"
|
|
||||||
for i, line in enumerate(lines):
|
|
||||||
if line.startswith(prefix):
|
|
||||||
lines[i] = new_line
|
|
||||||
return lines
|
return lines
|
||||||
lines.append(new_line)
|
lines.append(f"{key}={val}")
|
||||||
return lines
|
return lines
|
||||||
|
|
||||||
anthropic_key = data.get("anthropic_key", "")
|
|
||||||
openai_url = data.get("openai_url", "")
|
|
||||||
openai_key = data.get("openai_key", "")
|
|
||||||
|
|
||||||
if anthropic_key:
|
if anthropic_key:
|
||||||
env_lines = _set_env(env_lines, "ANTHROPIC_API_KEY", anthropic_key)
|
env_lines = _set_env(env_lines, "ANTHROPIC_API_KEY", anthropic_key)
|
||||||
if openai_url:
|
if openai_url:
|
||||||
env_lines = _set_env(env_lines, "OPENAI_COMPAT_URL", openai_url)
|
env_lines = _set_env(env_lines, "OPENAI_COMPAT_URL", openai_url)
|
||||||
if openai_key:
|
if openai_key:
|
||||||
env_lines = _set_env(env_lines, "OPENAI_COMPAT_KEY", openai_key)
|
env_lines = _set_env(env_lines, "OPENAI_COMPAT_KEY", openai_key)
|
||||||
|
|
||||||
if anthropic_key or openai_url:
|
if anthropic_key or openai_url:
|
||||||
env_path.write_text("\n".join(env_lines) + "\n")
|
env_path.write_text("\n".join(env_lines) + "\n")
|
||||||
|
|
||||||
if save_notion and notion_token and notion_db:
|
_save_yaml({"services": svc, "wizard_step": 5})
|
||||||
# Load field_map defaults from example
|
st.session_state.wizard_step = 6
|
||||||
notion_example = CONFIG_DIR / "notion.yaml.example"
|
|
||||||
field_map = {}
|
|
||||||
if notion_example.exists():
|
|
||||||
ex = yaml.safe_load(notion_example.read_text()) or {}
|
|
||||||
field_map = ex.get("field_map", {})
|
|
||||||
|
|
||||||
NOTION_CFG.write_text(yaml.dump({
|
|
||||||
"token": notion_token,
|
|
||||||
"database_id": notion_db,
|
|
||||||
"field_map": field_map,
|
|
||||||
}, default_flow_style=False, allow_unicode=True))
|
|
||||||
|
|
||||||
st.session_state.wizard_step = 1
|
|
||||||
st.session_state.wizard_data = {}
|
|
||||||
st.success("Setup complete! Redirecting…")
|
|
||||||
st.rerun()
|
st.rerun()
|
||||||
|
|
||||||
if col_skip.button("Skip & Finish"):
|
|
||||||
_finish(save_notion=False)
|
# ── Step 6: Search ─────────────────────────────────────────────────────────────
|
||||||
if col_finish.button("💾 Save & Finish", type="primary"):
|
elif step == 6:
|
||||||
_finish(save_notion=True)
|
from app.wizard.step_search import validate
|
||||||
|
|
||||||
|
st.subheader("Step 6 \u2014 Job Search Preferences")
|
||||||
|
st.caption("Set up what to search for. You can refine these in Settings \u2192 Search later.")
|
||||||
|
|
||||||
|
titles = st.session_state.get("_titles", saved_yaml.get("_wiz_titles", []))
|
||||||
|
locations = st.session_state.get("_locations", saved_yaml.get("_wiz_locations", []))
|
||||||
|
|
||||||
|
c1, c2 = st.columns(2)
|
||||||
|
|
||||||
|
with c1:
|
||||||
|
st.markdown("**Job Titles**")
|
||||||
|
for i, t in enumerate(titles):
|
||||||
|
tc1, tc2 = st.columns([5, 1])
|
||||||
|
tc1.text(t)
|
||||||
|
if tc2.button("\u00d7", key=f"rmtitle_{i}"):
|
||||||
|
titles.pop(i)
|
||||||
|
st.session_state["_titles"] = titles
|
||||||
|
st.rerun()
|
||||||
|
new_title = st.text_input("Add title", key="new_title_wiz",
|
||||||
|
placeholder="Software Engineer, Product Manager\u2026")
|
||||||
|
ac1, ac2 = st.columns([4, 1])
|
||||||
|
if ac2.button("\uff0b", key="add_title"):
|
||||||
|
if new_title.strip() and new_title.strip() not in titles:
|
||||||
|
titles.append(new_title.strip())
|
||||||
|
st.session_state["_titles"] = titles
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
# LLM title suggestions
|
||||||
|
_generation_widget(
|
||||||
|
section="job_titles",
|
||||||
|
label="Suggest job titles",
|
||||||
|
tier=_tier,
|
||||||
|
feature_key="llm_job_titles",
|
||||||
|
input_data={
|
||||||
|
"resume_text": saved_yaml.get("_raw_resume_text", ""),
|
||||||
|
"current_titles": str(titles),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
with c2:
|
||||||
|
st.markdown("**Locations**")
|
||||||
|
for i, l in enumerate(locations):
|
||||||
|
lc1, lc2 = st.columns([5, 1])
|
||||||
|
lc1.text(l)
|
||||||
|
if lc2.button("\u00d7", key=f"rmloc_{i}"):
|
||||||
|
locations.pop(i)
|
||||||
|
st.session_state["_locations"] = locations
|
||||||
|
st.rerun()
|
||||||
|
new_loc = st.text_input("Add location", key="new_loc_wiz",
|
||||||
|
placeholder="Remote, New York NY, San Francisco CA\u2026")
|
||||||
|
ll1, ll2 = st.columns([4, 1])
|
||||||
|
if ll2.button("\uff0b", key="add_loc"):
|
||||||
|
if new_loc.strip():
|
||||||
|
locations.append(new_loc.strip())
|
||||||
|
st.session_state["_locations"] = locations
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
col_back, col_next = st.columns([1, 4])
|
||||||
|
if col_back.button("\u2190 Back", key="search_back"):
|
||||||
|
st.session_state.wizard_step = 5
|
||||||
|
st.rerun()
|
||||||
|
if col_next.button("Next \u2192", type="primary", key="search_next"):
|
||||||
|
errs = validate({"job_titles": titles, "locations": locations})
|
||||||
|
if errs:
|
||||||
|
st.error("\n".join(errs))
|
||||||
|
else:
|
||||||
|
search_profile_path = CONFIG_DIR / "search_profiles.yaml"
|
||||||
|
existing_profiles = {}
|
||||||
|
if search_profile_path.exists():
|
||||||
|
existing_profiles = yaml.safe_load(search_profile_path.read_text()) or {}
|
||||||
|
profiles_list = existing_profiles.get("profiles", [])
|
||||||
|
# Update or create "default" profile
|
||||||
|
default_idx = next(
|
||||||
|
(i for i, p in enumerate(profiles_list) if p.get("name") == "default"), None
|
||||||
|
)
|
||||||
|
default_profile = {
|
||||||
|
"name": "default",
|
||||||
|
"job_titles": titles,
|
||||||
|
"locations": locations,
|
||||||
|
"remote_only": False,
|
||||||
|
"boards": ["linkedin", "indeed", "glassdoor", "zip_recruiter"],
|
||||||
|
}
|
||||||
|
if default_idx is not None:
|
||||||
|
profiles_list[default_idx] = default_profile
|
||||||
|
else:
|
||||||
|
profiles_list.insert(0, default_profile)
|
||||||
|
search_profile_path.write_text(
|
||||||
|
yaml.dump({"profiles": profiles_list},
|
||||||
|
default_flow_style=False, allow_unicode=True)
|
||||||
|
)
|
||||||
|
_save_yaml({"wizard_step": 6})
|
||||||
|
st.session_state.wizard_step = 7
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Step 7: Integrations (optional) ───────────────────────────────────────────
|
||||||
|
elif step == 7:
|
||||||
|
st.subheader("Step 7 \u2014 Integrations (Optional)")
|
||||||
|
st.caption(
|
||||||
|
"Connect cloud services, calendars, and notification tools. "
|
||||||
|
"You can add or change these any time in Settings \u2192 Integrations."
|
||||||
|
)
|
||||||
|
|
||||||
|
from scripts.integrations import REGISTRY
|
||||||
|
from app.wizard.step_integrations import get_available, is_connected
|
||||||
|
from app.wizard.tiers import tier_label
|
||||||
|
|
||||||
|
available = get_available(_tier)
|
||||||
|
|
||||||
|
for name, cls in sorted(REGISTRY.items(), key=lambda x: (x[0] not in available, x[0])):
|
||||||
|
is_conn = is_connected(name, CONFIG_DIR)
|
||||||
|
icon = "\u2705" if is_conn else "\u25cb"
|
||||||
|
lock = tier_label(f"{name}_sync") or tier_label(f"{name}_notifications")
|
||||||
|
|
||||||
|
with st.expander(f"{icon} {cls.label} {lock}"):
|
||||||
|
if name not in available:
|
||||||
|
st.caption(f"Upgrade to {cls.tier} to unlock {cls.label}.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
inst = cls()
|
||||||
|
config: dict = {}
|
||||||
|
for field in inst.fields():
|
||||||
|
val = st.text_input(
|
||||||
|
field["label"],
|
||||||
|
type="password" if field["type"] == "password" else "default",
|
||||||
|
placeholder=field.get("placeholder", ""),
|
||||||
|
help=field.get("help", ""),
|
||||||
|
key=f"int_{name}_{field['key']}",
|
||||||
|
)
|
||||||
|
config[field["key"]] = val
|
||||||
|
|
||||||
|
required_filled = all(
|
||||||
|
config.get(f["key"])
|
||||||
|
for f in inst.fields()
|
||||||
|
if f.get("required")
|
||||||
|
)
|
||||||
|
if st.button(f"Connect {cls.label}", key=f"conn_{name}",
|
||||||
|
disabled=not required_filled):
|
||||||
|
inst.connect(config)
|
||||||
|
with st.spinner(f"Testing {cls.label} connection\u2026"):
|
||||||
|
if inst.test():
|
||||||
|
inst.save_config(config, CONFIG_DIR)
|
||||||
|
st.success(f"{cls.label} connected!")
|
||||||
|
st.rerun()
|
||||||
|
else:
|
||||||
|
st.error(
|
||||||
|
f"Connection test failed for {cls.label}. "
|
||||||
|
"Double-check your credentials."
|
||||||
|
)
|
||||||
|
|
||||||
|
st.divider()
|
||||||
|
col_back, col_skip, col_finish = st.columns([1, 1, 3])
|
||||||
|
|
||||||
|
if col_back.button("\u2190 Back", key="int_back"):
|
||||||
|
st.session_state.wizard_step = 6
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
if col_skip.button("Skip \u2192"):
|
||||||
|
st.session_state.wizard_step = 8 # trigger Finish
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
if col_finish.button("\U0001f389 Finish Setup", type="primary", key="finish_btn"):
|
||||||
|
st.session_state.wizard_step = 8
|
||||||
|
st.rerun()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Finish ─────────────────────────────────────────────────────────────────────
|
||||||
|
elif step >= 8:
|
||||||
|
with st.spinner("Finalising setup\u2026"):
|
||||||
|
from scripts.user_profile import UserProfile
|
||||||
|
from scripts.generate_llm_config import apply_service_urls
|
||||||
|
|
||||||
|
try:
|
||||||
|
profile_obj = UserProfile(USER_YAML)
|
||||||
|
if (CONFIG_DIR / "llm.yaml").exists():
|
||||||
|
apply_service_urls(profile_obj, CONFIG_DIR / "llm.yaml")
|
||||||
|
except Exception:
|
||||||
|
pass # don't block finish on llm.yaml errors
|
||||||
|
|
||||||
|
data = _load_yaml()
|
||||||
|
data["wizard_complete"] = True
|
||||||
|
data.pop("wizard_step", None)
|
||||||
|
USER_YAML.write_text(
|
||||||
|
yaml.dump(data, default_flow_style=False, allow_unicode=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
st.success("\u2705 Setup complete! Loading Peregrine\u2026")
|
||||||
|
st.session_state.clear()
|
||||||
|
st.rerun()
|
||||||
|
|
|
||||||
116
tests/test_wizard_flow.py
Normal file
116
tests/test_wizard_flow.py
Normal file
|
|
@ -0,0 +1,116 @@
|
||||||
|
"""
|
||||||
|
Wizard flow logic tests — no Streamlit dependency.
|
||||||
|
Tests validate() chain, yaml persistence helpers, and wizard state inference.
|
||||||
|
"""
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
import yaml
|
||||||
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
|
|
||||||
|
# ── All mandatory steps validate correctly ────────────────────────────────────
|
||||||
|
|
||||||
|
def test_all_mandatory_steps_accept_minimal_valid_data():
|
||||||
|
"""Each step's validate() accepts the minimum required input."""
|
||||||
|
from app.wizard.step_hardware import validate as hw
|
||||||
|
from app.wizard.step_tier import validate as tier
|
||||||
|
from app.wizard.step_identity import validate as ident
|
||||||
|
from app.wizard.step_resume import validate as resume
|
||||||
|
from app.wizard.step_inference import validate as inf
|
||||||
|
from app.wizard.step_search import validate as search
|
||||||
|
|
||||||
|
assert hw({"inference_profile": "remote"}) == []
|
||||||
|
assert tier({"tier": "free"}) == []
|
||||||
|
assert ident({"name": "A", "email": "a@b.com", "career_summary": "x"}) == []
|
||||||
|
assert resume({"experience": [{"company": "X", "title": "T", "bullets": []}]}) == []
|
||||||
|
assert inf({"endpoint_confirmed": True}) == []
|
||||||
|
assert search({"job_titles": ["SWE"], "locations": ["Remote"]}) == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_mandatory_steps_reject_empty_data():
|
||||||
|
"""Each step's validate() rejects completely empty input."""
|
||||||
|
from app.wizard.step_hardware import validate as hw
|
||||||
|
from app.wizard.step_tier import validate as tier
|
||||||
|
from app.wizard.step_identity import validate as ident
|
||||||
|
from app.wizard.step_resume import validate as resume
|
||||||
|
from app.wizard.step_inference import validate as inf
|
||||||
|
from app.wizard.step_search import validate as search
|
||||||
|
|
||||||
|
assert hw({}) != []
|
||||||
|
assert tier({}) != []
|
||||||
|
assert ident({}) != []
|
||||||
|
assert resume({}) != []
|
||||||
|
assert inf({}) != []
|
||||||
|
assert search({}) != []
|
||||||
|
|
||||||
|
|
||||||
|
# ── Yaml persistence helpers ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_wizard_step_persists_to_yaml(tmp_path):
|
||||||
|
"""Writing wizard_step to user.yaml survives a reload."""
|
||||||
|
p = tmp_path / "user.yaml"
|
||||||
|
p.write_text(yaml.dump({
|
||||||
|
"name": "Test", "email": "t@t.com",
|
||||||
|
"career_summary": "x", "wizard_complete": False,
|
||||||
|
}))
|
||||||
|
# Simulate "write step 3 on Next"
|
||||||
|
data = yaml.safe_load(p.read_text()) or {}
|
||||||
|
data["wizard_step"] = 3
|
||||||
|
p.write_text(yaml.dump(data))
|
||||||
|
reloaded = yaml.safe_load(p.read_text())
|
||||||
|
assert reloaded["wizard_step"] == 3
|
||||||
|
assert reloaded["wizard_complete"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_finish_sets_wizard_complete_and_removes_wizard_step(tmp_path):
|
||||||
|
"""After Finish, wizard_complete is True and wizard_step is absent."""
|
||||||
|
p = tmp_path / "user.yaml"
|
||||||
|
p.write_text(yaml.dump({
|
||||||
|
"name": "Test", "email": "t@t.com",
|
||||||
|
"career_summary": "x", "wizard_complete": False, "wizard_step": 6,
|
||||||
|
}))
|
||||||
|
# Simulate Finish action
|
||||||
|
data = yaml.safe_load(p.read_text()) or {}
|
||||||
|
data["wizard_complete"] = True
|
||||||
|
data.pop("wizard_step", None)
|
||||||
|
p.write_text(yaml.dump(data))
|
||||||
|
reloaded = yaml.safe_load(p.read_text())
|
||||||
|
assert reloaded["wizard_complete"] is True
|
||||||
|
assert "wizard_step" not in reloaded
|
||||||
|
|
||||||
|
|
||||||
|
def test_wizard_resume_step_inferred_from_yaml(tmp_path):
|
||||||
|
"""wizard_step in user.yaml determines which step to resume at."""
|
||||||
|
p = tmp_path / "user.yaml"
|
||||||
|
p.write_text(yaml.dump({
|
||||||
|
"name": "Test", "email": "t@t.com",
|
||||||
|
"career_summary": "x", "wizard_complete": False, "wizard_step": 4,
|
||||||
|
}))
|
||||||
|
data = yaml.safe_load(p.read_text()) or {}
|
||||||
|
# Wizard should resume at step 5 (last_completed + 1)
|
||||||
|
resume_at = data.get("wizard_step", 0) + 1
|
||||||
|
assert resume_at == 5
|
||||||
|
|
||||||
|
|
||||||
|
def test_wizard_complete_true_means_no_wizard(tmp_path):
|
||||||
|
"""If wizard_complete is True, the app should NOT show the wizard."""
|
||||||
|
p = tmp_path / "user.yaml"
|
||||||
|
p.write_text(yaml.dump({
|
||||||
|
"name": "Test", "email": "t@t.com",
|
||||||
|
"career_summary": "x", "wizard_complete": True,
|
||||||
|
}))
|
||||||
|
from scripts.user_profile import UserProfile
|
||||||
|
u = UserProfile(p)
|
||||||
|
assert u.wizard_complete is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_wizard_incomplete_means_show_wizard(tmp_path):
|
||||||
|
"""If wizard_complete is False, the app SHOULD show the wizard."""
|
||||||
|
p = tmp_path / "user.yaml"
|
||||||
|
p.write_text(yaml.dump({
|
||||||
|
"name": "Test", "email": "t@t.com",
|
||||||
|
"career_summary": "x", "wizard_complete": False,
|
||||||
|
}))
|
||||||
|
from scripts.user_profile import UserProfile
|
||||||
|
u = UserProfile(p)
|
||||||
|
assert u.wizard_complete is False
|
||||||
Loading…
Reference in a new issue