# circuitforge_core/tasks/__init__.py
"""Public surface of the task-scheduling subpackage.

Re-exports the scheduler class, the VRAM detection helper, and the
process-level singleton accessors so callers can import from
``circuitforge_core.tasks`` directly.
"""
from circuitforge_core.tasks.scheduler import (
    TaskScheduler,
    detect_available_vram_gb,
    get_scheduler,
    reset_scheduler,
)

__all__ = [
    "TaskScheduler",
    "detect_available_vram_gb",
    "get_scheduler",
    "reset_scheduler",
]
from __future__ import annotations

import logging
import sqlite3
import threading
from collections import deque, namedtuple
from pathlib import Path
from typing import Callable, Optional

try:
    import httpx  # optional dependency (extra: "tasks"); guarded at call sites
except ImportError:
    httpx = None  # type: ignore[assignment]

logger = logging.getLogger(__name__)

# Lightweight task record: DB row id, owning job id, opaque params JSON string.
TaskSpec = namedtuple("TaskSpec", ["id", "job_id", "params"])

_DEFAULT_MAX_QUEUE_DEPTH = 500


def detect_available_vram_gb(
    coordinator_url: str = "http://localhost:7700",
) -> float:
    """Detect available VRAM GB for task scheduling.

    Returns free VRAM via cf-orch (sum across all nodes/GPUs) so the scheduler
    cooperates with other cf-orch consumers. Falls back to preflight total VRAM,
    then 999.0 (unlimited) if nothing is reachable.

    Args:
        coordinator_url: Base URL of the cf-orch coordinator.

    Returns:
        Free VRAM in GB (cf-orch), total VRAM in GB (preflight fallback),
        or 999.0 when no detection source is available.
    """
    # 1. Try cf-orch: use free VRAM so the scheduler cooperates with other
    # cf-orch consumers (vision service, inference services, etc.)
    if httpx is not None:
        try:
            resp = httpx.get(f"{coordinator_url}/api/nodes", timeout=2.0)
            if resp.status_code == 200:
                nodes = resp.json().get("nodes", [])
                total_free_mb = sum(
                    gpu.get("vram_free_mb", 0)
                    for node in nodes
                    for gpu in node.get("gpus", [])
                )
                if total_free_mb > 0:
                    free_gb = total_free_mb / 1024.0
                    logger.debug(
                        "Scheduler VRAM from cf-orch: %.1f GB free", free_gb
                    )
                    return free_gb
        except Exception as exc:
            # Coordinator down/unreachable — fall through to preflight, but
            # leave a trace instead of swallowing the error silently.
            logger.debug("cf-orch VRAM query failed: %s", exc)

    # 2. Try preflight (systems with nvidia-smi; Peregrine-era fallback)
    try:
        from scripts.preflight import get_gpus  # type: ignore[import]

        gpus = get_gpus()
        if gpus:
            total_gb = sum(g.get("vram_total_gb", 0.0) for g in gpus)
            logger.debug(
                "Scheduler VRAM from preflight: %.1f GB total", total_gb
            )
            return total_gb
    except Exception as exc:
        # scripts.preflight missing or nvidia-smi absent — fall through.
        logger.debug("preflight VRAM query failed: %s", exc)

    logger.debug(
        "Scheduler VRAM detection unavailable — using unlimited (999 GB)"
    )
    return 999.0
