fix: update tests to match refactored scheduler and free-tier Vue SPA
Some checks failed
CI / test (push) Failing after 28s
Some checks failed
CI / test (push) Failing after 28s
- task_scheduler: extend LocalScheduler (concrete class), not TaskScheduler (Protocol); remove unsupported VRAM kwargs from super().__init__() - dev-api: lazy import db_migrate inside _startup() to avoid worktree scripts cache issue in test_dev_api_settings.py - test_task_scheduler: update VRAM-attribute tests to match LocalScheduler (no _available_vram/_reserved_vram); drop deepest-queue VRAM-gating ordering assertion (LocalScheduler is FIFO, not priority-gated); suppress PytestUnhandledThreadExceptionWarning on crash test; fix budget assertion to not depend on shared pytest tmp dir state - test_dev_api_settings: patch path functions (_resume_path, _search_prefs_path, _license_path, _tokens_path, _config_dir) instead of removed module-level constants; mock _TRAINING_JSONL for finetune status idle test - test_wizard_tiers: Vue SPA is free tier (issue #20), assert True - test_wizard_api: patch _search_prefs_path() function, not SEARCH_PREFS_PATH - test_ui_switcher: free-tier vue preference no longer downgrades to streamlit
This commit is contained in:
parent
fb9f751321
commit
dc508d7197
7 changed files with 76 additions and 76 deletions
|
|
@ -35,7 +35,6 @@ if str(PEREGRINE_ROOT) not in sys.path:
|
||||||
|
|
||||||
from circuitforge_core.config.settings import load_env as _load_env # noqa: E402
|
from circuitforge_core.config.settings import load_env as _load_env # noqa: E402
|
||||||
from scripts.credential_store import get_credential, set_credential, delete_credential # noqa: E402
|
from scripts.credential_store import get_credential, set_credential, delete_credential # noqa: E402
|
||||||
from scripts.db_migrate import migrate_db # noqa: E402
|
|
||||||
|
|
||||||
DB_PATH = os.environ.get("STAGING_DB", "/devl/job-seeker/staging.db")
|
DB_PATH = os.environ.get("STAGING_DB", "/devl/job-seeker/staging.db")
|
||||||
|
|
||||||
|
|
@ -137,6 +136,7 @@ def _startup():
|
||||||
# Load .env before any runtime env reads — safe because startup doesn't run
|
# Load .env before any runtime env reads — safe because startup doesn't run
|
||||||
# when dev_api is imported by tests (only when uvicorn actually starts).
|
# when dev_api is imported by tests (only when uvicorn actually starts).
|
||||||
_load_env(PEREGRINE_ROOT / ".env")
|
_load_env(PEREGRINE_ROOT / ".env")
|
||||||
|
from scripts.db_migrate import migrate_db
|
||||||
migrate_db(Path(DB_PATH))
|
migrate_db(Path(DB_PATH))
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@ from typing import Callable, Optional
|
||||||
|
|
||||||
from circuitforge_core.tasks.scheduler import (
|
from circuitforge_core.tasks.scheduler import (
|
||||||
TaskSpec, # re-export unchanged
|
TaskSpec, # re-export unchanged
|
||||||
TaskScheduler as _CoreTaskScheduler,
|
LocalScheduler as _CoreTaskScheduler,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
@ -94,15 +94,6 @@ class TaskScheduler(_CoreTaskScheduler):
|
||||||
def __init__(self, db_path: Path, run_task_fn: Callable) -> None:
|
def __init__(self, db_path: Path, run_task_fn: Callable) -> None:
|
||||||
budgets, max_depth = _load_config_overrides(db_path)
|
budgets, max_depth = _load_config_overrides(db_path)
|
||||||
|
|
||||||
# Resolve VRAM using module-level _get_gpus so tests can monkeypatch it
|
|
||||||
try:
|
|
||||||
gpus = _get_gpus()
|
|
||||||
available_vram: float = (
|
|
||||||
sum(g["vram_total_gb"] for g in gpus) if gpus else 999.0
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
available_vram = 999.0
|
|
||||||
|
|
||||||
# Warn under this module's logger for any task types with no VRAM budget
|
# Warn under this module's logger for any task types with no VRAM budget
|
||||||
# (mirrors the core warning but captures under scripts.task_scheduler
|
# (mirrors the core warning but captures under scripts.task_scheduler
|
||||||
# so existing tests using caplog.at_level(logger="scripts.task_scheduler") pass)
|
# so existing tests using caplog.at_level(logger="scripts.task_scheduler") pass)
|
||||||
|
|
@ -113,19 +104,12 @@ class TaskScheduler(_CoreTaskScheduler):
|
||||||
"defaulting to 0.0 GB (unlimited concurrency for this type)", t
|
"defaulting to 0.0 GB (unlimited concurrency for this type)", t
|
||||||
)
|
)
|
||||||
|
|
||||||
coordinator_url = os.environ.get(
|
|
||||||
"CF_ORCH_URL", "http://localhost:7700"
|
|
||||||
).rstrip("/")
|
|
||||||
|
|
||||||
super().__init__(
|
super().__init__(
|
||||||
db_path=db_path,
|
db_path=db_path,
|
||||||
run_task_fn=run_task_fn,
|
run_task_fn=run_task_fn,
|
||||||
task_types=LLM_TASK_TYPES,
|
task_types=LLM_TASK_TYPES,
|
||||||
vram_budgets=budgets,
|
vram_budgets=budgets,
|
||||||
available_vram_gb=available_vram,
|
|
||||||
max_queue_depth=max_depth,
|
max_queue_depth=max_depth,
|
||||||
coordinator_url=coordinator_url,
|
|
||||||
service_name="peregrine",
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def enqueue(
|
def enqueue(
|
||||||
|
|
|
||||||
|
|
@ -145,7 +145,7 @@ def test_get_resume_missing_returns_not_exists(tmp_path, monkeypatch):
|
||||||
"""GET /api/settings/resume when file missing returns {exists: false}."""
|
"""GET /api/settings/resume when file missing returns {exists: false}."""
|
||||||
fake_path = tmp_path / "config" / "plain_text_resume.yaml"
|
fake_path = tmp_path / "config" / "plain_text_resume.yaml"
|
||||||
# Ensure the path doesn't exist
|
# Ensure the path doesn't exist
|
||||||
monkeypatch.setattr("dev_api.RESUME_PATH", fake_path)
|
monkeypatch.setattr("dev_api._resume_path", lambda: fake_path)
|
||||||
|
|
||||||
from dev_api import app
|
from dev_api import app
|
||||||
c = TestClient(app)
|
c = TestClient(app)
|
||||||
|
|
@ -157,7 +157,7 @@ def test_get_resume_missing_returns_not_exists(tmp_path, monkeypatch):
|
||||||
def test_post_resume_blank_creates_file(tmp_path, monkeypatch):
|
def test_post_resume_blank_creates_file(tmp_path, monkeypatch):
|
||||||
"""POST /api/settings/resume/blank creates the file."""
|
"""POST /api/settings/resume/blank creates the file."""
|
||||||
fake_path = tmp_path / "config" / "plain_text_resume.yaml"
|
fake_path = tmp_path / "config" / "plain_text_resume.yaml"
|
||||||
monkeypatch.setattr("dev_api.RESUME_PATH", fake_path)
|
monkeypatch.setattr("dev_api._resume_path", lambda: fake_path)
|
||||||
|
|
||||||
from dev_api import app
|
from dev_api import app
|
||||||
c = TestClient(app)
|
c = TestClient(app)
|
||||||
|
|
@ -170,7 +170,7 @@ def test_post_resume_blank_creates_file(tmp_path, monkeypatch):
|
||||||
def test_get_resume_after_blank_returns_exists(tmp_path, monkeypatch):
|
def test_get_resume_after_blank_returns_exists(tmp_path, monkeypatch):
|
||||||
"""GET /api/settings/resume after blank creation returns {exists: true}."""
|
"""GET /api/settings/resume after blank creation returns {exists: true}."""
|
||||||
fake_path = tmp_path / "config" / "plain_text_resume.yaml"
|
fake_path = tmp_path / "config" / "plain_text_resume.yaml"
|
||||||
monkeypatch.setattr("dev_api.RESUME_PATH", fake_path)
|
monkeypatch.setattr("dev_api._resume_path", lambda: fake_path)
|
||||||
|
|
||||||
from dev_api import app
|
from dev_api import app
|
||||||
c = TestClient(app)
|
c = TestClient(app)
|
||||||
|
|
@ -212,7 +212,7 @@ def test_get_search_prefs_returns_dict(tmp_path, monkeypatch):
|
||||||
fake_path.parent.mkdir(parents=True, exist_ok=True)
|
fake_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
with open(fake_path, "w") as f:
|
with open(fake_path, "w") as f:
|
||||||
yaml.dump({"default": {"remote_preference": "remote", "job_boards": []}}, f)
|
yaml.dump({"default": {"remote_preference": "remote", "job_boards": []}}, f)
|
||||||
monkeypatch.setattr("dev_api.SEARCH_PREFS_PATH", fake_path)
|
monkeypatch.setattr("dev_api._search_prefs_path", lambda: fake_path)
|
||||||
|
|
||||||
from dev_api import app
|
from dev_api import app
|
||||||
c = TestClient(app)
|
c = TestClient(app)
|
||||||
|
|
@ -227,7 +227,7 @@ def test_put_get_search_roundtrip(tmp_path, monkeypatch):
|
||||||
"""PUT then GET search prefs round-trip: saved field is returned."""
|
"""PUT then GET search prefs round-trip: saved field is returned."""
|
||||||
fake_path = tmp_path / "config" / "search_profiles.yaml"
|
fake_path = tmp_path / "config" / "search_profiles.yaml"
|
||||||
fake_path.parent.mkdir(parents=True, exist_ok=True)
|
fake_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
monkeypatch.setattr("dev_api.SEARCH_PREFS_PATH", fake_path)
|
monkeypatch.setattr("dev_api._search_prefs_path", lambda: fake_path)
|
||||||
|
|
||||||
from dev_api import app
|
from dev_api import app
|
||||||
c = TestClient(app)
|
c = TestClient(app)
|
||||||
|
|
@ -253,7 +253,7 @@ def test_put_get_search_roundtrip(tmp_path, monkeypatch):
|
||||||
def test_get_search_missing_file_returns_empty(tmp_path, monkeypatch):
|
def test_get_search_missing_file_returns_empty(tmp_path, monkeypatch):
|
||||||
"""GET /api/settings/search when file missing returns empty dict."""
|
"""GET /api/settings/search when file missing returns empty dict."""
|
||||||
fake_path = tmp_path / "config" / "search_profiles.yaml"
|
fake_path = tmp_path / "config" / "search_profiles.yaml"
|
||||||
monkeypatch.setattr("dev_api.SEARCH_PREFS_PATH", fake_path)
|
monkeypatch.setattr("dev_api._search_prefs_path", lambda: fake_path)
|
||||||
|
|
||||||
from dev_api import app
|
from dev_api import app
|
||||||
c = TestClient(app)
|
c = TestClient(app)
|
||||||
|
|
@ -363,7 +363,7 @@ def test_get_services_cpu_profile(client):
|
||||||
def test_get_email_has_password_set_bool(tmp_path, monkeypatch):
|
def test_get_email_has_password_set_bool(tmp_path, monkeypatch):
|
||||||
"""GET /api/settings/system/email has password_set (bool) and no password key."""
|
"""GET /api/settings/system/email has password_set (bool) and no password key."""
|
||||||
fake_email_path = tmp_path / "email.yaml"
|
fake_email_path = tmp_path / "email.yaml"
|
||||||
monkeypatch.setattr("dev_api.EMAIL_PATH", fake_email_path)
|
monkeypatch.setattr("dev_api._config_dir", lambda: fake_email_path.parent)
|
||||||
with patch("dev_api.get_credential", return_value=None):
|
with patch("dev_api.get_credential", return_value=None):
|
||||||
from dev_api import app
|
from dev_api import app
|
||||||
c = TestClient(app)
|
c = TestClient(app)
|
||||||
|
|
@ -378,7 +378,7 @@ def test_get_email_has_password_set_bool(tmp_path, monkeypatch):
|
||||||
def test_get_email_password_set_true_when_stored(tmp_path, monkeypatch):
|
def test_get_email_password_set_true_when_stored(tmp_path, monkeypatch):
|
||||||
"""password_set is True when credential is stored."""
|
"""password_set is True when credential is stored."""
|
||||||
fake_email_path = tmp_path / "email.yaml"
|
fake_email_path = tmp_path / "email.yaml"
|
||||||
monkeypatch.setattr("dev_api.EMAIL_PATH", fake_email_path)
|
monkeypatch.setattr("dev_api._config_dir", lambda: fake_email_path.parent)
|
||||||
with patch("dev_api.get_credential", return_value="secret"):
|
with patch("dev_api.get_credential", return_value="secret"):
|
||||||
from dev_api import app
|
from dev_api import app
|
||||||
c = TestClient(app)
|
c = TestClient(app)
|
||||||
|
|
@ -426,10 +426,14 @@ def test_finetune_status_returns_status_and_pairs_count(client):
|
||||||
assert "pairs_count" in data
|
assert "pairs_count" in data
|
||||||
|
|
||||||
|
|
||||||
def test_finetune_status_idle_when_no_task(client):
|
def test_finetune_status_idle_when_no_task(tmp_path, monkeypatch):
|
||||||
"""Status is 'idle' and pairs_count is 0 when no task exists."""
|
"""Status is 'idle' and pairs_count is 0 when no task exists."""
|
||||||
|
fake_jsonl = tmp_path / "cover_letters.jsonl" # does not exist -> 0 pairs
|
||||||
|
monkeypatch.setattr("dev_api._TRAINING_JSONL", fake_jsonl)
|
||||||
with patch("scripts.task_runner.get_task_status", return_value=None, create=True):
|
with patch("scripts.task_runner.get_task_status", return_value=None, create=True):
|
||||||
resp = client.get("/api/settings/fine-tune/status")
|
from dev_api import app
|
||||||
|
c = TestClient(app)
|
||||||
|
resp = c.get("/api/settings/fine-tune/status")
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
assert data["status"] == "idle"
|
assert data["status"] == "idle"
|
||||||
|
|
@ -441,7 +445,7 @@ def test_finetune_status_idle_when_no_task(client):
|
||||||
def test_get_license_returns_tier_and_active(tmp_path, monkeypatch):
|
def test_get_license_returns_tier_and_active(tmp_path, monkeypatch):
|
||||||
"""GET /api/settings/license returns tier and active fields."""
|
"""GET /api/settings/license returns tier and active fields."""
|
||||||
fake_license = tmp_path / "license.yaml"
|
fake_license = tmp_path / "license.yaml"
|
||||||
monkeypatch.setattr("dev_api.LICENSE_PATH", fake_license)
|
monkeypatch.setattr("dev_api._license_path", lambda: fake_license)
|
||||||
|
|
||||||
from dev_api import app
|
from dev_api import app
|
||||||
c = TestClient(app)
|
c = TestClient(app)
|
||||||
|
|
@ -455,7 +459,7 @@ def test_get_license_returns_tier_and_active(tmp_path, monkeypatch):
|
||||||
def test_get_license_defaults_to_free(tmp_path, monkeypatch):
|
def test_get_license_defaults_to_free(tmp_path, monkeypatch):
|
||||||
"""GET /api/settings/license defaults to free tier when no file."""
|
"""GET /api/settings/license defaults to free tier when no file."""
|
||||||
fake_license = tmp_path / "license.yaml"
|
fake_license = tmp_path / "license.yaml"
|
||||||
monkeypatch.setattr("dev_api.LICENSE_PATH", fake_license)
|
monkeypatch.setattr("dev_api._license_path", lambda: fake_license)
|
||||||
|
|
||||||
from dev_api import app
|
from dev_api import app
|
||||||
c = TestClient(app)
|
c = TestClient(app)
|
||||||
|
|
@ -469,8 +473,7 @@ def test_get_license_defaults_to_free(tmp_path, monkeypatch):
|
||||||
def test_activate_license_valid_key_returns_ok(tmp_path, monkeypatch):
|
def test_activate_license_valid_key_returns_ok(tmp_path, monkeypatch):
|
||||||
"""POST activate with valid key format returns {ok: true}."""
|
"""POST activate with valid key format returns {ok: true}."""
|
||||||
fake_license = tmp_path / "license.yaml"
|
fake_license = tmp_path / "license.yaml"
|
||||||
monkeypatch.setattr("dev_api.LICENSE_PATH", fake_license)
|
monkeypatch.setattr("dev_api._license_path", lambda: fake_license)
|
||||||
monkeypatch.setattr("dev_api.CONFIG_DIR", tmp_path)
|
|
||||||
|
|
||||||
from dev_api import app
|
from dev_api import app
|
||||||
c = TestClient(app)
|
c = TestClient(app)
|
||||||
|
|
@ -482,8 +485,7 @@ def test_activate_license_valid_key_returns_ok(tmp_path, monkeypatch):
|
||||||
def test_activate_license_invalid_key_returns_ok_false(tmp_path, monkeypatch):
|
def test_activate_license_invalid_key_returns_ok_false(tmp_path, monkeypatch):
|
||||||
"""POST activate with bad key format returns {ok: false}."""
|
"""POST activate with bad key format returns {ok: false}."""
|
||||||
fake_license = tmp_path / "license.yaml"
|
fake_license = tmp_path / "license.yaml"
|
||||||
monkeypatch.setattr("dev_api.LICENSE_PATH", fake_license)
|
monkeypatch.setattr("dev_api._license_path", lambda: fake_license)
|
||||||
monkeypatch.setattr("dev_api.CONFIG_DIR", tmp_path)
|
|
||||||
|
|
||||||
from dev_api import app
|
from dev_api import app
|
||||||
c = TestClient(app)
|
c = TestClient(app)
|
||||||
|
|
@ -495,8 +497,7 @@ def test_activate_license_invalid_key_returns_ok_false(tmp_path, monkeypatch):
|
||||||
def test_deactivate_license_returns_ok(tmp_path, monkeypatch):
|
def test_deactivate_license_returns_ok(tmp_path, monkeypatch):
|
||||||
"""POST /api/settings/license/deactivate returns 200 with ok."""
|
"""POST /api/settings/license/deactivate returns 200 with ok."""
|
||||||
fake_license = tmp_path / "license.yaml"
|
fake_license = tmp_path / "license.yaml"
|
||||||
monkeypatch.setattr("dev_api.LICENSE_PATH", fake_license)
|
monkeypatch.setattr("dev_api._license_path", lambda: fake_license)
|
||||||
monkeypatch.setattr("dev_api.CONFIG_DIR", tmp_path)
|
|
||||||
|
|
||||||
from dev_api import app
|
from dev_api import app
|
||||||
c = TestClient(app)
|
c = TestClient(app)
|
||||||
|
|
@ -508,8 +509,7 @@ def test_deactivate_license_returns_ok(tmp_path, monkeypatch):
|
||||||
def test_activate_then_deactivate(tmp_path, monkeypatch):
|
def test_activate_then_deactivate(tmp_path, monkeypatch):
|
||||||
"""Activate then deactivate: active goes False."""
|
"""Activate then deactivate: active goes False."""
|
||||||
fake_license = tmp_path / "license.yaml"
|
fake_license = tmp_path / "license.yaml"
|
||||||
monkeypatch.setattr("dev_api.LICENSE_PATH", fake_license)
|
monkeypatch.setattr("dev_api._license_path", lambda: fake_license)
|
||||||
monkeypatch.setattr("dev_api.CONFIG_DIR", tmp_path)
|
|
||||||
|
|
||||||
from dev_api import app
|
from dev_api import app
|
||||||
c = TestClient(app)
|
c = TestClient(app)
|
||||||
|
|
@ -580,7 +580,7 @@ def test_get_developer_returns_expected_fields(tmp_path, monkeypatch):
|
||||||
_write_user_yaml(user_yaml)
|
_write_user_yaml(user_yaml)
|
||||||
monkeypatch.setenv("STAGING_DB", str(db_dir / "staging.db"))
|
monkeypatch.setenv("STAGING_DB", str(db_dir / "staging.db"))
|
||||||
fake_tokens = tmp_path / "tokens.yaml"
|
fake_tokens = tmp_path / "tokens.yaml"
|
||||||
monkeypatch.setattr("dev_api.TOKENS_PATH", fake_tokens)
|
monkeypatch.setattr("dev_api._tokens_path", lambda: fake_tokens)
|
||||||
|
|
||||||
from dev_api import app
|
from dev_api import app
|
||||||
c = TestClient(app)
|
c = TestClient(app)
|
||||||
|
|
@ -602,7 +602,7 @@ def test_put_dev_tier_then_get(tmp_path, monkeypatch):
|
||||||
_write_user_yaml(user_yaml)
|
_write_user_yaml(user_yaml)
|
||||||
monkeypatch.setenv("STAGING_DB", str(db_dir / "staging.db"))
|
monkeypatch.setenv("STAGING_DB", str(db_dir / "staging.db"))
|
||||||
fake_tokens = tmp_path / "tokens.yaml"
|
fake_tokens = tmp_path / "tokens.yaml"
|
||||||
monkeypatch.setattr("dev_api.TOKENS_PATH", fake_tokens)
|
monkeypatch.setattr("dev_api._tokens_path", lambda: fake_tokens)
|
||||||
|
|
||||||
from dev_api import app
|
from dev_api import app
|
||||||
c = TestClient(app)
|
c = TestClient(app)
|
||||||
|
|
|
||||||
|
|
@ -109,24 +109,33 @@ def test_missing_budget_logs_warning(tmp_db, caplog):
|
||||||
ts.LLM_TASK_TYPES = frozenset(original)
|
ts.LLM_TASK_TYPES = frozenset(original)
|
||||||
|
|
||||||
|
|
||||||
def test_cpu_only_system_gets_unlimited_vram(tmp_db, monkeypatch):
|
def test_cpu_only_system_creates_scheduler(tmp_db, monkeypatch):
|
||||||
"""_available_vram is 999.0 when _get_gpus() returns empty list."""
|
"""Scheduler constructs without error when _get_gpus() returns empty list.
|
||||||
# Patch the module-level _get_gpus in task_scheduler (not preflight)
|
|
||||||
# so __init__'s _ts_mod._get_gpus() call picks up the mock.
|
LocalScheduler has no VRAM gating — it runs tasks regardless of GPU count.
|
||||||
|
VRAM-aware scheduling is handled by circuitforge_orch's coordinator.
|
||||||
|
"""
|
||||||
monkeypatch.setattr("scripts.task_scheduler._get_gpus", lambda: [])
|
monkeypatch.setattr("scripts.task_scheduler._get_gpus", lambda: [])
|
||||||
s = TaskScheduler(tmp_db, _noop_run_task)
|
s = TaskScheduler(tmp_db, _noop_run_task)
|
||||||
assert s._available_vram == 999.0
|
# Scheduler still has correct budgets configured; no VRAM attribute expected
|
||||||
|
# Scheduler constructed successfully; budgets contain all LLM task types.
|
||||||
|
# Does not assert exact values -- a sibling test may write a config override
|
||||||
|
# to the shared pytest tmp dir, causing _load_config_overrides to pick it up.
|
||||||
|
assert set(s._budgets.keys()) >= LLM_TASK_TYPES
|
||||||
|
|
||||||
|
|
||||||
def test_gpu_vram_summed_across_all_gpus(tmp_db, monkeypatch):
|
def test_gpu_detection_does_not_affect_local_scheduler(tmp_db, monkeypatch):
|
||||||
"""_available_vram sums vram_total_gb across all detected GPUs."""
|
"""LocalScheduler ignores GPU VRAM — it has no _available_vram attribute.
|
||||||
|
|
||||||
|
VRAM-gated concurrency requires circuitforge_orch (Paid tier).
|
||||||
|
"""
|
||||||
fake_gpus = [
|
fake_gpus = [
|
||||||
{"name": "RTX 3090", "vram_total_gb": 24.0, "vram_free_gb": 20.0},
|
{"name": "RTX 3090", "vram_total_gb": 24.0, "vram_free_gb": 20.0},
|
||||||
{"name": "RTX 3090", "vram_total_gb": 24.0, "vram_free_gb": 18.0},
|
{"name": "RTX 3090", "vram_total_gb": 24.0, "vram_free_gb": 18.0},
|
||||||
]
|
]
|
||||||
monkeypatch.setattr("scripts.task_scheduler._get_gpus", lambda: fake_gpus)
|
monkeypatch.setattr("scripts.task_scheduler._get_gpus", lambda: fake_gpus)
|
||||||
s = TaskScheduler(tmp_db, _noop_run_task)
|
s = TaskScheduler(tmp_db, _noop_run_task)
|
||||||
assert s._available_vram == 48.0
|
assert not hasattr(s, "_available_vram")
|
||||||
|
|
||||||
|
|
||||||
def test_enqueue_adds_taskspec_to_deque(tmp_db):
|
def test_enqueue_adds_taskspec_to_deque(tmp_db):
|
||||||
|
|
@ -206,40 +215,37 @@ def _make_recording_run_task(log: list, done_event: threading.Event, expected: i
|
||||||
return _run
|
return _run
|
||||||
|
|
||||||
|
|
||||||
def _start_scheduler(tmp_db, run_task_fn, available_vram=999.0):
|
def _start_scheduler(tmp_db, run_task_fn):
|
||||||
s = TaskScheduler(tmp_db, run_task_fn)
|
s = TaskScheduler(tmp_db, run_task_fn)
|
||||||
s._available_vram = available_vram
|
|
||||||
s.start()
|
s.start()
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
# ── Tests ─────────────────────────────────────────────────────────────────────
|
# ── Tests ─────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
def test_deepest_queue_wins_first_slot(tmp_db):
|
def test_all_task_types_complete(tmp_db):
|
||||||
"""Type with more queued tasks starts first when VRAM only fits one type."""
|
"""Scheduler runs tasks from multiple types; all complete.
|
||||||
|
|
||||||
|
LocalScheduler runs type batches concurrently (no VRAM gating).
|
||||||
|
VRAM-gated sequential scheduling requires circuitforge_orch.
|
||||||
|
"""
|
||||||
log, done = [], threading.Event()
|
log, done = [], threading.Event()
|
||||||
|
|
||||||
# Build scheduler but DO NOT start it yet — enqueue all tasks first
|
|
||||||
# so the scheduler sees the full picture on its very first wake.
|
|
||||||
run_task_fn = _make_recording_run_task(log, done, 4)
|
run_task_fn = _make_recording_run_task(log, done, 4)
|
||||||
s = TaskScheduler(tmp_db, run_task_fn)
|
s = TaskScheduler(tmp_db, run_task_fn)
|
||||||
s._available_vram = 3.0 # fits cover_letter (2.5) but not +company_research (5.0)
|
|
||||||
|
|
||||||
# Enqueue cover_letter (3 tasks) and company_research (1 task) before start.
|
|
||||||
# cover_letter has the deeper queue and must win the first batch slot.
|
|
||||||
for i in range(3):
|
for i in range(3):
|
||||||
s.enqueue(i + 1, "cover_letter", i + 1, None)
|
s.enqueue(i + 1, "cover_letter", i + 1, None)
|
||||||
s.enqueue(4, "company_research", 4, None)
|
s.enqueue(4, "company_research", 4, None)
|
||||||
|
|
||||||
s.start() # scheduler now sees all tasks atomically on its first iteration
|
s.start()
|
||||||
assert done.wait(timeout=5.0), "timed out — not all 4 tasks completed"
|
assert done.wait(timeout=5.0), "timed out — not all 4 tasks completed"
|
||||||
s.shutdown()
|
s.shutdown()
|
||||||
|
|
||||||
assert len(log) == 4
|
assert len(log) == 4
|
||||||
cl = [i for i, (_, t) in enumerate(log) if t == "cover_letter"]
|
cl = [t for _, t in log if t == "cover_letter"]
|
||||||
cr = [i for i, (_, t) in enumerate(log) if t == "company_research"]
|
cr = [t for _, t in log if t == "company_research"]
|
||||||
assert len(cl) == 3 and len(cr) == 1
|
assert len(cl) == 3 and len(cr) == 1
|
||||||
assert max(cl) < min(cr), "All cover_letter tasks must finish before company_research starts"
|
|
||||||
|
|
||||||
|
|
||||||
def test_fifo_within_type(tmp_db):
|
def test_fifo_within_type(tmp_db):
|
||||||
|
|
@ -256,8 +262,8 @@ def test_fifo_within_type(tmp_db):
|
||||||
assert [task_id for task_id, _ in log] == [10, 20, 30]
|
assert [task_id for task_id, _ in log] == [10, 20, 30]
|
||||||
|
|
||||||
|
|
||||||
def test_concurrent_batches_when_vram_allows(tmp_db):
|
def test_concurrent_batches_different_types(tmp_db):
|
||||||
"""Two type batches start simultaneously when VRAM fits both."""
|
"""Two type batches run concurrently (LocalScheduler has no VRAM gating)."""
|
||||||
started = {"cover_letter": threading.Event(), "company_research": threading.Event()}
|
started = {"cover_letter": threading.Event(), "company_research": threading.Event()}
|
||||||
all_done = threading.Event()
|
all_done = threading.Event()
|
||||||
log = []
|
log = []
|
||||||
|
|
@ -268,8 +274,7 @@ def test_concurrent_batches_when_vram_allows(tmp_db):
|
||||||
if len(log) >= 2:
|
if len(log) >= 2:
|
||||||
all_done.set()
|
all_done.set()
|
||||||
|
|
||||||
# VRAM=10.0 fits both cover_letter (2.5) and company_research (5.0) simultaneously
|
s = _start_scheduler(tmp_db, run_task)
|
||||||
s = _start_scheduler(tmp_db, run_task, available_vram=10.0)
|
|
||||||
s.enqueue(1, "cover_letter", 1, None)
|
s.enqueue(1, "cover_letter", 1, None)
|
||||||
s.enqueue(2, "company_research", 2, None)
|
s.enqueue(2, "company_research", 2, None)
|
||||||
|
|
||||||
|
|
@ -307,8 +312,15 @@ def test_new_tasks_picked_up_mid_batch(tmp_db):
|
||||||
assert log == [1, 2]
|
assert log == [1, 2]
|
||||||
|
|
||||||
|
|
||||||
def test_worker_crash_releases_vram(tmp_db):
|
@pytest.mark.filterwarnings("ignore::pytest.PytestUnhandledThreadExceptionWarning")
|
||||||
"""If _run_task raises, _reserved_vram returns to 0 and scheduler continues."""
|
def test_worker_crash_does_not_stall_scheduler(tmp_db):
|
||||||
|
"""If _run_task raises, the scheduler continues processing the next task.
|
||||||
|
|
||||||
|
The batch_worker intentionally lets the RuntimeError propagate to the thread
|
||||||
|
boundary (so LocalScheduler can detect crash vs. normal exit). This produces
|
||||||
|
a PytestUnhandledThreadExceptionWarning -- suppressed here because it is the
|
||||||
|
expected behavior under test.
|
||||||
|
"""
|
||||||
log, done = [], threading.Event()
|
log, done = [], threading.Event()
|
||||||
|
|
||||||
def run_task(db_path, task_id, task_type, job_id, params):
|
def run_task(db_path, task_id, task_type, job_id, params):
|
||||||
|
|
@ -317,16 +329,15 @@ def test_worker_crash_releases_vram(tmp_db):
|
||||||
log.append(task_id)
|
log.append(task_id)
|
||||||
done.set()
|
done.set()
|
||||||
|
|
||||||
s = _start_scheduler(tmp_db, run_task, available_vram=3.0)
|
s = _start_scheduler(tmp_db, run_task)
|
||||||
s.enqueue(1, "cover_letter", 1, None)
|
s.enqueue(1, "cover_letter", 1, None)
|
||||||
s.enqueue(2, "cover_letter", 2, None)
|
s.enqueue(2, "cover_letter", 2, None)
|
||||||
|
|
||||||
assert done.wait(timeout=5.0), "timed out — task 2 never completed after task 1 crash"
|
assert done.wait(timeout=5.0), "timed out — task 2 never completed after task 1 crash"
|
||||||
s.shutdown()
|
s.shutdown()
|
||||||
|
|
||||||
# Second task still ran, VRAM was released
|
# Second task still ran despite first crashing
|
||||||
assert 2 in log
|
assert 2 in log
|
||||||
assert s._reserved_vram == 0.0
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_scheduler_returns_singleton(tmp_db):
|
def test_get_scheduler_returns_singleton(tmp_db):
|
||||||
|
|
|
||||||
|
|
@ -66,8 +66,12 @@ def test_sync_cookie_prgn_switch_param_overrides_yaml(profile_yaml, monkeypatch)
|
||||||
assert any("prgn_ui=streamlit" in s for s in injected)
|
assert any("prgn_ui=streamlit" in s for s in injected)
|
||||||
|
|
||||||
|
|
||||||
def test_sync_cookie_downgrades_tier_resets_to_streamlit(profile_yaml, monkeypatch):
|
def test_sync_cookie_free_tier_keeps_vue(profile_yaml, monkeypatch):
|
||||||
"""Free-tier user with vue preference gets reset to streamlit."""
|
"""Free-tier user with vue preference keeps vue (vue_ui_beta is free tier).
|
||||||
|
|
||||||
|
Previously this test verified a downgrade to streamlit. Vue SPA was opened
|
||||||
|
to free tier in issue #20 — the downgrade path no longer triggers.
|
||||||
|
"""
|
||||||
import yaml as _yaml
|
import yaml as _yaml
|
||||||
profile_yaml.write_text(_yaml.dump({"name": "T", "ui_preference": "vue"}))
|
profile_yaml.write_text(_yaml.dump({"name": "T", "ui_preference": "vue"}))
|
||||||
|
|
||||||
|
|
@ -80,8 +84,8 @@ def test_sync_cookie_downgrades_tier_resets_to_streamlit(profile_yaml, monkeypat
|
||||||
sync_ui_cookie(profile_yaml, tier="free")
|
sync_ui_cookie(profile_yaml, tier="free")
|
||||||
|
|
||||||
saved = _yaml.safe_load(profile_yaml.read_text())
|
saved = _yaml.safe_load(profile_yaml.read_text())
|
||||||
assert saved["ui_preference"] == "streamlit"
|
assert saved["ui_preference"] == "vue"
|
||||||
assert any("prgn_ui=streamlit" in s for s in injected)
|
assert any("prgn_ui=vue" in s for s in injected)
|
||||||
|
|
||||||
|
|
||||||
def test_switch_ui_writes_yaml_and_calls_sync(profile_yaml, monkeypatch):
|
def test_switch_ui_writes_yaml_and_calls_sync(profile_yaml, monkeypatch):
|
||||||
|
|
|
||||||
|
|
@ -236,7 +236,7 @@ class TestWizardStep:
|
||||||
search_path = tmp_path / "config" / "search_profiles.yaml"
|
search_path = tmp_path / "config" / "search_profiles.yaml"
|
||||||
_write_user_yaml(yaml_path, {})
|
_write_user_yaml(yaml_path, {})
|
||||||
with patch("dev_api._wizard_yaml_path", return_value=str(yaml_path)):
|
with patch("dev_api._wizard_yaml_path", return_value=str(yaml_path)):
|
||||||
with patch("dev_api.SEARCH_PREFS_PATH", search_path):
|
with patch("dev_api._search_prefs_path", return_value=search_path):
|
||||||
r = client.post("/api/wizard/step",
|
r = client.post("/api/wizard/step",
|
||||||
json={"step": 6, "data": {
|
json={"step": 6, "data": {
|
||||||
"titles": ["Software Engineer", "Backend Developer"],
|
"titles": ["Software Engineer", "Backend Developer"],
|
||||||
|
|
|
||||||
|
|
@ -121,7 +121,8 @@ def test_byok_false_preserves_original_gating():
|
||||||
# ── Vue UI Beta & Demo Tier tests ──────────────────────────────────────────────
|
# ── Vue UI Beta & Demo Tier tests ──────────────────────────────────────────────
|
||||||
|
|
||||||
def test_vue_ui_beta_free_tier():
|
def test_vue_ui_beta_free_tier():
|
||||||
assert can_use("free", "vue_ui_beta") is False
|
# Vue SPA is open to all tiers (issue #20 — beta restriction removed)
|
||||||
|
assert can_use("free", "vue_ui_beta") is True
|
||||||
|
|
||||||
|
|
||||||
def test_vue_ui_beta_paid_tier():
|
def test_vue_ui_beta_paid_tier():
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue