From 415e98d401f3bf3e455130bfc7a3227aa5c93fed Mon Sep 17 00:00:00 2001
From: pyr0ball
Date: Sun, 15 Mar 2026 03:32:11 -0700
Subject: [PATCH] feat(scheduler): implement TaskScheduler.__init__ with budget loading and VRAM detection

---
Review notes (this area is not part of the commit message):
- Config loading now tolerates null "scheduler" / "vram_budgets" sections and
  reads the file as UTF-8; a GPU detection failure is logged instead of being
  swallowed silently.
- Tests use monkeypatch for LLM_TASK_TYPES and remove the llm.yaml they create.
- The "index" blob ids below are those of the original commit and were not
  regenerated after these revisions.

 scripts/task_scheduler.py    | 60 ++++++++++++++++++++++++++-
 tests/test_task_scheduler.py | 79 ++++++++++++++++++++++++++++++++++++
 2 files changed, 138 insertions(+), 1 deletion(-)

diff --git a/scripts/task_scheduler.py b/scripts/task_scheduler.py
index b5aa0d4..1d2a29f 100644
--- a/scripts/task_scheduler.py
+++ b/scripts/task_scheduler.py
@@ -45,7 +45,65 @@ TaskSpec = namedtuple("TaskSpec", ["id", "job_id", "params"])
 
 class TaskScheduler:
     """Resource-aware LLM task batch scheduler. Use get_scheduler() — not direct construction."""
-    pass
+
+    def __init__(self, db_path: Path, run_task_fn: Callable) -> None:
+        """Set up scheduler state, load VRAM budgets, and detect total GPU VRAM.
+
+        Args:
+            db_path: Path to the task database. The optional config file is
+                looked up at ``<db_path.parent.parent>/config/llm.yaml``.
+            run_task_fn: Callable stored as ``self._run_task``; it is not
+                invoked during construction.
+        """
+        self._db_path = db_path
+        self._run_task = run_task_fn
+
+        self._lock = threading.Lock()
+        self._wake = threading.Event()
+        self._stop = threading.Event()
+        self._queues: dict[str, deque] = {}
+        self._active: dict[str, threading.Thread] = {}
+        self._reserved_vram: float = 0.0
+        self._thread: Optional[threading.Thread] = None
+
+        # Load VRAM budgets: defaults + optional config overrides
+        self._budgets: dict[str, float] = dict(DEFAULT_VRAM_BUDGETS)
+        config_path = db_path.parent.parent / "config" / "llm.yaml"
+        self._max_queue_depth: int = 500
+        if config_path.exists():
+            try:
+                import yaml
+                with open(config_path, encoding="utf-8") as f:
+                    cfg = yaml.safe_load(f) or {}
+                # "or {}" also covers keys that are present but null in the
+                # YAML, so an empty section does not abort the whole load.
+                sched_cfg = cfg.get("scheduler") or {}
+                self._budgets.update(sched_cfg.get("vram_budgets") or {})
+                self._max_queue_depth = sched_cfg.get("max_queue_depth", 500)
+            except Exception as exc:
+                # Best effort: a broken config only logs a warning.
+                logger.warning("Failed to load scheduler config from %s: %s", config_path, exc)
+
+        # Warn on LLM types with no budget entry after merge
+        for t in LLM_TASK_TYPES:
+            if t not in self._budgets:
+                logger.warning(
+                    "No VRAM budget defined for LLM task type %r — "
+                    "defaulting to 0.0 GB (unlimited concurrency for this type)", t
+                )
+
+        # Detect total GPU VRAM; fall back to unlimited (999) on CPU-only systems.
+        # Uses module-level _get_gpus so tests can monkeypatch scripts.task_scheduler._get_gpus.
+        try:
+            from scripts import task_scheduler as _ts_mod
+            gpus = _ts_mod._get_gpus()
+            self._available_vram: float = (
+                sum(g["vram_total_gb"] for g in gpus) if gpus else 999.0
+            )
+        except Exception as exc:
+            # Detection is best effort, but leave a trace instead of failing silently.
+            logger.warning("GPU detection failed (%s); assuming unlimited VRAM", exc)
+            self._available_vram = 999.0
 
 
 # ── Singleton ─────────────────────────────────────────────────────────────────