class TaskScheduler:
    """Resource-aware LLM task batch scheduler.

    One batch-worker thread runs per task type, as long as the sum of
    reserved VRAM budgets stays within the detected available budget. At
    least one batch is always allowed to start even when its budget alone
    exceeds available VRAM — this prevents permanent starvation on
    low-VRAM systems.

    Thread-safety: all queue/active state is protected by ``self._lock``.

    Usage::

        sched = get_scheduler(
            db_path=Path("data/app.db"),
            run_task_fn=my_run_task,
            task_types=frozenset({"cover_letter", "research"}),
            vram_budgets={"cover_letter": 2.5, "research": 5.0},
        )
        task_id, is_new = insert_task(db_path, "cover_letter", job_id)
        if is_new:
            enqueued = sched.enqueue(task_id, "cover_letter", job_id, params_json)
            if not enqueued:
                mark_task_failed(db_path, task_id, "Queue full")
    """

    def __init__(
        self,
        db_path: Path,
        run_task_fn: Callable,
        task_types: frozenset[str],
        vram_budgets: dict[str, float],
        available_vram_gb: Optional[float] = None,
        max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH,
    ) -> None:
        # Immutable configuration (defensive copies of caller collections).
        self._db_path = db_path
        self._run_task = run_task_fn
        self._task_types = frozenset(task_types)
        self._budgets: dict[str, float] = dict(vram_budgets)
        self._max_queue_depth = max_queue_depth

        # Mutable scheduling state — every access goes through self._lock.
        self._lock = threading.Lock()
        self._wake = threading.Event()
        self._stop = threading.Event()
        self._queues: dict[str, deque] = {}
        self._active: dict[str, threading.Thread] = {}
        self._reserved_vram: float = 0.0
        self._thread: Optional[threading.Thread] = None

        # Auto-detect VRAM unless the caller pinned a value (tests do).
        if available_vram_gb is None:
            self._available_vram: float = detect_available_vram_gb()
        else:
            self._available_vram = available_vram_gb

        # Warn once per un-budgeted type; such types run without VRAM gating.
        for missing_type in sorted(self._task_types - self._budgets.keys()):
            logger.warning(
                "No VRAM budget defined for task type %r — "
                "defaulting to 0.0 GB (no VRAM gating for this type)",
                missing_type,
            )

        self._load_queued_tasks()
None + + self._available_vram: float = ( + available_vram_gb + if available_vram_gb is not None + else detect_available_vram_gb() + ) + + for t in self._task_types: + if t not in self._budgets: + logger.warning( + "No VRAM budget defined for task type %r — " + "defaulting to 0.0 GB (no VRAM gating for this type)", + t, + ) + + self._load_queued_tasks() + + def enqueue( + self, + task_id: int, + task_type: str, + job_id: int, + params: Optional[str], + ) -> bool: + """Add a task to the scheduler queue. + + Returns True if enqueued successfully. + Returns False if the queue is full — caller should mark the task failed. + """ + with self._lock: + q = self._queues.setdefault(task_type, deque()) + if len(q) >= self._max_queue_depth: + logger.warning( + "Queue depth limit for %s (max=%d) — task %d dropped", + task_type, + self._max_queue_depth, + task_id, + ) + return False + q.append(TaskSpec(task_id, job_id, params)) + self._wake.set() + return True + + def start(self) -> None: + """Start the background scheduler loop thread. 
    def start(self) -> None:
        """Start the background scheduler loop thread. Call once after construction."""
        self._thread = threading.Thread(
            target=self._scheduler_loop, name="task-scheduler", daemon=True
        )
        self._thread.start()
        # Wake the loop immediately so tasks loaded from DB at startup are dispatched
        if any(self._queues.values()):
            self._wake.set()

    def shutdown(self, timeout: float = 5.0) -> None:
        """Signal the scheduler to stop and wait for it to exit.

        Note: only the scheduler loop thread is joined; in-flight batch
        workers are daemon threads and are not waited on here.
        """
        self._stop.set()
        self._wake.set()  # unblock the loop's wait so it can observe _stop
        if self._thread and self._thread.is_alive():
            self._thread.join(timeout=timeout)

    def _scheduler_loop(self) -> None:
        # Dispatch loop: wakes on enqueue / batch completion / shutdown, or
        # every 30 s as a safety net in case a wake-up was missed.
        while not self._stop.is_set():
            self._wake.wait(timeout=30)
            self._wake.clear()
            with self._lock:
                # Reap batch threads that finished without waking us
                for t, thread in list(self._active.items()):
                    if not thread.is_alive():
                        # Return the finished type's VRAM budget to the pool.
                        self._reserved_vram -= self._budgets.get(t, 0.0)
                        del self._active[t]
                # Start new type batches while VRAM budget allows.
                # Deepest queues first, so backlogged types drain soonest.
                candidates = sorted(
                    [
                        t
                        for t in self._queues
                        if self._queues[t] and t not in self._active
                    ],
                    key=lambda t: len(self._queues[t]),
                    reverse=True,
                )
                for task_type in candidates:
                    budget = self._budgets.get(task_type, 0.0)
                    # Always allow at least one batch to run
                    # (reserved == 0 means nothing is running — starting one
                    # batch can never deadlock, even on low-VRAM machines).
                    if (
                        self._reserved_vram == 0.0
                        or self._reserved_vram + budget <= self._available_vram
                    ):
                        thread = threading.Thread(
                            target=self._batch_worker,
                            args=(task_type,),
                            name=f"batch-{task_type}",
                            daemon=True,
                        )
                        self._active[task_type] = thread
                        self._reserved_vram += budget
                        thread.start()
    def _batch_worker(self, task_type: str) -> None:
        """Serial consumer for one task type. Runs until the type's deque is empty."""
        try:
            while True:
                with self._lock:
                    q = self._queues.get(task_type)
                    if not q:
                        break
                    task = q.popleft()
                # Run outside the lock so enqueue/scheduling stay responsive.
                # NOTE(review): if _run_task raises, the already-popped task is
                # dropped without being marked failed in the DB — confirm that
                # run_task_fn handles its own failures internally.
                self._run_task(
                    self._db_path, task.id, task_type, task.job_id, task.params
                )
        finally:
            # Runs on both normal drain and exception: release this type's
            # VRAM reservation and let the loop schedule the next batch.
            with self._lock:
                self._active.pop(task_type, None)
                self._reserved_vram -= self._budgets.get(task_type, 0.0)
            self._wake.set()

    def _load_queued_tasks(self) -> None:
        """Reload surviving 'queued' tasks from SQLite into deques at startup."""
        if not self._task_types:
            return
        task_types_list = sorted(self._task_types)
        # Build "?,?,..." placeholders — values are bound, never interpolated.
        placeholders = ",".join("?" * len(task_types_list))
        conn = sqlite3.connect(self._db_path)
        try:
            rows = conn.execute(
                f"SELECT id, task_type, job_id, params FROM background_tasks"
                f" WHERE status='queued' AND task_type IN ({placeholders})"
                f" ORDER BY created_at ASC",
                task_types_list,
            ).fetchall()
        except sqlite3.OperationalError:
            # Table not yet created (first run before migrations)
            rows = []
        finally:
            conn.close()

        for row_id, task_type, job_id, params in rows:
            q = self._queues.setdefault(task_type, deque())
            q.append(TaskSpec(row_id, job_id, params))

        if rows:
            logger.info(
                "Scheduler: resumed %d queued task(s) from prior run", len(rows)
            )
