feat(tasks): add shared VRAM-aware LLM task scheduler

Extract generic batch scheduler into circuitforge_core.tasks.scheduler
so any CircuitForge product can use it. Includes VRAM detection via
cf-orch coordinator (cooperative free-VRAM), preflight fallback, and
unlimited fallback; singleton API; full test coverage (12 tests).
This commit is contained in:
pyr0ball 2026-03-30 23:12:23 -07:00
parent db4e3047fd
commit 5801928f8e
5 changed files with 592 additions and 0 deletions

View file

@ -0,0 +1,14 @@
# circuitforge_core/tasks/__init__.py
# Package facade: re-export the scheduler public API so products can
# `from circuitforge_core.tasks import get_scheduler` without knowing the
# internal module layout.
from circuitforge_core.tasks.scheduler import (
    TaskScheduler,
    detect_available_vram_gb,
    get_scheduler,
    reset_scheduler,
)

# Explicit public API of this package.
__all__ = [
    "TaskScheduler",
    "detect_available_vram_gb",
    "get_scheduler",
    "reset_scheduler",
]

View file

@ -0,0 +1,331 @@
# circuitforge_core/tasks/scheduler.py
"""Resource-aware batch scheduler for LLM background tasks.
Generic scheduler that any CircuitForge product can use. Products supply:
- task_types: frozenset[str] task type strings routed through this scheduler
- vram_budgets: dict[str, float] VRAM GB estimate per task type
- run_task_fn product's task execution function
VRAM detection priority:
1. cf-orch coordinator /api/nodes free VRAM (lease-aware, cooperative)
2. scripts.preflight.get_gpus() total GPU VRAM (Peregrine-era fallback)
3. 999.0 unlimited (CPU-only or no detection available)
Public API:
TaskScheduler the scheduler class
detect_available_vram_gb() standalone VRAM query helper
get_scheduler() lazy process-level singleton
reset_scheduler() test teardown only
"""
from __future__ import annotations
import logging
import sqlite3
import threading
from collections import deque, namedtuple
from pathlib import Path
from typing import Callable, Optional
try:
import httpx as httpx
except ImportError:
httpx = None # type: ignore[assignment]
logger = logging.getLogger(__name__)
TaskSpec = namedtuple("TaskSpec", ["id", "job_id", "params"])
_DEFAULT_MAX_QUEUE_DEPTH = 500
def detect_available_vram_gb(
    coordinator_url: str = "http://localhost:7700",
) -> float:
    """Detect available VRAM GB for task scheduling.

    Tries, in order:
      1. cf-orch coordinator free VRAM (summed over every node/GPU), so the
         scheduler cooperates with other cf-orch consumers;
      2. preflight's total GPU VRAM (nvidia-smi based fallback);
      3. 999.0 ("unlimited") when neither source is reachable.
    """
    # 1. cf-orch coordinator — free VRAM is the cooperative figure, since
    # other consumers (vision service, inference services, ...) hold leases.
    if httpx is not None:
        try:
            response = httpx.get(f"{coordinator_url}/api/nodes", timeout=2.0)
            if response.status_code == 200:
                free_mb = 0
                for node in response.json().get("nodes", []):
                    for gpu in node.get("gpus", []):
                        free_mb += gpu.get("vram_free_mb", 0)
                if free_mb > 0:
                    detected_gb = free_mb / 1024.0
                    logger.debug(
                        "Scheduler VRAM from cf-orch: %.1f GB free", detected_gb
                    )
                    return detected_gb
        except Exception:
            # Best-effort probe: any failure just falls through to preflight.
            pass

    # 2. Preflight fallback (systems with nvidia-smi; Peregrine-era path).
    try:
        from scripts.preflight import get_gpus  # type: ignore[import]

        detected_gpus = get_gpus()
        if detected_gpus:
            total_gb = sum(g.get("vram_total_gb", 0.0) for g in detected_gpus)
            logger.debug(
                "Scheduler VRAM from preflight: %.1f GB total", total_gb
            )
            return total_gb
    except Exception:
        # Preflight unavailable or nvidia-smi missing — fall through.
        pass

    # 3. Nothing reachable: report "unlimited" so scheduling never stalls.
    logger.debug(
        "Scheduler VRAM detection unavailable — using unlimited (999 GB)"
    )
    return 999.0
