Merge pull request 'feat(tasks): shared VRAM-aware LLM task scheduler' (#2) from feature/shared-task-scheduler into main
This commit is contained in:
commit
563b73ce85
8 changed files with 678 additions and 13 deletions
|
|
@ -22,7 +22,9 @@ def get_connection(db_path: Path, key: str = "") -> sqlite3.Connection:
|
||||||
cloud_mode = os.environ.get("CLOUD_MODE", "").lower() in ("1", "true", "yes")
|
cloud_mode = os.environ.get("CLOUD_MODE", "").lower() in ("1", "true", "yes")
|
||||||
if cloud_mode and key:
|
if cloud_mode and key:
|
||||||
from pysqlcipher3 import dbapi2 as _sqlcipher # type: ignore
|
from pysqlcipher3 import dbapi2 as _sqlcipher # type: ignore
|
||||||
conn = _sqlcipher.connect(str(db_path))
|
conn = _sqlcipher.connect(str(db_path), timeout=30)
|
||||||
conn.execute(f"PRAGMA key='{key}'")
|
conn.execute(f"PRAGMA key='{key}'")
|
||||||
return conn
|
return conn
|
||||||
return sqlite3.connect(str(db_path))
|
# timeout=30: retry for up to 30s when another writer holds the lock (WAL mode
|
||||||
|
# allows concurrent readers but only one writer at a time).
|
||||||
|
return sqlite3.connect(str(db_path), timeout=30)
|
||||||
|
|
|
||||||
|
|
@ -23,5 +23,7 @@ def run_migrations(conn: sqlite3.Connection, migrations_dir: Path) -> None:
|
||||||
if sql_file.name in applied:
|
if sql_file.name in applied:
|
||||||
continue
|
continue
|
||||||
conn.executescript(sql_file.read_text())
|
conn.executescript(sql_file.read_text())
|
||||||
conn.execute("INSERT INTO _migrations (name) VALUES (?)", (sql_file.name,))
|
# OR IGNORE: safe if two Store() calls race on the same DB — second writer
|
||||||
|
# just skips the insert rather than raising UNIQUE constraint failed.
|
||||||
|
conn.execute("INSERT OR IGNORE INTO _migrations (name) VALUES (?)", (sql_file.name,))
|
||||||
conn.commit()
|
conn.commit()
|
||||||
|
|
|
||||||
14
circuitforge_core/tasks/__init__.py
Normal file
14
circuitforge_core/tasks/__init__.py
Normal file
|
|
@ -0,0 +1,14 @@
|
||||||
|
# circuitforge_core/tasks/__init__.py
|
||||||
|
from circuitforge_core.tasks.scheduler import (
|
||||||
|
TaskScheduler,
|
||||||
|
detect_available_vram_gb,
|
||||||
|
get_scheduler,
|
||||||
|
reset_scheduler,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"TaskScheduler",
|
||||||
|
"detect_available_vram_gb",
|
||||||
|
"get_scheduler",
|
||||||
|
"reset_scheduler",
|
||||||
|
]
|
||||||
344
circuitforge_core/tasks/scheduler.py
Normal file
344
circuitforge_core/tasks/scheduler.py
Normal file
|
|
@ -0,0 +1,344 @@
|
||||||
|
# circuitforge_core/tasks/scheduler.py
|
||||||
|
"""Resource-aware batch scheduler for LLM background tasks.
|
||||||
|
|
||||||
|
Generic scheduler that any CircuitForge product can use. Products supply:
|
||||||
|
- task_types: frozenset[str] — task type strings routed through this scheduler
|
||||||
|
- vram_budgets: dict[str, float] — VRAM GB estimate per task type
|
||||||
|
- run_task_fn — product's task execution function
|
||||||
|
|
||||||
|
VRAM detection priority:
|
||||||
|
1. cf-orch coordinator /api/nodes — free VRAM (lease-aware, cooperative)
|
||||||
|
2. scripts.preflight.get_gpus() — total GPU VRAM (Peregrine-era fallback)
|
||||||
|
3. 999.0 — unlimited (CPU-only or no detection available)
|
||||||
|
|
||||||
|
Public API:
|
||||||
|
TaskScheduler — the scheduler class
|
||||||
|
detect_available_vram_gb() — standalone VRAM query helper
|
||||||
|
get_scheduler() — lazy process-level singleton
|
||||||
|
reset_scheduler() — test teardown only
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
from collections import deque
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, NamedTuple, Optional
|
||||||
|
|
||||||
|
try:
|
||||||
|
import httpx as httpx
|
||||||
|
except ImportError:
|
||||||
|
httpx = None # type: ignore[assignment]
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
RunTaskFn = Callable[["Path", int, str, int, Optional[str]], None]
|
||||||
|
|
||||||
|
|
||||||
|
class TaskSpec(NamedTuple):
|
||||||
|
id: int
|
||||||
|
job_id: int
|
||||||
|
params: Optional[str]
|
||||||
|
|
||||||
|
_DEFAULT_MAX_QUEUE_DEPTH = 500
|
||||||
|
|
||||||
|
|
||||||
|
def detect_available_vram_gb(
    coordinator_url: str = "http://localhost:7700",
) -> float:
    """Detect available VRAM GB for task scheduling.

    Returns free VRAM via cf-orch (sum across all nodes/GPUs) so the scheduler
    cooperates with other cf-orch consumers. Falls back to preflight total VRAM,
    then 999.0 (unlimited) if nothing is reachable.

    Args:
        coordinator_url: Base URL of the cf-orch coordinator.

    Returns:
        Free VRAM in GB (cf-orch), total VRAM in GB (preflight), or 999.0
        when no detection source is available.
    """
    # 1. Try cf-orch: use free VRAM so the scheduler cooperates with other
    #    cf-orch consumers (vision service, inference services, etc.)
    if httpx is not None:
        try:
            resp = httpx.get(f"{coordinator_url}/api/nodes", timeout=2.0)
            if resp.status_code == 200:
                nodes = resp.json().get("nodes", [])
                total_free_mb = sum(
                    gpu.get("vram_free_mb", 0)
                    for node in nodes
                    for gpu in node.get("gpus", [])
                )
                if total_free_mb > 0:
                    free_gb = total_free_mb / 1024.0
                    logger.debug(
                        "Scheduler VRAM from cf-orch: %.1f GB free", free_gb
                    )
                    return free_gb
        except Exception as exc:
            # Deliberate best-effort: the coordinator being down is normal.
            # Previously this swallowed the error silently; log it at debug
            # so fallback decisions are diagnosable.
            logger.debug(
                "cf-orch VRAM query failed (%s) — trying preflight", exc
            )

    # 2. Try preflight (systems with nvidia-smi; Peregrine-era fallback)
    try:
        from scripts.preflight import get_gpus  # type: ignore[import]

        gpus = get_gpus()
        if gpus:
            total_gb = sum(g.get("vram_total_gb", 0.0) for g in gpus)
            logger.debug(
                "Scheduler VRAM from preflight: %.1f GB total", total_gb
            )
            return total_gb
    except Exception as exc:
        # Same best-effort contract: preflight may simply not be installed.
        logger.debug("preflight VRAM query failed (%s)", exc)

    logger.debug(
        "Scheduler VRAM detection unavailable — using unlimited (999 GB)"
    )
    return 999.0
|
||||||
|
|
||||||
|
|
||||||
|
class TaskScheduler:
    """Resource-aware LLM task batch scheduler.

    Runs one batch-worker thread per task type while total reserved VRAM
    stays within the detected available budget. Always allows at least one
    batch to start even if its budget exceeds available VRAM (prevents
    permanent starvation on low-VRAM systems).

    Thread-safety: all queue/active state protected by self._lock.

    Usage::

        sched = get_scheduler(
            db_path=Path("data/app.db"),
            run_task_fn=my_run_task,
            task_types=frozenset({"cover_letter", "research"}),
            vram_budgets={"cover_letter": 2.5, "research": 5.0},
        )
        task_id, is_new = insert_task(db_path, "cover_letter", job_id)
        if is_new:
            enqueued = sched.enqueue(task_id, "cover_letter", job_id, params_json)
            if not enqueued:
                mark_task_failed(db_path, task_id, "Queue full")
    """

    def __init__(
        self,
        db_path: Path,
        run_task_fn: RunTaskFn,
        task_types: frozenset[str],
        vram_budgets: dict[str, float],
        available_vram_gb: Optional[float] = None,
        max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH,
    ) -> None:
        # Immutable configuration supplied by the product.
        self._db_path = db_path
        self._run_task = run_task_fn
        self._task_types = frozenset(task_types)
        self._budgets: dict[str, float] = dict(vram_budgets)
        self._max_queue_depth = max_queue_depth

        # Concurrency primitives: _lock guards _queues/_active/_reserved_vram;
        # _wake nudges the scheduler loop; _stop requests shutdown.
        self._lock = threading.Lock()
        self._wake = threading.Event()
        self._stop = threading.Event()
        self._queues: dict[str, deque[TaskSpec]] = {}
        self._active: dict[str, threading.Thread] = {}
        self._reserved_vram: float = 0.0
        self._thread: Optional[threading.Thread] = None

        # Explicit override wins; otherwise probe cf-orch/preflight (may make
        # a short network call — see detect_available_vram_gb).
        self._available_vram: float = (
            available_vram_gb
            if available_vram_gb is not None
            else detect_available_vram_gb()
        )

        # A missing budget is tolerated: _budgets.get(t, 0.0) later means the
        # type is effectively exempt from VRAM gating — warn so it's visible.
        for t in self._task_types:
            if t not in self._budgets:
                logger.warning(
                    "No VRAM budget defined for task type %r — "
                    "defaulting to 0.0 GB (no VRAM gating for this type)",
                    t,
                )

        # Resume any 'queued' rows that survived a previous process.
        self._load_queued_tasks()

    def enqueue(
        self,
        task_id: int,
        task_type: str,
        job_id: int,
        params: Optional[str],
    ) -> bool:
        """Add a task to the scheduler queue.

        Returns True if enqueued successfully.
        Returns False if the queue is full — caller should mark the task failed.
        """
        with self._lock:
            q = self._queues.setdefault(task_type, deque())
            if len(q) >= self._max_queue_depth:
                logger.warning(
                    "Queue depth limit for %s (max=%d) — task %d dropped",
                    task_type,
                    self._max_queue_depth,
                    task_id,
                )
                return False
            q.append(TaskSpec(task_id, job_id, params))
        # Nudge the scheduler loop so the task is picked up promptly.
        self._wake.set()
        return True

    def start(self) -> None:
        """Start the background scheduler loop thread. Call once after construction."""
        self._thread = threading.Thread(
            target=self._scheduler_loop, name="task-scheduler", daemon=True
        )
        self._thread.start()
        # Wake the loop immediately so tasks loaded from DB at startup are dispatched
        with self._lock:
            if any(self._queues.values()):
                self._wake.set()

    def shutdown(self, timeout: float = 5.0) -> None:
        """Signal the scheduler to stop and wait for it to exit.

        NOTE(review): only the loop thread is joined; in-flight batch-worker
        threads are daemons and are not waited on — confirm this is intended.
        """
        self._stop.set()
        self._wake.set()
        if self._thread and self._thread.is_alive():
            self._thread.join(timeout=timeout)

    def _scheduler_loop(self) -> None:
        # Event-driven loop with a 30 s safety timeout so missed wakeups
        # cannot stall dispatch forever.
        while not self._stop.is_set():
            self._wake.wait(timeout=30)
            self._wake.clear()
            with self._lock:
                # Reap batch threads that finished without waking us.
                # VRAM accounting is handled solely by _batch_worker's finally block;
                # the reaper only removes dead entries from _active.
                for t, thread in list(self._active.items()):
                    if not thread.is_alive():
                        del self._active[t]
                # Start new type batches while VRAM budget allows.
                # Deepest queues first so the biggest backlog drains soonest.
                candidates = sorted(
                    [
                        t
                        for t in self._queues
                        if self._queues[t] and t not in self._active
                    ],
                    key=lambda t: len(self._queues[t]),
                    reverse=True,
                )
                for task_type in candidates:
                    budget = self._budgets.get(task_type, 0.0)
                    # Always allow at least one batch to run
                    if (
                        self._reserved_vram == 0.0
                        or self._reserved_vram + budget <= self._available_vram
                    ):
                        thread = threading.Thread(
                            target=self._batch_worker,
                            args=(task_type,),
                            name=f"batch-{task_type}",
                            daemon=True,
                        )
                        self._active[task_type] = thread
                        self._reserved_vram += budget
                        thread.start()

    def _batch_worker(self, task_type: str) -> None:
        """Serial consumer for one task type. Runs until the type's deque is empty."""
        try:
            while True:
                # Pop under the lock, but run the task outside it so a long
                # LLM call never blocks enqueue() or the scheduler loop.
                with self._lock:
                    q = self._queues.get(task_type)
                    if not q:
                        break
                    task = q.popleft()
                self._run_task(
                    self._db_path, task.id, task_type, task.job_id, task.params
                )
        finally:
            # Sole place VRAM is released — keeps accounting single-owner even
            # if _run_task raises.
            with self._lock:
                self._active.pop(task_type, None)
                self._reserved_vram -= self._budgets.get(task_type, 0.0)
            self._wake.set()

    def _load_queued_tasks(self) -> None:
        """Reload surviving 'queued' tasks from SQLite into deques at startup."""
        if not self._task_types:
            return
        task_types_list = sorted(self._task_types)
        # Placeholders are generated, never user input — safe f-string SQL.
        placeholders = ",".join("?" * len(task_types_list))
        try:
            with sqlite3.connect(self._db_path) as conn:
                rows = conn.execute(
                    f"SELECT id, task_type, job_id, params FROM background_tasks"
                    f" WHERE status='queued' AND task_type IN ({placeholders})"
                    f" ORDER BY created_at ASC",
                    task_types_list,
                ).fetchall()
        except sqlite3.OperationalError:
            # Table not yet created (first run before migrations)
            rows = []

        for row_id, task_type, job_id, params in rows:
            q = self._queues.setdefault(task_type, deque())
            q.append(TaskSpec(row_id, job_id, params))

        if rows:
            logger.info(
                "Scheduler: resumed %d queued task(s) from prior run", len(rows)
            )
|
||||||
|
|
||||||
|
|
||||||
|
# ── Process-level singleton ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_scheduler: Optional[TaskScheduler] = None
|
||||||
|
_scheduler_lock = threading.Lock()
|
||||||
|
|
||||||
|
|
||||||
|
def get_scheduler(
    db_path: Path,
    run_task_fn: Optional[RunTaskFn] = None,
    task_types: Optional[frozenset[str]] = None,
    vram_budgets: Optional[dict[str, float]] = None,
    max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH,
) -> TaskScheduler:
    """Return the process-level TaskScheduler singleton.

    ``run_task_fn``, ``task_types``, and ``vram_budgets`` are required on the
    first call; ignored on subsequent calls (singleton already constructed).

    VRAM detection (which may involve a network call) is performed outside the
    lock so the lock is never held across blocking I/O.

    Raises:
        ValueError: if the required arguments are missing on the first call.
    """
    global _scheduler
    # Fast path: _scheduler is assigned exactly once, under _scheduler_lock,
    # so observing a non-None value here is safe without taking the lock.
    if _scheduler is not None:
        return _scheduler
    if run_task_fn is None or task_types is None or vram_budgets is None:
        raise ValueError(
            "run_task_fn, task_types, and vram_budgets are required "
            "on the first call to get_scheduler()"
        )
    # Build outside the lock — TaskScheduler.__init__ may call detect_available_vram_gb()
    # which makes an httpx network call (up to 2 s). Holding the lock during that
    # would block any concurrent caller for the full duration.
    candidate = TaskScheduler(
        db_path=db_path,
        run_task_fn=run_task_fn,
        task_types=task_types,
        vram_budgets=vram_budgets,
        max_queue_depth=max_queue_depth,
    )
    with _scheduler_lock:
        if _scheduler is None:
            _scheduler = candidate
            # BUGFIX: start() only after winning the race. Previously every
            # candidate was started *before* the lock, so two racing callers
            # could each dispatch the same 'queued' rows reloaded from SQLite
            # in __init__ — duplicate task execution. A losing candidate now
            # never spawns its loop thread.
            _scheduler.start()
        else:
            # Another thread beat us — discard our candidate and use the winner.
            # shutdown() on a never-started scheduler is a no-op (no thread).
            candidate.shutdown()
    return _scheduler
|
||||||
|
|
||||||
|
|
||||||
|
def reset_scheduler() -> None:
    """Shut down and clear the singleton. TEST TEARDOWN ONLY."""
    global _scheduler
    with _scheduler_lock:
        # Guard clause: nothing to tear down.
        if _scheduler is None:
            return
        _scheduler.shutdown()
        _scheduler = None
|
||||||
|
|
@ -30,6 +30,8 @@ def can_use(
|
||||||
has_byok: bool = False,
|
has_byok: bool = False,
|
||||||
has_local_vision: bool = False,
|
has_local_vision: bool = False,
|
||||||
_features: dict[str, str] | None = None,
|
_features: dict[str, str] | None = None,
|
||||||
|
_byok_unlockable: frozenset[str] | None = None,
|
||||||
|
_local_vision_unlockable: frozenset[str] | None = None,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Return True if the given tier (and optional unlocks) can access feature.
|
Return True if the given tier (and optional unlocks) can access feature.
|
||||||
|
|
@ -41,15 +43,22 @@ def can_use(
|
||||||
has_local_vision: True if user has a local vision model configured.
|
has_local_vision: True if user has a local vision model configured.
|
||||||
_features: Feature→min_tier map. Products pass their own dict here.
|
_features: Feature→min_tier map. Products pass their own dict here.
|
||||||
If None, all features are free.
|
If None, all features are free.
|
||||||
|
_byok_unlockable: Product-specific BYOK-unlockable features.
|
||||||
|
If None, uses module-level BYOK_UNLOCKABLE.
|
||||||
|
_local_vision_unlockable: Product-specific local vision unlockable features.
|
||||||
|
If None, uses module-level LOCAL_VISION_UNLOCKABLE.
|
||||||
"""
|
"""
|
||||||
features = _features or {}
|
features = _features or {}
|
||||||
|
byok_unlockable = _byok_unlockable if _byok_unlockable is not None else BYOK_UNLOCKABLE
|
||||||
|
local_vision_unlockable = _local_vision_unlockable if _local_vision_unlockable is not None else LOCAL_VISION_UNLOCKABLE
|
||||||
|
|
||||||
if feature not in features:
|
if feature not in features:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if has_byok and feature in BYOK_UNLOCKABLE:
|
if has_byok and feature in byok_unlockable:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
if has_local_vision and feature in LOCAL_VISION_UNLOCKABLE:
|
if has_local_vision and feature in local_vision_unlockable:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
min_tier = features[feature]
|
min_tier = features[feature]
|
||||||
|
|
@ -64,13 +73,18 @@ def tier_label(
|
||||||
has_byok: bool = False,
|
has_byok: bool = False,
|
||||||
has_local_vision: bool = False,
|
has_local_vision: bool = False,
|
||||||
_features: dict[str, str] | None = None,
|
_features: dict[str, str] | None = None,
|
||||||
|
_byok_unlockable: frozenset[str] | None = None,
|
||||||
|
_local_vision_unlockable: frozenset[str] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Return a human-readable label for the minimum tier needed for feature."""
|
"""Return a human-readable label for the minimum tier needed for feature."""
|
||||||
features = _features or {}
|
features = _features or {}
|
||||||
|
byok_unlockable = _byok_unlockable if _byok_unlockable is not None else BYOK_UNLOCKABLE
|
||||||
|
local_vision_unlockable = _local_vision_unlockable if _local_vision_unlockable is not None else LOCAL_VISION_UNLOCKABLE
|
||||||
|
|
||||||
if feature not in features:
|
if feature not in features:
|
||||||
return "free"
|
return "free"
|
||||||
if has_byok and feature in BYOK_UNLOCKABLE:
|
if has_byok and feature in byok_unlockable:
|
||||||
return "free (BYOK)"
|
return "free (BYOK)"
|
||||||
if has_local_vision and feature in LOCAL_VISION_UNLOCKABLE:
|
if has_local_vision and feature in local_vision_unlockable:
|
||||||
return "free (local vision)"
|
return "free (local vision)"
|
||||||
return features[feature]
|
return features[feature]
|
||||||
|
|
|
||||||
|
|
@ -22,8 +22,12 @@ orch = [
|
||||||
"typer[all]>=0.12",
|
"typer[all]>=0.12",
|
||||||
"psutil>=5.9",
|
"psutil>=5.9",
|
||||||
]
|
]
|
||||||
|
tasks = [
|
||||||
|
"httpx>=0.27",
|
||||||
|
]
|
||||||
dev = [
|
dev = [
|
||||||
"circuitforge-core[orch]",
|
"circuitforge-core[orch]",
|
||||||
|
"circuitforge-core[tasks]",
|
||||||
"pytest>=8.0",
|
"pytest>=8.0",
|
||||||
"pytest-asyncio>=0.23",
|
"pytest-asyncio>=0.23",
|
||||||
"httpx>=0.27",
|
"httpx>=0.27",
|
||||||
|
|
|
||||||
0
tests/test_tasks/__init__.py
Normal file
0
tests/test_tasks/__init__.py
Normal file
285
tests/test_tasks/test_scheduler.py
Normal file
285
tests/test_tasks/test_scheduler.py
Normal file
|
|
@ -0,0 +1,285 @@
|
||||||
|
"""Tests for circuitforge_core.tasks.scheduler."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import sqlite3
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from types import ModuleType
|
||||||
|
from typing import List
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from circuitforge_core.tasks.scheduler import (
|
||||||
|
TaskScheduler,
|
||||||
|
detect_available_vram_gb,
|
||||||
|
get_scheduler,
|
||||||
|
reset_scheduler,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fixtures ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.fixture
def tmp_db(tmp_path: Path) -> Path:
    """SQLite DB with background_tasks table.

    Mirrors the schema the scheduler queries in _load_queued_tasks:
    (id, task_type, job_id, status, params, created_at).
    """
    db = tmp_path / "test.db"
    conn = sqlite3.connect(db)
    conn.execute("""
        CREATE TABLE background_tasks (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            task_type TEXT NOT NULL,
            job_id INTEGER NOT NULL DEFAULT 0,
            status TEXT NOT NULL DEFAULT 'queued',
            params TEXT,
            error TEXT,
            created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
        )
    """)
    conn.commit()
    conn.close()
    return db
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def _reset_singleton():
    """Always tear down the scheduler singleton between tests.

    autouse ensures no test can leak a running scheduler (and its daemon
    threads) into the next test.
    """
    yield
    reset_scheduler()
|
||||||
|
|
||||||
|
|
||||||
|
TASK_TYPES = frozenset({"fast_task"})
|
||||||
|
BUDGETS = {"fast_task": 1.0}
|
||||||
|
|
||||||
|
|
||||||
|
# ── detect_available_vram_gb ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_detect_vram_from_cfortch():
    """Uses cf-orch free VRAM when coordinator is reachable."""
    # NOTE(review): "cfortch" in the test name looks like a typo for
    # "cf_orch"; harmless (pytest discovers by prefix) so left unchanged.
    mock_resp = MagicMock()
    mock_resp.status_code = 200
    mock_resp.json.return_value = {
        "nodes": [
            {"node_id": "local", "gpus": [{"vram_free_mb": 4096}, {"vram_free_mb": 4096}]}
        ]
    }
    # Patch the module-level httpx reference, not the httpx package itself.
    with patch("circuitforge_core.tasks.scheduler.httpx") as mock_httpx:
        mock_httpx.get.return_value = mock_resp
        result = detect_available_vram_gb(coordinator_url="http://localhost:7700")
        assert result == pytest.approx(8.0)  # 4096 + 4096 MB → 8 GB


def test_detect_vram_cforch_unavailable_falls_back_to_unlimited():
    """Falls back to 999.0 when cf-orch is unreachable and preflight unavailable."""
    # Assumes scripts.preflight is not importable in the test environment,
    # so tier 2 also fails and the 999.0 default is returned.
    with patch("circuitforge_core.tasks.scheduler.httpx") as mock_httpx:
        mock_httpx.get.side_effect = ConnectionRefusedError()
        result = detect_available_vram_gb()
        assert result == 999.0


def test_detect_vram_cforch_empty_nodes_falls_back():
    """If cf-orch returns no nodes with GPUs, falls back to unlimited."""
    # total_free_mb sums to 0 for an empty node list, which the detector
    # treats as "no usable answer" and falls through.
    mock_resp = MagicMock()
    mock_resp.status_code = 200
    mock_resp.json.return_value = {"nodes": []}
    with patch("circuitforge_core.tasks.scheduler.httpx") as mock_httpx:
        mock_httpx.get.return_value = mock_resp
        result = detect_available_vram_gb()
        assert result == 999.0


def test_detect_vram_preflight_fallback():
    """Falls back to preflight total VRAM when cf-orch is unreachable."""
    # Build a fake scripts.preflight module with get_gpus returning two GPUs.
    fake_scripts = ModuleType("scripts")
    fake_preflight = ModuleType("scripts.preflight")
    fake_preflight.get_gpus = lambda: [  # type: ignore[attr-defined]
        {"vram_total_gb": 8.0},
        {"vram_total_gb": 4.0},
    ]
    fake_scripts.preflight = fake_preflight  # type: ignore[attr-defined]

    # patch.dict on sys.modules makes `from scripts.preflight import get_gpus`
    # resolve to the fake for the duration of the with-block.
    with patch("circuitforge_core.tasks.scheduler.httpx") as mock_httpx, \
        patch.dict(
            __import__("sys").modules,
            {"scripts": fake_scripts, "scripts.preflight": fake_preflight},
        ):
        mock_httpx.get.side_effect = ConnectionRefusedError()
        result = detect_available_vram_gb()

    assert result == pytest.approx(12.0)  # 8.0 + 4.0 GB
|
||||||
|
|
||||||
|
|
||||||
|
# ── TaskScheduler basic behaviour ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_enqueue_returns_true_on_success(tmp_db: Path):
    """enqueue() reports success when the queue has capacity."""
    ran: List[int] = []

    def run_fn(db_path, task_id, task_type, job_id, params):
        ran.append(task_id)

    # Fixed available_vram_gb avoids the network-probing detection path.
    sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
    sched.start()
    result = sched.enqueue(1, "fast_task", 0, None)
    sched.shutdown()
    assert result is True


def test_scheduler_runs_task(tmp_db: Path):
    """Enqueued task is executed by the batch worker."""
    ran: List[int] = []
    event = threading.Event()

    def run_fn(db_path, task_id, task_type, job_id, params):
        ran.append(task_id)
        event.set()

    sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
    sched.start()
    sched.enqueue(42, "fast_task", 0, None)
    # Event-based wait keeps the test fast and non-flaky (no fixed sleep).
    assert event.wait(timeout=3.0), "Task was not executed within 3 seconds"
    sched.shutdown()
    assert ran == [42]


def test_enqueue_returns_false_when_queue_full(tmp_db: Path):
    """Returns False and does not enqueue when max_queue_depth is reached."""
    gate = threading.Event()

    def blocking_run_fn(db_path, task_id, task_type, job_id, params):
        # Block the worker so enqueued tasks pile up past max_queue_depth.
        gate.wait()

    sched = TaskScheduler(
        tmp_db, blocking_run_fn, TASK_TYPES, BUDGETS,
        available_vram_gb=8.0, max_queue_depth=2
    )
    sched.start()
    results = [sched.enqueue(i, "fast_task", 0, None) for i in range(1, 10)]
    gate.set()
    sched.shutdown()
    assert not all(results), "Expected at least one enqueue to be rejected"


def test_scheduler_drains_multiple_tasks(tmp_db: Path):
    """All enqueued tasks of the same type are run serially."""
    ran: List[int] = []
    done = threading.Event()
    TOTAL = 5

    def run_fn(db_path, task_id, task_type, job_id, params):
        ran.append(task_id)
        if len(ran) >= TOTAL:
            done.set()

    sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
    sched.start()
    for i in range(1, TOTAL + 1):
        sched.enqueue(i, "fast_task", 0, None)
    assert done.wait(timeout=5.0), f"Only ran {len(ran)} of {TOTAL} tasks"
    sched.shutdown()
    # sorted(): one serial worker per type, but don't over-constrain ordering.
    assert sorted(ran) == list(range(1, TOTAL + 1))
|
||||||
|
|
||||||
|
|
||||||
|
def test_vram_budget_blocks_second_type(tmp_db: Path):
    """Second task type is not started when VRAM would be exceeded."""
    type_a_started = threading.Event()
    type_b_started = threading.Event()
    gate_a = threading.Event()
    gate_b = threading.Event()
    started = []

    def run_fn(db_path, task_id, task_type, job_id, params):
        started.append(task_type)
        if task_type == "type_a":
            type_a_started.set()
            gate_a.wait()  # hold type_a's VRAM reservation until released
        else:
            type_b_started.set()
            gate_b.wait()

    two_types = frozenset({"type_a", "type_b"})
    tight_budgets = {"type_a": 4.0, "type_b": 4.0}  # 4+4 > 6 GB available

    sched = TaskScheduler(
        tmp_db, run_fn, two_types, tight_budgets, available_vram_gb=6.0
    )
    sched.start()
    sched.enqueue(1, "type_a", 0, None)
    sched.enqueue(2, "type_b", 0, None)

    assert type_a_started.wait(timeout=3.0), "type_a never started"
    # NOTE(review): this is a point-in-time check, not a wait — it relies on
    # the scheduler not having started type_b in the window before the assert.
    assert not type_b_started.is_set(), "type_b should be blocked by VRAM"

    # Releasing type_a frees its budget; the worker's finally-block wake
    # lets the loop start type_b.
    gate_a.set()
    assert type_b_started.wait(timeout=3.0), "type_b never started after type_a finished"
    gate_b.set()
    sched.shutdown()
    assert sorted(started) == ["type_a", "type_b"]


def test_get_scheduler_singleton(tmp_db: Path):
    """get_scheduler() returns the same instance on repeated calls."""
    run_fn = MagicMock()
    s1 = get_scheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS)
    s2 = get_scheduler(tmp_db)  # no run_fn — should reuse existing
    assert s1 is s2


def test_reset_scheduler_clears_singleton(tmp_db: Path):
    """reset_scheduler() allows a new singleton to be constructed."""
    run_fn = MagicMock()
    s1 = get_scheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS)
    reset_scheduler()
    s2 = get_scheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS)
    assert s1 is not s2


def test_load_queued_tasks_on_startup(tmp_db: Path):
    """Tasks with status='queued' in the DB at startup are loaded into the deque."""
    # Seed one queued row before the scheduler is constructed.
    conn = sqlite3.connect(tmp_db)
    conn.execute(
        "INSERT INTO background_tasks (task_type, job_id, status) VALUES ('fast_task', 0, 'queued')"
    )
    conn.commit()
    conn.close()

    ran: List[int] = []
    done = threading.Event()

    def run_fn(db_path, task_id, task_type, job_id, params):
        ran.append(task_id)
        done.set()

    sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
    sched.start()
    assert done.wait(timeout=3.0), "Pre-loaded task was not run"
    sched.shutdown()
    assert len(ran) == 1


def test_load_queued_tasks_missing_table_does_not_crash(tmp_path: Path):
    """Scheduler does not crash if background_tasks table doesn't exist yet."""
    db = tmp_path / "empty.db"
    sqlite3.connect(db).close()  # empty DB file, no schema

    run_fn = MagicMock()
    sched = TaskScheduler(db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
    sched.start()
    sched.shutdown()
    # No exception = pass


def test_reserved_vram_zero_after_task_completes(tmp_db: Path):
    """_reserved_vram returns to 0.0 after a task finishes — no double-decrement."""
    done = threading.Event()

    def run_fn(db_path, task_id, task_type, job_id, params):
        done.set()

    sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
    sched.start()
    sched.enqueue(1, "fast_task", 0, None)
    assert done.wait(timeout=3.0), "Task never completed"
    sched.shutdown()
    # White-box peek at internal accounting; acceptable in the unit test.
    assert sched._reserved_vram == 0.0, f"Expected 0.0, got {sched._reserved_vram}"
|
||||||
Loading…
Reference in a new issue