refactor: replace coordinator-aware TaskScheduler with Protocol + LocalScheduler (MIT); update LLMRouter import path

This commit is contained in:
pyr0ball 2026-04-04 22:26:06 -07:00
parent 090a86ce1b
commit 2259382d0b
3 changed files with 158 additions and 477 deletions

View file

@ -1,6 +1,6 @@
# circuitforge_core/tasks/__init__.py
from circuitforge_core.tasks.scheduler import ( from circuitforge_core.tasks.scheduler import (
TaskScheduler, TaskScheduler,
LocalScheduler,
detect_available_vram_gb, detect_available_vram_gb,
get_scheduler, get_scheduler,
reset_scheduler, reset_scheduler,
@ -8,6 +8,7 @@ from circuitforge_core.tasks.scheduler import (
__all__ = [ __all__ = [
"TaskScheduler", "TaskScheduler",
"LocalScheduler",
"detect_available_vram_gb", "detect_available_vram_gb",
"get_scheduler", "get_scheduler",
"reset_scheduler", "reset_scheduler",

View file

@ -1,21 +1,17 @@
# circuitforge_core/tasks/scheduler.py # circuitforge_core/tasks/scheduler.py
"""Resource-aware batch scheduler for LLM background tasks. """Task scheduler for CircuitForge products — MIT layer.
Generic scheduler that any CircuitForge product can use. Products supply: Provides a simple FIFO task queue with no coordinator dependency.
- task_types: frozenset[str] task type strings routed through this scheduler
- vram_budgets: dict[str, float] VRAM GB estimate per task type
- run_task_fn product's task execution function
VRAM detection priority: For coordinator-aware VRAM-budgeted scheduling on paid/premium tiers, install
1. cf-orch coordinator /api/nodes free VRAM (lease-aware, cooperative) circuitforge-orch and use OrchestratedScheduler instead.
2. scripts.preflight.get_gpus() total GPU VRAM (Peregrine-era fallback)
3. 999.0 unlimited (CPU-only or no detection available)
Public API: Public API:
TaskScheduler the scheduler class TaskScheduler Protocol defining the scheduler interface
detect_available_vram_gb() standalone VRAM query helper LocalScheduler Simple FIFO queue implementation (MIT, no coordinator)
get_scheduler() lazy process-level singleton detect_available_vram_gb() Returns 999.0 (unlimited; no coordinator on free tier)
reset_scheduler() test teardown only get_scheduler() Lazy process-level singleton returning a LocalScheduler
reset_scheduler() Test teardown only
""" """
from __future__ import annotations from __future__ import annotations
@ -24,12 +20,7 @@ import sqlite3
import threading import threading
from collections import deque from collections import deque
from pathlib import Path from pathlib import Path
from typing import Callable, NamedTuple, Optional from typing import Callable, NamedTuple, Optional, Protocol, runtime_checkable
try:
import httpx as httpx
except ImportError:
httpx = None # type: ignore[assignment]
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -41,68 +32,45 @@ class TaskSpec(NamedTuple):
job_id: int job_id: int
params: Optional[str] params: Optional[str]
_DEFAULT_MAX_QUEUE_DEPTH = 500 _DEFAULT_MAX_QUEUE_DEPTH = 500
def detect_available_vram_gb( def detect_available_vram_gb() -> float:
coordinator_url: str = "http://localhost:7700", """Return available VRAM for task scheduling.
) -> float:
"""Detect available VRAM GB for task scheduling.
Returns free VRAM via cf-orch (sum across all nodes/GPUs) so the scheduler Free tier (no coordinator): always returns 999.0 no VRAM gating.
cooperates with other cf-orch consumers. Falls back to preflight total VRAM, For coordinator-aware VRAM detection use circuitforge_orch.scheduler.
then 999.0 (unlimited) if nothing is reachable.
""" """
# 1. Try cf-orch: use free VRAM so the scheduler cooperates with other
# cf-orch consumers (vision service, inference services, etc.)
if httpx is not None:
try:
resp = httpx.get(f"{coordinator_url}/api/nodes", timeout=2.0)
if resp.status_code == 200:
nodes = resp.json().get("nodes", [])
total_free_mb = sum(
gpu.get("vram_free_mb", 0)
for node in nodes
for gpu in node.get("gpus", [])
)
if total_free_mb > 0:
free_gb = total_free_mb / 1024.0
logger.debug(
"Scheduler VRAM from cf-orch: %.1f GB free", free_gb
)
return free_gb
except Exception:
pass
# 2. Try preflight (systems with nvidia-smi; Peregrine-era fallback)
try:
from scripts.preflight import get_gpus # type: ignore[import]
gpus = get_gpus()
if gpus:
total_gb = sum(g.get("vram_total_gb", 0.0) for g in gpus)
logger.debug(
"Scheduler VRAM from preflight: %.1f GB total", total_gb
)
return total_gb
except Exception:
pass
logger.debug(
"Scheduler VRAM detection unavailable — using unlimited (999 GB)"
)
return 999.0 return 999.0
class TaskScheduler: @runtime_checkable
"""Resource-aware LLM task batch scheduler. class TaskScheduler(Protocol):
"""Protocol for task schedulers across free and paid tiers.
Runs one batch-worker thread per task type while total reserved VRAM Both LocalScheduler (MIT) and OrchestratedScheduler (BSL, circuitforge-orch)
stays within the detected available budget. Always allows at least one implement this interface so products can inject either without API changes.
batch to start even if its budget exceeds available VRAM (prevents """
permanent starvation on low-VRAM systems).
Thread-safety: all queue/active state protected by self._lock. def enqueue(self, task_id: int, task_type: str, job_id: int, params: Optional[str]) -> bool:
"""Add a task to the queue. Returns True if enqueued, False if queue full."""
...
def start(self) -> None:
"""Start the background worker thread."""
...
def shutdown(self, timeout: float = 5.0) -> None:
"""Stop the scheduler and wait for it to exit."""
...
class LocalScheduler:
"""Simple FIFO task scheduler with no coordinator dependency.
Processes tasks serially per task type. No VRAM gating all tasks run.
Suitable for free tier (single-host, up to 2 GPUs, static config).
Usage:: Usage::
@ -112,11 +80,7 @@ class TaskScheduler:
task_types=frozenset({"cover_letter", "research"}), task_types=frozenset({"cover_letter", "research"}),
vram_budgets={"cover_letter": 2.5, "research": 5.0}, vram_budgets={"cover_letter": 2.5, "research": 5.0},
) )
task_id, is_new = insert_task(db_path, "cover_letter", job_id)
if is_new:
enqueued = sched.enqueue(task_id, "cover_letter", job_id, params_json) enqueued = sched.enqueue(task_id, "cover_letter", job_id, params_json)
if not enqueued:
mark_task_failed(db_path, task_id, "Queue full")
""" """
def __init__( def __init__(
@ -125,11 +89,7 @@ class TaskScheduler:
run_task_fn: RunTaskFn, run_task_fn: RunTaskFn,
task_types: frozenset[str], task_types: frozenset[str],
vram_budgets: dict[str, float], vram_budgets: dict[str, float],
available_vram_gb: Optional[float] = None,
max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH, max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH,
coordinator_url: str = "http://localhost:7700",
service_name: str = "peregrine",
lease_priority: int = 2,
) -> None: ) -> None:
self._db_path = db_path self._db_path = db_path
self._run_task = run_task_fn self._run_task = run_task_fn
@ -137,54 +97,22 @@ class TaskScheduler:
self._budgets: dict[str, float] = dict(vram_budgets) self._budgets: dict[str, float] = dict(vram_budgets)
self._max_queue_depth = max_queue_depth self._max_queue_depth = max_queue_depth
self._coordinator_url = coordinator_url.rstrip("/")
self._service_name = service_name
self._lease_priority = lease_priority
self._lock = threading.Lock() self._lock = threading.Lock()
self._wake = threading.Event() self._wake = threading.Event()
self._stop = threading.Event() self._stop = threading.Event()
self._queues: dict[str, deque[TaskSpec]] = {} self._queues: dict[str, deque[TaskSpec]] = {}
self._active: dict[str, threading.Thread] = {} self._active: dict[str, threading.Thread] = {}
self._reserved_vram: float = 0.0
self._thread: Optional[threading.Thread] = None self._thread: Optional[threading.Thread] = None
self._available_vram: float = (
available_vram_gb
if available_vram_gb is not None
else detect_available_vram_gb()
)
for t in self._task_types:
if t not in self._budgets:
logger.warning(
"No VRAM budget defined for task type %r"
"defaulting to 0.0 GB (no VRAM gating for this type)",
t,
)
self._load_queued_tasks() self._load_queued_tasks()
def enqueue( def enqueue(self, task_id: int, task_type: str, job_id: int, params: Optional[str]) -> bool:
self,
task_id: int,
task_type: str,
job_id: int,
params: Optional[str],
) -> bool:
"""Add a task to the scheduler queue.
Returns True if enqueued successfully.
Returns False if the queue is full caller should mark the task failed.
"""
with self._lock: with self._lock:
q = self._queues.setdefault(task_type, deque()) q = self._queues.setdefault(task_type, deque())
if len(q) >= self._max_queue_depth: if len(q) >= self._max_queue_depth:
logger.warning( logger.warning(
"Queue depth limit for %s (max=%d) — task %d dropped", "Queue depth limit for %s (max=%d) — task %d dropped",
task_type, task_type, self._max_queue_depth, task_id,
self._max_queue_depth,
task_id,
) )
return False return False
q.append(TaskSpec(task_id, job_id, params)) q.append(TaskSpec(task_id, job_id, params))
@ -192,28 +120,19 @@ class TaskScheduler:
return True return True
def start(self) -> None: def start(self) -> None:
"""Start the background scheduler loop thread. Call once after construction."""
self._thread = threading.Thread( self._thread = threading.Thread(
target=self._scheduler_loop, name="task-scheduler", daemon=True target=self._scheduler_loop, name="task-scheduler", daemon=True
) )
self._thread.start() self._thread.start()
# Wake the loop immediately so tasks loaded from DB at startup are dispatched
with self._lock: with self._lock:
if any(self._queues.values()): if any(self._queues.values()):
self._wake.set() self._wake.set()
def shutdown(self, timeout: float = 5.0) -> None: def shutdown(self, timeout: float = 5.0) -> None:
"""Signal the scheduler to stop and wait for it to exit.
Joins both the scheduler loop thread and any active batch worker
threads so callers can rely on clean state (e.g. _reserved_vram == 0)
immediately after this returns.
"""
self._stop.set() self._stop.set()
self._wake.set() self._wake.set()
if self._thread and self._thread.is_alive(): if self._thread and self._thread.is_alive():
self._thread.join(timeout=timeout) self._thread.join(timeout=timeout)
# Join active batch workers so _reserved_vram is settled on return
with self._lock: with self._lock:
workers = list(self._active.values()) workers = list(self._active.values())
for worker in workers: for worker in workers:
@ -224,29 +143,15 @@ class TaskScheduler:
self._wake.wait(timeout=30) self._wake.wait(timeout=30)
self._wake.clear() self._wake.clear()
with self._lock: with self._lock:
# Reap batch threads that finished without waking us.
# VRAM accounting is handled solely by _batch_worker's finally block;
# the reaper only removes dead entries from _active.
for t, thread in list(self._active.items()): for t, thread in list(self._active.items()):
if not thread.is_alive(): if not thread.is_alive():
del self._active[t] del self._active[t]
# Start new type batches while VRAM budget allows
candidates = sorted( candidates = sorted(
[ [t for t in self._queues if self._queues[t] and t not in self._active],
t
for t in self._queues
if self._queues[t] and t not in self._active
],
key=lambda t: len(self._queues[t]), key=lambda t: len(self._queues[t]),
reverse=True, reverse=True,
) )
for task_type in candidates: for task_type in candidates:
budget = self._budgets.get(task_type, 0.0)
# Always allow at least one batch to run
if (
self._reserved_vram == 0.0
or self._reserved_vram + budget <= self._available_vram
):
thread = threading.Thread( thread = threading.Thread(
target=self._batch_worker, target=self._batch_worker,
args=(task_type,), args=(task_type,),
@ -254,73 +159,9 @@ class TaskScheduler:
daemon=True, daemon=True,
) )
self._active[task_type] = thread self._active[task_type] = thread
self._reserved_vram += budget
thread.start() thread.start()
def _acquire_lease(self, task_type: str) -> Optional[str]:
"""Request a VRAM lease from the coordinator. Returns lease_id or None."""
if httpx is None:
return None
budget_gb = self._budgets.get(task_type, 0.0)
if budget_gb <= 0:
return None
mb = int(budget_gb * 1024)
try:
# Pick the GPU with the most free VRAM on the first registered node
resp = httpx.get(f"{self._coordinator_url}/api/nodes", timeout=2.0)
if resp.status_code != 200:
return None
nodes = resp.json().get("nodes", [])
if not nodes:
return None
best_node = best_gpu = best_free = None
for node in nodes:
for gpu in node.get("gpus", []):
free = gpu.get("vram_free_mb", 0)
if best_free is None or free > best_free:
best_node = node["node_id"]
best_gpu = gpu["gpu_id"]
best_free = free
if best_node is None:
return None
lease_resp = httpx.post(
f"{self._coordinator_url}/api/leases",
json={
"node_id": best_node,
"gpu_id": best_gpu,
"mb": mb,
"service": self._service_name,
"priority": self._lease_priority,
},
timeout=3.0,
)
if lease_resp.status_code == 200:
lease_id = lease_resp.json()["lease"]["lease_id"]
logger.debug(
"Acquired VRAM lease %s for task_type=%s (%d MB)",
lease_id, task_type, mb,
)
return lease_id
except Exception as exc:
logger.debug("Lease acquire failed (non-fatal): %s", exc)
return None
def _release_lease(self, lease_id: str) -> None:
"""Release a coordinator VRAM lease. Best-effort; failures are logged only."""
if httpx is None or not lease_id:
return
try:
httpx.delete(
f"{self._coordinator_url}/api/leases/{lease_id}",
timeout=3.0,
)
logger.debug("Released VRAM lease %s", lease_id)
except Exception as exc:
logger.debug("Lease release failed (non-fatal): %s", exc)
def _batch_worker(self, task_type: str) -> None: def _batch_worker(self, task_type: str) -> None:
"""Serial consumer for one task type. Runs until the type's deque is empty."""
lease_id: Optional[str] = self._acquire_lease(task_type)
try: try:
while True: while True:
with self._lock: with self._lock:
@ -328,19 +169,13 @@ class TaskScheduler:
if not q: if not q:
break break
task = q.popleft() task = q.popleft()
self._run_task( self._run_task(self._db_path, task.id, task_type, task.job_id, task.params)
self._db_path, task.id, task_type, task.job_id, task.params
)
finally: finally:
if lease_id:
self._release_lease(lease_id)
with self._lock: with self._lock:
self._active.pop(task_type, None) self._active.pop(task_type, None)
self._reserved_vram -= self._budgets.get(task_type, 0.0)
self._wake.set() self._wake.set()
def _load_queued_tasks(self) -> None: def _load_queued_tasks(self) -> None:
"""Reload surviving 'queued' tasks from SQLite into deques at startup."""
if not self._task_types: if not self._task_types:
return return
task_types_list = sorted(self._task_types) task_types_list = sorted(self._task_types)
@ -354,68 +189,58 @@ class TaskScheduler:
task_types_list, task_types_list,
).fetchall() ).fetchall()
except sqlite3.OperationalError: except sqlite3.OperationalError:
# Table not yet created (first run before migrations)
rows = [] rows = []
for row_id, task_type, job_id, params in rows: for row_id, task_type, job_id, params in rows:
q = self._queues.setdefault(task_type, deque()) q = self._queues.setdefault(task_type, deque())
q.append(TaskSpec(row_id, job_id, params)) q.append(TaskSpec(row_id, job_id, params))
if rows: if rows:
logger.info( logger.info("Scheduler: resumed %d queued task(s) from prior run", len(rows))
"Scheduler: resumed %d queued task(s) from prior run", len(rows)
)
# ── Process-level singleton ──────────────────────────────────────────────────── # ── Process-level singleton ────────────────────────────────────────────────────
_scheduler: Optional[TaskScheduler] = None _scheduler: Optional[LocalScheduler] = None
_scheduler_lock = threading.Lock() _scheduler_lock = threading.Lock()
def get_scheduler( def get_scheduler(
db_path: Path, db_path: Optional[Path] = None,
run_task_fn: Optional[RunTaskFn] = None, run_task_fn: Optional[RunTaskFn] = None,
task_types: Optional[frozenset[str]] = None, task_types: Optional[frozenset[str]] = None,
vram_budgets: Optional[dict[str, float]] = None, vram_budgets: Optional[dict[str, float]] = None,
max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH, max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH,
coordinator_url: str = "http://localhost:7700", coordinator_url: str = "http://localhost:7700",
service_name: str = "peregrine", service_name: str = "peregrine",
) -> TaskScheduler: ) -> LocalScheduler:
"""Return the process-level TaskScheduler singleton. """Return the process-level LocalScheduler singleton.
``run_task_fn``, ``task_types``, and ``vram_budgets`` are required on the ``run_task_fn``, ``task_types``, ``vram_budgets``, and ``db_path`` are
first call; ignored on subsequent calls (singleton already constructed). required on the first call; ignored on subsequent calls.
VRAM detection (which may involve a network call) is performed outside the ``coordinator_url`` and ``service_name`` are accepted but ignored
lock so the lock is never held across blocking I/O. LocalScheduler has no coordinator. They exist for API compatibility with
OrchestratedScheduler call sites.
""" """
global _scheduler global _scheduler
if _scheduler is not None: if _scheduler is not None:
return _scheduler return _scheduler
# Build outside the lock — TaskScheduler.__init__ may call detect_available_vram_gb() if run_task_fn is None or task_types is None or vram_budgets is None or db_path is None:
# which makes an httpx network call (up to 2 s). Holding the lock during that
# would block any concurrent caller for the full duration.
if run_task_fn is None or task_types is None or vram_budgets is None:
raise ValueError( raise ValueError(
"run_task_fn, task_types, and vram_budgets are required " "db_path, run_task_fn, task_types, and vram_budgets are required "
"on the first call to get_scheduler()" "on the first call to get_scheduler()"
) )
candidate = TaskScheduler( candidate = LocalScheduler(
db_path=db_path, db_path=db_path,
run_task_fn=run_task_fn, run_task_fn=run_task_fn,
task_types=task_types, task_types=task_types,
vram_budgets=vram_budgets, vram_budgets=vram_budgets,
max_queue_depth=max_queue_depth, max_queue_depth=max_queue_depth,
coordinator_url=coordinator_url,
service_name=service_name,
) )
candidate.start() candidate.start()
with _scheduler_lock: with _scheduler_lock:
if _scheduler is None: if _scheduler is None:
_scheduler = candidate _scheduler = candidate
else: else:
# Another thread beat us — shut down our candidate and use the winner.
candidate.shutdown() candidate.shutdown()
return _scheduler return _scheduler

View file

@ -1,17 +1,14 @@
"""Tests for circuitforge_core.tasks.scheduler.""" """Tests for TaskScheduler Protocol + LocalScheduler (MIT, no coordinator)."""
from __future__ import annotations from __future__ import annotations
import sqlite3 import sqlite3
import threading
import time import time
from pathlib import Path from pathlib import Path
from types import ModuleType
from typing import List
from unittest.mock import MagicMock, patch
import pytest import pytest
from circuitforge_core.tasks.scheduler import ( from circuitforge_core.tasks.scheduler import (
LocalScheduler,
TaskScheduler, TaskScheduler,
detect_available_vram_gb, detect_available_vram_gb,
get_scheduler, get_scheduler,
@ -19,267 +16,125 @@ from circuitforge_core.tasks.scheduler import (
) )
# ── Fixtures ──────────────────────────────────────────────────────────────────
@pytest.fixture @pytest.fixture
def tmp_db(tmp_path: Path) -> Path: def db_path(tmp_path: Path) -> Path:
"""SQLite DB with background_tasks table.""" p = tmp_path / "test.db"
db = tmp_path / "test.db" with sqlite3.connect(p) as conn:
conn = sqlite3.connect(db) conn.execute(
conn.execute(""" "CREATE TABLE background_tasks "
CREATE TABLE background_tasks ( "(id INTEGER PRIMARY KEY, task_type TEXT, job_id INTEGER, "
id INTEGER PRIMARY KEY AUTOINCREMENT, "params TEXT, status TEXT DEFAULT 'queued', created_at TEXT DEFAULT '')"
task_type TEXT NOT NULL,
job_id INTEGER NOT NULL DEFAULT 0,
status TEXT NOT NULL DEFAULT 'queued',
params TEXT,
error TEXT,
created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
) )
""") return p
conn.commit()
conn.close()
return db
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def _reset_singleton(): def clean_singleton():
"""Always tear down the scheduler singleton between tests."""
yield yield
reset_scheduler() reset_scheduler()
TASK_TYPES = frozenset({"fast_task"}) def make_run_fn(results: list):
BUDGETS = {"fast_task": 1.0} def run(db_path, task_id, task_type, job_id, params):
results.append((task_type, task_id))
time.sleep(0.01)
return run
# ── detect_available_vram_gb ────────────────────────────────────────────────── def test_local_scheduler_implements_protocol():
assert isinstance(LocalScheduler.__new__(LocalScheduler), TaskScheduler)
def test_detect_vram_from_cfortch():
"""Uses cf-orch free VRAM when coordinator is reachable."""
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.json.return_value = {
"nodes": [
{"node_id": "local", "gpus": [{"vram_free_mb": 4096}, {"vram_free_mb": 4096}]}
]
}
with patch("circuitforge_core.tasks.scheduler.httpx") as mock_httpx:
mock_httpx.get.return_value = mock_resp
result = detect_available_vram_gb(coordinator_url="http://localhost:7700")
assert result == pytest.approx(8.0) # 4096 + 4096 MB → 8 GB
def test_detect_vram_cforch_unavailable_falls_back_to_unlimited(): def test_detect_available_vram_returns_unlimited():
"""Falls back to 999.0 when cf-orch is unreachable and preflight unavailable.""" assert detect_available_vram_gb() == 999.0
with patch("circuitforge_core.tasks.scheduler.httpx") as mock_httpx:
mock_httpx.get.side_effect = ConnectionRefusedError()
result = detect_available_vram_gb()
assert result == 999.0
def test_detect_vram_cforch_empty_nodes_falls_back(): def test_enqueue_and_execute(db_path):
"""If cf-orch returns no nodes with GPUs, falls back to unlimited.""" results = []
mock_resp = MagicMock() sched = LocalScheduler(
mock_resp.status_code = 200 db_path=db_path,
mock_resp.json.return_value = {"nodes": []} run_task_fn=make_run_fn(results),
with patch("circuitforge_core.tasks.scheduler.httpx") as mock_httpx: task_types=frozenset({"cover_letter"}),
mock_httpx.get.return_value = mock_resp vram_budgets={"cover_letter": 0.0},
result = detect_available_vram_gb()
assert result == 999.0
def test_detect_vram_preflight_fallback():
"""Falls back to preflight total VRAM when cf-orch is unreachable."""
# Build a fake scripts.preflight module with get_gpus returning two GPUs.
fake_scripts = ModuleType("scripts")
fake_preflight = ModuleType("scripts.preflight")
fake_preflight.get_gpus = lambda: [ # type: ignore[attr-defined]
{"vram_total_gb": 8.0},
{"vram_total_gb": 4.0},
]
fake_scripts.preflight = fake_preflight # type: ignore[attr-defined]
with patch("circuitforge_core.tasks.scheduler.httpx") as mock_httpx, \
patch.dict(
__import__("sys").modules,
{"scripts": fake_scripts, "scripts.preflight": fake_preflight},
):
mock_httpx.get.side_effect = ConnectionRefusedError()
result = detect_available_vram_gb()
assert result == pytest.approx(12.0) # 8.0 + 4.0 GB
# ── TaskScheduler basic behaviour ─────────────────────────────────────────────
def test_enqueue_returns_true_on_success(tmp_db: Path):
ran: List[int] = []
def run_fn(db_path, task_id, task_type, job_id, params):
ran.append(task_id)
sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
sched.start()
result = sched.enqueue(1, "fast_task", 0, None)
sched.shutdown()
assert result is True
def test_scheduler_runs_task(tmp_db: Path):
"""Enqueued task is executed by the batch worker."""
ran: List[int] = []
event = threading.Event()
def run_fn(db_path, task_id, task_type, job_id, params):
ran.append(task_id)
event.set()
sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
sched.start()
sched.enqueue(42, "fast_task", 0, None)
assert event.wait(timeout=3.0), "Task was not executed within 3 seconds"
sched.shutdown()
assert ran == [42]
def test_enqueue_returns_false_when_queue_full(tmp_db: Path):
"""Returns False and does not enqueue when max_queue_depth is reached."""
gate = threading.Event()
def blocking_run_fn(db_path, task_id, task_type, job_id, params):
gate.wait()
sched = TaskScheduler(
tmp_db, blocking_run_fn, TASK_TYPES, BUDGETS,
available_vram_gb=8.0, max_queue_depth=2
) )
sched.start() sched.start()
results = [sched.enqueue(i, "fast_task", 0, None) for i in range(1, 10)] sched.enqueue(1, "cover_letter", 1, None)
gate.set() time.sleep(0.3)
sched.shutdown() sched.shutdown()
assert not all(results), "Expected at least one enqueue to be rejected" assert ("cover_letter", 1) in results
def test_scheduler_drains_multiple_tasks(tmp_db: Path): def test_fifo_ordering(db_path):
"""All enqueued tasks of the same type are run serially.""" results = []
ran: List[int] = [] sched = LocalScheduler(
done = threading.Event() db_path=db_path,
TOTAL = 5 run_task_fn=make_run_fn(results),
task_types=frozenset({"t"}),
def run_fn(db_path, task_id, task_type, job_id, params): vram_budgets={"t": 0.0},
ran.append(task_id)
if len(ran) >= TOTAL:
done.set()
sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
sched.start()
for i in range(1, TOTAL + 1):
sched.enqueue(i, "fast_task", 0, None)
assert done.wait(timeout=5.0), f"Only ran {len(ran)} of {TOTAL} tasks"
sched.shutdown()
assert sorted(ran) == list(range(1, TOTAL + 1))
def test_vram_budget_blocks_second_type(tmp_db: Path):
"""Second task type is not started when VRAM would be exceeded."""
type_a_started = threading.Event()
type_b_started = threading.Event()
gate_a = threading.Event()
gate_b = threading.Event()
started = []
def run_fn(db_path, task_id, task_type, job_id, params):
started.append(task_type)
if task_type == "type_a":
type_a_started.set()
gate_a.wait()
else:
type_b_started.set()
gate_b.wait()
two_types = frozenset({"type_a", "type_b"})
tight_budgets = {"type_a": 4.0, "type_b": 4.0} # 4+4 > 6 GB available
sched = TaskScheduler(
tmp_db, run_fn, two_types, tight_budgets, available_vram_gb=6.0
) )
sched.start() sched.start()
sched.enqueue(1, "type_a", 0, None) sched.enqueue(1, "t", 1, None)
sched.enqueue(2, "type_b", 0, None) sched.enqueue(2, "t", 1, None)
sched.enqueue(3, "t", 1, None)
assert type_a_started.wait(timeout=3.0), "type_a never started" time.sleep(0.5)
assert not type_b_started.is_set(), "type_b should be blocked by VRAM"
gate_a.set()
assert type_b_started.wait(timeout=3.0), "type_b never started after type_a finished"
gate_b.set()
sched.shutdown() sched.shutdown()
assert sorted(started) == ["type_a", "type_b"] assert [r[1] for r in results] == [1, 2, 3]
def test_get_scheduler_singleton(tmp_db: Path): def test_queue_depth_limit(db_path):
"""get_scheduler() returns the same instance on repeated calls.""" sched = LocalScheduler(
run_fn = MagicMock() db_path=db_path,
s1 = get_scheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS) run_task_fn=make_run_fn([]),
s2 = get_scheduler(tmp_db) # no run_fn — should reuse existing task_types=frozenset({"t"}),
vram_budgets={"t": 0.0},
max_queue_depth=2,
)
assert sched.enqueue(1, "t", 1, None) is True
assert sched.enqueue(2, "t", 1, None) is True
assert sched.enqueue(3, "t", 1, None) is False
def test_get_scheduler_singleton(db_path):
results = []
s1 = get_scheduler(
db_path=db_path,
run_task_fn=make_run_fn(results),
task_types=frozenset({"t"}),
vram_budgets={"t": 0.0},
)
s2 = get_scheduler(db_path=db_path)
assert s1 is s2 assert s1 is s2
s1.shutdown()
def test_reset_scheduler_clears_singleton(tmp_db: Path): def test_local_scheduler_no_httpx_dependency():
"""reset_scheduler() allows a new singleton to be constructed.""" """LocalScheduler must not import httpx — not in MIT core's hard deps."""
run_fn = MagicMock() import ast, inspect
s1 = get_scheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS) from circuitforge_core.tasks import scheduler as sched_mod
reset_scheduler() src = inspect.getsource(sched_mod)
s2 = get_scheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS) tree = ast.parse(src)
assert s1 is not s2 for node in ast.walk(tree):
if isinstance(node, (ast.Import, ast.ImportFrom)):
names = [a.name for a in getattr(node, 'names', [])]
module = getattr(node, 'module', '') or ''
assert 'httpx' not in names and 'httpx' not in module, \
"LocalScheduler must not import httpx"
def test_load_queued_tasks_on_startup(tmp_db: Path): def test_load_queued_tasks_on_startup(db_path):
"""Tasks with status='queued' in the DB at startup are loaded into the deque.""" """Tasks with status='queued' in the DB at startup are loaded and run."""
conn = sqlite3.connect(tmp_db) with sqlite3.connect(db_path) as conn:
conn.execute( conn.execute(
"INSERT INTO background_tasks (task_type, job_id, status) VALUES ('fast_task', 0, 'queued')" "INSERT INTO background_tasks (id, task_type, job_id, status) VALUES (99, 't', 1, 'queued')"
)
results = []
sched = LocalScheduler(
db_path=db_path,
run_task_fn=make_run_fn(results),
task_types=frozenset({"t"}),
vram_budgets={"t": 0.0},
) )
conn.commit()
conn.close()
ran: List[int] = []
done = threading.Event()
def run_fn(db_path, task_id, task_type, job_id, params):
ran.append(task_id)
done.set()
sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
sched.start() sched.start()
assert done.wait(timeout=3.0), "Pre-loaded task was not run" time.sleep(0.3)
sched.shutdown() sched.shutdown()
assert len(ran) == 1 assert ("t", 99) in results
def test_load_queued_tasks_missing_table_does_not_crash(tmp_path: Path):
"""Scheduler does not crash if background_tasks table doesn't exist yet."""
db = tmp_path / "empty.db"
sqlite3.connect(db).close()
run_fn = MagicMock()
sched = TaskScheduler(db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
sched.start()
sched.shutdown()
# No exception = pass
def test_reserved_vram_zero_after_task_completes(tmp_db: Path):
"""_reserved_vram returns to 0.0 after a task finishes — no double-decrement."""
done = threading.Event()
def run_fn(db_path, task_id, task_type, job_id, params):
done.set()
sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
sched.start()
sched.enqueue(1, "fast_task", 0, None)
assert done.wait(timeout=3.0), "Task never completed"
sched.shutdown()
assert sched._reserved_vram == 0.0, f"Expected 0.0, got {sched._reserved_vram}"