refactor: replace coordinator-aware TaskScheduler with Protocol + LocalScheduler (MIT); update LLMRouter import path
This commit is contained in:
parent
090a86ce1b
commit
2259382d0b
3 changed files with 158 additions and 477 deletions
|
|
@ -1,6 +1,6 @@
|
|||
# circuitforge_core/tasks/__init__.py
|
||||
from circuitforge_core.tasks.scheduler import (
|
||||
TaskScheduler,
|
||||
LocalScheduler,
|
||||
detect_available_vram_gb,
|
||||
get_scheduler,
|
||||
reset_scheduler,
|
||||
|
|
@ -8,6 +8,7 @@ from circuitforge_core.tasks.scheduler import (
|
|||
|
||||
__all__ = [
|
||||
"TaskScheduler",
|
||||
"LocalScheduler",
|
||||
"detect_available_vram_gb",
|
||||
"get_scheduler",
|
||||
"reset_scheduler",
|
||||
|
|
|
|||
|
|
@ -1,21 +1,17 @@
|
|||
# circuitforge_core/tasks/scheduler.py
|
||||
"""Resource-aware batch scheduler for LLM background tasks.
|
||||
"""Task scheduler for CircuitForge products — MIT layer.
|
||||
|
||||
Generic scheduler that any CircuitForge product can use. Products supply:
|
||||
- task_types: frozenset[str] — task type strings routed through this scheduler
|
||||
- vram_budgets: dict[str, float] — VRAM GB estimate per task type
|
||||
- run_task_fn — product's task execution function
|
||||
Provides a simple FIFO task queue with no coordinator dependency.
|
||||
|
||||
VRAM detection priority:
|
||||
1. cf-orch coordinator /api/nodes — free VRAM (lease-aware, cooperative)
|
||||
2. scripts.preflight.get_gpus() — total GPU VRAM (Peregrine-era fallback)
|
||||
3. 999.0 — unlimited (CPU-only or no detection available)
|
||||
For coordinator-aware VRAM-budgeted scheduling on paid/premium tiers, install
|
||||
circuitforge-orch and use OrchestratedScheduler instead.
|
||||
|
||||
Public API:
|
||||
TaskScheduler — the scheduler class
|
||||
detect_available_vram_gb() — standalone VRAM query helper
|
||||
get_scheduler() — lazy process-level singleton
|
||||
reset_scheduler() — test teardown only
|
||||
TaskScheduler — Protocol defining the scheduler interface
|
||||
LocalScheduler — Simple FIFO queue implementation (MIT, no coordinator)
|
||||
detect_available_vram_gb() — Returns 999.0 (unlimited; no coordinator on free tier)
|
||||
get_scheduler() — Lazy process-level singleton returning a LocalScheduler
|
||||
reset_scheduler() — Test teardown only
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
|
|
@ -24,12 +20,7 @@ import sqlite3
|
|||
import threading
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import Callable, NamedTuple, Optional
|
||||
|
||||
try:
|
||||
import httpx as httpx
|
||||
except ImportError:
|
||||
httpx = None # type: ignore[assignment]
|
||||
from typing import Callable, NamedTuple, Optional, Protocol, runtime_checkable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -41,68 +32,45 @@ class TaskSpec(NamedTuple):
|
|||
job_id: int
|
||||
params: Optional[str]
|
||||
|
||||
|
||||
_DEFAULT_MAX_QUEUE_DEPTH = 500
|
||||
|
||||
|
||||
def detect_available_vram_gb(
|
||||
coordinator_url: str = "http://localhost:7700",
|
||||
) -> float:
|
||||
"""Detect available VRAM GB for task scheduling.
|
||||
def detect_available_vram_gb() -> float:
|
||||
"""Return available VRAM for task scheduling.
|
||||
|
||||
Returns free VRAM via cf-orch (sum across all nodes/GPUs) so the scheduler
|
||||
cooperates with other cf-orch consumers. Falls back to preflight total VRAM,
|
||||
then 999.0 (unlimited) if nothing is reachable.
|
||||
Free tier (no coordinator): always returns 999.0 — no VRAM gating.
|
||||
For coordinator-aware VRAM detection use circuitforge_orch.scheduler.
|
||||
"""
|
||||
# 1. Try cf-orch: use free VRAM so the scheduler cooperates with other
|
||||
# cf-orch consumers (vision service, inference services, etc.)
|
||||
if httpx is not None:
|
||||
try:
|
||||
resp = httpx.get(f"{coordinator_url}/api/nodes", timeout=2.0)
|
||||
if resp.status_code == 200:
|
||||
nodes = resp.json().get("nodes", [])
|
||||
total_free_mb = sum(
|
||||
gpu.get("vram_free_mb", 0)
|
||||
for node in nodes
|
||||
for gpu in node.get("gpus", [])
|
||||
)
|
||||
if total_free_mb > 0:
|
||||
free_gb = total_free_mb / 1024.0
|
||||
logger.debug(
|
||||
"Scheduler VRAM from cf-orch: %.1f GB free", free_gb
|
||||
)
|
||||
return free_gb
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 2. Try preflight (systems with nvidia-smi; Peregrine-era fallback)
|
||||
try:
|
||||
from scripts.preflight import get_gpus # type: ignore[import]
|
||||
|
||||
gpus = get_gpus()
|
||||
if gpus:
|
||||
total_gb = sum(g.get("vram_total_gb", 0.0) for g in gpus)
|
||||
logger.debug(
|
||||
"Scheduler VRAM from preflight: %.1f GB total", total_gb
|
||||
)
|
||||
return total_gb
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.debug(
|
||||
"Scheduler VRAM detection unavailable — using unlimited (999 GB)"
|
||||
)
|
||||
return 999.0
|
||||
|
||||
|
||||
class TaskScheduler:
|
||||
"""Resource-aware LLM task batch scheduler.
|
||||
@runtime_checkable
|
||||
class TaskScheduler(Protocol):
|
||||
"""Protocol for task schedulers across free and paid tiers.
|
||||
|
||||
Runs one batch-worker thread per task type while total reserved VRAM
|
||||
stays within the detected available budget. Always allows at least one
|
||||
batch to start even if its budget exceeds available VRAM (prevents
|
||||
permanent starvation on low-VRAM systems).
|
||||
Both LocalScheduler (MIT) and OrchestratedScheduler (BSL, circuitforge-orch)
|
||||
implement this interface so products can inject either without API changes.
|
||||
"""
|
||||
|
||||
Thread-safety: all queue/active state protected by self._lock.
|
||||
def enqueue(self, task_id: int, task_type: str, job_id: int, params: Optional[str]) -> bool:
|
||||
"""Add a task to the queue. Returns True if enqueued, False if queue full."""
|
||||
...
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the background worker thread."""
|
||||
...
|
||||
|
||||
def shutdown(self, timeout: float = 5.0) -> None:
|
||||
"""Stop the scheduler and wait for it to exit."""
|
||||
...
|
||||
|
||||
|
||||
class LocalScheduler:
|
||||
"""Simple FIFO task scheduler with no coordinator dependency.
|
||||
|
||||
Processes tasks serially per task type. No VRAM gating — all tasks run.
|
||||
Suitable for free tier (single-host, up to 2 GPUs, static config).
|
||||
|
||||
Usage::
|
||||
|
||||
|
|
@ -112,11 +80,7 @@ class TaskScheduler:
|
|||
task_types=frozenset({"cover_letter", "research"}),
|
||||
vram_budgets={"cover_letter": 2.5, "research": 5.0},
|
||||
)
|
||||
task_id, is_new = insert_task(db_path, "cover_letter", job_id)
|
||||
if is_new:
|
||||
enqueued = sched.enqueue(task_id, "cover_letter", job_id, params_json)
|
||||
if not enqueued:
|
||||
mark_task_failed(db_path, task_id, "Queue full")
|
||||
enqueued = sched.enqueue(task_id, "cover_letter", job_id, params_json)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
|
@ -125,11 +89,7 @@ class TaskScheduler:
|
|||
run_task_fn: RunTaskFn,
|
||||
task_types: frozenset[str],
|
||||
vram_budgets: dict[str, float],
|
||||
available_vram_gb: Optional[float] = None,
|
||||
max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH,
|
||||
coordinator_url: str = "http://localhost:7700",
|
||||
service_name: str = "peregrine",
|
||||
lease_priority: int = 2,
|
||||
) -> None:
|
||||
self._db_path = db_path
|
||||
self._run_task = run_task_fn
|
||||
|
|
@ -137,54 +97,22 @@ class TaskScheduler:
|
|||
self._budgets: dict[str, float] = dict(vram_budgets)
|
||||
self._max_queue_depth = max_queue_depth
|
||||
|
||||
self._coordinator_url = coordinator_url.rstrip("/")
|
||||
self._service_name = service_name
|
||||
self._lease_priority = lease_priority
|
||||
|
||||
self._lock = threading.Lock()
|
||||
self._wake = threading.Event()
|
||||
self._stop = threading.Event()
|
||||
self._queues: dict[str, deque[TaskSpec]] = {}
|
||||
self._active: dict[str, threading.Thread] = {}
|
||||
self._reserved_vram: float = 0.0
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
|
||||
self._available_vram: float = (
|
||||
available_vram_gb
|
||||
if available_vram_gb is not None
|
||||
else detect_available_vram_gb()
|
||||
)
|
||||
|
||||
for t in self._task_types:
|
||||
if t not in self._budgets:
|
||||
logger.warning(
|
||||
"No VRAM budget defined for task type %r — "
|
||||
"defaulting to 0.0 GB (no VRAM gating for this type)",
|
||||
t,
|
||||
)
|
||||
|
||||
self._load_queued_tasks()
|
||||
|
||||
def enqueue(
|
||||
self,
|
||||
task_id: int,
|
||||
task_type: str,
|
||||
job_id: int,
|
||||
params: Optional[str],
|
||||
) -> bool:
|
||||
"""Add a task to the scheduler queue.
|
||||
|
||||
Returns True if enqueued successfully.
|
||||
Returns False if the queue is full — caller should mark the task failed.
|
||||
"""
|
||||
def enqueue(self, task_id: int, task_type: str, job_id: int, params: Optional[str]) -> bool:
|
||||
with self._lock:
|
||||
q = self._queues.setdefault(task_type, deque())
|
||||
if len(q) >= self._max_queue_depth:
|
||||
logger.warning(
|
||||
"Queue depth limit for %s (max=%d) — task %d dropped",
|
||||
task_type,
|
||||
self._max_queue_depth,
|
||||
task_id,
|
||||
task_type, self._max_queue_depth, task_id,
|
||||
)
|
||||
return False
|
||||
q.append(TaskSpec(task_id, job_id, params))
|
||||
|
|
@ -192,28 +120,19 @@ class TaskScheduler:
|
|||
return True
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the background scheduler loop thread. Call once after construction."""
|
||||
self._thread = threading.Thread(
|
||||
target=self._scheduler_loop, name="task-scheduler", daemon=True
|
||||
)
|
||||
self._thread.start()
|
||||
# Wake the loop immediately so tasks loaded from DB at startup are dispatched
|
||||
with self._lock:
|
||||
if any(self._queues.values()):
|
||||
self._wake.set()
|
||||
|
||||
def shutdown(self, timeout: float = 5.0) -> None:
|
||||
"""Signal the scheduler to stop and wait for it to exit.
|
||||
|
||||
Joins both the scheduler loop thread and any active batch worker
|
||||
threads so callers can rely on clean state (e.g. _reserved_vram == 0)
|
||||
immediately after this returns.
|
||||
"""
|
||||
self._stop.set()
|
||||
self._wake.set()
|
||||
if self._thread and self._thread.is_alive():
|
||||
self._thread.join(timeout=timeout)
|
||||
# Join active batch workers so _reserved_vram is settled on return
|
||||
with self._lock:
|
||||
workers = list(self._active.values())
|
||||
for worker in workers:
|
||||
|
|
@ -224,103 +143,25 @@ class TaskScheduler:
|
|||
self._wake.wait(timeout=30)
|
||||
self._wake.clear()
|
||||
with self._lock:
|
||||
# Reap batch threads that finished without waking us.
|
||||
# VRAM accounting is handled solely by _batch_worker's finally block;
|
||||
# the reaper only removes dead entries from _active.
|
||||
for t, thread in list(self._active.items()):
|
||||
if not thread.is_alive():
|
||||
del self._active[t]
|
||||
# Start new type batches while VRAM budget allows
|
||||
candidates = sorted(
|
||||
[
|
||||
t
|
||||
for t in self._queues
|
||||
if self._queues[t] and t not in self._active
|
||||
],
|
||||
[t for t in self._queues if self._queues[t] and t not in self._active],
|
||||
key=lambda t: len(self._queues[t]),
|
||||
reverse=True,
|
||||
)
|
||||
for task_type in candidates:
|
||||
budget = self._budgets.get(task_type, 0.0)
|
||||
# Always allow at least one batch to run
|
||||
if (
|
||||
self._reserved_vram == 0.0
|
||||
or self._reserved_vram + budget <= self._available_vram
|
||||
):
|
||||
thread = threading.Thread(
|
||||
target=self._batch_worker,
|
||||
args=(task_type,),
|
||||
name=f"batch-{task_type}",
|
||||
daemon=True,
|
||||
)
|
||||
self._active[task_type] = thread
|
||||
self._reserved_vram += budget
|
||||
thread.start()
|
||||
|
||||
def _acquire_lease(self, task_type: str) -> Optional[str]:
|
||||
"""Request a VRAM lease from the coordinator. Returns lease_id or None."""
|
||||
if httpx is None:
|
||||
return None
|
||||
budget_gb = self._budgets.get(task_type, 0.0)
|
||||
if budget_gb <= 0:
|
||||
return None
|
||||
mb = int(budget_gb * 1024)
|
||||
try:
|
||||
# Pick the GPU with the most free VRAM on the first registered node
|
||||
resp = httpx.get(f"{self._coordinator_url}/api/nodes", timeout=2.0)
|
||||
if resp.status_code != 200:
|
||||
return None
|
||||
nodes = resp.json().get("nodes", [])
|
||||
if not nodes:
|
||||
return None
|
||||
best_node = best_gpu = best_free = None
|
||||
for node in nodes:
|
||||
for gpu in node.get("gpus", []):
|
||||
free = gpu.get("vram_free_mb", 0)
|
||||
if best_free is None or free > best_free:
|
||||
best_node = node["node_id"]
|
||||
best_gpu = gpu["gpu_id"]
|
||||
best_free = free
|
||||
if best_node is None:
|
||||
return None
|
||||
lease_resp = httpx.post(
|
||||
f"{self._coordinator_url}/api/leases",
|
||||
json={
|
||||
"node_id": best_node,
|
||||
"gpu_id": best_gpu,
|
||||
"mb": mb,
|
||||
"service": self._service_name,
|
||||
"priority": self._lease_priority,
|
||||
},
|
||||
timeout=3.0,
|
||||
)
|
||||
if lease_resp.status_code == 200:
|
||||
lease_id = lease_resp.json()["lease"]["lease_id"]
|
||||
logger.debug(
|
||||
"Acquired VRAM lease %s for task_type=%s (%d MB)",
|
||||
lease_id, task_type, mb,
|
||||
)
|
||||
return lease_id
|
||||
except Exception as exc:
|
||||
logger.debug("Lease acquire failed (non-fatal): %s", exc)
|
||||
return None
|
||||
|
||||
def _release_lease(self, lease_id: str) -> None:
|
||||
"""Release a coordinator VRAM lease. Best-effort; failures are logged only."""
|
||||
if httpx is None or not lease_id:
|
||||
return
|
||||
try:
|
||||
httpx.delete(
|
||||
f"{self._coordinator_url}/api/leases/{lease_id}",
|
||||
timeout=3.0,
|
||||
)
|
||||
logger.debug("Released VRAM lease %s", lease_id)
|
||||
except Exception as exc:
|
||||
logger.debug("Lease release failed (non-fatal): %s", exc)
|
||||
thread = threading.Thread(
|
||||
target=self._batch_worker,
|
||||
args=(task_type,),
|
||||
name=f"batch-{task_type}",
|
||||
daemon=True,
|
||||
)
|
||||
self._active[task_type] = thread
|
||||
thread.start()
|
||||
|
||||
def _batch_worker(self, task_type: str) -> None:
|
||||
"""Serial consumer for one task type. Runs until the type's deque is empty."""
|
||||
lease_id: Optional[str] = self._acquire_lease(task_type)
|
||||
try:
|
||||
while True:
|
||||
with self._lock:
|
||||
|
|
@ -328,19 +169,13 @@ class TaskScheduler:
|
|||
if not q:
|
||||
break
|
||||
task = q.popleft()
|
||||
self._run_task(
|
||||
self._db_path, task.id, task_type, task.job_id, task.params
|
||||
)
|
||||
self._run_task(self._db_path, task.id, task_type, task.job_id, task.params)
|
||||
finally:
|
||||
if lease_id:
|
||||
self._release_lease(lease_id)
|
||||
with self._lock:
|
||||
self._active.pop(task_type, None)
|
||||
self._reserved_vram -= self._budgets.get(task_type, 0.0)
|
||||
self._wake.set()
|
||||
|
||||
def _load_queued_tasks(self) -> None:
|
||||
"""Reload surviving 'queued' tasks from SQLite into deques at startup."""
|
||||
if not self._task_types:
|
||||
return
|
||||
task_types_list = sorted(self._task_types)
|
||||
|
|
@ -354,68 +189,58 @@ class TaskScheduler:
|
|||
task_types_list,
|
||||
).fetchall()
|
||||
except sqlite3.OperationalError:
|
||||
# Table not yet created (first run before migrations)
|
||||
rows = []
|
||||
|
||||
for row_id, task_type, job_id, params in rows:
|
||||
q = self._queues.setdefault(task_type, deque())
|
||||
q.append(TaskSpec(row_id, job_id, params))
|
||||
|
||||
if rows:
|
||||
logger.info(
|
||||
"Scheduler: resumed %d queued task(s) from prior run", len(rows)
|
||||
)
|
||||
logger.info("Scheduler: resumed %d queued task(s) from prior run", len(rows))
|
||||
|
||||
|
||||
# ── Process-level singleton ────────────────────────────────────────────────────
|
||||
|
||||
_scheduler: Optional[TaskScheduler] = None
|
||||
_scheduler: Optional[LocalScheduler] = None
|
||||
_scheduler_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_scheduler(
|
||||
db_path: Path,
|
||||
db_path: Optional[Path] = None,
|
||||
run_task_fn: Optional[RunTaskFn] = None,
|
||||
task_types: Optional[frozenset[str]] = None,
|
||||
vram_budgets: Optional[dict[str, float]] = None,
|
||||
max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH,
|
||||
coordinator_url: str = "http://localhost:7700",
|
||||
service_name: str = "peregrine",
|
||||
) -> TaskScheduler:
|
||||
"""Return the process-level TaskScheduler singleton.
|
||||
) -> LocalScheduler:
|
||||
"""Return the process-level LocalScheduler singleton.
|
||||
|
||||
``run_task_fn``, ``task_types``, and ``vram_budgets`` are required on the
|
||||
first call; ignored on subsequent calls (singleton already constructed).
|
||||
``run_task_fn``, ``task_types``, ``vram_budgets``, and ``db_path`` are
|
||||
required on the first call; ignored on subsequent calls.
|
||||
|
||||
VRAM detection (which may involve a network call) is performed outside the
|
||||
lock so the lock is never held across blocking I/O.
|
||||
``coordinator_url`` and ``service_name`` are accepted but ignored —
|
||||
LocalScheduler has no coordinator. They exist for API compatibility with
|
||||
OrchestratedScheduler call sites.
|
||||
"""
|
||||
global _scheduler
|
||||
if _scheduler is not None:
|
||||
return _scheduler
|
||||
# Build outside the lock — TaskScheduler.__init__ may call detect_available_vram_gb()
|
||||
# which makes an httpx network call (up to 2 s). Holding the lock during that
|
||||
# would block any concurrent caller for the full duration.
|
||||
if run_task_fn is None or task_types is None or vram_budgets is None:
|
||||
if run_task_fn is None or task_types is None or vram_budgets is None or db_path is None:
|
||||
raise ValueError(
|
||||
"run_task_fn, task_types, and vram_budgets are required "
|
||||
"db_path, run_task_fn, task_types, and vram_budgets are required "
|
||||
"on the first call to get_scheduler()"
|
||||
)
|
||||
candidate = TaskScheduler(
|
||||
candidate = LocalScheduler(
|
||||
db_path=db_path,
|
||||
run_task_fn=run_task_fn,
|
||||
task_types=task_types,
|
||||
vram_budgets=vram_budgets,
|
||||
max_queue_depth=max_queue_depth,
|
||||
coordinator_url=coordinator_url,
|
||||
service_name=service_name,
|
||||
)
|
||||
candidate.start()
|
||||
with _scheduler_lock:
|
||||
if _scheduler is None:
|
||||
_scheduler = candidate
|
||||
else:
|
||||
# Another thread beat us — shut down our candidate and use the winner.
|
||||
candidate.shutdown()
|
||||
return _scheduler
|
||||
|
||||
|
|
|
|||
|
|
@ -1,17 +1,14 @@
|
|||
"""Tests for circuitforge_core.tasks.scheduler."""
|
||||
"""Tests for TaskScheduler Protocol + LocalScheduler (MIT, no coordinator)."""
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import List
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from circuitforge_core.tasks.scheduler import (
|
||||
LocalScheduler,
|
||||
TaskScheduler,
|
||||
detect_available_vram_gb,
|
||||
get_scheduler,
|
||||
|
|
@ -19,267 +16,125 @@ from circuitforge_core.tasks.scheduler import (
|
|||
)
|
||||
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture
|
||||
def tmp_db(tmp_path: Path) -> Path:
|
||||
"""SQLite DB with background_tasks table."""
|
||||
db = tmp_path / "test.db"
|
||||
conn = sqlite3.connect(db)
|
||||
conn.execute("""
|
||||
CREATE TABLE background_tasks (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
task_type TEXT NOT NULL,
|
||||
job_id INTEGER NOT NULL DEFAULT 0,
|
||||
status TEXT NOT NULL DEFAULT 'queued',
|
||||
params TEXT,
|
||||
error TEXT,
|
||||
created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
def db_path(tmp_path: Path) -> Path:
|
||||
p = tmp_path / "test.db"
|
||||
with sqlite3.connect(p) as conn:
|
||||
conn.execute(
|
||||
"CREATE TABLE background_tasks "
|
||||
"(id INTEGER PRIMARY KEY, task_type TEXT, job_id INTEGER, "
|
||||
"params TEXT, status TEXT DEFAULT 'queued', created_at TEXT DEFAULT '')"
|
||||
)
|
||||
""")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return db
|
||||
return p
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_singleton():
|
||||
"""Always tear down the scheduler singleton between tests."""
|
||||
def clean_singleton():
|
||||
yield
|
||||
reset_scheduler()
|
||||
|
||||
|
||||
TASK_TYPES = frozenset({"fast_task"})
|
||||
BUDGETS = {"fast_task": 1.0}
|
||||
def make_run_fn(results: list):
|
||||
def run(db_path, task_id, task_type, job_id, params):
|
||||
results.append((task_type, task_id))
|
||||
time.sleep(0.01)
|
||||
return run
|
||||
|
||||
|
||||
# ── detect_available_vram_gb ──────────────────────────────────────────────────
|
||||
|
||||
def test_detect_vram_from_cfortch():
|
||||
"""Uses cf-orch free VRAM when coordinator is reachable."""
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {
|
||||
"nodes": [
|
||||
{"node_id": "local", "gpus": [{"vram_free_mb": 4096}, {"vram_free_mb": 4096}]}
|
||||
]
|
||||
}
|
||||
with patch("circuitforge_core.tasks.scheduler.httpx") as mock_httpx:
|
||||
mock_httpx.get.return_value = mock_resp
|
||||
result = detect_available_vram_gb(coordinator_url="http://localhost:7700")
|
||||
assert result == pytest.approx(8.0) # 4096 + 4096 MB → 8 GB
|
||||
def test_local_scheduler_implements_protocol():
|
||||
assert isinstance(LocalScheduler.__new__(LocalScheduler), TaskScheduler)
|
||||
|
||||
|
||||
def test_detect_vram_cforch_unavailable_falls_back_to_unlimited():
|
||||
"""Falls back to 999.0 when cf-orch is unreachable and preflight unavailable."""
|
||||
with patch("circuitforge_core.tasks.scheduler.httpx") as mock_httpx:
|
||||
mock_httpx.get.side_effect = ConnectionRefusedError()
|
||||
result = detect_available_vram_gb()
|
||||
assert result == 999.0
|
||||
def test_detect_available_vram_returns_unlimited():
|
||||
assert detect_available_vram_gb() == 999.0
|
||||
|
||||
|
||||
def test_detect_vram_cforch_empty_nodes_falls_back():
|
||||
"""If cf-orch returns no nodes with GPUs, falls back to unlimited."""
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {"nodes": []}
|
||||
with patch("circuitforge_core.tasks.scheduler.httpx") as mock_httpx:
|
||||
mock_httpx.get.return_value = mock_resp
|
||||
result = detect_available_vram_gb()
|
||||
assert result == 999.0
|
||||
|
||||
|
||||
def test_detect_vram_preflight_fallback():
|
||||
"""Falls back to preflight total VRAM when cf-orch is unreachable."""
|
||||
# Build a fake scripts.preflight module with get_gpus returning two GPUs.
|
||||
fake_scripts = ModuleType("scripts")
|
||||
fake_preflight = ModuleType("scripts.preflight")
|
||||
fake_preflight.get_gpus = lambda: [ # type: ignore[attr-defined]
|
||||
{"vram_total_gb": 8.0},
|
||||
{"vram_total_gb": 4.0},
|
||||
]
|
||||
fake_scripts.preflight = fake_preflight # type: ignore[attr-defined]
|
||||
|
||||
with patch("circuitforge_core.tasks.scheduler.httpx") as mock_httpx, \
|
||||
patch.dict(
|
||||
__import__("sys").modules,
|
||||
{"scripts": fake_scripts, "scripts.preflight": fake_preflight},
|
||||
):
|
||||
mock_httpx.get.side_effect = ConnectionRefusedError()
|
||||
result = detect_available_vram_gb()
|
||||
|
||||
assert result == pytest.approx(12.0) # 8.0 + 4.0 GB
|
||||
|
||||
|
||||
# ── TaskScheduler basic behaviour ─────────────────────────────────────────────
|
||||
|
||||
def test_enqueue_returns_true_on_success(tmp_db: Path):
|
||||
ran: List[int] = []
|
||||
|
||||
def run_fn(db_path, task_id, task_type, job_id, params):
|
||||
ran.append(task_id)
|
||||
|
||||
sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
|
||||
sched.start()
|
||||
result = sched.enqueue(1, "fast_task", 0, None)
|
||||
sched.shutdown()
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_scheduler_runs_task(tmp_db: Path):
|
||||
"""Enqueued task is executed by the batch worker."""
|
||||
ran: List[int] = []
|
||||
event = threading.Event()
|
||||
|
||||
def run_fn(db_path, task_id, task_type, job_id, params):
|
||||
ran.append(task_id)
|
||||
event.set()
|
||||
|
||||
sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
|
||||
sched.start()
|
||||
sched.enqueue(42, "fast_task", 0, None)
|
||||
assert event.wait(timeout=3.0), "Task was not executed within 3 seconds"
|
||||
sched.shutdown()
|
||||
assert ran == [42]
|
||||
|
||||
|
||||
def test_enqueue_returns_false_when_queue_full(tmp_db: Path):
|
||||
"""Returns False and does not enqueue when max_queue_depth is reached."""
|
||||
gate = threading.Event()
|
||||
|
||||
def blocking_run_fn(db_path, task_id, task_type, job_id, params):
|
||||
gate.wait()
|
||||
|
||||
sched = TaskScheduler(
|
||||
tmp_db, blocking_run_fn, TASK_TYPES, BUDGETS,
|
||||
available_vram_gb=8.0, max_queue_depth=2
|
||||
def test_enqueue_and_execute(db_path):
|
||||
results = []
|
||||
sched = LocalScheduler(
|
||||
db_path=db_path,
|
||||
run_task_fn=make_run_fn(results),
|
||||
task_types=frozenset({"cover_letter"}),
|
||||
vram_budgets={"cover_letter": 0.0},
|
||||
)
|
||||
sched.start()
|
||||
results = [sched.enqueue(i, "fast_task", 0, None) for i in range(1, 10)]
|
||||
gate.set()
|
||||
sched.enqueue(1, "cover_letter", 1, None)
|
||||
time.sleep(0.3)
|
||||
sched.shutdown()
|
||||
assert not all(results), "Expected at least one enqueue to be rejected"
|
||||
assert ("cover_letter", 1) in results
|
||||
|
||||
|
||||
def test_scheduler_drains_multiple_tasks(tmp_db: Path):
|
||||
"""All enqueued tasks of the same type are run serially."""
|
||||
ran: List[int] = []
|
||||
done = threading.Event()
|
||||
TOTAL = 5
|
||||
|
||||
def run_fn(db_path, task_id, task_type, job_id, params):
|
||||
ran.append(task_id)
|
||||
if len(ran) >= TOTAL:
|
||||
done.set()
|
||||
|
||||
sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
|
||||
sched.start()
|
||||
for i in range(1, TOTAL + 1):
|
||||
sched.enqueue(i, "fast_task", 0, None)
|
||||
assert done.wait(timeout=5.0), f"Only ran {len(ran)} of {TOTAL} tasks"
|
||||
sched.shutdown()
|
||||
assert sorted(ran) == list(range(1, TOTAL + 1))
|
||||
|
||||
|
||||
def test_vram_budget_blocks_second_type(tmp_db: Path):
|
||||
"""Second task type is not started when VRAM would be exceeded."""
|
||||
type_a_started = threading.Event()
|
||||
type_b_started = threading.Event()
|
||||
gate_a = threading.Event()
|
||||
gate_b = threading.Event()
|
||||
started = []
|
||||
|
||||
def run_fn(db_path, task_id, task_type, job_id, params):
|
||||
started.append(task_type)
|
||||
if task_type == "type_a":
|
||||
type_a_started.set()
|
||||
gate_a.wait()
|
||||
else:
|
||||
type_b_started.set()
|
||||
gate_b.wait()
|
||||
|
||||
two_types = frozenset({"type_a", "type_b"})
|
||||
tight_budgets = {"type_a": 4.0, "type_b": 4.0} # 4+4 > 6 GB available
|
||||
|
||||
sched = TaskScheduler(
|
||||
tmp_db, run_fn, two_types, tight_budgets, available_vram_gb=6.0
|
||||
def test_fifo_ordering(db_path):
|
||||
results = []
|
||||
sched = LocalScheduler(
|
||||
db_path=db_path,
|
||||
run_task_fn=make_run_fn(results),
|
||||
task_types=frozenset({"t"}),
|
||||
vram_budgets={"t": 0.0},
|
||||
)
|
||||
sched.start()
|
||||
sched.enqueue(1, "type_a", 0, None)
|
||||
sched.enqueue(2, "type_b", 0, None)
|
||||
|
||||
assert type_a_started.wait(timeout=3.0), "type_a never started"
|
||||
assert not type_b_started.is_set(), "type_b should be blocked by VRAM"
|
||||
|
||||
gate_a.set()
|
||||
assert type_b_started.wait(timeout=3.0), "type_b never started after type_a finished"
|
||||
gate_b.set()
|
||||
sched.enqueue(1, "t", 1, None)
|
||||
sched.enqueue(2, "t", 1, None)
|
||||
sched.enqueue(3, "t", 1, None)
|
||||
time.sleep(0.5)
|
||||
sched.shutdown()
|
||||
assert sorted(started) == ["type_a", "type_b"]
|
||||
assert [r[1] for r in results] == [1, 2, 3]
|
||||
|
||||
|
||||
def test_get_scheduler_singleton(tmp_db: Path):
|
||||
"""get_scheduler() returns the same instance on repeated calls."""
|
||||
run_fn = MagicMock()
|
||||
s1 = get_scheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS)
|
||||
s2 = get_scheduler(tmp_db) # no run_fn — should reuse existing
|
||||
def test_queue_depth_limit(db_path):
|
||||
sched = LocalScheduler(
|
||||
db_path=db_path,
|
||||
run_task_fn=make_run_fn([]),
|
||||
task_types=frozenset({"t"}),
|
||||
vram_budgets={"t": 0.0},
|
||||
max_queue_depth=2,
|
||||
)
|
||||
assert sched.enqueue(1, "t", 1, None) is True
|
||||
assert sched.enqueue(2, "t", 1, None) is True
|
||||
assert sched.enqueue(3, "t", 1, None) is False
|
||||
|
||||
|
||||
def test_get_scheduler_singleton(db_path):
|
||||
results = []
|
||||
s1 = get_scheduler(
|
||||
db_path=db_path,
|
||||
run_task_fn=make_run_fn(results),
|
||||
task_types=frozenset({"t"}),
|
||||
vram_budgets={"t": 0.0},
|
||||
)
|
||||
s2 = get_scheduler(db_path=db_path)
|
||||
assert s1 is s2
|
||||
s1.shutdown()
|
||||
|
||||
|
||||
def test_reset_scheduler_clears_singleton(tmp_db: Path):
|
||||
"""reset_scheduler() allows a new singleton to be constructed."""
|
||||
run_fn = MagicMock()
|
||||
s1 = get_scheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS)
|
||||
reset_scheduler()
|
||||
s2 = get_scheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS)
|
||||
assert s1 is not s2
|
||||
def test_local_scheduler_no_httpx_dependency():
|
||||
"""LocalScheduler must not import httpx — not in MIT core's hard deps."""
|
||||
import ast, inspect
|
||||
from circuitforge_core.tasks import scheduler as sched_mod
|
||||
src = inspect.getsource(sched_mod)
|
||||
tree = ast.parse(src)
|
||||
for node in ast.walk(tree):
|
||||
if isinstance(node, (ast.Import, ast.ImportFrom)):
|
||||
names = [a.name for a in getattr(node, 'names', [])]
|
||||
module = getattr(node, 'module', '') or ''
|
||||
assert 'httpx' not in names and 'httpx' not in module, \
|
||||
"LocalScheduler must not import httpx"
|
||||
|
||||
|
||||
def test_load_queued_tasks_on_startup(tmp_db: Path):
|
||||
"""Tasks with status='queued' in the DB at startup are loaded into the deque."""
|
||||
conn = sqlite3.connect(tmp_db)
|
||||
conn.execute(
|
||||
"INSERT INTO background_tasks (task_type, job_id, status) VALUES ('fast_task', 0, 'queued')"
|
||||
def test_load_queued_tasks_on_startup(db_path):
|
||||
"""Tasks with status='queued' in the DB at startup are loaded and run."""
|
||||
with sqlite3.connect(db_path) as conn:
|
||||
conn.execute(
|
||||
"INSERT INTO background_tasks (id, task_type, job_id, status) VALUES (99, 't', 1, 'queued')"
|
||||
)
|
||||
results = []
|
||||
sched = LocalScheduler(
|
||||
db_path=db_path,
|
||||
run_task_fn=make_run_fn(results),
|
||||
task_types=frozenset({"t"}),
|
||||
vram_budgets={"t": 0.0},
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
ran: List[int] = []
|
||||
done = threading.Event()
|
||||
|
||||
def run_fn(db_path, task_id, task_type, job_id, params):
|
||||
ran.append(task_id)
|
||||
done.set()
|
||||
|
||||
sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
|
||||
sched.start()
|
||||
assert done.wait(timeout=3.0), "Pre-loaded task was not run"
|
||||
time.sleep(0.3)
|
||||
sched.shutdown()
|
||||
assert len(ran) == 1
|
||||
|
||||
|
||||
def test_load_queued_tasks_missing_table_does_not_crash(tmp_path: Path):
|
||||
"""Scheduler does not crash if background_tasks table doesn't exist yet."""
|
||||
db = tmp_path / "empty.db"
|
||||
sqlite3.connect(db).close()
|
||||
|
||||
run_fn = MagicMock()
|
||||
sched = TaskScheduler(db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
|
||||
sched.start()
|
||||
sched.shutdown()
|
||||
# No exception = pass
|
||||
|
||||
|
||||
def test_reserved_vram_zero_after_task_completes(tmp_db: Path):
|
||||
"""_reserved_vram returns to 0.0 after a task finishes — no double-decrement."""
|
||||
done = threading.Event()
|
||||
|
||||
def run_fn(db_path, task_id, task_type, job_id, params):
|
||||
done.set()
|
||||
|
||||
sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
|
||||
sched.start()
|
||||
sched.enqueue(1, "fast_task", 0, None)
|
||||
assert done.wait(timeout=3.0), "Task never completed"
|
||||
sched.shutdown()
|
||||
assert sched._reserved_vram == 0.0, f"Expected 0.0, got {sched._reserved_vram}"
|
||||
assert ("t", 99) in results
|
||||
|
|
|
|||
Loading…
Reference in a new issue