class TaskScheduler:
    """Resource-aware LLM task batch scheduler.

    Runs one batch-worker thread per task type while total reserved VRAM
    stays within the detected available budget. Always allows at least one
    batch to start even if its budget exceeds available VRAM (prevents
    permanent starvation on low-VRAM systems).

    Thread-safety: all queue/active state protected by ``self._lock``.

    Usage::

        sched = get_scheduler(
            db_path=Path("data/app.db"),
            run_task_fn=my_run_task,
            task_types=frozenset({"cover_letter", "research"}),
            vram_budgets={"cover_letter": 2.5, "research": 5.0},
        )
        task_id, is_new = insert_task(db_path, "cover_letter", job_id)
        if is_new:
            enqueued = sched.enqueue(task_id, "cover_letter", job_id, params_json)
            if not enqueued:
                mark_task_failed(db_path, task_id, "Queue full")
    """

    def __init__(
        self,
        db_path: Path,
        run_task_fn: Callable,
        task_types: frozenset[str],
        vram_budgets: dict[str, float],
        available_vram_gb: Optional[float] = None,
        max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH,
    ) -> None:
        """Construct the scheduler and reload surviving queued tasks.

        Args:
            db_path: SQLite database containing the ``background_tasks`` table.
            run_task_fn: Product execution callback with signature
                ``(db_path, task_id, task_type, job_id, params)``.
            task_types: Task type strings routed through this scheduler.
            vram_budgets: Estimated VRAM (GB) reserved per task type.
            available_vram_gb: Explicit VRAM budget; when ``None`` it is
                auto-detected via :func:`detect_available_vram_gb`.
            max_queue_depth: Per-type queue cap; ``enqueue`` returns False
                once a type's queue reaches it.
        """
        self._db_path = db_path
        self._run_task = run_task_fn
        self._task_types = frozenset(task_types)
        self._budgets: dict[str, float] = dict(vram_budgets)
        self._max_queue_depth = max_queue_depth
        self._lock = threading.Lock()
        self._wake = threading.Event()
        self._stop = threading.Event()
        self._queues: dict[str, deque] = {}
        self._active: dict[str, threading.Thread] = {}
        self._reserved_vram: float = 0.0
        self._thread: Optional[threading.Thread] = None
        self._available_vram: float = (
            available_vram_gb
            if available_vram_gb is not None
            else detect_available_vram_gb()
        )
        for t in self._task_types:
            if t not in self._budgets:
                # BUGFIX: the two adjacent string literals previously had no
                # separator, producing "...task type %rdefaulting to...".
                logger.warning(
                    "No VRAM budget defined for task type %r — "
                    "defaulting to 0.0 GB (no VRAM gating for this type)",
                    t,
                )
        self._load_queued_tasks()

    def enqueue(
        self,
        task_id: int,
        task_type: str,
        job_id: int,
        params: Optional[str],
    ) -> bool:
        """Add a task to the scheduler queue.

        Returns True if enqueued successfully.
        Returns False if the queue is full caller should mark the task failed.
        """
        with self._lock:
            q = self._queues.setdefault(task_type, deque())
            if len(q) >= self._max_queue_depth:
                logger.warning(
                    "Queue depth limit for %s (max=%d) — task %d dropped",
                    task_type,
                    self._max_queue_depth,
                    task_id,
                )
                return False
            q.append(TaskSpec(task_id, job_id, params))
        # Nudge the scheduler loop outside the lock — Event.set() is cheap
        # and never blocks.
        self._wake.set()
        return True

    def start(self) -> None:
        """Start the background scheduler loop thread. Call once after construction."""
        self._thread = threading.Thread(
            target=self._scheduler_loop, name="task-scheduler", daemon=True
        )
        self._thread.start()
        # Wake the loop immediately so tasks loaded from DB at startup are dispatched
        if any(self._queues.values()):
            self._wake.set()

    def shutdown(self, timeout: float = 5.0) -> None:
        """Signal the scheduler to stop and wait for it to exit."""
        self._stop.set()
        self._wake.set()
        if self._thread and self._thread.is_alive():
            self._thread.join(timeout=timeout)

    def _scheduler_loop(self) -> None:
        """Main loop: reap finished batch workers, then start new ones.

        Woken by enqueue()/shutdown()/batch completion; also polls every 30 s
        as a safety net so a missed wake-up cannot stall dispatch forever.
        """
        while not self._stop.is_set():
            self._wake.wait(timeout=30)
            self._wake.clear()
            with self._lock:
                # Reap batch threads that finished without waking us
                for t, thread in list(self._active.items()):
                    if not thread.is_alive():
                        self._reserved_vram -= self._budgets.get(t, 0.0)
                        del self._active[t]
                # Start new type batches while VRAM budget allows.
                # Longest queue first: biggest backlog gets a worker soonest.
                candidates = sorted(
                    [
                        t
                        for t in self._queues
                        if self._queues[t] and t not in self._active
                    ],
                    key=lambda t: len(self._queues[t]),
                    reverse=True,
                )
                for task_type in candidates:
                    budget = self._budgets.get(task_type, 0.0)
                    # Always allow at least one batch to run (reserved == 0),
                    # even if its budget alone exceeds available VRAM.
                    if (
                        self._reserved_vram == 0.0
                        or self._reserved_vram + budget <= self._available_vram
                    ):
                        thread = threading.Thread(
                            target=self._batch_worker,
                            args=(task_type,),
                            name=f"batch-{task_type}",
                            daemon=True,
                        )
                        self._active[task_type] = thread
                        self._reserved_vram += budget
                        thread.start()

    def _batch_worker(self, task_type: str) -> None:
        """Serial consumer for one task type. Runs until the type's deque is empty."""
        try:
            while True:
                with self._lock:
                    q = self._queues.get(task_type)
                    if not q:
                        break
                    task = q.popleft()
                # Run the product callback outside the lock: it may take
                # seconds-to-minutes (LLM calls) and must not block enqueue().
                self._run_task(
                    self._db_path, task.id, task_type, task.job_id, task.params
                )
        finally:
            # Release this type's VRAM reservation and wake the scheduler so
            # a queued type blocked on budget can start immediately.
            with self._lock:
                self._active.pop(task_type, None)
                self._reserved_vram -= self._budgets.get(task_type, 0.0)
            self._wake.set()

    def _load_queued_tasks(self) -> None:
        """Reload surviving 'queued' tasks from SQLite into deques at startup."""
        if not self._task_types:
            return
        task_types_list = sorted(self._task_types)
        # Parameterized IN (...) clause — one placeholder per task type.
        placeholders = ",".join("?" * len(task_types_list))
        conn = sqlite3.connect(self._db_path)
        try:
            rows = conn.execute(
                f"SELECT id, task_type, job_id, params FROM background_tasks"
                f" WHERE status='queued' AND task_type IN ({placeholders})"
                f" ORDER BY created_at ASC",
                task_types_list,
            ).fetchall()
        except sqlite3.OperationalError:
            # Table not yet created (first run before migrations)
            rows = []
        finally:
            conn.close()
        for row_id, task_type, job_id, params in rows:
            q = self._queues.setdefault(task_type, deque())
            q.append(TaskSpec(row_id, job_id, params))
        if rows:
            logger.info(
                "Scheduler: resumed %d queued task(s) from prior run", len(rows)
            )
# ── Process-level singleton ────────────────────────────────────────────────────
# Lazily constructed by get_scheduler(); all reads/writes of _scheduler during
# construction and teardown are guarded by _scheduler_lock.
_scheduler: Optional[TaskScheduler] = None
_scheduler_lock = threading.Lock()
def get_scheduler(
    db_path: Path,
    run_task_fn: Optional[Callable] = None,
    task_types: Optional[frozenset[str]] = None,
    vram_budgets: Optional[dict[str, float]] = None,
    max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH,
) -> TaskScheduler:
    """Return the process-level TaskScheduler singleton.

    ``run_task_fn``, ``task_types``, and ``vram_budgets`` are required on the
    first call; ignored on subsequent calls (singleton already constructed).
    """
    global _scheduler
    # Fast path: already constructed — the unguarded read is safe under the
    # GIL because _scheduler is only ever assigned a fully-built instance.
    if _scheduler is not None:
        return _scheduler
    with _scheduler_lock:
        # Re-check under the lock: another thread may have won the race.
        if _scheduler is None:
            if run_task_fn is None or task_types is None or vram_budgets is None:
                raise ValueError(
                    "run_task_fn, task_types, and vram_budgets are required "
                    "on the first call to get_scheduler()"
                )
            instance = TaskScheduler(
                db_path=db_path,
                run_task_fn=run_task_fn,
                task_types=task_types,
                vram_budgets=vram_budgets,
                max_queue_depth=max_queue_depth,
            )
            _scheduler = instance
            _scheduler.start()
    return _scheduler
def reset_scheduler() -> None:
    """Tear down and clear the process-level singleton. TEST TEARDOWN ONLY."""
    global _scheduler
    with _scheduler_lock:
        existing = _scheduler
        if existing is not None:
            existing.shutdown()
            _scheduler = None

View file

