feat(scheduler): implement TaskScheduler.__init__ with budget loading and VRAM detection
This commit is contained in:
parent
fe8da36e00
commit
415e98d401
2 changed files with 121 additions and 1 deletions
|
|
@ -45,7 +45,52 @@ TaskSpec = namedtuple("TaskSpec", ["id", "job_id", "params"])
|
|||
|
||||
class TaskScheduler:
|
||||
"""Resource-aware LLM task batch scheduler. Use get_scheduler() — not direct construction."""
|
||||
pass
|
||||
|
||||
def __init__(self, db_path: Path, run_task_fn: Callable) -> None:
|
||||
self._db_path = db_path
|
||||
self._run_task = run_task_fn
|
||||
|
||||
self._lock = threading.Lock()
|
||||
self._wake = threading.Event()
|
||||
self._stop = threading.Event()
|
||||
self._queues: dict[str, deque] = {}
|
||||
self._active: dict[str, threading.Thread] = {}
|
||||
self._reserved_vram: float = 0.0
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
|
||||
# Load VRAM budgets: defaults + optional config overrides
|
||||
self._budgets: dict[str, float] = dict(DEFAULT_VRAM_BUDGETS)
|
||||
config_path = db_path.parent.parent / "config" / "llm.yaml"
|
||||
self._max_queue_depth: int = 500
|
||||
if config_path.exists():
|
||||
try:
|
||||
import yaml
|
||||
with open(config_path) as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
sched_cfg = cfg.get("scheduler", {})
|
||||
self._budgets.update(sched_cfg.get("vram_budgets", {}))
|
||||
self._max_queue_depth = sched_cfg.get("max_queue_depth", 500)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to load scheduler config from %s: %s", config_path, exc)
|
||||
|
||||
# Warn on LLM types with no budget entry after merge
|
||||
for t in LLM_TASK_TYPES:
|
||||
if t not in self._budgets:
|
||||
logger.warning(
|
||||
"No VRAM budget defined for LLM task type %r — "
|
||||
"defaulting to 0.0 GB (unlimited concurrency for this type)", t
|
||||
)
|
||||
|
||||
# Detect total GPU VRAM; fall back to unlimited (999) on CPU-only systems.
|
||||
# Uses module-level _get_gpus so tests can monkeypatch scripts.task_scheduler._get_gpus.
|
||||
try:
|
||||
from scripts import task_scheduler as _ts_mod
|
||||
gpus = _ts_mod._get_gpus()
|
||||
self._available_vram: float = (
|
||||
sum(g["vram_total_gb"] for g in gpus) if gpus else 999.0
|
||||
)
|
||||
except Exception:
|
||||
self._available_vram = 999.0
|
||||
|
||||
|
||||
# ── Singleton ─────────────────────────────────────────────────────────────────
|
||||
|
|
|
|||
|
|
@ -53,3 +53,78 @@ def test_reset_running_tasks_returns_zero_when_nothing_running(tmp_db):
|
|||
conn.close()
|
||||
|
||||
assert reset_running_tasks(tmp_db) == 0
|
||||
|
||||
|
||||
from scripts.task_scheduler import (
|
||||
TaskScheduler, LLM_TASK_TYPES, DEFAULT_VRAM_BUDGETS,
|
||||
get_scheduler, reset_scheduler,
|
||||
)
|
||||
|
||||
|
||||
def _noop_run_task(*args, **kwargs):
|
||||
"""Stand-in for _run_task that does nothing."""
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clean_scheduler():
|
||||
"""Reset singleton between every test."""
|
||||
yield
|
||||
try:
|
||||
reset_scheduler()
|
||||
except NotImplementedError:
|
||||
pass
|
||||
|
||||
|
||||
def test_default_budgets_used_when_no_config(tmp_db):
|
||||
"""Scheduler falls back to DEFAULT_VRAM_BUDGETS when config key absent."""
|
||||
s = TaskScheduler(tmp_db, _noop_run_task)
|
||||
assert s._budgets == DEFAULT_VRAM_BUDGETS
|
||||
|
||||
|
||||
def test_config_budgets_override_defaults(tmp_db, tmp_path):
|
||||
"""Values in llm.yaml scheduler.vram_budgets override defaults."""
|
||||
config_dir = tmp_db.parent.parent / "config"
|
||||
config_dir.mkdir(parents=True, exist_ok=True)
|
||||
(config_dir / "llm.yaml").write_text(
|
||||
"scheduler:\n vram_budgets:\n cover_letter: 9.9\n"
|
||||
)
|
||||
s = TaskScheduler(tmp_db, _noop_run_task)
|
||||
assert s._budgets["cover_letter"] == 9.9
|
||||
# Non-overridden keys still use defaults
|
||||
assert s._budgets["company_research"] == DEFAULT_VRAM_BUDGETS["company_research"]
|
||||
|
||||
|
||||
def test_missing_budget_logs_warning(tmp_db, caplog):
|
||||
"""A type in LLM_TASK_TYPES with no budget entry logs a warning."""
|
||||
import logging
|
||||
# Temporarily add a type with no budget
|
||||
original = LLM_TASK_TYPES.copy() if hasattr(LLM_TASK_TYPES, 'copy') else set(LLM_TASK_TYPES)
|
||||
from scripts import task_scheduler as ts
|
||||
ts.LLM_TASK_TYPES = frozenset(LLM_TASK_TYPES | {"orphan_type"})
|
||||
try:
|
||||
with caplog.at_level(logging.WARNING, logger="scripts.task_scheduler"):
|
||||
s = TaskScheduler(tmp_db, _noop_run_task)
|
||||
assert any("orphan_type" in r.message for r in caplog.records)
|
||||
finally:
|
||||
ts.LLM_TASK_TYPES = frozenset(original)
|
||||
|
||||
|
||||
def test_cpu_only_system_gets_unlimited_vram(tmp_db, monkeypatch):
|
||||
"""_available_vram is 999.0 when _get_gpus() returns empty list."""
|
||||
# Patch the module-level _get_gpus in task_scheduler (not preflight)
|
||||
# so __init__'s _ts_mod._get_gpus() call picks up the mock.
|
||||
monkeypatch.setattr("scripts.task_scheduler._get_gpus", lambda: [])
|
||||
s = TaskScheduler(tmp_db, _noop_run_task)
|
||||
assert s._available_vram == 999.0
|
||||
|
||||
|
||||
def test_gpu_vram_summed_across_all_gpus(tmp_db, monkeypatch):
|
||||
"""_available_vram sums vram_total_gb across all detected GPUs."""
|
||||
fake_gpus = [
|
||||
{"name": "RTX 3090", "vram_total_gb": 24.0, "vram_free_gb": 20.0},
|
||||
{"name": "RTX 3090", "vram_total_gb": 24.0, "vram_free_gb": 18.0},
|
||||
]
|
||||
monkeypatch.setattr("scripts.task_scheduler._get_gpus", lambda: fake_gpus)
|
||||
s = TaskScheduler(tmp_db, _noop_run_task)
|
||||
assert s._available_vram == 48.0
|
||||
|
|
|
|||
Loading…
Reference in a new issue