From 2259382d0b5ae7594ae7ffd47bf34f12febb2575 Mon Sep 17 00:00:00 2001 From: pyr0ball Date: Sat, 4 Apr 2026 22:26:06 -0700 Subject: [PATCH] refactor: replace coordinator-aware TaskScheduler with Protocol + LocalScheduler (MIT); update LLMRouter import path --- circuitforge_core/tasks/__init__.py | 3 +- circuitforge_core/tasks/scheduler.py | 305 ++++++------------------- tests/test_tasks/test_scheduler.py | 327 ++++++++------------------- 3 files changed, 158 insertions(+), 477 deletions(-) diff --git a/circuitforge_core/tasks/__init__.py b/circuitforge_core/tasks/__init__.py index dede9f5..d9ee25b 100644 --- a/circuitforge_core/tasks/__init__.py +++ b/circuitforge_core/tasks/__init__.py @@ -1,6 +1,6 @@ -# circuitforge_core/tasks/__init__.py from circuitforge_core.tasks.scheduler import ( TaskScheduler, + LocalScheduler, detect_available_vram_gb, get_scheduler, reset_scheduler, @@ -8,6 +8,7 @@ from circuitforge_core.tasks.scheduler import ( __all__ = [ "TaskScheduler", + "LocalScheduler", "detect_available_vram_gb", "get_scheduler", "reset_scheduler", diff --git a/circuitforge_core/tasks/scheduler.py b/circuitforge_core/tasks/scheduler.py index 4cb3347..49cf9ba 100644 --- a/circuitforge_core/tasks/scheduler.py +++ b/circuitforge_core/tasks/scheduler.py @@ -1,21 +1,17 @@ # circuitforge_core/tasks/scheduler.py -"""Resource-aware batch scheduler for LLM background tasks. +"""Task scheduler for CircuitForge products — MIT layer. -Generic scheduler that any CircuitForge product can use. Products supply: - - task_types: frozenset[str] — task type strings routed through this scheduler - - vram_budgets: dict[str, float] — VRAM GB estimate per task type - - run_task_fn — product's task execution function +Provides a simple FIFO task queue with no coordinator dependency. -VRAM detection priority: - 1. cf-orch coordinator /api/nodes — free VRAM (lease-aware, cooperative) - 2. scripts.preflight.get_gpus() — total GPU VRAM (Peregrine-era fallback) - 3. 999.0 — unlimited (CPU-only or no detection available) +For coordinator-aware VRAM-budgeted scheduling on paid/premium tiers, install +circuitforge-orch and use OrchestratedScheduler instead. Public API: - TaskScheduler — the scheduler class - detect_available_vram_gb() — standalone VRAM query helper - get_scheduler() — lazy process-level singleton - reset_scheduler() — test teardown only + TaskScheduler — Protocol defining the scheduler interface + LocalScheduler — Simple FIFO queue implementation (MIT, no coordinator) + detect_available_vram_gb() — Returns 999.0 (unlimited; no coordinator on free tier) + get_scheduler() — Lazy process-level singleton returning a LocalScheduler + reset_scheduler() — Test teardown only """ from __future__ import annotations @@ -24,12 +20,7 @@ import sqlite3 import threading from collections import deque from pathlib import Path -from typing import Callable, NamedTuple, Optional - -try: - import httpx as httpx -except ImportError: - httpx = None # type: ignore[assignment] +from typing import Callable, NamedTuple, Optional, Protocol, runtime_checkable logger = logging.getLogger(__name__) @@ -41,68 +32,45 @@ class TaskSpec(NamedTuple): job_id: int params: Optional[str] + _DEFAULT_MAX_QUEUE_DEPTH = 500 -def detect_available_vram_gb( - coordinator_url: str = "http://localhost:7700", -) -> float: - """Detect available VRAM GB for task scheduling. +def detect_available_vram_gb() -> float: + """Return available VRAM for task scheduling. - Returns free VRAM via cf-orch (sum across all nodes/GPUs) so the scheduler - cooperates with other cf-orch consumers. Falls back to preflight total VRAM, - then 999.0 (unlimited) if nothing is reachable. + Free tier (no coordinator): always returns 999.0 — no VRAM gating. + For coordinator-aware VRAM detection use circuitforge_orch.scheduler. """ - # 1. Try cf-orch: use free VRAM so the scheduler cooperates with other - # cf-orch consumers (vision service, inference services, etc.) - if httpx is not None: - try: - resp = httpx.get(f"{coordinator_url}/api/nodes", timeout=2.0) - if resp.status_code == 200: - nodes = resp.json().get("nodes", []) - total_free_mb = sum( - gpu.get("vram_free_mb", 0) - for node in nodes - for gpu in node.get("gpus", []) - ) - if total_free_mb > 0: - free_gb = total_free_mb / 1024.0 - logger.debug( - "Scheduler VRAM from cf-orch: %.1f GB free", free_gb - ) - return free_gb - except Exception: - pass - - # 2. Try preflight (systems with nvidia-smi; Peregrine-era fallback) - try: - from scripts.preflight import get_gpus # type: ignore[import] - - gpus = get_gpus() - if gpus: - total_gb = sum(g.get("vram_total_gb", 0.0) for g in gpus) - logger.debug( - "Scheduler VRAM from preflight: %.1f GB total", total_gb - ) - return total_gb - except Exception: - pass - - logger.debug( - "Scheduler VRAM detection unavailable — using unlimited (999 GB)" - ) return 999.0 -class TaskScheduler: - """Resource-aware LLM task batch scheduler. +@runtime_checkable +class TaskScheduler(Protocol): + """Protocol for task schedulers across free and paid tiers. - Runs one batch-worker thread per task type while total reserved VRAM - stays within the detected available budget. Always allows at least one - batch to start even if its budget exceeds available VRAM (prevents - permanent starvation on low-VRAM systems). + Both LocalScheduler (MIT) and OrchestratedScheduler (BSL, circuitforge-orch) + implement this interface so products can inject either without API changes. + """ - Thread-safety: all queue/active state protected by self._lock. + def enqueue(self, task_id: int, task_type: str, job_id: int, params: Optional[str]) -> bool: + """Add a task to the queue. Returns True if enqueued, False if queue full.""" + ... + + def start(self) -> None: + """Start the background worker thread.""" + ... + + def shutdown(self, timeout: float = 5.0) -> None: + """Stop the scheduler and wait for it to exit.""" + ... + + +class LocalScheduler: + """Simple FIFO task scheduler with no coordinator dependency. + + Processes tasks serially per task type. No VRAM gating — all tasks run. + Suitable for free tier (single-host, up to 2 GPUs, static config). Usage:: @@ -112,11 +80,7 @@ class TaskScheduler: task_types=frozenset({"cover_letter", "research"}), vram_budgets={"cover_letter": 2.5, "research": 5.0}, ) - task_id, is_new = insert_task(db_path, "cover_letter", job_id) - if is_new: - enqueued = sched.enqueue(task_id, "cover_letter", job_id, params_json) - if not enqueued: - mark_task_failed(db_path, task_id, "Queue full") + enqueued = sched.enqueue(task_id, "cover_letter", job_id, params_json) """ def __init__( @@ -125,11 +89,7 @@ class TaskScheduler: run_task_fn: RunTaskFn, task_types: frozenset[str], vram_budgets: dict[str, float], - available_vram_gb: Optional[float] = None, max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH, - coordinator_url: str = "http://localhost:7700", - service_name: str = "peregrine", - lease_priority: int = 2, ) -> None: self._db_path = db_path self._run_task = run_task_fn @@ -137,54 +97,22 @@ class TaskScheduler: self._budgets: dict[str, float] = dict(vram_budgets) self._max_queue_depth = max_queue_depth - self._coordinator_url = coordinator_url.rstrip("/") - self._service_name = service_name - self._lease_priority = lease_priority - self._lock = threading.Lock() self._wake = threading.Event() self._stop = threading.Event() self._queues: dict[str, deque[TaskSpec]] = {} self._active: dict[str, threading.Thread] = {} - self._reserved_vram: float = 0.0 self._thread: Optional[threading.Thread] = None - self._available_vram: float = ( - available_vram_gb - if available_vram_gb is not None - else detect_available_vram_gb() - ) - - for t in self._task_types: - if t not in self._budgets: - logger.warning( - "No VRAM budget defined for task type %r — " - "defaulting to 0.0 GB (no VRAM gating for this type)", - t, - ) - self._load_queued_tasks() - def enqueue( - self, - task_id: int, - task_type: str, - job_id: int, - params: Optional[str], - ) -> bool: - """Add a task to the scheduler queue. - - Returns True if enqueued successfully. - Returns False if the queue is full — caller should mark the task failed. - """ + def enqueue(self, task_id: int, task_type: str, job_id: int, params: Optional[str]) -> bool: with self._lock: q = self._queues.setdefault(task_type, deque()) if len(q) >= self._max_queue_depth: logger.warning( "Queue depth limit for %s (max=%d) — task %d dropped", - task_type, - self._max_queue_depth, - task_id, + task_type, self._max_queue_depth, task_id, ) return False q.append(TaskSpec(task_id, job_id, params)) @@ -192,28 +120,19 @@ class TaskScheduler: return True def start(self) -> None: - """Start the background scheduler loop thread. Call once after construction.""" self._thread = threading.Thread( target=self._scheduler_loop, name="task-scheduler", daemon=True ) self._thread.start() - # Wake the loop immediately so tasks loaded from DB at startup are dispatched with self._lock: if any(self._queues.values()): self._wake.set() def shutdown(self, timeout: float = 5.0) -> None: - """Signal the scheduler to stop and wait for it to exit. - - Joins both the scheduler loop thread and any active batch worker - threads so callers can rely on clean state (e.g. _reserved_vram == 0) - immediately after this returns. - """ self._stop.set() self._wake.set() if self._thread and self._thread.is_alive(): self._thread.join(timeout=timeout) - # Join active batch workers so _reserved_vram is settled on return with self._lock: workers = list(self._active.values()) for worker in workers: @@ -224,103 +143,25 @@ class TaskScheduler: self._wake.wait(timeout=30) self._wake.clear() with self._lock: - # Reap batch threads that finished without waking us. - # VRAM accounting is handled solely by _batch_worker's finally block; - # the reaper only removes dead entries from _active. for t, thread in list(self._active.items()): if not thread.is_alive(): del self._active[t] - # Start new type batches while VRAM budget allows candidates = sorted( - [ - t - for t in self._queues - if self._queues[t] and t not in self._active - ], + [t for t in self._queues if self._queues[t] and t not in self._active], key=lambda t: len(self._queues[t]), reverse=True, ) for task_type in candidates: - budget = self._budgets.get(task_type, 0.0) - # Always allow at least one batch to run - if ( - self._reserved_vram == 0.0 - or self._reserved_vram + budget <= self._available_vram - ): - thread = threading.Thread( - target=self._batch_worker, - args=(task_type,), - name=f"batch-{task_type}", - daemon=True, - ) - self._active[task_type] = thread - self._reserved_vram += budget - thread.start() - - def _acquire_lease(self, task_type: str) -> Optional[str]: - """Request a VRAM lease from the coordinator. Returns lease_id or None.""" - if httpx is None: - return None - budget_gb = self._budgets.get(task_type, 0.0) - if budget_gb <= 0: - return None - mb = int(budget_gb * 1024) - try: - # Pick the GPU with the most free VRAM on the first registered node - resp = httpx.get(f"{self._coordinator_url}/api/nodes", timeout=2.0) - if resp.status_code != 200: - return None - nodes = resp.json().get("nodes", []) - if not nodes: - return None - best_node = best_gpu = best_free = None - for node in nodes: - for gpu in node.get("gpus", []): - free = gpu.get("vram_free_mb", 0) - if best_free is None or free > best_free: - best_node = node["node_id"] - best_gpu = gpu["gpu_id"] - best_free = free - if best_node is None: - return None - lease_resp = httpx.post( - f"{self._coordinator_url}/api/leases", - json={ - "node_id": best_node, - "gpu_id": best_gpu, - "mb": mb, - "service": self._service_name, - "priority": self._lease_priority, - }, - timeout=3.0, - ) - if lease_resp.status_code == 200: - lease_id = lease_resp.json()["lease"]["lease_id"] - logger.debug( - "Acquired VRAM lease %s for task_type=%s (%d MB)", - lease_id, task_type, mb, - ) - return lease_id - except Exception as exc: - logger.debug("Lease acquire failed (non-fatal): %s", exc) - return None - - def _release_lease(self, lease_id: str) -> None: - """Release a coordinator VRAM lease. Best-effort; failures are logged only.""" - if httpx is None or not lease_id: - return - try: - httpx.delete( - f"{self._coordinator_url}/api/leases/{lease_id}", - timeout=3.0, - ) - logger.debug("Released VRAM lease %s", lease_id) - except Exception as exc: - logger.debug("Lease release failed (non-fatal): %s", exc) + thread = threading.Thread( + target=self._batch_worker, + args=(task_type,), + name=f"batch-{task_type}", + daemon=True, + ) + self._active[task_type] = thread + thread.start() def _batch_worker(self, task_type: str) -> None: - """Serial consumer for one task type. Runs until the type's deque is empty.""" - lease_id: Optional[str] = self._acquire_lease(task_type) try: while True: with self._lock: @@ -328,19 +169,13 @@ class TaskScheduler: if not q: break task = q.popleft() - self._run_task( - self._db_path, task.id, task_type, task.job_id, task.params - ) + self._run_task(self._db_path, task.id, task_type, task.job_id, task.params) finally: - if lease_id: - self._release_lease(lease_id) with self._lock: self._active.pop(task_type, None) - self._reserved_vram -= self._budgets.get(task_type, 0.0) self._wake.set() def _load_queued_tasks(self) -> None: - """Reload surviving 'queued' tasks from SQLite into deques at startup.""" if not self._task_types: return task_types_list = sorted(self._task_types) @@ -354,68 +189,58 @@ class TaskScheduler: task_types_list, ).fetchall() except sqlite3.OperationalError: - # Table not yet created (first run before migrations) rows = [] - for row_id, task_type, job_id, params in rows: q = self._queues.setdefault(task_type, deque()) q.append(TaskSpec(row_id, job_id, params)) - if rows: - logger.info( - "Scheduler: resumed %d queued task(s) from prior run", len(rows) - ) + logger.info("Scheduler: resumed %d queued task(s) from prior run", len(rows)) # ── Process-level singleton ──────────────────────────────────────────────────── -_scheduler: Optional[TaskScheduler] = None +_scheduler: Optional[LocalScheduler] = None _scheduler_lock = threading.Lock() def get_scheduler( - db_path: Path, + db_path: Optional[Path] = None, run_task_fn: Optional[RunTaskFn] = None, task_types: Optional[frozenset[str]] = None, vram_budgets: Optional[dict[str, float]] = None, max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH, coordinator_url: str = "http://localhost:7700", service_name: str = "peregrine", -) -> TaskScheduler: - """Return the process-level TaskScheduler singleton. +) -> LocalScheduler: + """Return the process-level LocalScheduler singleton. - ``run_task_fn``, ``task_types``, and ``vram_budgets`` are required on the - first call; ignored on subsequent calls (singleton already constructed). + ``run_task_fn``, ``task_types``, ``vram_budgets``, and ``db_path`` are + required on the first call; ignored on subsequent calls. - VRAM detection (which may involve a network call) is performed outside the - lock so the lock is never held across blocking I/O. + ``coordinator_url`` and ``service_name`` are accepted but ignored — + LocalScheduler has no coordinator. They exist for API compatibility with + OrchestratedScheduler call sites. """ global _scheduler if _scheduler is not None: return _scheduler - # Build outside the lock — TaskScheduler.__init__ may call detect_available_vram_gb() - # which makes an httpx network call (up to 2 s). Holding the lock during that - # would block any concurrent caller for the full duration. - if run_task_fn is None or task_types is None or vram_budgets is None: + if run_task_fn is None or task_types is None or vram_budgets is None or db_path is None: raise ValueError( - "run_task_fn, task_types, and vram_budgets are required " + "db_path, run_task_fn, task_types, and vram_budgets are required " "on the first call to get_scheduler()" ) - candidate = TaskScheduler( + candidate = LocalScheduler( db_path=db_path, run_task_fn=run_task_fn, task_types=task_types, vram_budgets=vram_budgets, max_queue_depth=max_queue_depth, - coordinator_url=coordinator_url, - service_name=service_name, ) candidate.start() with _scheduler_lock: if _scheduler is None: _scheduler = candidate else: - # Another thread beat us — shut down our candidate and use the winner. candidate.shutdown() return _scheduler diff --git a/tests/test_tasks/test_scheduler.py b/tests/test_tasks/test_scheduler.py index 4428e42..59b08af 100644 --- a/tests/test_tasks/test_scheduler.py +++ b/tests/test_tasks/test_scheduler.py @@ -1,17 +1,14 @@ -"""Tests for circuitforge_core.tasks.scheduler.""" +"""Tests for TaskScheduler Protocol + LocalScheduler (MIT, no coordinator).""" from __future__ import annotations import sqlite3 -import threading import time from pathlib import Path -from types import ModuleType -from typing import List -from unittest.mock import MagicMock, patch import pytest from circuitforge_core.tasks.scheduler import ( + LocalScheduler, TaskScheduler, detect_available_vram_gb, get_scheduler, @@ -19,267 +16,125 @@ from circuitforge_core.tasks.scheduler import ( ) -# ── Fixtures ────────────────────────────────────────────────────────────────── - @pytest.fixture -def tmp_db(tmp_path: Path) -> Path: - """SQLite DB with background_tasks table.""" - db = tmp_path / "test.db" - conn = sqlite3.connect(db) - conn.execute(""" - CREATE TABLE background_tasks ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - task_type TEXT NOT NULL, - job_id INTEGER NOT NULL DEFAULT 0, - status TEXT NOT NULL DEFAULT 'queued', - params TEXT, - error TEXT, - created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP +def db_path(tmp_path: Path) -> Path: + p = tmp_path / "test.db" + with sqlite3.connect(p) as conn: + conn.execute( + "CREATE TABLE background_tasks " + "(id INTEGER PRIMARY KEY, task_type TEXT, job_id INTEGER, " + "params TEXT, status TEXT DEFAULT 'queued', created_at TEXT DEFAULT '')" ) - """) - conn.commit() - conn.close() - return db + return p @pytest.fixture(autouse=True) -def _reset_singleton(): - """Always tear down the scheduler singleton between tests.""" +def clean_singleton(): yield reset_scheduler() -TASK_TYPES = frozenset({"fast_task"}) -BUDGETS = {"fast_task": 1.0} +def make_run_fn(results: list): + def run(db_path, task_id, task_type, job_id, params): + results.append((task_type, task_id)) + time.sleep(0.01) + return run -# ── detect_available_vram_gb ────────────────────────────────────────────────── - -def test_detect_vram_from_cfortch(): - """Uses cf-orch free VRAM when coordinator is reachable.""" - mock_resp = MagicMock() - mock_resp.status_code = 200 - mock_resp.json.return_value = { - "nodes": [ - {"node_id": "local", "gpus": [{"vram_free_mb": 4096}, {"vram_free_mb": 4096}]} - ] - } - with patch("circuitforge_core.tasks.scheduler.httpx") as mock_httpx: - mock_httpx.get.return_value = mock_resp - result = detect_available_vram_gb(coordinator_url="http://localhost:7700") - assert result == pytest.approx(8.0) # 4096 + 4096 MB → 8 GB +def test_local_scheduler_implements_protocol(): + assert isinstance(LocalScheduler.__new__(LocalScheduler), TaskScheduler) -def test_detect_vram_cforch_unavailable_falls_back_to_unlimited(): - """Falls back to 999.0 when cf-orch is unreachable and preflight unavailable.""" - with patch("circuitforge_core.tasks.scheduler.httpx") as mock_httpx: - mock_httpx.get.side_effect = ConnectionRefusedError() - result = detect_available_vram_gb() - assert result == 999.0 +def test_detect_available_vram_returns_unlimited(): + assert detect_available_vram_gb() == 999.0 -def test_detect_vram_cforch_empty_nodes_falls_back(): - """If cf-orch returns no nodes with GPUs, falls back to unlimited.""" - mock_resp = MagicMock() - mock_resp.status_code = 200 - mock_resp.json.return_value = {"nodes": []} - with patch("circuitforge_core.tasks.scheduler.httpx") as mock_httpx: - mock_httpx.get.return_value = mock_resp - result = detect_available_vram_gb() - assert result == 999.0 - - -def test_detect_vram_preflight_fallback(): - """Falls back to preflight total VRAM when cf-orch is unreachable.""" - # Build a fake scripts.preflight module with get_gpus returning two GPUs. - fake_scripts = ModuleType("scripts") - fake_preflight = ModuleType("scripts.preflight") - fake_preflight.get_gpus = lambda: [ # type: ignore[attr-defined] - {"vram_total_gb": 8.0}, - {"vram_total_gb": 4.0}, - ] - fake_scripts.preflight = fake_preflight # type: ignore[attr-defined] - - with patch("circuitforge_core.tasks.scheduler.httpx") as mock_httpx, \ - patch.dict( - __import__("sys").modules, - {"scripts": fake_scripts, "scripts.preflight": fake_preflight}, - ): - mock_httpx.get.side_effect = ConnectionRefusedError() - result = detect_available_vram_gb() - - assert result == pytest.approx(12.0) # 8.0 + 4.0 GB - - -# ── TaskScheduler basic behaviour ───────────────────────────────────────────── - -def test_enqueue_returns_true_on_success(tmp_db: Path): - ran: List[int] = [] - - def run_fn(db_path, task_id, task_type, job_id, params): - ran.append(task_id) - - sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0) - sched.start() - result = sched.enqueue(1, "fast_task", 0, None) - sched.shutdown() - assert result is True - - -def test_scheduler_runs_task(tmp_db: Path): - """Enqueued task is executed by the batch worker.""" - ran: List[int] = [] - event = threading.Event() - - def run_fn(db_path, task_id, task_type, job_id, params): - ran.append(task_id) - event.set() - - sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0) - sched.start() - sched.enqueue(42, "fast_task", 0, None) - assert event.wait(timeout=3.0), "Task was not executed within 3 seconds" - sched.shutdown() - assert ran == [42] - - -def test_enqueue_returns_false_when_queue_full(tmp_db: Path): - """Returns False and does not enqueue when max_queue_depth is reached.""" - gate = threading.Event() - - def blocking_run_fn(db_path, task_id, task_type, job_id, params): - gate.wait() - - sched = TaskScheduler( - tmp_db, blocking_run_fn, TASK_TYPES, BUDGETS, - available_vram_gb=8.0, max_queue_depth=2 +def test_enqueue_and_execute(db_path): + results = [] + sched = LocalScheduler( + db_path=db_path, + run_task_fn=make_run_fn(results), + task_types=frozenset({"cover_letter"}), + vram_budgets={"cover_letter": 0.0}, ) sched.start() - results = [sched.enqueue(i, "fast_task", 0, None) for i in range(1, 10)] - gate.set() + sched.enqueue(1, "cover_letter", 1, None) + time.sleep(0.3) sched.shutdown() - assert not all(results), "Expected at least one enqueue to be rejected" + assert ("cover_letter", 1) in results -def test_scheduler_drains_multiple_tasks(tmp_db: Path): - """All enqueued tasks of the same type are run serially.""" - ran: List[int] = [] - done = threading.Event() - TOTAL = 5 - - def run_fn(db_path, task_id, task_type, job_id, params): - ran.append(task_id) - if len(ran) >= TOTAL: - done.set() - - sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0) - sched.start() - for i in range(1, TOTAL + 1): - sched.enqueue(i, "fast_task", 0, None) - assert done.wait(timeout=5.0), f"Only ran {len(ran)} of {TOTAL} tasks" - sched.shutdown() - assert sorted(ran) == list(range(1, TOTAL + 1)) - - -def test_vram_budget_blocks_second_type(tmp_db: Path): - """Second task type is not started when VRAM would be exceeded.""" - type_a_started = threading.Event() - type_b_started = threading.Event() - gate_a = threading.Event() - gate_b = threading.Event() - started = [] - - def run_fn(db_path, task_id, task_type, job_id, params): - started.append(task_type) - if task_type == "type_a": - type_a_started.set() - gate_a.wait() - else: - type_b_started.set() - gate_b.wait() - - two_types = frozenset({"type_a", "type_b"}) - tight_budgets = {"type_a": 4.0, "type_b": 4.0} # 4+4 > 6 GB available - - sched = TaskScheduler( - tmp_db, run_fn, two_types, tight_budgets, available_vram_gb=6.0 +def test_fifo_ordering(db_path): + results = [] + sched = LocalScheduler( + db_path=db_path, + run_task_fn=make_run_fn(results), + task_types=frozenset({"t"}), + vram_budgets={"t": 0.0}, ) sched.start() - sched.enqueue(1, "type_a", 0, None) - sched.enqueue(2, "type_b", 0, None) - - assert type_a_started.wait(timeout=3.0), "type_a never started" - assert not type_b_started.is_set(), "type_b should be blocked by VRAM" - - gate_a.set() - assert type_b_started.wait(timeout=3.0), "type_b never started after type_a finished" - gate_b.set() + sched.enqueue(1, "t", 1, None) + sched.enqueue(2, "t", 1, None) + sched.enqueue(3, "t", 1, None) + time.sleep(0.5) sched.shutdown() - assert sorted(started) == ["type_a", "type_b"] + assert [r[1] for r in results] == [1, 2, 3] -def test_get_scheduler_singleton(tmp_db: Path): - """get_scheduler() returns the same instance on repeated calls.""" - run_fn = MagicMock() - s1 = get_scheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS) - s2 = get_scheduler(tmp_db) # no run_fn — should reuse existing +def test_queue_depth_limit(db_path): + sched = LocalScheduler( + db_path=db_path, + run_task_fn=make_run_fn([]), + task_types=frozenset({"t"}), + vram_budgets={"t": 0.0}, + max_queue_depth=2, + ) + assert sched.enqueue(1, "t", 1, None) is True + assert sched.enqueue(2, "t", 1, None) is True + assert sched.enqueue(3, "t", 1, None) is False + + +def test_get_scheduler_singleton(db_path): + results = [] + s1 = get_scheduler( + db_path=db_path, + run_task_fn=make_run_fn(results), + task_types=frozenset({"t"}), + vram_budgets={"t": 0.0}, + ) + s2 = get_scheduler(db_path=db_path) assert s1 is s2 + s1.shutdown() -def test_reset_scheduler_clears_singleton(tmp_db: Path): - """reset_scheduler() allows a new singleton to be constructed.""" - run_fn = MagicMock() - s1 = get_scheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS) - reset_scheduler() - s2 = get_scheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS) - assert s1 is not s2 +def test_local_scheduler_no_httpx_dependency(): + """LocalScheduler must not import httpx — not in MIT core's hard deps.""" + import ast, inspect + from circuitforge_core.tasks import scheduler as sched_mod + src = inspect.getsource(sched_mod) + tree = ast.parse(src) + for node in ast.walk(tree): + if isinstance(node, (ast.Import, ast.ImportFrom)): + names = [a.name for a in getattr(node, 'names', [])] + module = getattr(node, 'module', '') or '' + assert 'httpx' not in names and 'httpx' not in module, \ + "LocalScheduler must not import httpx" -def test_load_queued_tasks_on_startup(tmp_db: Path): - """Tasks with status='queued' in the DB at startup are loaded into the deque.""" - conn = sqlite3.connect(tmp_db) - conn.execute( - "INSERT INTO background_tasks (task_type, job_id, status) VALUES ('fast_task', 0, 'queued')" +def test_load_queued_tasks_on_startup(db_path): + """Tasks with status='queued' in the DB at startup are loaded and run.""" + with sqlite3.connect(db_path) as conn: + conn.execute( + "INSERT INTO background_tasks (id, task_type, job_id, status) VALUES (99, 't', 1, 'queued')" + ) + results = [] + sched = LocalScheduler( + db_path=db_path, + run_task_fn=make_run_fn(results), + task_types=frozenset({"t"}), + vram_budgets={"t": 0.0}, ) - conn.commit() - conn.close() - - ran: List[int] = [] - done = threading.Event() - - def run_fn(db_path, task_id, task_type, job_id, params): - ran.append(task_id) - done.set() - - sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0) sched.start() - assert done.wait(timeout=3.0), "Pre-loaded task was not run" + time.sleep(0.3) sched.shutdown() - assert len(ran) == 1 - - -def test_load_queued_tasks_missing_table_does_not_crash(tmp_path: Path): - """Scheduler does not crash if background_tasks table doesn't exist yet.""" - db = tmp_path / "empty.db" - sqlite3.connect(db).close() - - run_fn = MagicMock() - sched = TaskScheduler(db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0) - sched.start() - sched.shutdown() - # No exception = pass - - -def test_reserved_vram_zero_after_task_completes(tmp_db: Path): - """_reserved_vram returns to 0.0 after a task finishes — no double-decrement.""" - done = threading.Event() - - def run_fn(db_path, task_id, task_type, job_id, params): - done.set() - - sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0) - sched.start() - sched.enqueue(1, "fast_task", 0, None) - assert done.wait(timeout=3.0), "Task never completed" - sched.shutdown() - assert sched._reserved_vram == 0.0, f"Expected 0.0, got {sched._reserved_vram}" + assert ("t", 99) in results