# scripts/task_scheduler.py
"""Resource-aware batch scheduler for LLM background tasks.

Routes LLM task types through per-type deques with VRAM-aware scheduling.
Non-LLM tasks bypass this module — routing lives in scripts/task_runner.py.

Public API:
    LLM_TASK_TYPES    — set of task type strings routed through the scheduler
    get_scheduler()   — lazy singleton accessor
    reset_scheduler() — test teardown only
"""
import logging
import sqlite3
import threading
from collections import deque, namedtuple
from contextlib import closing
from pathlib import Path
from typing import Callable, Optional

# Module-level import so tests can monkeypatch scripts.task_scheduler._get_gpus
try:
    from scripts.preflight import get_gpus as _get_gpus
except Exception:  # graceful degradation if preflight unavailable
    def _get_gpus():
        """Fallback when scripts.preflight cannot be imported: report no GPUs."""
        return []

logger = logging.getLogger(__name__)

# Task types that go through the scheduler (all others spawn free threads)
LLM_TASK_TYPES: frozenset[str] = frozenset({
    "cover_letter",
    "company_research",
    "wizard_generate",
})

# Conservative peak VRAM estimates (GB) per task type.
# Overridable per-install via scheduler.vram_budgets in config/llm.yaml.
DEFAULT_VRAM_BUDGETS: dict[str, float] = {
    "cover_letter": 2.5,      # alex-cover-writer:latest (~2GB GGUF + headroom)
    "company_research": 5.0,  # llama3.1:8b or vllm model
    "wizard_generate": 2.5,   # same model family as cover_letter
}

# Per-type queue length cap used when config/llm.yaml does not override it.
_DEFAULT_MAX_QUEUE_DEPTH = 500

# Lightweight task descriptor stored in per-type deques
TaskSpec = namedtuple("TaskSpec", ["id", "job_id", "params"])


