fix: update tests to match refactored scheduler and free-tier Vue SPA
Some checks failed
CI / test (push) Failing after 28s

- task_scheduler: extend LocalScheduler (concrete class), not TaskScheduler
  (Protocol); remove unsupported VRAM kwargs from super().__init__()
- dev-api: lazy import db_migrate inside _startup() to avoid worktree
  scripts cache issue in test_dev_api_settings.py
- test_task_scheduler: update VRAM-attribute tests to match LocalScheduler
  (no _available_vram/_reserved_vram); drop deepest-queue VRAM-gating
  ordering assertion (LocalScheduler is FIFO, not priority-gated);
  suppress PytestUnhandledThreadExceptionWarning on crash test; fix
  budget assertion to not depend on shared pytest tmp dir state
- test_dev_api_settings: patch path functions (_resume_path, _search_prefs_path,
  _license_path, _tokens_path, _config_dir) instead of removed module-level
  constants; mock _TRAINING_JSONL for finetune status idle test
- test_wizard_tiers: Vue SPA is free tier (issue #20), assert True
- test_wizard_api: patch _search_prefs_path() function, not SEARCH_PREFS_PATH
- test_ui_switcher: free-tier vue preference no longer downgrades to streamlit
This commit is contained in:
pyr0ball 2026-04-05 07:35:45 -07:00
parent fb9f751321
commit dc508d7197
7 changed files with 76 additions and 76 deletions

View file

@ -35,7 +35,6 @@ if str(PEREGRINE_ROOT) not in sys.path:
from circuitforge_core.config.settings import load_env as _load_env # noqa: E402 from circuitforge_core.config.settings import load_env as _load_env # noqa: E402
from scripts.credential_store import get_credential, set_credential, delete_credential # noqa: E402 from scripts.credential_store import get_credential, set_credential, delete_credential # noqa: E402
from scripts.db_migrate import migrate_db # noqa: E402
DB_PATH = os.environ.get("STAGING_DB", "/devl/job-seeker/staging.db") DB_PATH = os.environ.get("STAGING_DB", "/devl/job-seeker/staging.db")
@ -137,6 +136,7 @@ def _startup():
# Load .env before any runtime env reads — safe because startup doesn't run # Load .env before any runtime env reads — safe because startup doesn't run
# when dev_api is imported by tests (only when uvicorn actually starts). # when dev_api is imported by tests (only when uvicorn actually starts).
_load_env(PEREGRINE_ROOT / ".env") _load_env(PEREGRINE_ROOT / ".env")
from scripts.db_migrate import migrate_db
migrate_db(Path(DB_PATH)) migrate_db(Path(DB_PATH))

View file

@ -22,7 +22,7 @@ from typing import Callable, Optional
from circuitforge_core.tasks.scheduler import ( from circuitforge_core.tasks.scheduler import (
TaskSpec, # re-export unchanged TaskSpec, # re-export unchanged
TaskScheduler as _CoreTaskScheduler, LocalScheduler as _CoreTaskScheduler,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -94,15 +94,6 @@ class TaskScheduler(_CoreTaskScheduler):
def __init__(self, db_path: Path, run_task_fn: Callable) -> None: def __init__(self, db_path: Path, run_task_fn: Callable) -> None:
budgets, max_depth = _load_config_overrides(db_path) budgets, max_depth = _load_config_overrides(db_path)
# Resolve VRAM using module-level _get_gpus so tests can monkeypatch it
try:
gpus = _get_gpus()
available_vram: float = (
sum(g["vram_total_gb"] for g in gpus) if gpus else 999.0
)
except Exception:
available_vram = 999.0
# Warn under this module's logger for any task types with no VRAM budget # Warn under this module's logger for any task types with no VRAM budget
# (mirrors the core warning but captures under scripts.task_scheduler # (mirrors the core warning but captures under scripts.task_scheduler
# so existing tests using caplog.at_level(logger="scripts.task_scheduler") pass) # so existing tests using caplog.at_level(logger="scripts.task_scheduler") pass)
@ -113,19 +104,12 @@ class TaskScheduler(_CoreTaskScheduler):
"defaulting to 0.0 GB (unlimited concurrency for this type)", t "defaulting to 0.0 GB (unlimited concurrency for this type)", t
) )
coordinator_url = os.environ.get(
"CF_ORCH_URL", "http://localhost:7700"
).rstrip("/")
super().__init__( super().__init__(
db_path=db_path, db_path=db_path,
run_task_fn=run_task_fn, run_task_fn=run_task_fn,
task_types=LLM_TASK_TYPES, task_types=LLM_TASK_TYPES,
vram_budgets=budgets, vram_budgets=budgets,
available_vram_gb=available_vram,
max_queue_depth=max_depth, max_queue_depth=max_depth,
coordinator_url=coordinator_url,
service_name="peregrine",
) )
def enqueue( def enqueue(

View file

@ -145,7 +145,7 @@ def test_get_resume_missing_returns_not_exists(tmp_path, monkeypatch):
"""GET /api/settings/resume when file missing returns {exists: false}.""" """GET /api/settings/resume when file missing returns {exists: false}."""
fake_path = tmp_path / "config" / "plain_text_resume.yaml" fake_path = tmp_path / "config" / "plain_text_resume.yaml"
# Ensure the path doesn't exist # Ensure the path doesn't exist
monkeypatch.setattr("dev_api.RESUME_PATH", fake_path) monkeypatch.setattr("dev_api._resume_path", lambda: fake_path)
from dev_api import app from dev_api import app
c = TestClient(app) c = TestClient(app)
@ -157,7 +157,7 @@ def test_get_resume_missing_returns_not_exists(tmp_path, monkeypatch):
def test_post_resume_blank_creates_file(tmp_path, monkeypatch): def test_post_resume_blank_creates_file(tmp_path, monkeypatch):
"""POST /api/settings/resume/blank creates the file.""" """POST /api/settings/resume/blank creates the file."""
fake_path = tmp_path / "config" / "plain_text_resume.yaml" fake_path = tmp_path / "config" / "plain_text_resume.yaml"
monkeypatch.setattr("dev_api.RESUME_PATH", fake_path) monkeypatch.setattr("dev_api._resume_path", lambda: fake_path)
from dev_api import app from dev_api import app
c = TestClient(app) c = TestClient(app)
@ -170,7 +170,7 @@ def test_post_resume_blank_creates_file(tmp_path, monkeypatch):
def test_get_resume_after_blank_returns_exists(tmp_path, monkeypatch): def test_get_resume_after_blank_returns_exists(tmp_path, monkeypatch):
"""GET /api/settings/resume after blank creation returns {exists: true}.""" """GET /api/settings/resume after blank creation returns {exists: true}."""
fake_path = tmp_path / "config" / "plain_text_resume.yaml" fake_path = tmp_path / "config" / "plain_text_resume.yaml"
monkeypatch.setattr("dev_api.RESUME_PATH", fake_path) monkeypatch.setattr("dev_api._resume_path", lambda: fake_path)
from dev_api import app from dev_api import app
c = TestClient(app) c = TestClient(app)
@ -212,7 +212,7 @@ def test_get_search_prefs_returns_dict(tmp_path, monkeypatch):
fake_path.parent.mkdir(parents=True, exist_ok=True) fake_path.parent.mkdir(parents=True, exist_ok=True)
with open(fake_path, "w") as f: with open(fake_path, "w") as f:
yaml.dump({"default": {"remote_preference": "remote", "job_boards": []}}, f) yaml.dump({"default": {"remote_preference": "remote", "job_boards": []}}, f)
monkeypatch.setattr("dev_api.SEARCH_PREFS_PATH", fake_path) monkeypatch.setattr("dev_api._search_prefs_path", lambda: fake_path)
from dev_api import app from dev_api import app
c = TestClient(app) c = TestClient(app)
@ -227,7 +227,7 @@ def test_put_get_search_roundtrip(tmp_path, monkeypatch):
"""PUT then GET search prefs round-trip: saved field is returned.""" """PUT then GET search prefs round-trip: saved field is returned."""
fake_path = tmp_path / "config" / "search_profiles.yaml" fake_path = tmp_path / "config" / "search_profiles.yaml"
fake_path.parent.mkdir(parents=True, exist_ok=True) fake_path.parent.mkdir(parents=True, exist_ok=True)
monkeypatch.setattr("dev_api.SEARCH_PREFS_PATH", fake_path) monkeypatch.setattr("dev_api._search_prefs_path", lambda: fake_path)
from dev_api import app from dev_api import app
c = TestClient(app) c = TestClient(app)
@ -253,7 +253,7 @@ def test_put_get_search_roundtrip(tmp_path, monkeypatch):
def test_get_search_missing_file_returns_empty(tmp_path, monkeypatch): def test_get_search_missing_file_returns_empty(tmp_path, monkeypatch):
"""GET /api/settings/search when file missing returns empty dict.""" """GET /api/settings/search when file missing returns empty dict."""
fake_path = tmp_path / "config" / "search_profiles.yaml" fake_path = tmp_path / "config" / "search_profiles.yaml"
monkeypatch.setattr("dev_api.SEARCH_PREFS_PATH", fake_path) monkeypatch.setattr("dev_api._search_prefs_path", lambda: fake_path)
from dev_api import app from dev_api import app
c = TestClient(app) c = TestClient(app)
@ -363,7 +363,7 @@ def test_get_services_cpu_profile(client):
def test_get_email_has_password_set_bool(tmp_path, monkeypatch): def test_get_email_has_password_set_bool(tmp_path, monkeypatch):
"""GET /api/settings/system/email has password_set (bool) and no password key.""" """GET /api/settings/system/email has password_set (bool) and no password key."""
fake_email_path = tmp_path / "email.yaml" fake_email_path = tmp_path / "email.yaml"
monkeypatch.setattr("dev_api.EMAIL_PATH", fake_email_path) monkeypatch.setattr("dev_api._config_dir", lambda: fake_email_path.parent)
with patch("dev_api.get_credential", return_value=None): with patch("dev_api.get_credential", return_value=None):
from dev_api import app from dev_api import app
c = TestClient(app) c = TestClient(app)
@ -378,7 +378,7 @@ def test_get_email_has_password_set_bool(tmp_path, monkeypatch):
def test_get_email_password_set_true_when_stored(tmp_path, monkeypatch): def test_get_email_password_set_true_when_stored(tmp_path, monkeypatch):
"""password_set is True when credential is stored.""" """password_set is True when credential is stored."""
fake_email_path = tmp_path / "email.yaml" fake_email_path = tmp_path / "email.yaml"
monkeypatch.setattr("dev_api.EMAIL_PATH", fake_email_path) monkeypatch.setattr("dev_api._config_dir", lambda: fake_email_path.parent)
with patch("dev_api.get_credential", return_value="secret"): with patch("dev_api.get_credential", return_value="secret"):
from dev_api import app from dev_api import app
c = TestClient(app) c = TestClient(app)
@ -426,10 +426,14 @@ def test_finetune_status_returns_status_and_pairs_count(client):
assert "pairs_count" in data assert "pairs_count" in data
def test_finetune_status_idle_when_no_task(client): def test_finetune_status_idle_when_no_task(tmp_path, monkeypatch):
"""Status is 'idle' and pairs_count is 0 when no task exists.""" """Status is 'idle' and pairs_count is 0 when no task exists."""
fake_jsonl = tmp_path / "cover_letters.jsonl" # does not exist -> 0 pairs
monkeypatch.setattr("dev_api._TRAINING_JSONL", fake_jsonl)
with patch("scripts.task_runner.get_task_status", return_value=None, create=True): with patch("scripts.task_runner.get_task_status", return_value=None, create=True):
resp = client.get("/api/settings/fine-tune/status") from dev_api import app
c = TestClient(app)
resp = c.get("/api/settings/fine-tune/status")
assert resp.status_code == 200 assert resp.status_code == 200
data = resp.json() data = resp.json()
assert data["status"] == "idle" assert data["status"] == "idle"
@ -441,7 +445,7 @@ def test_finetune_status_idle_when_no_task(client):
def test_get_license_returns_tier_and_active(tmp_path, monkeypatch): def test_get_license_returns_tier_and_active(tmp_path, monkeypatch):
"""GET /api/settings/license returns tier and active fields.""" """GET /api/settings/license returns tier and active fields."""
fake_license = tmp_path / "license.yaml" fake_license = tmp_path / "license.yaml"
monkeypatch.setattr("dev_api.LICENSE_PATH", fake_license) monkeypatch.setattr("dev_api._license_path", lambda: fake_license)
from dev_api import app from dev_api import app
c = TestClient(app) c = TestClient(app)
@ -455,7 +459,7 @@ def test_get_license_returns_tier_and_active(tmp_path, monkeypatch):
def test_get_license_defaults_to_free(tmp_path, monkeypatch): def test_get_license_defaults_to_free(tmp_path, monkeypatch):
"""GET /api/settings/license defaults to free tier when no file.""" """GET /api/settings/license defaults to free tier when no file."""
fake_license = tmp_path / "license.yaml" fake_license = tmp_path / "license.yaml"
monkeypatch.setattr("dev_api.LICENSE_PATH", fake_license) monkeypatch.setattr("dev_api._license_path", lambda: fake_license)
from dev_api import app from dev_api import app
c = TestClient(app) c = TestClient(app)
@ -469,8 +473,7 @@ def test_get_license_defaults_to_free(tmp_path, monkeypatch):
def test_activate_license_valid_key_returns_ok(tmp_path, monkeypatch): def test_activate_license_valid_key_returns_ok(tmp_path, monkeypatch):
"""POST activate with valid key format returns {ok: true}.""" """POST activate with valid key format returns {ok: true}."""
fake_license = tmp_path / "license.yaml" fake_license = tmp_path / "license.yaml"
monkeypatch.setattr("dev_api.LICENSE_PATH", fake_license) monkeypatch.setattr("dev_api._license_path", lambda: fake_license)
monkeypatch.setattr("dev_api.CONFIG_DIR", tmp_path)
from dev_api import app from dev_api import app
c = TestClient(app) c = TestClient(app)
@ -482,8 +485,7 @@ def test_activate_license_valid_key_returns_ok(tmp_path, monkeypatch):
def test_activate_license_invalid_key_returns_ok_false(tmp_path, monkeypatch): def test_activate_license_invalid_key_returns_ok_false(tmp_path, monkeypatch):
"""POST activate with bad key format returns {ok: false}.""" """POST activate with bad key format returns {ok: false}."""
fake_license = tmp_path / "license.yaml" fake_license = tmp_path / "license.yaml"
monkeypatch.setattr("dev_api.LICENSE_PATH", fake_license) monkeypatch.setattr("dev_api._license_path", lambda: fake_license)
monkeypatch.setattr("dev_api.CONFIG_DIR", tmp_path)
from dev_api import app from dev_api import app
c = TestClient(app) c = TestClient(app)
@ -495,8 +497,7 @@ def test_activate_license_invalid_key_returns_ok_false(tmp_path, monkeypatch):
def test_deactivate_license_returns_ok(tmp_path, monkeypatch): def test_deactivate_license_returns_ok(tmp_path, monkeypatch):
"""POST /api/settings/license/deactivate returns 200 with ok.""" """POST /api/settings/license/deactivate returns 200 with ok."""
fake_license = tmp_path / "license.yaml" fake_license = tmp_path / "license.yaml"
monkeypatch.setattr("dev_api.LICENSE_PATH", fake_license) monkeypatch.setattr("dev_api._license_path", lambda: fake_license)
monkeypatch.setattr("dev_api.CONFIG_DIR", tmp_path)
from dev_api import app from dev_api import app
c = TestClient(app) c = TestClient(app)
@ -508,8 +509,7 @@ def test_deactivate_license_returns_ok(tmp_path, monkeypatch):
def test_activate_then_deactivate(tmp_path, monkeypatch): def test_activate_then_deactivate(tmp_path, monkeypatch):
"""Activate then deactivate: active goes False.""" """Activate then deactivate: active goes False."""
fake_license = tmp_path / "license.yaml" fake_license = tmp_path / "license.yaml"
monkeypatch.setattr("dev_api.LICENSE_PATH", fake_license) monkeypatch.setattr("dev_api._license_path", lambda: fake_license)
monkeypatch.setattr("dev_api.CONFIG_DIR", tmp_path)
from dev_api import app from dev_api import app
c = TestClient(app) c = TestClient(app)
@ -580,7 +580,7 @@ def test_get_developer_returns_expected_fields(tmp_path, monkeypatch):
_write_user_yaml(user_yaml) _write_user_yaml(user_yaml)
monkeypatch.setenv("STAGING_DB", str(db_dir / "staging.db")) monkeypatch.setenv("STAGING_DB", str(db_dir / "staging.db"))
fake_tokens = tmp_path / "tokens.yaml" fake_tokens = tmp_path / "tokens.yaml"
monkeypatch.setattr("dev_api.TOKENS_PATH", fake_tokens) monkeypatch.setattr("dev_api._tokens_path", lambda: fake_tokens)
from dev_api import app from dev_api import app
c = TestClient(app) c = TestClient(app)
@ -602,7 +602,7 @@ def test_put_dev_tier_then_get(tmp_path, monkeypatch):
_write_user_yaml(user_yaml) _write_user_yaml(user_yaml)
monkeypatch.setenv("STAGING_DB", str(db_dir / "staging.db")) monkeypatch.setenv("STAGING_DB", str(db_dir / "staging.db"))
fake_tokens = tmp_path / "tokens.yaml" fake_tokens = tmp_path / "tokens.yaml"
monkeypatch.setattr("dev_api.TOKENS_PATH", fake_tokens) monkeypatch.setattr("dev_api._tokens_path", lambda: fake_tokens)
from dev_api import app from dev_api import app
c = TestClient(app) c = TestClient(app)

View file

@ -109,24 +109,33 @@ def test_missing_budget_logs_warning(tmp_db, caplog):
ts.LLM_TASK_TYPES = frozenset(original) ts.LLM_TASK_TYPES = frozenset(original)
def test_cpu_only_system_gets_unlimited_vram(tmp_db, monkeypatch): def test_cpu_only_system_creates_scheduler(tmp_db, monkeypatch):
"""_available_vram is 999.0 when _get_gpus() returns empty list.""" """Scheduler constructs without error when _get_gpus() returns empty list.
# Patch the module-level _get_gpus in task_scheduler (not preflight)
# so __init__'s _ts_mod._get_gpus() call picks up the mock. LocalScheduler has no VRAM gating it runs tasks regardless of GPU count.
VRAM-aware scheduling is handled by circuitforge_orch's coordinator.
"""
monkeypatch.setattr("scripts.task_scheduler._get_gpus", lambda: []) monkeypatch.setattr("scripts.task_scheduler._get_gpus", lambda: [])
s = TaskScheduler(tmp_db, _noop_run_task) s = TaskScheduler(tmp_db, _noop_run_task)
assert s._available_vram == 999.0 # Scheduler still has correct budgets configured; no VRAM attribute expected
# Scheduler constructed successfully; budgets contain all LLM task types.
# Does not assert exact values -- a sibling test may write a config override
# to the shared pytest tmp dir, causing _load_config_overrides to pick it up.
assert set(s._budgets.keys()) >= LLM_TASK_TYPES
def test_gpu_vram_summed_across_all_gpus(tmp_db, monkeypatch): def test_gpu_detection_does_not_affect_local_scheduler(tmp_db, monkeypatch):
"""_available_vram sums vram_total_gb across all detected GPUs.""" """LocalScheduler ignores GPU VRAM — it has no _available_vram attribute.
VRAM-gated concurrency requires circuitforge_orch (Paid tier).
"""
fake_gpus = [ fake_gpus = [
{"name": "RTX 3090", "vram_total_gb": 24.0, "vram_free_gb": 20.0}, {"name": "RTX 3090", "vram_total_gb": 24.0, "vram_free_gb": 20.0},
{"name": "RTX 3090", "vram_total_gb": 24.0, "vram_free_gb": 18.0}, {"name": "RTX 3090", "vram_total_gb": 24.0, "vram_free_gb": 18.0},
] ]
monkeypatch.setattr("scripts.task_scheduler._get_gpus", lambda: fake_gpus) monkeypatch.setattr("scripts.task_scheduler._get_gpus", lambda: fake_gpus)
s = TaskScheduler(tmp_db, _noop_run_task) s = TaskScheduler(tmp_db, _noop_run_task)
assert s._available_vram == 48.0 assert not hasattr(s, "_available_vram")
def test_enqueue_adds_taskspec_to_deque(tmp_db): def test_enqueue_adds_taskspec_to_deque(tmp_db):
@ -206,40 +215,37 @@ def _make_recording_run_task(log: list, done_event: threading.Event, expected: i
return _run return _run
def _start_scheduler(tmp_db, run_task_fn, available_vram=999.0): def _start_scheduler(tmp_db, run_task_fn):
s = TaskScheduler(tmp_db, run_task_fn) s = TaskScheduler(tmp_db, run_task_fn)
s._available_vram = available_vram
s.start() s.start()
return s return s
# ── Tests ───────────────────────────────────────────────────────────────────── # ── Tests ─────────────────────────────────────────────────────────────────────
def test_deepest_queue_wins_first_slot(tmp_db): def test_all_task_types_complete(tmp_db):
"""Type with more queued tasks starts first when VRAM only fits one type.""" """Scheduler runs tasks from multiple types; all complete.
LocalScheduler runs type batches concurrently (no VRAM gating).
VRAM-gated sequential scheduling requires circuitforge_orch.
"""
log, done = [], threading.Event() log, done = [], threading.Event()
# Build scheduler but DO NOT start it yet — enqueue all tasks first
# so the scheduler sees the full picture on its very first wake.
run_task_fn = _make_recording_run_task(log, done, 4) run_task_fn = _make_recording_run_task(log, done, 4)
s = TaskScheduler(tmp_db, run_task_fn) s = TaskScheduler(tmp_db, run_task_fn)
s._available_vram = 3.0 # fits cover_letter (2.5) but not +company_research (5.0)
# Enqueue cover_letter (3 tasks) and company_research (1 task) before start.
# cover_letter has the deeper queue and must win the first batch slot.
for i in range(3): for i in range(3):
s.enqueue(i + 1, "cover_letter", i + 1, None) s.enqueue(i + 1, "cover_letter", i + 1, None)
s.enqueue(4, "company_research", 4, None) s.enqueue(4, "company_research", 4, None)
s.start() # scheduler now sees all tasks atomically on its first iteration s.start()
assert done.wait(timeout=5.0), "timed out — not all 4 tasks completed" assert done.wait(timeout=5.0), "timed out — not all 4 tasks completed"
s.shutdown() s.shutdown()
assert len(log) == 4 assert len(log) == 4
cl = [i for i, (_, t) in enumerate(log) if t == "cover_letter"] cl = [t for _, t in log if t == "cover_letter"]
cr = [i for i, (_, t) in enumerate(log) if t == "company_research"] cr = [t for _, t in log if t == "company_research"]
assert len(cl) == 3 and len(cr) == 1 assert len(cl) == 3 and len(cr) == 1
assert max(cl) < min(cr), "All cover_letter tasks must finish before company_research starts"
def test_fifo_within_type(tmp_db): def test_fifo_within_type(tmp_db):
@ -256,8 +262,8 @@ def test_fifo_within_type(tmp_db):
assert [task_id for task_id, _ in log] == [10, 20, 30] assert [task_id for task_id, _ in log] == [10, 20, 30]
def test_concurrent_batches_when_vram_allows(tmp_db): def test_concurrent_batches_different_types(tmp_db):
"""Two type batches start simultaneously when VRAM fits both.""" """Two type batches run concurrently (LocalScheduler has no VRAM gating)."""
started = {"cover_letter": threading.Event(), "company_research": threading.Event()} started = {"cover_letter": threading.Event(), "company_research": threading.Event()}
all_done = threading.Event() all_done = threading.Event()
log = [] log = []
@ -268,8 +274,7 @@ def test_concurrent_batches_when_vram_allows(tmp_db):
if len(log) >= 2: if len(log) >= 2:
all_done.set() all_done.set()
# VRAM=10.0 fits both cover_letter (2.5) and company_research (5.0) simultaneously s = _start_scheduler(tmp_db, run_task)
s = _start_scheduler(tmp_db, run_task, available_vram=10.0)
s.enqueue(1, "cover_letter", 1, None) s.enqueue(1, "cover_letter", 1, None)
s.enqueue(2, "company_research", 2, None) s.enqueue(2, "company_research", 2, None)
@ -307,8 +312,15 @@ def test_new_tasks_picked_up_mid_batch(tmp_db):
assert log == [1, 2] assert log == [1, 2]
def test_worker_crash_releases_vram(tmp_db): @pytest.mark.filterwarnings("ignore::pytest.PytestUnhandledThreadExceptionWarning")
"""If _run_task raises, _reserved_vram returns to 0 and scheduler continues.""" def test_worker_crash_does_not_stall_scheduler(tmp_db):
"""If _run_task raises, the scheduler continues processing the next task.
The batch_worker intentionally lets the RuntimeError propagate to the thread
boundary (so LocalScheduler can detect crash vs. normal exit). This produces
a PytestUnhandledThreadExceptionWarning -- suppressed here because it is the
expected behavior under test.
"""
log, done = [], threading.Event() log, done = [], threading.Event()
def run_task(db_path, task_id, task_type, job_id, params): def run_task(db_path, task_id, task_type, job_id, params):
@ -317,16 +329,15 @@ def test_worker_crash_releases_vram(tmp_db):
log.append(task_id) log.append(task_id)
done.set() done.set()
s = _start_scheduler(tmp_db, run_task, available_vram=3.0) s = _start_scheduler(tmp_db, run_task)
s.enqueue(1, "cover_letter", 1, None) s.enqueue(1, "cover_letter", 1, None)
s.enqueue(2, "cover_letter", 2, None) s.enqueue(2, "cover_letter", 2, None)
assert done.wait(timeout=5.0), "timed out — task 2 never completed after task 1 crash" assert done.wait(timeout=5.0), "timed out — task 2 never completed after task 1 crash"
s.shutdown() s.shutdown()
# Second task still ran, VRAM was released # Second task still ran despite first crashing
assert 2 in log assert 2 in log
assert s._reserved_vram == 0.0
def test_get_scheduler_returns_singleton(tmp_db): def test_get_scheduler_returns_singleton(tmp_db):

View file

@ -66,8 +66,12 @@ def test_sync_cookie_prgn_switch_param_overrides_yaml(profile_yaml, monkeypatch)
assert any("prgn_ui=streamlit" in s for s in injected) assert any("prgn_ui=streamlit" in s for s in injected)
def test_sync_cookie_downgrades_tier_resets_to_streamlit(profile_yaml, monkeypatch): def test_sync_cookie_free_tier_keeps_vue(profile_yaml, monkeypatch):
"""Free-tier user with vue preference gets reset to streamlit.""" """Free-tier user with vue preference keeps vue (vue_ui_beta is free tier).
Previously this test verified a downgrade to streamlit. Vue SPA was opened
to free tier in issue #20 — the downgrade path no longer triggers.
"""
import yaml as _yaml import yaml as _yaml
profile_yaml.write_text(_yaml.dump({"name": "T", "ui_preference": "vue"})) profile_yaml.write_text(_yaml.dump({"name": "T", "ui_preference": "vue"}))
@ -80,8 +84,8 @@ def test_sync_cookie_downgrades_tier_resets_to_streamlit(profile_yaml, monkeypat
sync_ui_cookie(profile_yaml, tier="free") sync_ui_cookie(profile_yaml, tier="free")
saved = _yaml.safe_load(profile_yaml.read_text()) saved = _yaml.safe_load(profile_yaml.read_text())
assert saved["ui_preference"] == "streamlit" assert saved["ui_preference"] == "vue"
assert any("prgn_ui=streamlit" in s for s in injected) assert any("prgn_ui=vue" in s for s in injected)
def test_switch_ui_writes_yaml_and_calls_sync(profile_yaml, monkeypatch): def test_switch_ui_writes_yaml_and_calls_sync(profile_yaml, monkeypatch):

View file

@ -236,7 +236,7 @@ class TestWizardStep:
search_path = tmp_path / "config" / "search_profiles.yaml" search_path = tmp_path / "config" / "search_profiles.yaml"
_write_user_yaml(yaml_path, {}) _write_user_yaml(yaml_path, {})
with patch("dev_api._wizard_yaml_path", return_value=str(yaml_path)): with patch("dev_api._wizard_yaml_path", return_value=str(yaml_path)):
with patch("dev_api.SEARCH_PREFS_PATH", search_path): with patch("dev_api._search_prefs_path", return_value=search_path):
r = client.post("/api/wizard/step", r = client.post("/api/wizard/step",
json={"step": 6, "data": { json={"step": 6, "data": {
"titles": ["Software Engineer", "Backend Developer"], "titles": ["Software Engineer", "Backend Developer"],

View file

@ -121,7 +121,8 @@ def test_byok_false_preserves_original_gating():
# ── Vue UI Beta & Demo Tier tests ────────────────────────────────────────────── # ── Vue UI Beta & Demo Tier tests ──────────────────────────────────────────────
def test_vue_ui_beta_free_tier(): def test_vue_ui_beta_free_tier():
assert can_use("free", "vue_ui_beta") is False # Vue SPA is open to all tiers (issue #20 — beta restriction removed)
assert can_use("free", "vue_ui_beta") is True
def test_vue_ui_beta_paid_tier(): def test_vue_ui_beta_paid_tier():