@ -22,8 +22,12 @@ orch = [
"typer[all]>=0.12",
"psutil>=5.9",
]
tasks = [
"httpx>=0.27",
]
dev = [
"circuitforge-core[orch]",
"circuitforge-core[tasks]",
"pytest>=8.0",
"pytest-asyncio>=0.23",
"httpx>=0.27",

View file

View file

@ -0,0 +1,243 @@
"""Tests for circuitforge_core.tasks.scheduler."""
from __future__ import annotations
import sqlite3
import threading
import time
from pathlib import Path
from typing import List
from unittest.mock import MagicMock, patch
import pytest
from circuitforge_core.tasks.scheduler import (
TaskScheduler,
detect_available_vram_gb,
get_scheduler,
reset_scheduler,
)
# ── Fixtures ──────────────────────────────────────────────────────────────────
@pytest.fixture
def tmp_db(tmp_path: Path) -> Path:
    """SQLite DB with background_tasks table."""
    db_file = tmp_path / "test.db"
    connection = sqlite3.connect(db_file)
    connection.execute("""
        CREATE TABLE background_tasks (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            task_type TEXT NOT NULL,
            job_id INTEGER NOT NULL DEFAULT 0,
            status TEXT NOT NULL DEFAULT 'queued',
            params TEXT,
            error TEXT,
            created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
        )
    """)
    connection.commit()
    connection.close()
    return db_file
@pytest.fixture(autouse=True)
def _reset_singleton():
    """Always tear down the scheduler singleton between tests."""
    yield
    # Teardown: clear the process-level singleton so the next test can
    # construct a fresh scheduler with its own run_fn/budgets.
    reset_scheduler()
# Default single-type configuration shared by most scheduler tests:
# one cheap task type with a nominal 1 GB VRAM budget.
TASK_TYPES = frozenset({"fast_task"})
BUDGETS = {"fast_task": 1.0}
# ── detect_available_vram_gb ──────────────────────────────────────────────────
def test_detect_vram_from_cfortch():
    """Uses cf-orch free VRAM when coordinator is reachable."""
    fake_response = MagicMock()
    fake_response.status_code = 200
    fake_response.json.return_value = {
        "nodes": [
            {
                "node_id": "local",
                "gpus": [{"vram_free_mb": 4096}, {"vram_free_mb": 4096}],
            }
        ]
    }
    with patch("circuitforge_core.tasks.scheduler.httpx") as fake_httpx:
        fake_httpx.get.return_value = fake_response
        detected = detect_available_vram_gb(coordinator_url="http://localhost:7700")
    # Two GPUs × 4096 MB free → 8 GB total.
    assert detected == pytest.approx(8.0)
def test_detect_vram_cforch_unavailable_falls_back_to_unlimited():
    """Falls back to 999.0 when cf-orch is unreachable and preflight unavailable."""
    with patch("circuitforge_core.tasks.scheduler.httpx") as fake_httpx:
        # Simulate a coordinator that refuses the connection outright.
        fake_httpx.get.side_effect = ConnectionRefusedError()
        assert detect_available_vram_gb() == 999.0
def test_detect_vram_cforch_empty_nodes_falls_back():
    """If cf-orch returns no nodes with GPUs, falls back to unlimited."""
    empty_response = MagicMock()
    empty_response.status_code = 200
    empty_response.json.return_value = {"nodes": []}
    with patch("circuitforge_core.tasks.scheduler.httpx") as fake_httpx:
        fake_httpx.get.return_value = empty_response
        # Zero free MB reported → detection treats cf-orch as uninformative.
        assert detect_available_vram_gb() == 999.0
# ── TaskScheduler basic behaviour ─────────────────────────────────────────────
def test_enqueue_returns_true_on_success(tmp_db: Path):
    """enqueue() reports acceptance when the queue has room."""
    executed: List[int] = []

    def run_fn(db_path, task_id, task_type, job_id, params):
        executed.append(task_id)

    scheduler = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
    scheduler.start()
    accepted = scheduler.enqueue(1, "fast_task", 0, None)
    scheduler.shutdown()
    assert accepted is True
def test_scheduler_runs_task(tmp_db: Path):
    """Enqueued task is executed by the batch worker."""
    executed: List[int] = []
    finished = threading.Event()

    def run_fn(db_path, task_id, task_type, job_id, params):
        executed.append(task_id)
        finished.set()

    scheduler = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
    scheduler.start()
    scheduler.enqueue(42, "fast_task", 0, None)
    assert finished.wait(timeout=3.0), "Task was not executed within 3 seconds"
    scheduler.shutdown()
    assert executed == [42]
def test_enqueue_returns_false_when_queue_full(tmp_db: Path):
    """Returns False and does not enqueue when max_queue_depth is reached."""
    release = threading.Event()

    def blocking_run_fn(db_path, task_id, task_type, job_id, params):
        # Hold the batch worker so queued entries cannot drain mid-test.
        release.wait()

    scheduler = TaskScheduler(
        tmp_db, blocking_run_fn, TASK_TYPES, BUDGETS,
        available_vram_gb=8.0, max_queue_depth=2
    )
    scheduler.start()
    outcomes = []
    for task_id in range(1, 10):
        outcomes.append(scheduler.enqueue(task_id, "fast_task", 0, None))
    release.set()
    scheduler.shutdown()
    # With depth 2 and nine submissions, at least one must have been rejected.
    assert False in outcomes
def test_scheduler_drains_multiple_tasks(tmp_db: Path):
    """All enqueued tasks of the same type are run serially."""
    TOTAL = 5
    ran: List[int] = []
    all_finished = threading.Event()

    def run_fn(db_path, task_id, task_type, job_id, params):
        ran.append(task_id)
        if len(ran) >= TOTAL:
            all_finished.set()

    scheduler = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
    scheduler.start()
    for task_id in range(1, TOTAL + 1):
        scheduler.enqueue(task_id, "fast_task", 0, None)
    assert all_finished.wait(timeout=5.0), f"Only ran {len(ran)} of {TOTAL} tasks"
    scheduler.shutdown()
    assert sorted(ran) == list(range(1, TOTAL + 1))
def test_vram_budget_blocks_second_type(tmp_db: Path):
    """Second task type is not started when VRAM would be exceeded."""
    gate_a = threading.Event()
    gate_b = threading.Event()
    started = []  # task types in the order their workers began executing

    def run_fn(db_path, task_id, task_type, job_id, params):
        started.append(task_type)
        # Block each worker on its gate so its VRAM reservation stays held
        # until the test explicitly releases it.
        if task_type == "type_a":
            gate_a.wait()
        else:
            gate_b.wait()

    two_types = frozenset({"type_a", "type_b"})
    tight_budgets = {"type_a": 4.0, "type_b": 4.0}  # 4+4 > 6 GB available
    sched = TaskScheduler(
        tmp_db, run_fn, two_types, tight_budgets, available_vram_gb=6.0
    )
    sched.start()
    sched.enqueue(1, "type_a", 0, None)
    sched.enqueue(2, "type_b", 0, None)
    # Give the scheduler loop time to dispatch whatever fits the budget.
    time.sleep(0.2)
    # Only one 4 GB batch fits in 6 GB, so type_b must still be waiting.
    assert started == ["type_a"]
    gate_a.set()  # releasing type_a frees its 4 GB reservation
    time.sleep(0.2)  # scheduler loop should now dispatch type_b
    gate_b.set()
    sched.shutdown()
    assert "type_b" in started
def test_get_scheduler_singleton(tmp_db: Path):
    """get_scheduler() returns the same instance on repeated calls."""
    runner = MagicMock()
    first = get_scheduler(tmp_db, runner, TASK_TYPES, BUDGETS)
    second = get_scheduler(tmp_db)  # no run_fn — should reuse existing
    assert first is second
def test_reset_scheduler_clears_singleton(tmp_db: Path):
    """reset_scheduler() allows a new singleton to be constructed."""
    runner = MagicMock()
    before_reset = get_scheduler(tmp_db, runner, TASK_TYPES, BUDGETS)
    reset_scheduler()
    after_reset = get_scheduler(tmp_db, runner, TASK_TYPES, BUDGETS)
    assert before_reset is not after_reset
def test_load_queued_tasks_on_startup(tmp_db: Path):
    """Tasks with status='queued' in the DB at startup are loaded into the deque."""
    # Seed one queued row before the scheduler is ever constructed.
    seed = sqlite3.connect(tmp_db)
    seed.execute(
        "INSERT INTO background_tasks (task_type, job_id, status) VALUES ('fast_task', 0, 'queued')"
    )
    seed.commit()
    seed.close()

    executed: List[int] = []
    finished = threading.Event()

    def run_fn(db_path, task_id, task_type, job_id, params):
        executed.append(task_id)
        finished.set()

    scheduler = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
    scheduler.start()
    assert finished.wait(timeout=3.0), "Pre-loaded task was not run"
    scheduler.shutdown()
    assert len(executed) == 1
def test_load_queued_tasks_missing_table_does_not_crash(tmp_path: Path):
    """Scheduler does not crash if background_tasks table doesn't exist yet."""
    empty_db = tmp_path / "empty.db"
    # Create the file without any schema — first run before migrations.
    sqlite3.connect(empty_db).close()
    scheduler = TaskScheduler(
        empty_db, MagicMock(), TASK_TYPES, BUDGETS, available_vram_gb=8.0
    )
    scheduler.start()
    scheduler.shutdown()
    # No exception = pass