Merge pull request 'feat(tasks): shared VRAM-aware LLM task scheduler' (#2) from feature/shared-task-scheduler into main
This commit is contained in:
commit
563b73ce85
8 changed files with 678 additions and 13 deletions
|
|
@ -22,7 +22,9 @@ def get_connection(db_path: Path, key: str = "") -> sqlite3.Connection:
|
|||
cloud_mode = os.environ.get("CLOUD_MODE", "").lower() in ("1", "true", "yes")
|
||||
if cloud_mode and key:
|
||||
from pysqlcipher3 import dbapi2 as _sqlcipher # type: ignore
|
||||
conn = _sqlcipher.connect(str(db_path))
|
||||
conn = _sqlcipher.connect(str(db_path), timeout=30)
|
||||
conn.execute(f"PRAGMA key='{key}'")
|
||||
return conn
|
||||
return sqlite3.connect(str(db_path))
|
||||
# timeout=30: retry for up to 30s when another writer holds the lock (WAL mode
|
||||
# allows concurrent readers but only one writer at a time).
|
||||
return sqlite3.connect(str(db_path), timeout=30)
|
||||
|
|
|
|||
|
|
@ -23,5 +23,7 @@ def run_migrations(conn: sqlite3.Connection, migrations_dir: Path) -> None:
|
|||
if sql_file.name in applied:
|
||||
continue
|
||||
conn.executescript(sql_file.read_text())
|
||||
conn.execute("INSERT INTO _migrations (name) VALUES (?)", (sql_file.name,))
|
||||
# OR IGNORE: safe if two Store() calls race on the same DB — second writer
|
||||
# just skips the insert rather than raising UNIQUE constraint failed.
|
||||
conn.execute("INSERT OR IGNORE INTO _migrations (name) VALUES (?)", (sql_file.name,))
|
||||
conn.commit()
|
||||
|
|
|
|||
14
circuitforge_core/tasks/__init__.py
Normal file
14
circuitforge_core/tasks/__init__.py
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
# circuitforge_core/tasks/__init__.py
|
||||
from circuitforge_core.tasks.scheduler import (
|
||||
TaskScheduler,
|
||||
detect_available_vram_gb,
|
||||
get_scheduler,
|
||||
reset_scheduler,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"TaskScheduler",
|
||||
"detect_available_vram_gb",
|
||||
"get_scheduler",
|
||||
"reset_scheduler",
|
||||
]
|
||||
344
circuitforge_core/tasks/scheduler.py
Normal file
344
circuitforge_core/tasks/scheduler.py
Normal file
|
|
@ -0,0 +1,344 @@
|
|||
# circuitforge_core/tasks/scheduler.py
|
||||
"""Resource-aware batch scheduler for LLM background tasks.
|
||||
|
||||
Generic scheduler that any CircuitForge product can use. Products supply:
|
||||
- task_types: frozenset[str] — task type strings routed through this scheduler
|
||||
- vram_budgets: dict[str, float] — VRAM GB estimate per task type
|
||||
- run_task_fn — product's task execution function
|
||||
|
||||
VRAM detection priority:
|
||||
1. cf-orch coordinator /api/nodes — free VRAM (lease-aware, cooperative)
|
||||
2. scripts.preflight.get_gpus() — total GPU VRAM (Peregrine-era fallback)
|
||||
3. 999.0 — unlimited (CPU-only or no detection available)
|
||||
|
||||
Public API:
|
||||
TaskScheduler — the scheduler class
|
||||
detect_available_vram_gb() — standalone VRAM query helper
|
||||
get_scheduler() — lazy process-level singleton
|
||||
reset_scheduler() — test teardown only
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sqlite3
|
||||
import threading
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import Callable, NamedTuple, Optional
|
||||
|
||||
try:
|
||||
import httpx as httpx
|
||||
except ImportError:
|
||||
httpx = None # type: ignore[assignment]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
RunTaskFn = Callable[["Path", int, str, int, Optional[str]], None]
|
||||
|
||||
|
||||
class TaskSpec(NamedTuple):
|
||||
id: int
|
||||
job_id: int
|
||||
params: Optional[str]
|
||||
|
||||
_DEFAULT_MAX_QUEUE_DEPTH = 500
|
||||
|
||||
|
||||
def detect_available_vram_gb(
|
||||
coordinator_url: str = "http://localhost:7700",
|
||||
) -> float:
|
||||
"""Detect available VRAM GB for task scheduling.
|
||||
|
||||
Returns free VRAM via cf-orch (sum across all nodes/GPUs) so the scheduler
|
||||
cooperates with other cf-orch consumers. Falls back to preflight total VRAM,
|
||||
then 999.0 (unlimited) if nothing is reachable.
|
||||
"""
|
||||
# 1. Try cf-orch: use free VRAM so the scheduler cooperates with other
|
||||
# cf-orch consumers (vision service, inference services, etc.)
|
||||
if httpx is not None:
|
||||
try:
|
||||
resp = httpx.get(f"{coordinator_url}/api/nodes", timeout=2.0)
|
||||
if resp.status_code == 200:
|
||||
nodes = resp.json().get("nodes", [])
|
||||
total_free_mb = sum(
|
||||
gpu.get("vram_free_mb", 0)
|
||||
for node in nodes
|
||||
for gpu in node.get("gpus", [])
|
||||
)
|
||||
if total_free_mb > 0:
|
||||
free_gb = total_free_mb / 1024.0
|
||||
logger.debug(
|
||||
"Scheduler VRAM from cf-orch: %.1f GB free", free_gb
|
||||
)
|
||||
return free_gb
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 2. Try preflight (systems with nvidia-smi; Peregrine-era fallback)
|
||||
try:
|
||||
from scripts.preflight import get_gpus # type: ignore[import]
|
||||
|
||||
gpus = get_gpus()
|
||||
if gpus:
|
||||
total_gb = sum(g.get("vram_total_gb", 0.0) for g in gpus)
|
||||
logger.debug(
|
||||
"Scheduler VRAM from preflight: %.1f GB total", total_gb
|
||||
)
|
||||
return total_gb
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
logger.debug(
|
||||
"Scheduler VRAM detection unavailable — using unlimited (999 GB)"
|
||||
)
|
||||
return 999.0
|
||||
|
||||
|
||||
class TaskScheduler:
|
||||
"""Resource-aware LLM task batch scheduler.
|
||||
|
||||
Runs one batch-worker thread per task type while total reserved VRAM
|
||||
stays within the detected available budget. Always allows at least one
|
||||
batch to start even if its budget exceeds available VRAM (prevents
|
||||
permanent starvation on low-VRAM systems).
|
||||
|
||||
Thread-safety: all queue/active state protected by self._lock.
|
||||
|
||||
Usage::
|
||||
|
||||
sched = get_scheduler(
|
||||
db_path=Path("data/app.db"),
|
||||
run_task_fn=my_run_task,
|
||||
task_types=frozenset({"cover_letter", "research"}),
|
||||
vram_budgets={"cover_letter": 2.5, "research": 5.0},
|
||||
)
|
||||
task_id, is_new = insert_task(db_path, "cover_letter", job_id)
|
||||
if is_new:
|
||||
enqueued = sched.enqueue(task_id, "cover_letter", job_id, params_json)
|
||||
if not enqueued:
|
||||
mark_task_failed(db_path, task_id, "Queue full")
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
db_path: Path,
|
||||
run_task_fn: RunTaskFn,
|
||||
task_types: frozenset[str],
|
||||
vram_budgets: dict[str, float],
|
||||
available_vram_gb: Optional[float] = None,
|
||||
max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH,
|
||||
) -> None:
|
||||
self._db_path = db_path
|
||||
self._run_task = run_task_fn
|
||||
self._task_types = frozenset(task_types)
|
||||
self._budgets: dict[str, float] = dict(vram_budgets)
|
||||
self._max_queue_depth = max_queue_depth
|
||||
|
||||
self._lock = threading.Lock()
|
||||
self._wake = threading.Event()
|
||||
self._stop = threading.Event()
|
||||
self._queues: dict[str, deque[TaskSpec]] = {}
|
||||
self._active: dict[str, threading.Thread] = {}
|
||||
self._reserved_vram: float = 0.0
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
|
||||
self._available_vram: float = (
|
||||
available_vram_gb
|
||||
if available_vram_gb is not None
|
||||
else detect_available_vram_gb()
|
||||
)
|
||||
|
||||
for t in self._task_types:
|
||||
if t not in self._budgets:
|
||||
logger.warning(
|
||||
"No VRAM budget defined for task type %r — "
|
||||
"defaulting to 0.0 GB (no VRAM gating for this type)",
|
||||
t,
|
||||
)
|
||||
|
||||
self._load_queued_tasks()
|
||||
|
||||
def enqueue(
|
||||
self,
|
||||
task_id: int,
|
||||
task_type: str,
|
||||
job_id: int,
|
||||
params: Optional[str],
|
||||
) -> bool:
|
||||
"""Add a task to the scheduler queue.
|
||||
|
||||
Returns True if enqueued successfully.
|
||||
Returns False if the queue is full — caller should mark the task failed.
|
||||
"""
|
||||
with self._lock:
|
||||
q = self._queues.setdefault(task_type, deque())
|
||||
if len(q) >= self._max_queue_depth:
|
||||
logger.warning(
|
||||
"Queue depth limit for %s (max=%d) — task %d dropped",
|
||||
task_type,
|
||||
self._max_queue_depth,
|
||||
task_id,
|
||||
)
|
||||
return False
|
||||
q.append(TaskSpec(task_id, job_id, params))
|
||||
self._wake.set()
|
||||
return True
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the background scheduler loop thread. Call once after construction."""
|
||||
self._thread = threading.Thread(
|
||||
target=self._scheduler_loop, name="task-scheduler", daemon=True
|
||||
)
|
||||
self._thread.start()
|
||||
# Wake the loop immediately so tasks loaded from DB at startup are dispatched
|
||||
with self._lock:
|
||||
if any(self._queues.values()):
|
||||
self._wake.set()
|
||||
|
||||
def shutdown(self, timeout: float = 5.0) -> None:
|
||||
"""Signal the scheduler to stop and wait for it to exit."""
|
||||
self._stop.set()
|
||||
self._wake.set()
|
||||
if self._thread and self._thread.is_alive():
|
||||
self._thread.join(timeout=timeout)
|
||||
|
||||
def _scheduler_loop(self) -> None:
|
||||
while not self._stop.is_set():
|
||||
self._wake.wait(timeout=30)
|
||||
self._wake.clear()
|
||||
with self._lock:
|
||||
# Reap batch threads that finished without waking us.
|
||||
# VRAM accounting is handled solely by _batch_worker's finally block;
|
||||
# the reaper only removes dead entries from _active.
|
||||
for t, thread in list(self._active.items()):
|
||||
if not thread.is_alive():
|
||||
del self._active[t]
|
||||
# Start new type batches while VRAM budget allows
|
||||
candidates = sorted(
|
||||
[
|
||||
t
|
||||
for t in self._queues
|
||||
if self._queues[t] and t not in self._active
|
||||
],
|
||||
key=lambda t: len(self._queues[t]),
|
||||
reverse=True,
|
||||
)
|
||||
for task_type in candidates:
|
||||
budget = self._budgets.get(task_type, 0.0)
|
||||
# Always allow at least one batch to run
|
||||
if (
|
||||
self._reserved_vram == 0.0
|
||||
or self._reserved_vram + budget <= self._available_vram
|
||||
):
|
||||
thread = threading.Thread(
|
||||
target=self._batch_worker,
|
||||
args=(task_type,),
|
||||
name=f"batch-{task_type}",
|
||||
daemon=True,
|
||||
)
|
||||
self._active[task_type] = thread
|
||||
self._reserved_vram += budget
|
||||
thread.start()
|
||||
|
||||
def _batch_worker(self, task_type: str) -> None:
|
||||
"""Serial consumer for one task type. Runs until the type's deque is empty."""
|
||||
try:
|
||||
while True:
|
||||
with self._lock:
|
||||
q = self._queues.get(task_type)
|
||||
if not q:
|
||||
break
|
||||
task = q.popleft()
|
||||
self._run_task(
|
||||
self._db_path, task.id, task_type, task.job_id, task.params
|
||||
)
|
||||
finally:
|
||||
with self._lock:
|
||||
self._active.pop(task_type, None)
|
||||
self._reserved_vram -= self._budgets.get(task_type, 0.0)
|
||||
self._wake.set()
|
||||
|
||||
def _load_queued_tasks(self) -> None:
|
||||
"""Reload surviving 'queued' tasks from SQLite into deques at startup."""
|
||||
if not self._task_types:
|
||||
return
|
||||
task_types_list = sorted(self._task_types)
|
||||
placeholders = ",".join("?" * len(task_types_list))
|
||||
try:
|
||||
with sqlite3.connect(self._db_path) as conn:
|
||||
rows = conn.execute(
|
||||
f"SELECT id, task_type, job_id, params FROM background_tasks"
|
||||
f" WHERE status='queued' AND task_type IN ({placeholders})"
|
||||
f" ORDER BY created_at ASC",
|
||||
task_types_list,
|
||||
).fetchall()
|
||||
except sqlite3.OperationalError:
|
||||
# Table not yet created (first run before migrations)
|
||||
rows = []
|
||||
|
||||
for row_id, task_type, job_id, params in rows:
|
||||
q = self._queues.setdefault(task_type, deque())
|
||||
q.append(TaskSpec(row_id, job_id, params))
|
||||
|
||||
if rows:
|
||||
logger.info(
|
||||
"Scheduler: resumed %d queued task(s) from prior run", len(rows)
|
||||
)
|
||||
|
||||
|
||||
# ── Process-level singleton ────────────────────────────────────────────────────
|
||||
|
||||
_scheduler: Optional[TaskScheduler] = None
|
||||
_scheduler_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_scheduler(
|
||||
db_path: Path,
|
||||
run_task_fn: Optional[RunTaskFn] = None,
|
||||
task_types: Optional[frozenset[str]] = None,
|
||||
vram_budgets: Optional[dict[str, float]] = None,
|
||||
max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH,
|
||||
) -> TaskScheduler:
|
||||
"""Return the process-level TaskScheduler singleton.
|
||||
|
||||
``run_task_fn``, ``task_types``, and ``vram_budgets`` are required on the
|
||||
first call; ignored on subsequent calls (singleton already constructed).
|
||||
|
||||
VRAM detection (which may involve a network call) is performed outside the
|
||||
lock so the lock is never held across blocking I/O.
|
||||
"""
|
||||
global _scheduler
|
||||
if _scheduler is not None:
|
||||
return _scheduler
|
||||
# Build outside the lock — TaskScheduler.__init__ may call detect_available_vram_gb()
|
||||
# which makes an httpx network call (up to 2 s). Holding the lock during that
|
||||
# would block any concurrent caller for the full duration.
|
||||
if run_task_fn is None or task_types is None or vram_budgets is None:
|
||||
raise ValueError(
|
||||
"run_task_fn, task_types, and vram_budgets are required "
|
||||
"on the first call to get_scheduler()"
|
||||
)
|
||||
candidate = TaskScheduler(
|
||||
db_path=db_path,
|
||||
run_task_fn=run_task_fn,
|
||||
task_types=task_types,
|
||||
vram_budgets=vram_budgets,
|
||||
max_queue_depth=max_queue_depth,
|
||||
)
|
||||
candidate.start()
|
||||
with _scheduler_lock:
|
||||
if _scheduler is None:
|
||||
_scheduler = candidate
|
||||
else:
|
||||
# Another thread beat us — shut down our candidate and use the winner.
|
||||
candidate.shutdown()
|
||||
return _scheduler
|
||||
|
||||
|
||||
def reset_scheduler() -> None:
|
||||
"""Shut down and clear the singleton. TEST TEARDOWN ONLY."""
|
||||
global _scheduler
|
||||
with _scheduler_lock:
|
||||
if _scheduler is not None:
|
||||
_scheduler.shutdown()
|
||||
_scheduler = None
|
||||
|
|
@ -30,26 +30,35 @@ def can_use(
|
|||
has_byok: bool = False,
|
||||
has_local_vision: bool = False,
|
||||
_features: dict[str, str] | None = None,
|
||||
_byok_unlockable: frozenset[str] | None = None,
|
||||
_local_vision_unlockable: frozenset[str] | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Return True if the given tier (and optional unlocks) can access feature.
|
||||
|
||||
Args:
|
||||
feature: Feature key string.
|
||||
tier: User's current tier ("free", "paid", "premium", "ultra").
|
||||
has_byok: True if user has a configured LLM backend.
|
||||
has_local_vision: True if user has a local vision model configured.
|
||||
_features: Feature→min_tier map. Products pass their own dict here.
|
||||
If None, all features are free.
|
||||
feature: Feature key string.
|
||||
tier: User's current tier ("free", "paid", "premium", "ultra").
|
||||
has_byok: True if user has a configured LLM backend.
|
||||
has_local_vision: True if user has a local vision model configured.
|
||||
_features: Feature→min_tier map. Products pass their own dict here.
|
||||
If None, all features are free.
|
||||
_byok_unlockable: Product-specific BYOK-unlockable features.
|
||||
If None, uses module-level BYOK_UNLOCKABLE.
|
||||
_local_vision_unlockable: Product-specific local vision unlockable features.
|
||||
If None, uses module-level LOCAL_VISION_UNLOCKABLE.
|
||||
"""
|
||||
features = _features or {}
|
||||
byok_unlockable = _byok_unlockable if _byok_unlockable is not None else BYOK_UNLOCKABLE
|
||||
local_vision_unlockable = _local_vision_unlockable if _local_vision_unlockable is not None else LOCAL_VISION_UNLOCKABLE
|
||||
|
||||
if feature not in features:
|
||||
return True
|
||||
|
||||
if has_byok and feature in BYOK_UNLOCKABLE:
|
||||
if has_byok and feature in byok_unlockable:
|
||||
return True
|
||||
|
||||
if has_local_vision and feature in LOCAL_VISION_UNLOCKABLE:
|
||||
if has_local_vision and feature in local_vision_unlockable:
|
||||
return True
|
||||
|
||||
min_tier = features[feature]
|
||||
|
|
@ -64,13 +73,18 @@ def tier_label(
|
|||
has_byok: bool = False,
|
||||
has_local_vision: bool = False,
|
||||
_features: dict[str, str] | None = None,
|
||||
_byok_unlockable: frozenset[str] | None = None,
|
||||
_local_vision_unlockable: frozenset[str] | None = None,
|
||||
) -> str:
|
||||
"""Return a human-readable label for the minimum tier needed for feature."""
|
||||
features = _features or {}
|
||||
byok_unlockable = _byok_unlockable if _byok_unlockable is not None else BYOK_UNLOCKABLE
|
||||
local_vision_unlockable = _local_vision_unlockable if _local_vision_unlockable is not None else LOCAL_VISION_UNLOCKABLE
|
||||
|
||||
if feature not in features:
|
||||
return "free"
|
||||
if has_byok and feature in BYOK_UNLOCKABLE:
|
||||
if has_byok and feature in byok_unlockable:
|
||||
return "free (BYOK)"
|
||||
if has_local_vision and feature in LOCAL_VISION_UNLOCKABLE:
|
||||
if has_local_vision and feature in local_vision_unlockable:
|
||||
return "free (local vision)"
|
||||
return features[feature]
|
||||
|
|
|
|||
|
|
@ -22,8 +22,12 @@ orch = [
|
|||
"typer[all]>=0.12",
|
||||
"psutil>=5.9",
|
||||
]
|
||||
tasks = [
|
||||
"httpx>=0.27",
|
||||
]
|
||||
dev = [
|
||||
"circuitforge-core[orch]",
|
||||
"circuitforge-core[tasks]",
|
||||
"pytest>=8.0",
|
||||
"pytest-asyncio>=0.23",
|
||||
"httpx>=0.27",
|
||||
|
|
|
|||
0
tests/test_tasks/__init__.py
Normal file
0
tests/test_tasks/__init__.py
Normal file
285
tests/test_tasks/test_scheduler.py
Normal file
285
tests/test_tasks/test_scheduler.py
Normal file
|
|
@ -0,0 +1,285 @@
|
|||
"""Tests for circuitforge_core.tasks.scheduler."""
|
||||
from __future__ import annotations
|
||||
|
||||
import sqlite3
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import List
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from circuitforge_core.tasks.scheduler import (
|
||||
TaskScheduler,
|
||||
detect_available_vram_gb,
|
||||
get_scheduler,
|
||||
reset_scheduler,
|
||||
)
|
||||
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture
|
||||
def tmp_db(tmp_path: Path) -> Path:
|
||||
"""SQLite DB with background_tasks table."""
|
||||
db = tmp_path / "test.db"
|
||||
conn = sqlite3.connect(db)
|
||||
conn.execute("""
|
||||
CREATE TABLE background_tasks (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
task_type TEXT NOT NULL,
|
||||
job_id INTEGER NOT NULL DEFAULT 0,
|
||||
status TEXT NOT NULL DEFAULT 'queued',
|
||||
params TEXT,
|
||||
error TEXT,
|
||||
created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
)
|
||||
""")
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return db
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _reset_singleton():
|
||||
"""Always tear down the scheduler singleton between tests."""
|
||||
yield
|
||||
reset_scheduler()
|
||||
|
||||
|
||||
TASK_TYPES = frozenset({"fast_task"})
|
||||
BUDGETS = {"fast_task": 1.0}
|
||||
|
||||
|
||||
# ── detect_available_vram_gb ──────────────────────────────────────────────────
|
||||
|
||||
def test_detect_vram_from_cfortch():
|
||||
"""Uses cf-orch free VRAM when coordinator is reachable."""
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {
|
||||
"nodes": [
|
||||
{"node_id": "local", "gpus": [{"vram_free_mb": 4096}, {"vram_free_mb": 4096}]}
|
||||
]
|
||||
}
|
||||
with patch("circuitforge_core.tasks.scheduler.httpx") as mock_httpx:
|
||||
mock_httpx.get.return_value = mock_resp
|
||||
result = detect_available_vram_gb(coordinator_url="http://localhost:7700")
|
||||
assert result == pytest.approx(8.0) # 4096 + 4096 MB → 8 GB
|
||||
|
||||
|
||||
def test_detect_vram_cforch_unavailable_falls_back_to_unlimited():
|
||||
"""Falls back to 999.0 when cf-orch is unreachable and preflight unavailable."""
|
||||
with patch("circuitforge_core.tasks.scheduler.httpx") as mock_httpx:
|
||||
mock_httpx.get.side_effect = ConnectionRefusedError()
|
||||
result = detect_available_vram_gb()
|
||||
assert result == 999.0
|
||||
|
||||
|
||||
def test_detect_vram_cforch_empty_nodes_falls_back():
|
||||
"""If cf-orch returns no nodes with GPUs, falls back to unlimited."""
|
||||
mock_resp = MagicMock()
|
||||
mock_resp.status_code = 200
|
||||
mock_resp.json.return_value = {"nodes": []}
|
||||
with patch("circuitforge_core.tasks.scheduler.httpx") as mock_httpx:
|
||||
mock_httpx.get.return_value = mock_resp
|
||||
result = detect_available_vram_gb()
|
||||
assert result == 999.0
|
||||
|
||||
|
||||
def test_detect_vram_preflight_fallback():
|
||||
"""Falls back to preflight total VRAM when cf-orch is unreachable."""
|
||||
# Build a fake scripts.preflight module with get_gpus returning two GPUs.
|
||||
fake_scripts = ModuleType("scripts")
|
||||
fake_preflight = ModuleType("scripts.preflight")
|
||||
fake_preflight.get_gpus = lambda: [ # type: ignore[attr-defined]
|
||||
{"vram_total_gb": 8.0},
|
||||
{"vram_total_gb": 4.0},
|
||||
]
|
||||
fake_scripts.preflight = fake_preflight # type: ignore[attr-defined]
|
||||
|
||||
with patch("circuitforge_core.tasks.scheduler.httpx") as mock_httpx, \
|
||||
patch.dict(
|
||||
__import__("sys").modules,
|
||||
{"scripts": fake_scripts, "scripts.preflight": fake_preflight},
|
||||
):
|
||||
mock_httpx.get.side_effect = ConnectionRefusedError()
|
||||
result = detect_available_vram_gb()
|
||||
|
||||
assert result == pytest.approx(12.0) # 8.0 + 4.0 GB
|
||||
|
||||
|
||||
# ── TaskScheduler basic behaviour ─────────────────────────────────────────────
|
||||
|
||||
def test_enqueue_returns_true_on_success(tmp_db: Path):
|
||||
ran: List[int] = []
|
||||
|
||||
def run_fn(db_path, task_id, task_type, job_id, params):
|
||||
ran.append(task_id)
|
||||
|
||||
sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
|
||||
sched.start()
|
||||
result = sched.enqueue(1, "fast_task", 0, None)
|
||||
sched.shutdown()
|
||||
assert result is True
|
||||
|
||||
|
||||
def test_scheduler_runs_task(tmp_db: Path):
|
||||
"""Enqueued task is executed by the batch worker."""
|
||||
ran: List[int] = []
|
||||
event = threading.Event()
|
||||
|
||||
def run_fn(db_path, task_id, task_type, job_id, params):
|
||||
ran.append(task_id)
|
||||
event.set()
|
||||
|
||||
sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
|
||||
sched.start()
|
||||
sched.enqueue(42, "fast_task", 0, None)
|
||||
assert event.wait(timeout=3.0), "Task was not executed within 3 seconds"
|
||||
sched.shutdown()
|
||||
assert ran == [42]
|
||||
|
||||
|
||||
def test_enqueue_returns_false_when_queue_full(tmp_db: Path):
|
||||
"""Returns False and does not enqueue when max_queue_depth is reached."""
|
||||
gate = threading.Event()
|
||||
|
||||
def blocking_run_fn(db_path, task_id, task_type, job_id, params):
|
||||
gate.wait()
|
||||
|
||||
sched = TaskScheduler(
|
||||
tmp_db, blocking_run_fn, TASK_TYPES, BUDGETS,
|
||||
available_vram_gb=8.0, max_queue_depth=2
|
||||
)
|
||||
sched.start()
|
||||
results = [sched.enqueue(i, "fast_task", 0, None) for i in range(1, 10)]
|
||||
gate.set()
|
||||
sched.shutdown()
|
||||
assert not all(results), "Expected at least one enqueue to be rejected"
|
||||
|
||||
|
||||
def test_scheduler_drains_multiple_tasks(tmp_db: Path):
|
||||
"""All enqueued tasks of the same type are run serially."""
|
||||
ran: List[int] = []
|
||||
done = threading.Event()
|
||||
TOTAL = 5
|
||||
|
||||
def run_fn(db_path, task_id, task_type, job_id, params):
|
||||
ran.append(task_id)
|
||||
if len(ran) >= TOTAL:
|
||||
done.set()
|
||||
|
||||
sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
|
||||
sched.start()
|
||||
for i in range(1, TOTAL + 1):
|
||||
sched.enqueue(i, "fast_task", 0, None)
|
||||
assert done.wait(timeout=5.0), f"Only ran {len(ran)} of {TOTAL} tasks"
|
||||
sched.shutdown()
|
||||
assert sorted(ran) == list(range(1, TOTAL + 1))
|
||||
|
||||
|
||||
def test_vram_budget_blocks_second_type(tmp_db: Path):
|
||||
"""Second task type is not started when VRAM would be exceeded."""
|
||||
type_a_started = threading.Event()
|
||||
type_b_started = threading.Event()
|
||||
gate_a = threading.Event()
|
||||
gate_b = threading.Event()
|
||||
started = []
|
||||
|
||||
def run_fn(db_path, task_id, task_type, job_id, params):
|
||||
started.append(task_type)
|
||||
if task_type == "type_a":
|
||||
type_a_started.set()
|
||||
gate_a.wait()
|
||||
else:
|
||||
type_b_started.set()
|
||||
gate_b.wait()
|
||||
|
||||
two_types = frozenset({"type_a", "type_b"})
|
||||
tight_budgets = {"type_a": 4.0, "type_b": 4.0} # 4+4 > 6 GB available
|
||||
|
||||
sched = TaskScheduler(
|
||||
tmp_db, run_fn, two_types, tight_budgets, available_vram_gb=6.0
|
||||
)
|
||||
sched.start()
|
||||
sched.enqueue(1, "type_a", 0, None)
|
||||
sched.enqueue(2, "type_b", 0, None)
|
||||
|
||||
assert type_a_started.wait(timeout=3.0), "type_a never started"
|
||||
assert not type_b_started.is_set(), "type_b should be blocked by VRAM"
|
||||
|
||||
gate_a.set()
|
||||
assert type_b_started.wait(timeout=3.0), "type_b never started after type_a finished"
|
||||
gate_b.set()
|
||||
sched.shutdown()
|
||||
assert sorted(started) == ["type_a", "type_b"]
|
||||
|
||||
|
||||
def test_get_scheduler_singleton(tmp_db: Path):
|
||||
"""get_scheduler() returns the same instance on repeated calls."""
|
||||
run_fn = MagicMock()
|
||||
s1 = get_scheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS)
|
||||
s2 = get_scheduler(tmp_db) # no run_fn — should reuse existing
|
||||
assert s1 is s2
|
||||
|
||||
|
||||
def test_reset_scheduler_clears_singleton(tmp_db: Path):
|
||||
"""reset_scheduler() allows a new singleton to be constructed."""
|
||||
run_fn = MagicMock()
|
||||
s1 = get_scheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS)
|
||||
reset_scheduler()
|
||||
s2 = get_scheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS)
|
||||
assert s1 is not s2
|
||||
|
||||
|
||||
def test_load_queued_tasks_on_startup(tmp_db: Path):
|
||||
"""Tasks with status='queued' in the DB at startup are loaded into the deque."""
|
||||
conn = sqlite3.connect(tmp_db)
|
||||
conn.execute(
|
||||
"INSERT INTO background_tasks (task_type, job_id, status) VALUES ('fast_task', 0, 'queued')"
|
||||
)
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
ran: List[int] = []
|
||||
done = threading.Event()
|
||||
|
||||
def run_fn(db_path, task_id, task_type, job_id, params):
|
||||
ran.append(task_id)
|
||||
done.set()
|
||||
|
||||
sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
|
||||
sched.start()
|
||||
assert done.wait(timeout=3.0), "Pre-loaded task was not run"
|
||||
sched.shutdown()
|
||||
assert len(ran) == 1
|
||||
|
||||
|
||||
def test_load_queued_tasks_missing_table_does_not_crash(tmp_path: Path):
|
||||
"""Scheduler does not crash if background_tasks table doesn't exist yet."""
|
||||
db = tmp_path / "empty.db"
|
||||
sqlite3.connect(db).close()
|
||||
|
||||
run_fn = MagicMock()
|
||||
sched = TaskScheduler(db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
|
||||
sched.start()
|
||||
sched.shutdown()
|
||||
# No exception = pass
|
||||
|
||||
|
||||
def test_reserved_vram_zero_after_task_completes(tmp_db: Path):
|
||||
"""_reserved_vram returns to 0.0 after a task finishes — no double-decrement."""
|
||||
done = threading.Event()
|
||||
|
||||
def run_fn(db_path, task_id, task_type, job_id, params):
|
||||
done.set()
|
||||
|
||||
sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
|
||||
sched.start()
|
||||
sched.enqueue(1, "fast_task", 0, None)
|
||||
assert done.wait(timeout=3.0), "Task never completed"
|
||||
sched.shutdown()
|
||||
assert sched._reserved_vram == 0.0, f"Expected 0.0, got {sched._reserved_vram}"
|
||||
Loading…
Reference in a new issue