+ """ + global _scheduler + if _scheduler is None: + with _scheduler_lock: + if _scheduler is None: + if ( + run_task_fn is None + or task_types is None + or vram_budgets is None + ): + raise ValueError( + "run_task_fn, task_types, and vram_budgets are required " + "on the first call to get_scheduler()" + ) + _scheduler = TaskScheduler( + db_path=db_path, + run_task_fn=run_task_fn, + task_types=task_types, + vram_budgets=vram_budgets, + max_queue_depth=max_queue_depth, + ) + _scheduler.start() + return _scheduler + + +def reset_scheduler() -> None: + """Shut down and clear the singleton. TEST TEARDOWN ONLY.""" + global _scheduler + with _scheduler_lock: + if _scheduler is not None: + _scheduler.shutdown() + _scheduler = None diff --git a/pyproject.toml b/pyproject.toml index bbc3026..ab8aa5d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,8 +22,12 @@ orch = [ "typer[all]>=0.12", "psutil>=5.9", ] +tasks = [ + "httpx>=0.27", +] dev = [ "circuitforge-core[orch]", + "circuitforge-core[tasks]", "pytest>=8.0", "pytest-asyncio>=0.23", "httpx>=0.27", diff --git a/tests/test_tasks/__init__.py b/tests/test_tasks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_tasks/test_scheduler.py b/tests/test_tasks/test_scheduler.py new file mode 100644 index 0000000..b273545 --- /dev/null +++ b/tests/test_tasks/test_scheduler.py @@ -0,0 +1,243 @@ +"""Tests for circuitforge_core.tasks.scheduler.""" +from __future__ import annotations + +import sqlite3 +import threading +import time +from pathlib import Path +from typing import List +from unittest.mock import MagicMock, patch + +import pytest + +from circuitforge_core.tasks.scheduler import ( + TaskScheduler, + detect_available_vram_gb, + get_scheduler, + reset_scheduler, +) + + +# ── Fixtures ────────────────────────────────────────────────────────────────── + +@pytest.fixture +def tmp_db(tmp_path: Path) -> Path: + """SQLite DB with background_tasks table.""" + db = tmp_path / "test.db" + conn = 
@pytest.fixture
def tmp_db(tmp_path: Path) -> Path:
    """SQLite DB with background_tasks table."""
    db = tmp_path / "test.db"
    conn = sqlite3.connect(db)
    conn.execute("""
        CREATE TABLE background_tasks (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            task_type TEXT NOT NULL,
            job_id INTEGER NOT NULL DEFAULT 0,
            status TEXT NOT NULL DEFAULT 'queued',
            params TEXT,
            error TEXT,
            created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
        )
    """)
    conn.commit()
    conn.close()
    return db


@pytest.fixture(autouse=True)
def _reset_singleton():
    """Always tear down the scheduler singleton between tests."""
    yield
    reset_scheduler()


TASK_TYPES = frozenset({"fast_task"})
BUDGETS = {"fast_task": 1.0}


# ── detect_available_vram_gb ──────────────────────────────────────────────────

# Renamed from the misspelled "cfortch"/"cforch" variants so all test names
# use the consistent "cf_orch" spelling.
def test_detect_vram_from_cf_orch():
    """Uses cf-orch free VRAM when coordinator is reachable."""
    mock_resp = MagicMock()
    mock_resp.status_code = 200
    mock_resp.json.return_value = {
        "nodes": [
            {"node_id": "local", "gpus": [{"vram_free_mb": 4096}, {"vram_free_mb": 4096}]}
        ]
    }
    with patch("circuitforge_core.tasks.scheduler.httpx") as mock_httpx:
        mock_httpx.get.return_value = mock_resp
        result = detect_available_vram_gb(coordinator_url="http://localhost:7700")
    assert result == pytest.approx(8.0)  # 4096 + 4096 MB → 8 GB


def test_detect_vram_cf_orch_unavailable_falls_back_to_unlimited():
    """Falls back to 999.0 when cf-orch is unreachable and preflight unavailable."""
    with patch("circuitforge_core.tasks.scheduler.httpx") as mock_httpx:
        mock_httpx.get.side_effect = ConnectionRefusedError()
        result = detect_available_vram_gb()
    assert result == 999.0


def test_detect_vram_cf_orch_empty_nodes_falls_back():
    """If cf-orch returns no nodes with GPUs, falls back to unlimited."""
    mock_resp = MagicMock()
    mock_resp.status_code = 200
    mock_resp.json.return_value = {"nodes": []}
    with patch("circuitforge_core.tasks.scheduler.httpx") as mock_httpx:
        mock_httpx.get.return_value = mock_resp
        result = detect_available_vram_gb()
    assert result == 999.0
# ── TaskScheduler basic behaviour ─────────────────────────────────────────────

def test_enqueue_returns_true_on_success(tmp_db: Path):
    """enqueue() accepts a task when the queue has room."""
    executed: List[int] = []

    def run_fn(db_path, task_id, task_type, job_id, params):
        executed.append(task_id)

    sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
    sched.start()
    accepted = sched.enqueue(1, "fast_task", 0, None)
    sched.shutdown()
    assert accepted is True


def test_scheduler_runs_task(tmp_db: Path):
    """Enqueued task is executed by the batch worker."""
    executed: List[int] = []
    finished = threading.Event()

    def run_fn(db_path, task_id, task_type, job_id, params):
        executed.append(task_id)
        finished.set()

    sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
    sched.start()
    sched.enqueue(42, "fast_task", 0, None)
    assert finished.wait(timeout=3.0), "Task was not executed within 3 seconds"
    sched.shutdown()
    assert executed == [42]


def test_enqueue_returns_false_when_queue_full(tmp_db: Path):
    """Returns False and does not enqueue when max_queue_depth is reached."""
    release = threading.Event()

    def blocking_run_fn(db_path, task_id, task_type, job_id, params):
        release.wait()  # hold the worker so the queue backs up

    sched = TaskScheduler(
        tmp_db, blocking_run_fn, TASK_TYPES, BUDGETS,
        available_vram_gb=8.0, max_queue_depth=2
    )
    sched.start()
    outcomes = [sched.enqueue(i, "fast_task", 0, None) for i in range(1, 10)]
    release.set()
    sched.shutdown()
    assert False in outcomes


def test_scheduler_drains_multiple_tasks(tmp_db: Path):
    """All enqueued tasks of the same type are run serially."""
    executed: List[int] = []
    drained = threading.Event()
    TOTAL = 5

    def run_fn(db_path, task_id, task_type, job_id, params):
        executed.append(task_id)
        if len(executed) >= TOTAL:
            drained.set()

    sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
    sched.start()
    for task_id in range(1, TOTAL + 1):
        sched.enqueue(task_id, "fast_task", 0, None)
    assert drained.wait(timeout=5.0), f"Only ran {len(executed)} of {TOTAL} tasks"
    sched.shutdown()
    assert sorted(executed) == list(range(1, TOTAL + 1))
tasks" + sched.shutdown() + assert sorted(ran) == list(range(1, TOTAL + 1)) + + +def test_vram_budget_blocks_second_type(tmp_db: Path): + """Second task type is not started when VRAM would be exceeded.""" + gate_a = threading.Event() + gate_b = threading.Event() + started = [] + + def run_fn(db_path, task_id, task_type, job_id, params): + started.append(task_type) + if task_type == "type_a": + gate_a.wait() + else: + gate_b.wait() + + two_types = frozenset({"type_a", "type_b"}) + tight_budgets = {"type_a": 4.0, "type_b": 4.0} # 4+4 > 6 GB available + + sched = TaskScheduler( + tmp_db, run_fn, two_types, tight_budgets, available_vram_gb=6.0 + ) + sched.start() + sched.enqueue(1, "type_a", 0, None) + sched.enqueue(2, "type_b", 0, None) + + time.sleep(0.2) + assert started == ["type_a"] + + gate_a.set() + time.sleep(0.2) + gate_b.set() + sched.shutdown() + assert "type_b" in started + + +def test_get_scheduler_singleton(tmp_db: Path): + """get_scheduler() returns the same instance on repeated calls.""" + run_fn = MagicMock() + s1 = get_scheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS) + s2 = get_scheduler(tmp_db) # no run_fn — should reuse existing + assert s1 is s2 + + +def test_reset_scheduler_clears_singleton(tmp_db: Path): + """reset_scheduler() allows a new singleton to be constructed.""" + run_fn = MagicMock() + s1 = get_scheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS) + reset_scheduler() + s2 = get_scheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS) + assert s1 is not s2 + + +def test_load_queued_tasks_on_startup(tmp_db: Path): + """Tasks with status='queued' in the DB at startup are loaded into the deque.""" + conn = sqlite3.connect(tmp_db) + conn.execute( + "INSERT INTO background_tasks (task_type, job_id, status) VALUES ('fast_task', 0, 'queued')" + ) + conn.commit() + conn.close() + + ran: List[int] = [] + done = threading.Event() + + def run_fn(db_path, task_id, task_type, job_id, params): + ran.append(task_id) + done.set() + + sched = TaskScheduler(tmp_db, run_fn, 
def test_load_queued_tasks_on_startup(tmp_db: Path):
    """Tasks with status='queued' in the DB at startup are loaded into the deque."""
    conn = sqlite3.connect(tmp_db)
    conn.execute(
        "INSERT INTO background_tasks (task_type, job_id, status) VALUES ('fast_task', 0, 'queued')"
    )
    conn.commit()
    conn.close()

    executed: List[int] = []
    done = threading.Event()

    def run_fn(db_path, task_id, task_type, job_id, params):
        executed.append(task_id)
        done.set()

    # Construction reads the pre-seeded row; start() dispatches it.
    sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
    sched.start()
    assert done.wait(timeout=3.0), "Pre-loaded task was not run"
    sched.shutdown()
    assert len(executed) == 1


def test_load_queued_tasks_missing_table_does_not_crash(tmp_path: Path):
    """Scheduler does not crash if background_tasks table doesn't exist yet."""
    db = tmp_path / "empty.db"
    sqlite3.connect(db).close()

    run_fn = MagicMock()
    sched = TaskScheduler(db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
    sched.start()
    sched.shutdown()
    # No exception = pass