diff --git a/tests/test_task_scheduler.py b/tests/test_task_scheduler.py
index f165990..de0dc6e 100644
--- a/tests/test_task_scheduler.py
+++ b/tests/test_task_scheduler.py
@@ -53,3 +53,82 @@ def test_reset_running_tasks_returns_zero_when_nothing_running(tmp_db):
     conn.close()
 
     assert reset_running_tasks(tmp_db) == 0
+
+
+from scripts.task_scheduler import (
+    TaskScheduler, LLM_TASK_TYPES, DEFAULT_VRAM_BUDGETS,
+    get_scheduler, reset_scheduler,
+)
+
+
+def _noop_run_task(*args, **kwargs):
+    """Stand-in for _run_task that does nothing."""
+    pass
+
+
+@pytest.fixture(autouse=True)
+def clean_scheduler():
+    """Reset singleton between every test."""
+    yield
+    try:
+        reset_scheduler()
+    except NotImplementedError:
+        pass
+
+
+def test_default_budgets_used_when_no_config(tmp_db):
+    """Scheduler falls back to DEFAULT_VRAM_BUDGETS when config key absent."""
+    s = TaskScheduler(tmp_db, _noop_run_task)
+    assert s._budgets == DEFAULT_VRAM_BUDGETS
+
+
+def test_config_budgets_override_defaults(tmp_db, tmp_path):
+    """Values in llm.yaml scheduler.vram_budgets override defaults."""
+    config_dir = tmp_db.parent.parent / "config"
+    config_dir.mkdir(parents=True, exist_ok=True)
+    config_file = config_dir / "llm.yaml"
+    config_file.write_text(
+        "scheduler:\n  vram_budgets:\n    cover_letter: 9.9\n"
+    )
+    try:
+        s = TaskScheduler(tmp_db, _noop_run_task)
+    finally:
+        # The config lives two levels above tmp_db, which may be a directory
+        # other tests also resolve to; remove it so the override cannot leak.
+        config_file.unlink()
+    assert s._budgets["cover_letter"] == 9.9
+    # Non-overridden keys still use defaults
+    assert s._budgets["company_research"] == DEFAULT_VRAM_BUDGETS["company_research"]
+
+
+def test_missing_budget_logs_warning(tmp_db, caplog, monkeypatch):
+    """A type in LLM_TASK_TYPES with no budget entry logs a warning."""
+    import logging
+    from scripts import task_scheduler as ts
+    # monkeypatch restores the original set on teardown, even if the test fails
+    monkeypatch.setattr(
+        ts, "LLM_TASK_TYPES", frozenset(ts.LLM_TASK_TYPES) | {"orphan_type"}
+    )
+    with caplog.at_level(logging.WARNING, logger="scripts.task_scheduler"):
+        TaskScheduler(tmp_db, _noop_run_task)
+    assert any("orphan_type" in r.getMessage() for r in caplog.records)
+
+
+def test_cpu_only_system_gets_unlimited_vram(tmp_db, monkeypatch):
+    """_available_vram is 999.0 when _get_gpus() returns empty list."""
+    # Patch the module-level _get_gpus in task_scheduler (not preflight)
+    # so __init__'s _ts_mod._get_gpus() call picks up the mock.
+    monkeypatch.setattr("scripts.task_scheduler._get_gpus", lambda: [])
+    s = TaskScheduler(tmp_db, _noop_run_task)
+    assert s._available_vram == 999.0
+
+
+def test_gpu_vram_summed_across_all_gpus(tmp_db, monkeypatch):
+    """_available_vram sums vram_total_gb across all detected GPUs."""
+    fake_gpus = [
+        {"name": "RTX 3090", "vram_total_gb": 24.0, "vram_free_gb": 20.0},
+        {"name": "RTX 3090", "vram_total_gb": 24.0, "vram_free_gb": 18.0},
+    ]
+    monkeypatch.setattr("scripts.task_scheduler._get_gpus", lambda: fake_gpus)
+    s = TaskScheduler(tmp_db, _noop_run_task)
+    assert s._available_vram == 48.0