class TaskScheduler:
    """Resource-aware LLM task batch scheduler.

    Each task type gets its own deque and at most one serial batch worker
    thread at a time. A background loop starts workers for types with pending
    tasks while the sum of their VRAM budgets fits in the detected GPU VRAM.

    Use get_scheduler() — not direct construction.
    """

    def __init__(self, db_path: Path, run_task_fn: Callable) -> None:
        """Load config, detect VRAM and re-queue tasks left over from a prior run.

        Args:
            db_path: SQLite database file. ``config/llm.yaml`` is looked up
                relative to it, at ``db_path.parent.parent / "config"``.
            run_task_fn: Callable invoked by batch workers as
                ``run_task_fn(db_path, task_id, task_type, job_id, params)``.
        """
        self._db_path = db_path
        self._run_task = run_task_fn
        self._lock = threading.Lock()    # guards _queues, _active, _reserved_vram
        self._wake = threading.Event()   # set on enqueue and on batch completion
        self._stop = threading.Event()   # set by shutdown()
        self._queues: dict[str, deque] = {}
        self._active: dict[str, threading.Thread] = {}
        self._reserved_vram: float = 0.0
        self._thread: Optional[threading.Thread] = None

        # Load VRAM budgets: defaults + optional config overrides
        self._budgets: dict[str, float] = dict(DEFAULT_VRAM_BUDGETS)
        self._max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH
        config_path = db_path.parent.parent / "config" / "llm.yaml"
        if config_path.exists():
            try:
                import yaml  # local import: only needed when a config file exists

                with open(config_path, encoding="utf-8") as f:
                    cfg = yaml.safe_load(f) or {}
                # "or {}" treats an explicit null key the same as a missing one.
                sched_cfg = cfg.get("scheduler") or {}
                merged_budgets = dict(self._budgets)
                merged_budgets.update(sched_cfg.get("vram_budgets") or {})
                max_depth = int(
                    sched_cfg.get("max_queue_depth", _DEFAULT_MAX_QUEUE_DEPTH)
                )
                # Assign only after everything validated, so a bad value
                # cannot leave the overrides half-applied.
                self._budgets = merged_budgets
                self._max_queue_depth = max_depth
            except Exception as exc:
                logger.warning(
                    "Failed to load scheduler config from %s: %s", config_path, exc
                )

        # Warn on LLM types with no budget entry after merge
        for t in LLM_TASK_TYPES:
            if t not in self._budgets:
                logger.warning(
                    "No VRAM budget defined for LLM task type %r — "
                    "defaulting to 0.0 GB (unlimited concurrency for this type)", t
                )

        # Detect total GPU VRAM; fall back to unlimited (999) on CPU-only systems.
        # Uses module-level _get_gpus so tests can monkeypatch
        # scripts.task_scheduler._get_gpus.
        try:
            gpus = _get_gpus()
            self._available_vram: float = (
                sum(g["vram_total_gb"] for g in gpus) if gpus else 999.0
            )
        except Exception:
            self._available_vram = 999.0

        # Durability: reload surviving 'queued' LLM tasks from prior run
        self._load_queued_tasks()

    def enqueue(self, task_id: int, task_type: str, job_id: int,
                params: Optional[str]) -> None:
        """Add an LLM task to the scheduler queue.

        If the queue for this type is at max_queue_depth, the task is marked
        failed in SQLite immediately (no ghost queued rows) and a warning is
        logged.

        Args:
            task_id: Row id of the task; passed through to the run function.
            task_type: Selects the per-type queue (created on first use).
            job_id: Passed through to the run function.
            params: Opaque parameter payload, passed through to the run function.
        """
        with self._lock:
            q = self._queues.setdefault(task_type, deque())
            accepted = len(q) < self._max_queue_depth
            if accepted:
                q.append(TaskSpec(task_id, job_id, params))

        if accepted:
            self._wake.set()
            return

        # Overflow path runs outside the lock so the database write does not
        # stall the scheduler loop or the batch workers.
        logger.warning(
            "Queue depth limit reached for %s (max=%d) — task %d dropped",
            task_type, self._max_queue_depth, task_id,
        )
        from scripts.db import update_task_status

        update_task_status(
            self._db_path, task_id, "failed", error="Queue depth limit reached"
        )

    def start(self) -> None:
        """Start the background scheduler loop thread.

        Calling it again while the loop thread is alive is a no-op, so two
        loops never compete over the same queues.
        """
        if self._thread is not None and self._thread.is_alive():
            return
        self._thread = threading.Thread(
            target=self._scheduler_loop, name="task-scheduler", daemon=True
        )
        self._thread.start()

    def shutdown(self, timeout: float = 5.0) -> None:
        """Signal the scheduler loop to stop and wait up to ``timeout`` seconds.

        Only the loop thread is joined; batch worker threads are not.
        """
        self._stop.set()
        self._wake.set()  # unblock any wait()
        if self._thread and self._thread.is_alive():
            self._thread.join(timeout=timeout)

    def _release_vram(self, task_type: str) -> None:
        """Give back ``task_type``'s VRAM reservation. Caller must hold _lock.

        Must be called after the type has been removed from ``_active``.
        Repeated float add/subtract can leave a tiny residue (0.1 + 0.2 - 0.1
        - 0.2 != 0.0), which would defeat the ``_reserved_vram == 0.0`` check
        in the scheduler loop. The counter is therefore reset to exactly 0.0
        when nothing is active, and never allowed to go negative.
        """
        self._reserved_vram -= self._budgets.get(task_type, 0.0)
        if not self._active or self._reserved_vram < 0.0:
            self._reserved_vram = 0.0

    def _scheduler_loop(self) -> None:
        """Main scheduler daemon — wakes on enqueue or batch completion.

        Also wakes every 30 seconds as a fallback. Each pass reaps dead batch
        threads, then starts one batch worker per pending type (longest queue
        first) for as long as the VRAM budget allows.
        """
        while not self._stop.is_set():
            self._wake.wait(timeout=30)
            self._wake.clear()

            with self._lock:
                # Defense in depth: reap externally-killed batch threads.
                # In normal operation _active.pop() runs in finally before
                # _wake fires, so this reap finds nothing — no
                # double-decrement risk.
                for t, thread in list(self._active.items()):
                    if not thread.is_alive():
                        del self._active[t]
                        self._release_vram(t)

                # Start new type batches while VRAM allows
                candidates = sorted(
                    [t for t in self._queues
                     if self._queues[t] and t not in self._active],
                    key=lambda t: len(self._queues[t]),
                    reverse=True,
                )
                for task_type in candidates:
                    budget = self._budgets.get(task_type, 0.0)
                    # Always allow at least one batch to run even if its budget
                    # exceeds _available_vram (prevents permanent starvation
                    # when a single type's budget is larger than the VRAM
                    # ceiling).
                    if (
                        self._reserved_vram == 0.0
                        or self._reserved_vram + budget <= self._available_vram
                    ):
                        thread = threading.Thread(
                            target=self._batch_worker,
                            args=(task_type,),
                            name=f"batch-{task_type}",
                            daemon=True,
                        )
                        self._active[task_type] = thread
                        self._reserved_vram += budget
                        thread.start()

    def _batch_worker(self, task_type: str) -> None:
        """Serial consumer for one task type. Runs until the type's deque is empty."""
        try:
            while True:
                with self._lock:
                    q = self._queues.get(task_type)
                    if not q:
                        break
                    task = q.popleft()
                # The task runs outside the lock so other types can be
                # scheduled while it executes.
                self._run_task(
                    self._db_path, task.id, task_type, task.job_id, task.params
                )
        finally:
            # Always release — even if _run_task raises.
            # _active.pop here prevents the scheduler loop reap from
            # double-decrementing.
            with self._lock:
                self._active.pop(task_type, None)
                self._release_vram(task_type)
            self._wake.set()

    def _load_queued_tasks(self) -> None:
        """Load pre-existing queued LLM tasks from SQLite into deques.

        Called once from __init__. Rows are read oldest first (by created_at).
        Database errors propagate to the caller, but the connection is always
        closed.
        """
        llm_types = sorted(LLM_TASK_TYPES)  # sorted for deterministic SQL params in logs
        placeholders = ",".join("?" * len(llm_types))
        # closing() guarantees conn.close() even when the query raises;
        # sqlite3's own context manager only manages transactions.
        with closing(sqlite3.connect(self._db_path)) as conn:
            rows = conn.execute(
                f"SELECT id, task_type, job_id, params FROM background_tasks"
                f" WHERE status='queued' AND task_type IN ({placeholders})"
                f" ORDER BY created_at ASC",
                llm_types,
            ).fetchall()

        for row_id, task_type, job_id, params in rows:
            q = self._queues.setdefault(task_type, deque())
            q.append(TaskSpec(row_id, job_id, params))

        if rows:
            logger.info(
                "Scheduler: resumed %d queued task(s) from prior run", len(rows)
            )


