Merge pull request 'feat(tasks): shared VRAM-aware LLM task scheduler' (#2) from feature/shared-task-scheduler into main

This commit is contained in:
pyr0ball 2026-03-31 10:45:21 -07:00
commit 563b73ce85
8 changed files with 678 additions and 13 deletions

View file

@ -22,7 +22,9 @@ def get_connection(db_path: Path, key: str = "") -> sqlite3.Connection:
cloud_mode = os.environ.get("CLOUD_MODE", "").lower() in ("1", "true", "yes")
if cloud_mode and key:
from pysqlcipher3 import dbapi2 as _sqlcipher # type: ignore
conn = _sqlcipher.connect(str(db_path))
conn = _sqlcipher.connect(str(db_path), timeout=30)
conn.execute(f"PRAGMA key='{key}'")
return conn
return sqlite3.connect(str(db_path))
# timeout=30: retry for up to 30s when another writer holds the lock (WAL mode
# allows concurrent readers but only one writer at a time).
return sqlite3.connect(str(db_path), timeout=30)

View file

@ -23,5 +23,7 @@ def run_migrations(conn: sqlite3.Connection, migrations_dir: Path) -> None:
if sql_file.name in applied:
continue
conn.executescript(sql_file.read_text())
conn.execute("INSERT INTO _migrations (name) VALUES (?)", (sql_file.name,))
# OR IGNORE: safe if two Store() calls race on the same DB — second writer
# just skips the insert rather than raising UNIQUE constraint failed.
conn.execute("INSERT OR IGNORE INTO _migrations (name) VALUES (?)", (sql_file.name,))
conn.commit()

View file

@ -0,0 +1,14 @@
# circuitforge_core/tasks/__init__.py
from circuitforge_core.tasks.scheduler import (
TaskScheduler,
detect_available_vram_gb,
get_scheduler,
reset_scheduler,
)
# Public API re-exported for `from circuitforge_core.tasks import ...`.
__all__ = [
    "TaskScheduler",
    "detect_available_vram_gb",
    "get_scheduler",
    "reset_scheduler",
]

View file

@ -0,0 +1,344 @@
# circuitforge_core/tasks/scheduler.py
"""Resource-aware batch scheduler for LLM background tasks.
Generic scheduler that any CircuitForge product can use. Products supply:
- task_types: frozenset[str] task type strings routed through this scheduler
- vram_budgets: dict[str, float] VRAM GB estimate per task type
- run_task_fn product's task execution function
VRAM detection priority:
1. cf-orch coordinator /api/nodes free VRAM (lease-aware, cooperative)
2. scripts.preflight.get_gpus() total GPU VRAM (Peregrine-era fallback)
3. 999.0 unlimited (CPU-only or no detection available)
Public API:
TaskScheduler the scheduler class
detect_available_vram_gb() standalone VRAM query helper
get_scheduler() lazy process-level singleton
reset_scheduler() test teardown only
"""
from __future__ import annotations
import logging
import sqlite3
import threading
from collections import deque
from pathlib import Path
from typing import Callable, NamedTuple, Optional
try:
import httpx as httpx
except ImportError:
httpx = None # type: ignore[assignment]
logger = logging.getLogger(__name__)
RunTaskFn = Callable[["Path", int, str, int, Optional[str]], None]
class TaskSpec(NamedTuple):
    """Immutable description of one queued task held in a scheduler deque."""
    id: int  # background_tasks.id row — passed to run_task_fn as task_id
    job_id: int  # owning job row id (semantics are product-defined)
    params: Optional[str]  # opaque params string forwarded verbatim to run_task_fn
_DEFAULT_MAX_QUEUE_DEPTH = 500
def detect_available_vram_gb(
    coordinator_url: str = "http://localhost:7700",
) -> float:
    """Detect available VRAM GB for task scheduling.

    Preference order: cf-orch free VRAM summed over every node/GPU (so the
    scheduler cooperates with other coordinator clients), then preflight
    total VRAM, then 999.0 ("unlimited") when nothing is reachable.
    """
    # Preferred source: cf-orch free VRAM. Using *free* rather than total
    # VRAM keeps this scheduler cooperative with other cf-orch consumers
    # (vision service, inference services, etc.).
    if httpx is not None:
        try:
            response = httpx.get(f"{coordinator_url}/api/nodes", timeout=2.0)
            if response.status_code == 200:
                free_mb = 0
                for node in response.json().get("nodes", []):
                    for gpu in node.get("gpus", []):
                        free_mb += gpu.get("vram_free_mb", 0)
                if free_mb > 0:
                    free_gb = free_mb / 1024.0
                    logger.debug(
                        "Scheduler VRAM from cf-orch: %.1f GB free", free_gb
                    )
                    return free_gb
        except Exception:
            # Best-effort probe: any failure just falls through to preflight.
            pass
    # Fallback: preflight total VRAM (systems with nvidia-smi; Peregrine-era).
    try:
        from scripts.preflight import get_gpus  # type: ignore[import]
        detected = get_gpus()
        if detected:
            total_gb = sum(g.get("vram_total_gb", 0.0) for g in detected)
            logger.debug(
                "Scheduler VRAM from preflight: %.1f GB total", total_gb
            )
            return total_gb
    except Exception:
        pass
    # Last resort: pretend VRAM is unlimited (CPU-only or no detection).
    logger.debug(
        "Scheduler VRAM detection unavailable — using unlimited (999 GB)"
    )
    return 999.0
class TaskScheduler:
    """Resource-aware LLM task batch scheduler.

    Runs one batch-worker thread per task type while total reserved VRAM
    stays within the detected available budget. Always allows at least one
    batch to start even if its budget exceeds available VRAM (prevents
    permanent starvation on low-VRAM systems).

    Thread-safety: all queue/active state is protected by ``self._lock``.

    Usage::

        sched = get_scheduler(
            db_path=Path("data/app.db"),
            run_task_fn=my_run_task,
            task_types=frozenset({"cover_letter", "research"}),
            vram_budgets={"cover_letter": 2.5, "research": 5.0},
        )
        task_id, is_new = insert_task(db_path, "cover_letter", job_id)
        if is_new:
            enqueued = sched.enqueue(task_id, "cover_letter", job_id, params_json)
            if not enqueued:
                mark_task_failed(db_path, task_id, "Queue full")
    """

    def __init__(
        self,
        db_path: Path,
        run_task_fn: RunTaskFn,
        task_types: frozenset[str],
        vram_budgets: dict[str, float],
        available_vram_gb: Optional[float] = None,
        max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH,
    ) -> None:
        """Construct the scheduler and reload surviving queued tasks.

        Args:
            db_path: SQLite DB containing the ``background_tasks`` table.
            run_task_fn: Product executor, invoked as
                ``run_task_fn(db_path, task_id, task_type, job_id, params)``.
            task_types: Task type strings routed through this scheduler.
            vram_budgets: Estimated VRAM (GB) reserved per task type.
            available_vram_gb: Override for detected VRAM; ``None`` calls
                ``detect_available_vram_gb()`` (may do a short network probe).
            max_queue_depth: Per-type queue cap; ``enqueue`` rejects past it.
        """
        self._db_path = db_path
        self._run_task = run_task_fn
        self._task_types = frozenset(task_types)
        self._budgets: dict[str, float] = dict(vram_budgets)
        self._max_queue_depth = max_queue_depth
        self._lock = threading.Lock()    # guards _queues, _active, _reserved_vram
        self._wake = threading.Event()   # nudges the scheduler loop
        self._stop = threading.Event()   # shutdown signal for the loop
        self._queues: dict[str, deque[TaskSpec]] = {}
        self._active: dict[str, threading.Thread] = {}
        self._reserved_vram: float = 0.0
        self._thread: Optional[threading.Thread] = None
        self._available_vram: float = (
            available_vram_gb
            if available_vram_gb is not None
            else detect_available_vram_gb()
        )
        for t in self._task_types:
            if t not in self._budgets:
                # BUGFIX: the original message was two adjacent string literals
                # with no separator, so it rendered "...type %rdefaulting to...".
                logger.warning(
                    "No VRAM budget defined for task type %r — "
                    "defaulting to 0.0 GB (no VRAM gating for this type)",
                    t,
                )
        self._load_queued_tasks()

    def enqueue(
        self,
        task_id: int,
        task_type: str,
        job_id: int,
        params: Optional[str],
    ) -> bool:
        """Add a task to the scheduler queue.

        Returns True if enqueued successfully.
        Returns False if the queue is full — caller should mark the task failed.
        """
        with self._lock:
            q = self._queues.setdefault(task_type, deque())
            if len(q) >= self._max_queue_depth:
                logger.warning(
                    "Queue depth limit for %s (max=%d) — task %d dropped",
                    task_type,
                    self._max_queue_depth,
                    task_id,
                )
                return False
            q.append(TaskSpec(task_id, job_id, params))
        self._wake.set()
        return True

    def start(self) -> None:
        """Start the background scheduler loop thread. Call once after construction."""
        self._thread = threading.Thread(
            target=self._scheduler_loop, name="task-scheduler", daemon=True
        )
        self._thread.start()
        # Wake the loop immediately so tasks loaded from DB at startup are dispatched.
        with self._lock:
            if any(self._queues.values()):
                self._wake.set()

    def shutdown(self, timeout: float = 5.0) -> None:
        """Signal the scheduler to stop and wait up to *timeout* s for it to exit."""
        self._stop.set()
        self._wake.set()
        if self._thread and self._thread.is_alive():
            self._thread.join(timeout=timeout)

    def _scheduler_loop(self) -> None:
        """Dispatch loop: reap finished batch workers, start new ones within budget."""
        while not self._stop.is_set():
            # 30 s timeout is a safety net in case a wake-up is ever missed.
            self._wake.wait(timeout=30)
            self._wake.clear()
            with self._lock:
                # Reap batch threads that finished without waking us.
                # VRAM accounting is handled solely by _batch_worker's finally block;
                # the reaper only removes dead entries from _active.
                for t, thread in list(self._active.items()):
                    if not thread.is_alive():
                        del self._active[t]
                # Start new type batches while VRAM budget allows; deepest
                # queues first so backlog pressure drives dispatch order.
                candidates = sorted(
                    [
                        t
                        for t in self._queues
                        if self._queues[t] and t not in self._active
                    ],
                    key=lambda t: len(self._queues[t]),
                    reverse=True,
                )
                for task_type in candidates:
                    budget = self._budgets.get(task_type, 0.0)
                    # Always allow at least one batch to run (anti-starvation).
                    if (
                        self._reserved_vram == 0.0
                        or self._reserved_vram + budget <= self._available_vram
                    ):
                        thread = threading.Thread(
                            target=self._batch_worker,
                            args=(task_type,),
                            name=f"batch-{task_type}",
                            daemon=True,
                        )
                        self._active[task_type] = thread
                        self._reserved_vram += budget
                        thread.start()

    def _batch_worker(self, task_type: str) -> None:
        """Serial consumer for one task type. Runs until the type's deque is empty."""
        try:
            while True:
                with self._lock:
                    q = self._queues.get(task_type)
                    if not q:
                        break
                    task = q.popleft()
                # Execute outside the lock — the task may run for a long time.
                self._run_task(
                    self._db_path, task.id, task_type, task.job_id, task.params
                )
        finally:
            with self._lock:
                self._active.pop(task_type, None)
                self._reserved_vram -= self._budgets.get(task_type, 0.0)
            # Wake the loop so it can dispatch a type that was VRAM-blocked.
            self._wake.set()

    def _load_queued_tasks(self) -> None:
        """Reload surviving 'queued' tasks from SQLite into deques at startup."""
        if not self._task_types:
            return
        task_types_list = sorted(self._task_types)
        placeholders = ",".join("?" * len(task_types_list))
        # BUGFIX: `with sqlite3.connect(...)` only manages the transaction, not
        # the connection — the original leaked an open connection per startup.
        conn = sqlite3.connect(self._db_path)
        try:
            rows = conn.execute(
                f"SELECT id, task_type, job_id, params FROM background_tasks"
                f" WHERE status='queued' AND task_type IN ({placeholders})"
                f" ORDER BY created_at ASC",
                task_types_list,
            ).fetchall()
        except sqlite3.OperationalError:
            # Table not yet created (first run before migrations)
            rows = []
        finally:
            conn.close()
        for row_id, task_type, job_id, params in rows:
            q = self._queues.setdefault(task_type, deque())
            q.append(TaskSpec(row_id, job_id, params))
        if rows:
            logger.info(
                "Scheduler: resumed %d queued task(s) from prior run", len(rows)
            )
# ── Process-level singleton ────────────────────────────────────────────────────
_scheduler: Optional[TaskScheduler] = None  # lazily built by get_scheduler()
_scheduler_lock = threading.Lock()  # guards assignment of _scheduler
def get_scheduler(
    db_path: Path,
    run_task_fn: Optional[RunTaskFn] = None,
    task_types: Optional[frozenset[str]] = None,
    vram_budgets: Optional[dict[str, float]] = None,
    max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH,
) -> TaskScheduler:
    """Return the process-level TaskScheduler singleton.

    ``run_task_fn``, ``task_types``, and ``vram_budgets`` are required on the
    first call; ignored on subsequent calls (singleton already constructed).
    VRAM detection (which may involve a network call) is performed outside the
    lock so the lock is never held across blocking I/O.

    Raises:
        ValueError: first call without run_task_fn/task_types/vram_budgets.
    """
    global _scheduler
    if _scheduler is not None:
        return _scheduler
    if run_task_fn is None or task_types is None or vram_budgets is None:
        raise ValueError(
            "run_task_fn, task_types, and vram_budgets are required "
            "on the first call to get_scheduler()"
        )
    # Build outside the lock — TaskScheduler.__init__ may call detect_available_vram_gb()
    # which makes an httpx network call (up to 2 s). Holding the lock during that
    # would block any concurrent caller for the full duration.
    candidate = TaskScheduler(
        db_path=db_path,
        run_task_fn=run_task_fn,
        task_types=task_types,
        vram_budgets=vram_budgets,
        max_queue_depth=max_queue_depth,
    )
    with _scheduler_lock:
        if _scheduler is None:
            _scheduler = candidate
    winner = _scheduler
    if winner is candidate:
        # BUGFIX: start only the winning instance. The previous code started the
        # candidate *before* the lock race was decided, so a losing candidate
        # could briefly execute the same DB-resumed 'queued' tasks as the winner
        # (duplicate task execution) before being shut down. A never-started
        # loser is simply garbage-collected.
        winner.start()
    return winner
def reset_scheduler() -> None:
    """Shut down and clear the singleton. TEST TEARDOWN ONLY."""
    global _scheduler
    with _scheduler_lock:
        current = _scheduler
        if current is None:
            return
        current.shutdown()
        _scheduler = None

View file

@ -30,26 +30,35 @@ def can_use(
has_byok: bool = False,
has_local_vision: bool = False,
_features: dict[str, str] | None = None,
_byok_unlockable: frozenset[str] | None = None,
_local_vision_unlockable: frozenset[str] | None = None,
) -> bool:
"""
Return True if the given tier (and optional unlocks) can access feature.
Args:
feature: Feature key string.
tier: User's current tier ("free", "paid", "premium", "ultra").
has_byok: True if user has a configured LLM backend.
has_local_vision: True if user has a local vision model configured.
_features: Featuremin_tier map. Products pass their own dict here.
If None, all features are free.
feature: Feature key string.
tier: User's current tier ("free", "paid", "premium", "ultra").
has_byok: True if user has a configured LLM backend.
has_local_vision: True if user has a local vision model configured.
_features: Featuremin_tier map. Products pass their own dict here.
If None, all features are free.
_byok_unlockable: Product-specific BYOK-unlockable features.
If None, uses module-level BYOK_UNLOCKABLE.
_local_vision_unlockable: Product-specific local vision unlockable features.
If None, uses module-level LOCAL_VISION_UNLOCKABLE.
"""
features = _features or {}
byok_unlockable = _byok_unlockable if _byok_unlockable is not None else BYOK_UNLOCKABLE
local_vision_unlockable = _local_vision_unlockable if _local_vision_unlockable is not None else LOCAL_VISION_UNLOCKABLE
if feature not in features:
return True
if has_byok and feature in BYOK_UNLOCKABLE:
if has_byok and feature in byok_unlockable:
return True
if has_local_vision and feature in LOCAL_VISION_UNLOCKABLE:
if has_local_vision and feature in local_vision_unlockable:
return True
min_tier = features[feature]
@ -64,13 +73,18 @@ def tier_label(
has_byok: bool = False,
has_local_vision: bool = False,
_features: dict[str, str] | None = None,
_byok_unlockable: frozenset[str] | None = None,
_local_vision_unlockable: frozenset[str] | None = None,
) -> str:
"""Return a human-readable label for the minimum tier needed for feature."""
features = _features or {}
byok_unlockable = _byok_unlockable if _byok_unlockable is not None else BYOK_UNLOCKABLE
local_vision_unlockable = _local_vision_unlockable if _local_vision_unlockable is not None else LOCAL_VISION_UNLOCKABLE
if feature not in features:
return "free"
if has_byok and feature in BYOK_UNLOCKABLE:
if has_byok and feature in byok_unlockable:
return "free (BYOK)"
if has_local_vision and feature in LOCAL_VISION_UNLOCKABLE:
if has_local_vision and feature in local_vision_unlockable:
return "free (local vision)"
return features[feature]

View file

@ -22,8 +22,12 @@ orch = [
"typer[all]>=0.12",
"psutil>=5.9",
]
tasks = [
"httpx>=0.27",
]
dev = [
"circuitforge-core[orch]",
"circuitforge-core[tasks]",
"pytest>=8.0",
"pytest-asyncio>=0.23",
"httpx>=0.27",

View file

View file

@ -0,0 +1,285 @@
"""Tests for circuitforge_core.tasks.scheduler."""
from __future__ import annotations
import sqlite3
import threading
import time
from pathlib import Path
from types import ModuleType
from typing import List
from unittest.mock import MagicMock, patch
import pytest
from circuitforge_core.tasks.scheduler import (
TaskScheduler,
detect_available_vram_gb,
get_scheduler,
reset_scheduler,
)
# ── Fixtures ──────────────────────────────────────────────────────────────────
@pytest.fixture
def tmp_db(tmp_path: Path) -> Path:
    """SQLite DB with background_tasks table."""
    db_path = tmp_path / "test.db"
    connection = sqlite3.connect(db_path)
    try:
        connection.execute("""
    CREATE TABLE background_tasks (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        task_type TEXT NOT NULL,
        job_id INTEGER NOT NULL DEFAULT 0,
        status TEXT NOT NULL DEFAULT 'queued',
        params TEXT,
        error TEXT,
        created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
    )
    """)
        connection.commit()
    finally:
        connection.close()
    return db_path
@pytest.fixture(autouse=True)
def _reset_singleton():
    """Always tear down the scheduler singleton between tests."""
    yield
    # Teardown: drop the process-level singleton so each test builds its own
    # scheduler rather than inheriting a previous test's instance.
    reset_scheduler()
TASK_TYPES = frozenset({"fast_task"})
BUDGETS = {"fast_task": 1.0}
# ── detect_available_vram_gb ──────────────────────────────────────────────────
def test_detect_vram_from_cfortch():
    """Uses cf-orch free VRAM when coordinator is reachable."""
    fake_response = MagicMock()
    fake_response.status_code = 200
    fake_response.json.return_value = {
        "nodes": [
            {"node_id": "local", "gpus": [{"vram_free_mb": 4096}, {"vram_free_mb": 4096}]}
        ]
    }
    with patch("circuitforge_core.tasks.scheduler.httpx") as fake_httpx:
        fake_httpx.get.return_value = fake_response
        detected = detect_available_vram_gb(coordinator_url="http://localhost:7700")
    # Two GPUs at 4096 MB free each → 8 GB total.
    assert detected == pytest.approx(8.0)
def test_detect_vram_cforch_unavailable_falls_back_to_unlimited():
    """Falls back to 999.0 when cf-orch is unreachable and preflight unavailable."""
    with patch("circuitforge_core.tasks.scheduler.httpx") as fake_httpx:
        fake_httpx.get.side_effect = ConnectionRefusedError()
        detected = detect_available_vram_gb()
    assert detected == 999.0
def test_detect_vram_cforch_empty_nodes_falls_back():
    """If cf-orch returns no nodes with GPUs, falls back to unlimited."""
    fake_response = MagicMock()
    fake_response.status_code = 200
    fake_response.json.return_value = {"nodes": []}
    with patch("circuitforge_core.tasks.scheduler.httpx") as fake_httpx:
        fake_httpx.get.return_value = fake_response
        detected = detect_available_vram_gb()
    assert detected == 999.0
def test_detect_vram_preflight_fallback():
    """Falls back to preflight total VRAM when cf-orch is unreachable."""
    sys_mod = __import__("sys")
    # Fake scripts.preflight module with get_gpus returning two GPUs.
    scripts_pkg = ModuleType("scripts")
    preflight_mod = ModuleType("scripts.preflight")
    preflight_mod.get_gpus = lambda: [  # type: ignore[attr-defined]
        {"vram_total_gb": 8.0},
        {"vram_total_gb": 4.0},
    ]
    scripts_pkg.preflight = preflight_mod  # type: ignore[attr-defined]
    with patch("circuitforge_core.tasks.scheduler.httpx") as fake_httpx, \
         patch.dict(
             sys_mod.modules,
             {"scripts": scripts_pkg, "scripts.preflight": preflight_mod},
         ):
        fake_httpx.get.side_effect = ConnectionRefusedError()
        detected = detect_available_vram_gb()
    # 8.0 + 4.0 GB across the two fake GPUs.
    assert detected == pytest.approx(12.0)
# ── TaskScheduler basic behaviour ─────────────────────────────────────────────
def test_enqueue_returns_true_on_success(tmp_db: Path):
    executed: List[int] = []

    def run_fn(db_path, task_id, task_type, job_id, params):
        executed.append(task_id)

    scheduler = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
    scheduler.start()
    accepted = scheduler.enqueue(1, "fast_task", 0, None)
    scheduler.shutdown()
    assert accepted is True
def test_scheduler_runs_task(tmp_db: Path):
    """Enqueued task is executed by the batch worker."""
    executed: List[int] = []
    finished = threading.Event()

    def run_fn(db_path, task_id, task_type, job_id, params):
        executed.append(task_id)
        finished.set()

    scheduler = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
    scheduler.start()
    scheduler.enqueue(42, "fast_task", 0, None)
    assert finished.wait(timeout=3.0), "Task was not executed within 3 seconds"
    scheduler.shutdown()
    assert executed == [42]
def test_enqueue_returns_false_when_queue_full(tmp_db: Path):
    """Returns False and does not enqueue when max_queue_depth is reached."""
    release = threading.Event()

    def blocking_run_fn(db_path, task_id, task_type, job_id, params):
        # Hold the worker so the queue fills up behind it.
        release.wait()

    scheduler = TaskScheduler(
        tmp_db, blocking_run_fn, TASK_TYPES, BUDGETS,
        available_vram_gb=8.0, max_queue_depth=2
    )
    scheduler.start()
    outcomes = []
    for task_id in range(1, 10):
        outcomes.append(scheduler.enqueue(task_id, "fast_task", 0, None))
    release.set()
    scheduler.shutdown()
    assert False in outcomes, "Expected at least one enqueue to be rejected"
def test_scheduler_drains_multiple_tasks(tmp_db: Path):
    """All enqueued tasks of the same type are run serially."""
    executed: List[int] = []
    all_done = threading.Event()
    total = 5

    def run_fn(db_path, task_id, task_type, job_id, params):
        executed.append(task_id)
        if len(executed) >= total:
            all_done.set()

    scheduler = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
    scheduler.start()
    for task_id in range(1, total + 1):
        scheduler.enqueue(task_id, "fast_task", 0, None)
    assert all_done.wait(timeout=5.0), f"Only ran {len(executed)} of {total} tasks"
    scheduler.shutdown()
    assert sorted(executed) == list(range(1, total + 1))
def test_vram_budget_blocks_second_type(tmp_db: Path):
    """Second task type is not started when VRAM would be exceeded."""
    # *_started events let the test observe when each type's worker begins;
    # gate_* events hold each worker inside run_fn so its budget stays reserved.
    type_a_started = threading.Event()
    type_b_started = threading.Event()
    gate_a = threading.Event()
    gate_b = threading.Event()
    started = []

    def run_fn(db_path, task_id, task_type, job_id, params):
        started.append(task_type)
        if task_type == "type_a":
            type_a_started.set()
            gate_a.wait()  # keep type_a's 4 GB reserved until released below
        else:
            type_b_started.set()
            gate_b.wait()

    two_types = frozenset({"type_a", "type_b"})
    tight_budgets = {"type_a": 4.0, "type_b": 4.0}  # 4+4 > 6 GB available
    sched = TaskScheduler(
        tmp_db, run_fn, two_types, tight_budgets, available_vram_gb=6.0
    )
    sched.start()
    sched.enqueue(1, "type_a", 0, None)
    sched.enqueue(2, "type_b", 0, None)
    # type_a fits (4 <= 6); starting type_b would push reserved VRAM to 8 > 6.
    assert type_a_started.wait(timeout=3.0), "type_a never started"
    assert not type_b_started.is_set(), "type_b should be blocked by VRAM"
    # Releasing type_a frees its budget; the worker's finally block wakes the
    # scheduler loop, which should then start the previously blocked type_b.
    gate_a.set()
    assert type_b_started.wait(timeout=3.0), "type_b never started after type_a finished"
    gate_b.set()
    sched.shutdown()
    assert sorted(started) == ["type_a", "type_b"]
def test_get_scheduler_singleton(tmp_db: Path):
    """get_scheduler() returns the same instance on repeated calls."""
    task_runner = MagicMock()
    first = get_scheduler(tmp_db, task_runner, TASK_TYPES, BUDGETS)
    second = get_scheduler(tmp_db)  # no run_fn — should reuse existing
    assert first is second
def test_reset_scheduler_clears_singleton(tmp_db: Path):
    """reset_scheduler() allows a new singleton to be constructed."""
    task_runner = MagicMock()
    first = get_scheduler(tmp_db, task_runner, TASK_TYPES, BUDGETS)
    reset_scheduler()
    second = get_scheduler(tmp_db, task_runner, TASK_TYPES, BUDGETS)
    assert first is not second
def test_load_queued_tasks_on_startup(tmp_db: Path):
    """Tasks with status='queued' in the DB at startup are loaded into the deque."""
    connection = sqlite3.connect(tmp_db)
    try:
        connection.execute(
            "INSERT INTO background_tasks (task_type, job_id, status) VALUES ('fast_task', 0, 'queued')"
        )
        connection.commit()
    finally:
        connection.close()

    executed: List[int] = []
    finished = threading.Event()

    def run_fn(db_path, task_id, task_type, job_id, params):
        executed.append(task_id)
        finished.set()

    scheduler = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
    scheduler.start()
    assert finished.wait(timeout=3.0), "Pre-loaded task was not run"
    scheduler.shutdown()
    assert len(executed) == 1
def test_load_queued_tasks_missing_table_does_not_crash(tmp_path: Path):
    """Scheduler does not crash if background_tasks table doesn't exist yet."""
    empty_db = tmp_path / "empty.db"
    sqlite3.connect(empty_db).close()
    scheduler = TaskScheduler(
        empty_db, MagicMock(), TASK_TYPES, BUDGETS, available_vram_gb=8.0
    )
    scheduler.start()
    scheduler.shutdown()
    # No exception = pass
def test_reserved_vram_zero_after_task_completes(tmp_db: Path):
    """_reserved_vram returns to 0.0 after a task finishes — no double-decrement."""
    done = threading.Event()

    def run_fn(db_path, task_id, task_type, job_id, params):
        done.set()

    sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
    sched.start()
    sched.enqueue(1, "fast_task", 0, None)
    assert done.wait(timeout=3.0), "Task never completed"
    sched.shutdown()
    # FLAKINESS FIX: the batch worker decrements _reserved_vram in its `finally`
    # block, which runs *after* run_fn sets `done` — and shutdown() only joins
    # the scheduler loop, not the worker. Poll briefly instead of asserting
    # immediately so the worker has time to finish its accounting.
    deadline = time.monotonic() + 3.0
    while sched._reserved_vram != 0.0 and time.monotonic() < deadline:
        time.sleep(0.01)
    assert sched._reserved_vram == 0.0, f"Expected 0.0, got {sched._reserved_vram}"