refactor: replace coordinator-aware TaskScheduler with Protocol + LocalScheduler (MIT); update LLMRouter import path

This commit is contained in:
pyr0ball 2026-04-04 22:26:06 -07:00
parent 090a86ce1b
commit 2259382d0b
3 changed files with 158 additions and 477 deletions

View file

@ -1,6 +1,6 @@
# circuitforge_core/tasks/__init__.py
from circuitforge_core.tasks.scheduler import (
TaskScheduler,
LocalScheduler,
detect_available_vram_gb,
get_scheduler,
reset_scheduler,
@ -8,6 +8,7 @@ from circuitforge_core.tasks.scheduler import (
__all__ = [
"TaskScheduler",
"LocalScheduler",
"detect_available_vram_gb",
"get_scheduler",
"reset_scheduler",

View file

@ -1,21 +1,17 @@
# circuitforge_core/tasks/scheduler.py
"""Resource-aware batch scheduler for LLM background tasks.
"""Task scheduler for CircuitForge products — MIT layer.
Generic scheduler that any CircuitForge product can use. Products supply:
- task_types: frozenset[str] — task type strings routed through this scheduler
- vram_budgets: dict[str, float] — VRAM GB estimate per task type
- run_task_fn — product's task execution function
Provides a simple FIFO task queue with no coordinator dependency.
VRAM detection priority:
1. cf-orch coordinator /api/nodes free VRAM (lease-aware, cooperative)
2. scripts.preflight.get_gpus() total GPU VRAM (Peregrine-era fallback)
3. 999.0 unlimited (CPU-only or no detection available)
For coordinator-aware VRAM-budgeted scheduling on paid/premium tiers, install
circuitforge-orch and use OrchestratedScheduler instead.
Public API:
TaskScheduler — the scheduler class
detect_available_vram_gb() — standalone VRAM query helper
get_scheduler() — lazy process-level singleton
reset_scheduler() — test teardown only
TaskScheduler — Protocol defining the scheduler interface
LocalScheduler — simple FIFO queue implementation (MIT, no coordinator)
detect_available_vram_gb() — returns 999.0 (unlimited; no coordinator on free tier)
get_scheduler() — lazy process-level singleton returning a LocalScheduler
reset_scheduler() — test teardown only
"""
from __future__ import annotations
@ -24,12 +20,7 @@ import sqlite3
import threading
from collections import deque
from pathlib import Path
from typing import Callable, NamedTuple, Optional
try:
import httpx as httpx
except ImportError:
httpx = None # type: ignore[assignment]
from typing import Callable, NamedTuple, Optional, Protocol, runtime_checkable
logger = logging.getLogger(__name__)
@ -41,68 +32,45 @@ class TaskSpec(NamedTuple):
job_id: int
params: Optional[str]
_DEFAULT_MAX_QUEUE_DEPTH = 500
def detect_available_vram_gb(
coordinator_url: str = "http://localhost:7700",
) -> float:
"""Detect available VRAM GB for task scheduling.
def detect_available_vram_gb() -> float:
"""Return available VRAM for task scheduling.
Returns free VRAM via cf-orch (sum across all nodes/GPUs) so the scheduler
cooperates with other cf-orch consumers. Falls back to preflight total VRAM,
then 999.0 (unlimited) if nothing is reachable.
Free tier (no coordinator): always returns 999.0 — no VRAM gating.
For coordinator-aware VRAM detection use circuitforge_orch.scheduler.
"""
# 1. Try cf-orch: use free VRAM so the scheduler cooperates with other
# cf-orch consumers (vision service, inference services, etc.)
if httpx is not None:
try:
resp = httpx.get(f"{coordinator_url}/api/nodes", timeout=2.0)
if resp.status_code == 200:
nodes = resp.json().get("nodes", [])
total_free_mb = sum(
gpu.get("vram_free_mb", 0)
for node in nodes
for gpu in node.get("gpus", [])
)
if total_free_mb > 0:
free_gb = total_free_mb / 1024.0
logger.debug(
"Scheduler VRAM from cf-orch: %.1f GB free", free_gb
)
return free_gb
except Exception:
pass
# 2. Try preflight (systems with nvidia-smi; Peregrine-era fallback)
try:
from scripts.preflight import get_gpus # type: ignore[import]
gpus = get_gpus()
if gpus:
total_gb = sum(g.get("vram_total_gb", 0.0) for g in gpus)
logger.debug(
"Scheduler VRAM from preflight: %.1f GB total", total_gb
)
return total_gb
except Exception:
pass
logger.debug(
"Scheduler VRAM detection unavailable — using unlimited (999 GB)"
)
return 999.0
class TaskScheduler:
"""Resource-aware LLM task batch scheduler.
@runtime_checkable
class TaskScheduler(Protocol):
"""Protocol for task schedulers across free and paid tiers.
Runs one batch-worker thread per task type while total reserved VRAM
stays within the detected available budget. Always allows at least one
batch to start even if its budget exceeds available VRAM (prevents
permanent starvation on low-VRAM systems).
Both LocalScheduler (MIT) and OrchestratedScheduler (BSL, circuitforge-orch)
implement this interface so products can inject either without API changes.
"""
Thread-safety: all queue/active state protected by self._lock.
def enqueue(self, task_id: int, task_type: str, job_id: int, params: Optional[str]) -> bool:
"""Add a task to the queue. Returns True if enqueued, False if queue full."""
...
def start(self) -> None:
"""Start the background worker thread."""
...
def shutdown(self, timeout: float = 5.0) -> None:
"""Stop the scheduler and wait for it to exit."""
...
class LocalScheduler:
"""Simple FIFO task scheduler with no coordinator dependency.
Processes tasks serially per task type. No VRAM gating — all tasks run.
Suitable for free tier (single-host, up to 2 GPUs, static config).
Usage::
@ -112,11 +80,7 @@ class TaskScheduler:
task_types=frozenset({"cover_letter", "research"}),
vram_budgets={"cover_letter": 2.5, "research": 5.0},
)
task_id, is_new = insert_task(db_path, "cover_letter", job_id)
if is_new:
enqueued = sched.enqueue(task_id, "cover_letter", job_id, params_json)
if not enqueued:
mark_task_failed(db_path, task_id, "Queue full")
enqueued = sched.enqueue(task_id, "cover_letter", job_id, params_json)
"""
def __init__(
@ -125,11 +89,7 @@ class TaskScheduler:
run_task_fn: RunTaskFn,
task_types: frozenset[str],
vram_budgets: dict[str, float],
available_vram_gb: Optional[float] = None,
max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH,
coordinator_url: str = "http://localhost:7700",
service_name: str = "peregrine",
lease_priority: int = 2,
) -> None:
self._db_path = db_path
self._run_task = run_task_fn
@ -137,54 +97,22 @@ class TaskScheduler:
self._budgets: dict[str, float] = dict(vram_budgets)
self._max_queue_depth = max_queue_depth
self._coordinator_url = coordinator_url.rstrip("/")
self._service_name = service_name
self._lease_priority = lease_priority
self._lock = threading.Lock()
self._wake = threading.Event()
self._stop = threading.Event()
self._queues: dict[str, deque[TaskSpec]] = {}
self._active: dict[str, threading.Thread] = {}
self._reserved_vram: float = 0.0
self._thread: Optional[threading.Thread] = None
self._available_vram: float = (
available_vram_gb
if available_vram_gb is not None
else detect_available_vram_gb()
)
for t in self._task_types:
if t not in self._budgets:
logger.warning(
"No VRAM budget defined for task type %r"
"defaulting to 0.0 GB (no VRAM gating for this type)",
t,
)
self._load_queued_tasks()
def enqueue(
self,
task_id: int,
task_type: str,
job_id: int,
params: Optional[str],
) -> bool:
"""Add a task to the scheduler queue.
Returns True if enqueued successfully.
Returns False if the queue is full — caller should mark the task failed.
"""
def enqueue(self, task_id: int, task_type: str, job_id: int, params: Optional[str]) -> bool:
with self._lock:
q = self._queues.setdefault(task_type, deque())
if len(q) >= self._max_queue_depth:
logger.warning(
"Queue depth limit for %s (max=%d) — task %d dropped",
task_type,
self._max_queue_depth,
task_id,
task_type, self._max_queue_depth, task_id,
)
return False
q.append(TaskSpec(task_id, job_id, params))
@ -192,28 +120,19 @@ class TaskScheduler:
return True
def start(self) -> None:
"""Start the background scheduler loop thread. Call once after construction."""
self._thread = threading.Thread(
target=self._scheduler_loop, name="task-scheduler", daemon=True
)
self._thread.start()
# Wake the loop immediately so tasks loaded from DB at startup are dispatched
with self._lock:
if any(self._queues.values()):
self._wake.set()
def shutdown(self, timeout: float = 5.0) -> None:
"""Signal the scheduler to stop and wait for it to exit.
Joins both the scheduler loop thread and any active batch worker
threads so callers can rely on clean state (e.g. _reserved_vram == 0)
immediately after this returns.
"""
self._stop.set()
self._wake.set()
if self._thread and self._thread.is_alive():
self._thread.join(timeout=timeout)
# Join active batch workers so _reserved_vram is settled on return
with self._lock:
workers = list(self._active.values())
for worker in workers:
@ -224,103 +143,25 @@ class TaskScheduler:
self._wake.wait(timeout=30)
self._wake.clear()
with self._lock:
# Reap batch threads that finished without waking us.
# VRAM accounting is handled solely by _batch_worker's finally block;
# the reaper only removes dead entries from _active.
for t, thread in list(self._active.items()):
if not thread.is_alive():
del self._active[t]
# Start new type batches while VRAM budget allows
candidates = sorted(
[
t
for t in self._queues
if self._queues[t] and t not in self._active
],
[t for t in self._queues if self._queues[t] and t not in self._active],
key=lambda t: len(self._queues[t]),
reverse=True,
)
for task_type in candidates:
budget = self._budgets.get(task_type, 0.0)
# Always allow at least one batch to run
if (
self._reserved_vram == 0.0
or self._reserved_vram + budget <= self._available_vram
):
thread = threading.Thread(
target=self._batch_worker,
args=(task_type,),
name=f"batch-{task_type}",
daemon=True,
)
self._active[task_type] = thread
self._reserved_vram += budget
thread.start()
def _acquire_lease(self, task_type: str) -> Optional[str]:
"""Request a VRAM lease from the coordinator. Returns lease_id or None."""
if httpx is None:
return None
budget_gb = self._budgets.get(task_type, 0.0)
if budget_gb <= 0:
return None
mb = int(budget_gb * 1024)
try:
# Pick the GPU with the most free VRAM on the first registered node
resp = httpx.get(f"{self._coordinator_url}/api/nodes", timeout=2.0)
if resp.status_code != 200:
return None
nodes = resp.json().get("nodes", [])
if not nodes:
return None
best_node = best_gpu = best_free = None
for node in nodes:
for gpu in node.get("gpus", []):
free = gpu.get("vram_free_mb", 0)
if best_free is None or free > best_free:
best_node = node["node_id"]
best_gpu = gpu["gpu_id"]
best_free = free
if best_node is None:
return None
lease_resp = httpx.post(
f"{self._coordinator_url}/api/leases",
json={
"node_id": best_node,
"gpu_id": best_gpu,
"mb": mb,
"service": self._service_name,
"priority": self._lease_priority,
},
timeout=3.0,
)
if lease_resp.status_code == 200:
lease_id = lease_resp.json()["lease"]["lease_id"]
logger.debug(
"Acquired VRAM lease %s for task_type=%s (%d MB)",
lease_id, task_type, mb,
)
return lease_id
except Exception as exc:
logger.debug("Lease acquire failed (non-fatal): %s", exc)
return None
def _release_lease(self, lease_id: str) -> None:
"""Release a coordinator VRAM lease. Best-effort; failures are logged only."""
if httpx is None or not lease_id:
return
try:
httpx.delete(
f"{self._coordinator_url}/api/leases/{lease_id}",
timeout=3.0,
)
logger.debug("Released VRAM lease %s", lease_id)
except Exception as exc:
logger.debug("Lease release failed (non-fatal): %s", exc)
thread = threading.Thread(
target=self._batch_worker,
args=(task_type,),
name=f"batch-{task_type}",
daemon=True,
)
self._active[task_type] = thread
thread.start()
def _batch_worker(self, task_type: str) -> None:
"""Serial consumer for one task type. Runs until the type's deque is empty."""
lease_id: Optional[str] = self._acquire_lease(task_type)
try:
while True:
with self._lock:
@ -328,19 +169,13 @@ class TaskScheduler:
if not q:
break
task = q.popleft()
self._run_task(
self._db_path, task.id, task_type, task.job_id, task.params
)
self._run_task(self._db_path, task.id, task_type, task.job_id, task.params)
finally:
if lease_id:
self._release_lease(lease_id)
with self._lock:
self._active.pop(task_type, None)
self._reserved_vram -= self._budgets.get(task_type, 0.0)
self._wake.set()
def _load_queued_tasks(self) -> None:
"""Reload surviving 'queued' tasks from SQLite into deques at startup."""
if not self._task_types:
return
task_types_list = sorted(self._task_types)
@ -354,68 +189,58 @@ class TaskScheduler:
task_types_list,
).fetchall()
except sqlite3.OperationalError:
# Table not yet created (first run before migrations)
rows = []
for row_id, task_type, job_id, params in rows:
q = self._queues.setdefault(task_type, deque())
q.append(TaskSpec(row_id, job_id, params))
if rows:
logger.info(
"Scheduler: resumed %d queued task(s) from prior run", len(rows)
)
logger.info("Scheduler: resumed %d queued task(s) from prior run", len(rows))
# ── Process-level singleton ────────────────────────────────────────────────────
_scheduler: Optional[TaskScheduler] = None
_scheduler: Optional[LocalScheduler] = None
_scheduler_lock = threading.Lock()
def get_scheduler(
db_path: Path,
db_path: Optional[Path] = None,
run_task_fn: Optional[RunTaskFn] = None,
task_types: Optional[frozenset[str]] = None,
vram_budgets: Optional[dict[str, float]] = None,
max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH,
coordinator_url: str = "http://localhost:7700",
service_name: str = "peregrine",
) -> TaskScheduler:
"""Return the process-level TaskScheduler singleton.
) -> LocalScheduler:
"""Return the process-level LocalScheduler singleton.
``run_task_fn``, ``task_types``, and ``vram_budgets`` are required on the
first call; ignored on subsequent calls (singleton already constructed).
``run_task_fn``, ``task_types``, ``vram_budgets``, and ``db_path`` are
required on the first call; ignored on subsequent calls.
VRAM detection (which may involve a network call) is performed outside the
lock so the lock is never held across blocking I/O.
``coordinator_url`` and ``service_name`` are accepted but ignored —
LocalScheduler has no coordinator. They exist for API compatibility with
OrchestratedScheduler call sites.
"""
global _scheduler
if _scheduler is not None:
return _scheduler
# Build outside the lock — TaskScheduler.__init__ may call detect_available_vram_gb()
# which makes an httpx network call (up to 2 s). Holding the lock during that
# would block any concurrent caller for the full duration.
if run_task_fn is None or task_types is None or vram_budgets is None:
if run_task_fn is None or task_types is None or vram_budgets is None or db_path is None:
raise ValueError(
"run_task_fn, task_types, and vram_budgets are required "
"db_path, run_task_fn, task_types, and vram_budgets are required "
"on the first call to get_scheduler()"
)
candidate = TaskScheduler(
candidate = LocalScheduler(
db_path=db_path,
run_task_fn=run_task_fn,
task_types=task_types,
vram_budgets=vram_budgets,
max_queue_depth=max_queue_depth,
coordinator_url=coordinator_url,
service_name=service_name,
)
candidate.start()
with _scheduler_lock:
if _scheduler is None:
_scheduler = candidate
else:
# Another thread beat us — shut down our candidate and use the winner.
candidate.shutdown()
return _scheduler

View file

@ -1,17 +1,14 @@
"""Tests for circuitforge_core.tasks.scheduler."""
"""Tests for TaskScheduler Protocol + LocalScheduler (MIT, no coordinator)."""
from __future__ import annotations
import sqlite3
import threading
import time
from pathlib import Path
from types import ModuleType
from typing import List
from unittest.mock import MagicMock, patch
import pytest
from circuitforge_core.tasks.scheduler import (
LocalScheduler,
TaskScheduler,
detect_available_vram_gb,
get_scheduler,
@ -19,267 +16,125 @@ from circuitforge_core.tasks.scheduler import (
)
# ── Fixtures ──────────────────────────────────────────────────────────────────
@pytest.fixture
def tmp_db(tmp_path: Path) -> Path:
"""SQLite DB with background_tasks table."""
db = tmp_path / "test.db"
conn = sqlite3.connect(db)
conn.execute("""
CREATE TABLE background_tasks (
id INTEGER PRIMARY KEY AUTOINCREMENT,
task_type TEXT NOT NULL,
job_id INTEGER NOT NULL DEFAULT 0,
status TEXT NOT NULL DEFAULT 'queued',
params TEXT,
error TEXT,
created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
def db_path(tmp_path: Path) -> Path:
p = tmp_path / "test.db"
with sqlite3.connect(p) as conn:
conn.execute(
"CREATE TABLE background_tasks "
"(id INTEGER PRIMARY KEY, task_type TEXT, job_id INTEGER, "
"params TEXT, status TEXT DEFAULT 'queued', created_at TEXT DEFAULT '')"
)
""")
conn.commit()
conn.close()
return db
return p
@pytest.fixture(autouse=True)
def _reset_singleton():
"""Always tear down the scheduler singleton between tests."""
def clean_singleton():
yield
reset_scheduler()
TASK_TYPES = frozenset({"fast_task"})
BUDGETS = {"fast_task": 1.0}
def make_run_fn(results: list):
def run(db_path, task_id, task_type, job_id, params):
results.append((task_type, task_id))
time.sleep(0.01)
return run
# ── detect_available_vram_gb ──────────────────────────────────────────────────
def test_detect_vram_from_cfortch():
"""Uses cf-orch free VRAM when coordinator is reachable."""
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.json.return_value = {
"nodes": [
{"node_id": "local", "gpus": [{"vram_free_mb": 4096}, {"vram_free_mb": 4096}]}
]
}
with patch("circuitforge_core.tasks.scheduler.httpx") as mock_httpx:
mock_httpx.get.return_value = mock_resp
result = detect_available_vram_gb(coordinator_url="http://localhost:7700")
assert result == pytest.approx(8.0) # 4096 + 4096 MB → 8 GB
def test_local_scheduler_implements_protocol():
assert isinstance(LocalScheduler.__new__(LocalScheduler), TaskScheduler)
def test_detect_vram_cforch_unavailable_falls_back_to_unlimited():
"""Falls back to 999.0 when cf-orch is unreachable and preflight unavailable."""
with patch("circuitforge_core.tasks.scheduler.httpx") as mock_httpx:
mock_httpx.get.side_effect = ConnectionRefusedError()
result = detect_available_vram_gb()
assert result == 999.0
def test_detect_available_vram_returns_unlimited():
assert detect_available_vram_gb() == 999.0
def test_detect_vram_cforch_empty_nodes_falls_back():
"""If cf-orch returns no nodes with GPUs, falls back to unlimited."""
mock_resp = MagicMock()
mock_resp.status_code = 200
mock_resp.json.return_value = {"nodes": []}
with patch("circuitforge_core.tasks.scheduler.httpx") as mock_httpx:
mock_httpx.get.return_value = mock_resp
result = detect_available_vram_gb()
assert result == 999.0
def test_detect_vram_preflight_fallback():
"""Falls back to preflight total VRAM when cf-orch is unreachable."""
# Build a fake scripts.preflight module with get_gpus returning two GPUs.
fake_scripts = ModuleType("scripts")
fake_preflight = ModuleType("scripts.preflight")
fake_preflight.get_gpus = lambda: [ # type: ignore[attr-defined]
{"vram_total_gb": 8.0},
{"vram_total_gb": 4.0},
]
fake_scripts.preflight = fake_preflight # type: ignore[attr-defined]
with patch("circuitforge_core.tasks.scheduler.httpx") as mock_httpx, \
patch.dict(
__import__("sys").modules,
{"scripts": fake_scripts, "scripts.preflight": fake_preflight},
):
mock_httpx.get.side_effect = ConnectionRefusedError()
result = detect_available_vram_gb()
assert result == pytest.approx(12.0) # 8.0 + 4.0 GB
# ── TaskScheduler basic behaviour ─────────────────────────────────────────────
def test_enqueue_returns_true_on_success(tmp_db: Path):
ran: List[int] = []
def run_fn(db_path, task_id, task_type, job_id, params):
ran.append(task_id)
sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
sched.start()
result = sched.enqueue(1, "fast_task", 0, None)
sched.shutdown()
assert result is True
def test_scheduler_runs_task(tmp_db: Path):
"""Enqueued task is executed by the batch worker."""
ran: List[int] = []
event = threading.Event()
def run_fn(db_path, task_id, task_type, job_id, params):
ran.append(task_id)
event.set()
sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
sched.start()
sched.enqueue(42, "fast_task", 0, None)
assert event.wait(timeout=3.0), "Task was not executed within 3 seconds"
sched.shutdown()
assert ran == [42]
def test_enqueue_returns_false_when_queue_full(tmp_db: Path):
"""Returns False and does not enqueue when max_queue_depth is reached."""
gate = threading.Event()
def blocking_run_fn(db_path, task_id, task_type, job_id, params):
gate.wait()
sched = TaskScheduler(
tmp_db, blocking_run_fn, TASK_TYPES, BUDGETS,
available_vram_gb=8.0, max_queue_depth=2
def test_enqueue_and_execute(db_path):
results = []
sched = LocalScheduler(
db_path=db_path,
run_task_fn=make_run_fn(results),
task_types=frozenset({"cover_letter"}),
vram_budgets={"cover_letter": 0.0},
)
sched.start()
results = [sched.enqueue(i, "fast_task", 0, None) for i in range(1, 10)]
gate.set()
sched.enqueue(1, "cover_letter", 1, None)
time.sleep(0.3)
sched.shutdown()
assert not all(results), "Expected at least one enqueue to be rejected"
assert ("cover_letter", 1) in results
def test_scheduler_drains_multiple_tasks(tmp_db: Path):
"""All enqueued tasks of the same type are run serially."""
ran: List[int] = []
done = threading.Event()
TOTAL = 5
def run_fn(db_path, task_id, task_type, job_id, params):
ran.append(task_id)
if len(ran) >= TOTAL:
done.set()
sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
sched.start()
for i in range(1, TOTAL + 1):
sched.enqueue(i, "fast_task", 0, None)
assert done.wait(timeout=5.0), f"Only ran {len(ran)} of {TOTAL} tasks"
sched.shutdown()
assert sorted(ran) == list(range(1, TOTAL + 1))
def test_vram_budget_blocks_second_type(tmp_db: Path):
"""Second task type is not started when VRAM would be exceeded."""
type_a_started = threading.Event()
type_b_started = threading.Event()
gate_a = threading.Event()
gate_b = threading.Event()
started = []
def run_fn(db_path, task_id, task_type, job_id, params):
started.append(task_type)
if task_type == "type_a":
type_a_started.set()
gate_a.wait()
else:
type_b_started.set()
gate_b.wait()
two_types = frozenset({"type_a", "type_b"})
tight_budgets = {"type_a": 4.0, "type_b": 4.0} # 4+4 > 6 GB available
sched = TaskScheduler(
tmp_db, run_fn, two_types, tight_budgets, available_vram_gb=6.0
def test_fifo_ordering(db_path):
results = []
sched = LocalScheduler(
db_path=db_path,
run_task_fn=make_run_fn(results),
task_types=frozenset({"t"}),
vram_budgets={"t": 0.0},
)
sched.start()
sched.enqueue(1, "type_a", 0, None)
sched.enqueue(2, "type_b", 0, None)
assert type_a_started.wait(timeout=3.0), "type_a never started"
assert not type_b_started.is_set(), "type_b should be blocked by VRAM"
gate_a.set()
assert type_b_started.wait(timeout=3.0), "type_b never started after type_a finished"
gate_b.set()
sched.enqueue(1, "t", 1, None)
sched.enqueue(2, "t", 1, None)
sched.enqueue(3, "t", 1, None)
time.sleep(0.5)
sched.shutdown()
assert sorted(started) == ["type_a", "type_b"]
assert [r[1] for r in results] == [1, 2, 3]
def test_get_scheduler_singleton(tmp_db: Path):
"""get_scheduler() returns the same instance on repeated calls."""
run_fn = MagicMock()
s1 = get_scheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS)
s2 = get_scheduler(tmp_db) # no run_fn — should reuse existing
def test_queue_depth_limit(db_path):
sched = LocalScheduler(
db_path=db_path,
run_task_fn=make_run_fn([]),
task_types=frozenset({"t"}),
vram_budgets={"t": 0.0},
max_queue_depth=2,
)
assert sched.enqueue(1, "t", 1, None) is True
assert sched.enqueue(2, "t", 1, None) is True
assert sched.enqueue(3, "t", 1, None) is False
def test_get_scheduler_singleton(db_path):
results = []
s1 = get_scheduler(
db_path=db_path,
run_task_fn=make_run_fn(results),
task_types=frozenset({"t"}),
vram_budgets={"t": 0.0},
)
s2 = get_scheduler(db_path=db_path)
assert s1 is s2
s1.shutdown()
def test_reset_scheduler_clears_singleton(tmp_db: Path):
"""reset_scheduler() allows a new singleton to be constructed."""
run_fn = MagicMock()
s1 = get_scheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS)
reset_scheduler()
s2 = get_scheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS)
assert s1 is not s2
def test_local_scheduler_no_httpx_dependency():
"""LocalScheduler must not import httpx — not in MIT core's hard deps."""
import ast, inspect
from circuitforge_core.tasks import scheduler as sched_mod
src = inspect.getsource(sched_mod)
tree = ast.parse(src)
for node in ast.walk(tree):
if isinstance(node, (ast.Import, ast.ImportFrom)):
names = [a.name for a in getattr(node, 'names', [])]
module = getattr(node, 'module', '') or ''
assert 'httpx' not in names and 'httpx' not in module, \
"LocalScheduler must not import httpx"
def test_load_queued_tasks_on_startup(tmp_db: Path):
"""Tasks with status='queued' in the DB at startup are loaded into the deque."""
conn = sqlite3.connect(tmp_db)
conn.execute(
"INSERT INTO background_tasks (task_type, job_id, status) VALUES ('fast_task', 0, 'queued')"
def test_load_queued_tasks_on_startup(db_path):
"""Tasks with status='queued' in the DB at startup are loaded and run."""
with sqlite3.connect(db_path) as conn:
conn.execute(
"INSERT INTO background_tasks (id, task_type, job_id, status) VALUES (99, 't', 1, 'queued')"
)
results = []
sched = LocalScheduler(
db_path=db_path,
run_task_fn=make_run_fn(results),
task_types=frozenset({"t"}),
vram_budgets={"t": 0.0},
)
conn.commit()
conn.close()
ran: List[int] = []
done = threading.Event()
def run_fn(db_path, task_id, task_type, job_id, params):
ran.append(task_id)
done.set()
sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
sched.start()
assert done.wait(timeout=3.0), "Pre-loaded task was not run"
time.sleep(0.3)
sched.shutdown()
assert len(ran) == 1
def test_load_queued_tasks_missing_table_does_not_crash(tmp_path: Path):
"""Scheduler does not crash if background_tasks table doesn't exist yet."""
db = tmp_path / "empty.db"
sqlite3.connect(db).close()
run_fn = MagicMock()
sched = TaskScheduler(db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
sched.start()
sched.shutdown()
# No exception = pass
def test_reserved_vram_zero_after_task_completes(tmp_db: Path):
"""_reserved_vram returns to 0.0 after a task finishes — no double-decrement."""
done = threading.Event()
def run_fn(db_path, task_id, task_type, job_id, params):
done.set()
sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
sched.start()
sched.enqueue(1, "fast_task", 0, None)
assert done.wait(timeout=3.0), "Task never completed"
sched.shutdown()
assert sched._reserved_vram == 0.0, f"Expected 0.0, got {sched._reserved_vram}"
assert ("t", 99) in results