# ── Singleton ─────────────────────────────────────────────────────────────────

_scheduler: Optional[TaskScheduler] = None
_scheduler_lock = threading.Lock()


def get_scheduler(db_path: Path,
                  run_task_fn: Optional[Callable] = None) -> TaskScheduler:
    """Return the process-level TaskScheduler singleton, constructing it if needed.

    run_task_fn is required on the first call; ignored on subsequent calls.
    db_path is likewise only used when the singleton is constructed.

    Safety: inner lock + double-check prevents double-construction under races.
    The outer None check is a fast-path performance optimisation only.

    Raises:
        ValueError: if the singleton does not exist yet and run_task_fn is None.
    """
    global _scheduler
    if _scheduler is None:  # fast path — avoids lock on steady state
        with _scheduler_lock:
            if _scheduler is None:  # re-check under lock (double-checked locking)
                if run_task_fn is None:
                    raise ValueError(
                        "run_task_fn required on first get_scheduler() call"
                    )
                _scheduler = TaskScheduler(db_path, run_task_fn)
                _scheduler.start()
    return _scheduler


def reset_scheduler() -> None:
    """Shut down and clear the singleton. TEST TEARDOWN ONLY."""
    global _scheduler
    with _scheduler_lock:
        if _scheduler is not None:
            _scheduler.shutdown()
            _scheduler = None