From 5801928f8e5ab85cae2d16326446416e86bb416b Mon Sep 17 00:00:00 2001 From: pyr0ball Date: Mon, 30 Mar 2026 23:12:23 -0700 Subject: [PATCH 1/4] feat(tasks): add shared VRAM-aware LLM task scheduler Extract generic batch scheduler into circuitforge_core.tasks.scheduler so any CircuitForge product can use it. Includes VRAM detection via cf-orch coordinator (cooperative free-VRAM), preflight fallback, and unlimited fallback; singleton API; full test coverage (12 tests). --- circuitforge_core/tasks/__init__.py | 14 ++ circuitforge_core/tasks/scheduler.py | 331 +++++++++++++++++++++++++++ pyproject.toml | 4 + tests/test_tasks/__init__.py | 0 tests/test_tasks/test_scheduler.py | 243 ++++++++++++++++++++ 5 files changed, 592 insertions(+) create mode 100644 circuitforge_core/tasks/__init__.py create mode 100644 circuitforge_core/tasks/scheduler.py create mode 100644 tests/test_tasks/__init__.py create mode 100644 tests/test_tasks/test_scheduler.py diff --git a/circuitforge_core/tasks/__init__.py b/circuitforge_core/tasks/__init__.py new file mode 100644 index 0000000..dede9f5 --- /dev/null +++ b/circuitforge_core/tasks/__init__.py @@ -0,0 +1,14 @@ +# circuitforge_core/tasks/__init__.py +from circuitforge_core.tasks.scheduler import ( + TaskScheduler, + detect_available_vram_gb, + get_scheduler, + reset_scheduler, +) + +__all__ = [ + "TaskScheduler", + "detect_available_vram_gb", + "get_scheduler", + "reset_scheduler", +] diff --git a/circuitforge_core/tasks/scheduler.py b/circuitforge_core/tasks/scheduler.py new file mode 100644 index 0000000..7ccd189 --- /dev/null +++ b/circuitforge_core/tasks/scheduler.py @@ -0,0 +1,331 @@ +# circuitforge_core/tasks/scheduler.py +"""Resource-aware batch scheduler for LLM background tasks. + +Generic scheduler that any CircuitForge product can use. Products supply: + - task_types: frozenset[str] — task type strings routed through this scheduler + - vram_budgets: dict[str, float] — VRAM GB estimate per task type + - run_task_fn — product's task execution function + +VRAM detection priority: + 1. cf-orch coordinator /api/nodes — free VRAM (lease-aware, cooperative) + 2. scripts.preflight.get_gpus() — total GPU VRAM (Peregrine-era fallback) + 3. 999.0 — unlimited (CPU-only or no detection available) + +Public API: + TaskScheduler — the scheduler class + detect_available_vram_gb() — standalone VRAM query helper + get_scheduler() — lazy process-level singleton + reset_scheduler() — test teardown only +""" +from __future__ import annotations + +import logging +import sqlite3 +import threading +from collections import deque, namedtuple +from pathlib import Path +from typing import Callable, Optional + +try: + import httpx as httpx +except ImportError: + httpx = None # type: ignore[assignment] + +logger = logging.getLogger(__name__) + +TaskSpec = namedtuple("TaskSpec", ["id", "job_id", "params"]) + +_DEFAULT_MAX_QUEUE_DEPTH = 500 + + +def detect_available_vram_gb( + coordinator_url: str = "http://localhost:7700", +) -> float: + """Detect available VRAM GB for task scheduling. + + Returns free VRAM via cf-orch (sum across all nodes/GPUs) so the scheduler + cooperates with other cf-orch consumers. Falls back to preflight total VRAM, + then 999.0 (unlimited) if nothing is reachable. + """ + # 1. Try cf-orch: use free VRAM so the scheduler cooperates with other + # cf-orch consumers (vision service, inference services, etc.) + if httpx is not None: + try: + resp = httpx.get(f"{coordinator_url}/api/nodes", timeout=2.0) + if resp.status_code == 200: + nodes = resp.json().get("nodes", []) + total_free_mb = sum( + gpu.get("vram_free_mb", 0) + for node in nodes + for gpu in node.get("gpus", []) + ) + if total_free_mb > 0: + free_gb = total_free_mb / 1024.0 + logger.debug( + "Scheduler VRAM from cf-orch: %.1f GB free", free_gb + ) + return free_gb + except Exception: + pass + + # 2. Try preflight (systems with nvidia-smi; Peregrine-era fallback) + try: + from scripts.preflight import get_gpus # type: ignore[import] + + gpus = get_gpus() + if gpus: + total_gb = sum(g.get("vram_total_gb", 0.0) for g in gpus) + logger.debug( + "Scheduler VRAM from preflight: %.1f GB total", total_gb + ) + return total_gb + except Exception: + pass + + logger.debug( + "Scheduler VRAM detection unavailable — using unlimited (999 GB)" + ) + return 999.0 + + +class TaskScheduler: + """Resource-aware LLM task batch scheduler. + + Runs one batch-worker thread per task type while total reserved VRAM + stays within the detected available budget. Always allows at least one + batch to start even if its budget exceeds available VRAM (prevents + permanent starvation on low-VRAM systems). + + Thread-safety: all queue/active state protected by self._lock. + + Usage:: + + sched = get_scheduler( + db_path=Path("data/app.db"), + run_task_fn=my_run_task, + task_types=frozenset({"cover_letter", "research"}), + vram_budgets={"cover_letter": 2.5, "research": 5.0}, + ) + task_id, is_new = insert_task(db_path, "cover_letter", job_id) + if is_new: + enqueued = sched.enqueue(task_id, "cover_letter", job_id, params_json) + if not enqueued: + mark_task_failed(db_path, task_id, "Queue full") + """ + + def __init__( + self, + db_path: Path, + run_task_fn: Callable, + task_types: frozenset[str], + vram_budgets: dict[str, float], + available_vram_gb: Optional[float] = None, + max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH, + ) -> None: + self._db_path = db_path + self._run_task = run_task_fn + self._task_types = frozenset(task_types) + self._budgets: dict[str, float] = dict(vram_budgets) + self._max_queue_depth = max_queue_depth + + self._lock = threading.Lock() + self._wake = threading.Event() + self._stop = threading.Event() + self._queues: dict[str, deque] = {} + self._active: dict[str, threading.Thread] = {} + self._reserved_vram: float = 0.0 + self._thread: Optional[threading.Thread] = None + + self._available_vram: float = ( + available_vram_gb + if available_vram_gb is not None + else detect_available_vram_gb() + ) + + for t in self._task_types: + if t not in self._budgets: + logger.warning( + "No VRAM budget defined for task type %r — " + "defaulting to 0.0 GB (no VRAM gating for this type)", + t, + ) + + self._load_queued_tasks() + + def enqueue( + self, + task_id: int, + task_type: str, + job_id: int, + params: Optional[str], + ) -> bool: + """Add a task to the scheduler queue. + + Returns True if enqueued successfully. + Returns False if the queue is full — caller should mark the task failed. + """ + with self._lock: + q = self._queues.setdefault(task_type, deque()) + if len(q) >= self._max_queue_depth: + logger.warning( + "Queue depth limit for %s (max=%d) — task %d dropped", + task_type, + self._max_queue_depth, + task_id, + ) + return False + q.append(TaskSpec(task_id, job_id, params)) + self._wake.set() + return True + + def start(self) -> None: + """Start the background scheduler loop thread. Call once after construction.""" + self._thread = threading.Thread( + target=self._scheduler_loop, name="task-scheduler", daemon=True + ) + self._thread.start() + # Wake the loop immediately so tasks loaded from DB at startup are dispatched + if any(self._queues.values()): + self._wake.set() + + def shutdown(self, timeout: float = 5.0) -> None: + """Signal the scheduler to stop and wait for it to exit.""" + self._stop.set() + self._wake.set() + if self._thread and self._thread.is_alive(): + self._thread.join(timeout=timeout) + + def _scheduler_loop(self) -> None: + while not self._stop.is_set(): + self._wake.wait(timeout=30) + self._wake.clear() + with self._lock: + # Reap batch threads that finished without waking us + for t, thread in list(self._active.items()): + if not thread.is_alive(): + self._reserved_vram -= self._budgets.get(t, 0.0) + del self._active[t] + # Start new type batches while VRAM budget allows + candidates = sorted( + [ + t + for t in self._queues + if self._queues[t] and t not in self._active + ], + key=lambda t: len(self._queues[t]), + reverse=True, + ) + for task_type in candidates: + budget = self._budgets.get(task_type, 0.0) + # Always allow at least one batch to run + if ( + self._reserved_vram == 0.0 + or self._reserved_vram + budget <= self._available_vram + ): + thread = threading.Thread( + target=self._batch_worker, + args=(task_type,), + name=f"batch-{task_type}", + daemon=True, + ) + self._active[task_type] = thread + self._reserved_vram += budget + thread.start() + + def _batch_worker(self, task_type: str) -> None: + """Serial consumer for one task type. Runs until the type's deque is empty.""" + try: + while True: + with self._lock: + q = self._queues.get(task_type) + if not q: + break + task = q.popleft() + self._run_task( + self._db_path, task.id, task_type, task.job_id, task.params + ) + finally: + with self._lock: + self._active.pop(task_type, None) + self._reserved_vram -= self._budgets.get(task_type, 0.0) + self._wake.set() + + def _load_queued_tasks(self) -> None: + """Reload surviving 'queued' tasks from SQLite into deques at startup.""" + if not self._task_types: + return + task_types_list = sorted(self._task_types) + placeholders = ",".join("?" * len(task_types_list)) + conn = sqlite3.connect(self._db_path) + try: + rows = conn.execute( + f"SELECT id, task_type, job_id, params FROM background_tasks" + f" WHERE status='queued' AND task_type IN ({placeholders})" + f" ORDER BY created_at ASC", + task_types_list, + ).fetchall() + except sqlite3.OperationalError: + # Table not yet created (first run before migrations) + rows = [] + finally: + conn.close() + + for row_id, task_type, job_id, params in rows: + q = self._queues.setdefault(task_type, deque()) + q.append(TaskSpec(row_id, job_id, params)) + + if rows: + logger.info( + "Scheduler: resumed %d queued task(s) from prior run", len(rows) + ) + + +# ── Process-level singleton ──────────────────────────────────────────────────── + +_scheduler: Optional[TaskScheduler] = None +_scheduler_lock = threading.Lock() + + +def get_scheduler( + db_path: Path, + run_task_fn: Optional[Callable] = None, + task_types: Optional[frozenset[str]] = None, + vram_budgets: Optional[dict[str, float]] = None, + max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH, +) -> TaskScheduler: + """Return the process-level TaskScheduler singleton. + + ``run_task_fn``, ``task_types``, and ``vram_budgets`` are required on the + first call; ignored on subsequent calls (singleton already constructed). + """ + global _scheduler + if _scheduler is None: + with _scheduler_lock: + if _scheduler is None: + if ( + run_task_fn is None + or task_types is None + or vram_budgets is None + ): + raise ValueError( + "run_task_fn, task_types, and vram_budgets are required " + "on the first call to get_scheduler()" + ) + _scheduler = TaskScheduler( + db_path=db_path, + run_task_fn=run_task_fn, + task_types=task_types, + vram_budgets=vram_budgets, + max_queue_depth=max_queue_depth, + ) + _scheduler.start() + return _scheduler + + +def reset_scheduler() -> None: + """Shut down and clear the singleton. TEST TEARDOWN ONLY.""" + global _scheduler + with _scheduler_lock: + if _scheduler is not None: + _scheduler.shutdown() + _scheduler = None diff --git a/pyproject.toml b/pyproject.toml index bbc3026..ab8aa5d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,8 +22,12 @@ orch = [ "typer[all]>=0.12", "psutil>=5.9", ] +tasks = [ + "httpx>=0.27", +] dev = [ "circuitforge-core[orch]", + "circuitforge-core[tasks]", "pytest>=8.0", "pytest-asyncio>=0.23", "httpx>=0.27", diff --git a/tests/test_tasks/__init__.py b/tests/test_tasks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_tasks/test_scheduler.py b/tests/test_tasks/test_scheduler.py new file mode 100644 index 0000000..b273545 --- /dev/null +++ b/tests/test_tasks/test_scheduler.py @@ -0,0 +1,243 @@ +"""Tests for circuitforge_core.tasks.scheduler.""" +from __future__ import annotations + +import sqlite3 +import threading +import time +from pathlib import Path +from typing import List +from unittest.mock import MagicMock, patch + +import pytest + +from circuitforge_core.tasks.scheduler import ( + TaskScheduler, + detect_available_vram_gb, + get_scheduler, + reset_scheduler, +) + + +# ── Fixtures ────────────────────────────────────────────────────────────────── + +@pytest.fixture +def tmp_db(tmp_path: Path) -> Path: + """SQLite DB with background_tasks table.""" + db = tmp_path / "test.db" + conn = sqlite3.connect(db) + conn.execute(""" + CREATE TABLE background_tasks ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + task_type TEXT NOT NULL, + job_id INTEGER NOT NULL DEFAULT 0, + status TEXT NOT NULL DEFAULT 'queued', + params TEXT, + error TEXT, + created_at TEXT NOT NULL DEFAULT CURRENT_TIMESTAMP + ) + """) + conn.commit() + conn.close() + return db + + +@pytest.fixture(autouse=True) +def _reset_singleton(): + """Always tear down the scheduler singleton between tests.""" + yield + reset_scheduler() + + +TASK_TYPES = frozenset({"fast_task"}) +BUDGETS = {"fast_task": 1.0} + + +# ── detect_available_vram_gb ────────────────────────────────────────────────── + +def test_detect_vram_from_cfortch(): + """Uses cf-orch free VRAM when coordinator is reachable.""" + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = { + "nodes": [ + {"node_id": "local", "gpus": [{"vram_free_mb": 4096}, {"vram_free_mb": 4096}]} + ] + } + with patch("circuitforge_core.tasks.scheduler.httpx") as mock_httpx: + mock_httpx.get.return_value = mock_resp + result = detect_available_vram_gb(coordinator_url="http://localhost:7700") + assert result == pytest.approx(8.0) # 4096 + 4096 MB → 8 GB + + +def test_detect_vram_cforch_unavailable_falls_back_to_unlimited(): + """Falls back to 999.0 when cf-orch is unreachable and preflight unavailable.""" + with patch("circuitforge_core.tasks.scheduler.httpx") as mock_httpx: + mock_httpx.get.side_effect = ConnectionRefusedError() + result = detect_available_vram_gb() + assert result == 999.0 + + +def test_detect_vram_cforch_empty_nodes_falls_back(): + """If cf-orch returns no nodes with GPUs, falls back to unlimited.""" + mock_resp = MagicMock() + mock_resp.status_code = 200 + mock_resp.json.return_value = {"nodes": []} + with patch("circuitforge_core.tasks.scheduler.httpx") as mock_httpx: + mock_httpx.get.return_value = mock_resp + result = detect_available_vram_gb() + assert result == 999.0 + + +# ── TaskScheduler basic behaviour ───────────────────────────────────────────── + +def test_enqueue_returns_true_on_success(tmp_db: Path): + ran: List[int] = [] + + def run_fn(db_path, task_id, task_type, job_id, params): + ran.append(task_id) + + sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0) + sched.start() + result = sched.enqueue(1, "fast_task", 0, None) + sched.shutdown() + assert result is True + + +def test_scheduler_runs_task(tmp_db: Path): + """Enqueued task is executed by the batch worker.""" + ran: List[int] = [] + event = threading.Event() + + def run_fn(db_path, task_id, task_type, job_id, params): + ran.append(task_id) + event.set() + + sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0) + sched.start() + sched.enqueue(42, "fast_task", 0, None) + assert event.wait(timeout=3.0), "Task was not executed within 3 seconds" + sched.shutdown() + assert ran == [42] + + +def test_enqueue_returns_false_when_queue_full(tmp_db: Path): + """Returns False and does not enqueue when max_queue_depth is reached.""" + gate = threading.Event() + + def blocking_run_fn(db_path, task_id, task_type, job_id, params): + gate.wait() + + sched = TaskScheduler( + tmp_db, blocking_run_fn, TASK_TYPES, BUDGETS, + available_vram_gb=8.0, max_queue_depth=2 + ) + sched.start() + results = [sched.enqueue(i, "fast_task", 0, None) for i in range(1, 10)] + gate.set() + sched.shutdown() + assert False in results + + +def test_scheduler_drains_multiple_tasks(tmp_db: Path): + """All enqueued tasks of the same type are run serially.""" + ran: List[int] = [] + done = threading.Event() + TOTAL = 5 + + def run_fn(db_path, task_id, task_type, job_id, params): + ran.append(task_id) + if len(ran) >= TOTAL: + done.set() + + sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0) + sched.start() + for i in range(1, TOTAL + 1): + sched.enqueue(i, "fast_task", 0, None) + assert done.wait(timeout=5.0), f"Only ran {len(ran)} of {TOTAL} tasks" + sched.shutdown() + assert sorted(ran) == list(range(1, TOTAL + 1)) + + +def test_vram_budget_blocks_second_type(tmp_db: Path): + """Second task type is not started when VRAM would be exceeded.""" + gate_a = threading.Event() + gate_b = threading.Event() + started = [] + + def run_fn(db_path, task_id, task_type, job_id, params): + started.append(task_type) + if task_type == "type_a": + gate_a.wait() + else: + gate_b.wait() + + two_types = frozenset({"type_a", "type_b"}) + tight_budgets = {"type_a": 4.0, "type_b": 4.0} # 4+4 > 6 GB available + + sched = TaskScheduler( + tmp_db, run_fn, two_types, tight_budgets, available_vram_gb=6.0 + ) + sched.start() + sched.enqueue(1, "type_a", 0, None) + sched.enqueue(2, "type_b", 0, None) + + time.sleep(0.2) + assert started == ["type_a"] + + gate_a.set() + time.sleep(0.2) + gate_b.set() + sched.shutdown() + assert "type_b" in started + + +def test_get_scheduler_singleton(tmp_db: Path): + """get_scheduler() returns the same instance on repeated calls.""" + run_fn = MagicMock() + s1 = get_scheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS) + s2 = get_scheduler(tmp_db) # no run_fn — should reuse existing + assert s1 is s2 + + +def test_reset_scheduler_clears_singleton(tmp_db: Path): + """reset_scheduler() allows a new singleton to be constructed.""" + run_fn = MagicMock() + s1 = get_scheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS) + reset_scheduler() + s2 = get_scheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS) + assert s1 is not s2 + + +def test_load_queued_tasks_on_startup(tmp_db: Path): + """Tasks with status='queued' in the DB at startup are loaded into the deque.""" + conn = sqlite3.connect(tmp_db) + conn.execute( + "INSERT INTO background_tasks (task_type, job_id, status) VALUES ('fast_task', 0, 'queued')" + ) + conn.commit() + conn.close() + + ran: List[int] = [] + done = threading.Event() + + def run_fn(db_path, task_id, task_type, job_id, params): + ran.append(task_id) + done.set() + + sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0) + sched.start() + assert done.wait(timeout=3.0), "Pre-loaded task was not run" + sched.shutdown() + assert len(ran) == 1 + + +def test_load_queued_tasks_missing_table_does_not_crash(tmp_path: Path): + """Scheduler does not crash if background_tasks table doesn't exist yet.""" + db = tmp_path / "empty.db" + sqlite3.connect(db).close() + + run_fn = MagicMock() + sched = TaskScheduler(db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0) + sched.start() + sched.shutdown() + # No exception = pass -- 2.45.2 From 09a5087c72b1f0d809e5ce06207317fe501f7a9b Mon Sep 17 00:00:00 2001 From: pyr0ball Date: Mon, 30 Mar 2026 23:15:19 -0700 Subject: [PATCH 2/4] test(tasks): add preflight fallback coverage to scheduler tests Adds test_detect_vram_preflight_fallback to cover the spec path where cf-orch is unreachable but scripts.preflight.get_gpus() succeeds, verifying detect_available_vram_gb() returns the summed total VRAM. Uses sys.modules injection to simulate the preflight module being present. --- tests/test_tasks/test_scheduler.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/tests/test_tasks/test_scheduler.py b/tests/test_tasks/test_scheduler.py index b273545..99e9ea8 100644 --- a/tests/test_tasks/test_scheduler.py +++ b/tests/test_tasks/test_scheduler.py @@ -5,6 +5,7 @@ import sqlite3 import threading import time from pathlib import Path +from types import ModuleType from typing import List from unittest.mock import MagicMock, patch @@ -88,6 +89,28 @@ def test_detect_vram_cforch_empty_nodes_falls_back(): assert result == 999.0 +def test_detect_vram_preflight_fallback(): + """Falls back to preflight total VRAM when cf-orch is unreachable.""" + # Build a fake scripts.preflight module with get_gpus returning two GPUs. + fake_scripts = ModuleType("scripts") + fake_preflight = ModuleType("scripts.preflight") + fake_preflight.get_gpus = lambda: [ # type: ignore[attr-defined] + {"vram_total_gb": 8.0}, + {"vram_total_gb": 4.0}, + ] + fake_scripts.preflight = fake_preflight # type: ignore[attr-defined] + + with patch("circuitforge_core.tasks.scheduler.httpx") as mock_httpx, \ + patch.dict( + __import__("sys").modules, + {"scripts": fake_scripts, "scripts.preflight": fake_preflight}, + ): + mock_httpx.get.side_effect = ConnectionRefusedError() + result = detect_available_vram_gb() + + assert result == pytest.approx(12.0) # 8.0 + 4.0 GB + + # ── TaskScheduler basic behaviour ───────────────────────────────────────────── def test_enqueue_returns_true_on_success(tmp_db: Path): -- 2.45.2 From 22bad8590ac0d2a8b21134698a9d91a3f5701f50 Mon Sep 17 00:00:00 2001 From: pyr0ball Date: Tue, 31 Mar 2026 09:15:09 -0700 Subject: [PATCH 3/4] fix(tasks): fix VRAM accounting race, lock scope, type annotations - C1: Remove _reserved_vram decrement from _scheduler_loop reaper; sole responsibility now belongs to _batch_worker's finally block, eliminating the double-decrement race that could drive _reserved_vram negative. - C2: Move TaskScheduler construction (including VRAM detection httpx call) outside _scheduler_lock in get_scheduler(); lock is now only held for the final singleton assignment, preventing 2s lock contention on first call. - I1: Add RunTaskFn type alias (Callable[...]) and use it in __init__ and get_scheduler() instead of bare Callable. - I2: Replace namedtuple TaskSpec with typed NamedTuple class. - I3: Parameterize _queues annotation as dict[str, deque[TaskSpec]]. - I4: Wrap _queues read in start() with self._lock. - I5: Replace time.sleep() ordering assertion in test_vram_budget_blocks_second_type with event-based synchronization using type_a_started/type_b_started events. - M2: Use sqlite3.connect() as context manager in _load_queued_tasks. - M3: Strengthen weak assertion in test_enqueue_returns_false_when_queue_full. - M4: Add test_reserved_vram_zero_after_task_completes to catch C1 regression. --- circuitforge_core/tasks/scheduler.py | 91 ++++++++++++++++------------ tests/test_tasks/test_scheduler.py | 29 +++++++-- 2 files changed, 76 insertions(+), 44 deletions(-) diff --git a/circuitforge_core/tasks/scheduler.py b/circuitforge_core/tasks/scheduler.py index 7ccd189..a6c4453 100644 --- a/circuitforge_core/tasks/scheduler.py +++ b/circuitforge_core/tasks/scheduler.py @@ -22,9 +22,9 @@ from __future__ import annotations import logging import sqlite3 import threading -from collections import deque, namedtuple +from collections import deque from pathlib import Path -from typing import Callable, Optional +from typing import Callable, NamedTuple, Optional try: import httpx as httpx @@ -33,7 +33,13 @@ except ImportError: logger = logging.getLogger(__name__) -TaskSpec = namedtuple("TaskSpec", ["id", "job_id", "params"]) +RunTaskFn = Callable[["Path", int, str, int, Optional[str]], None] + + +class TaskSpec(NamedTuple): + id: int + job_id: int + params: Optional[str] _DEFAULT_MAX_QUEUE_DEPTH = 500 @@ -116,7 +122,7 @@ class TaskScheduler: def __init__( self, db_path: Path, - run_task_fn: Callable, + run_task_fn: RunTaskFn, task_types: frozenset[str], vram_budgets: dict[str, float], available_vram_gb: Optional[float] = None, @@ -131,7 +137,7 @@ class TaskScheduler: self._lock = threading.Lock() self._wake = threading.Event() self._stop = threading.Event() - self._queues: dict[str, deque] = {} + self._queues: dict[str, deque[TaskSpec]] = {} self._active: dict[str, threading.Thread] = {} self._reserved_vram: float = 0.0 self._thread: Optional[threading.Thread] = None @@ -185,8 +191,9 @@ class TaskScheduler: ) self._thread.start() # Wake the loop immediately so tasks loaded from DB at startup are dispatched - if any(self._queues.values()): - self._wake.set() + with self._lock: + if any(self._queues.values()): + self._wake.set() def shutdown(self, timeout: float = 5.0) -> None: """Signal the scheduler to stop and wait for it to exit.""" @@ -200,10 +207,11 @@ class TaskScheduler: self._wake.wait(timeout=30) self._wake.clear() with self._lock: - # Reap batch threads that finished without waking us + # Reap batch threads that finished without waking us. + # VRAM accounting is handled solely by _batch_worker's finally block; + # the reaper only removes dead entries from _active. for t, thread in list(self._active.items()): if not thread.is_alive(): - self._reserved_vram -= self._budgets.get(t, 0.0) del self._active[t] # Start new type batches while VRAM budget allows candidates = sorted( @@ -256,19 +264,17 @@ class TaskScheduler: return task_types_list = sorted(self._task_types) placeholders = ",".join("?" * len(task_types_list)) - conn = sqlite3.connect(self._db_path) try: - rows = conn.execute( - f"SELECT id, task_type, job_id, params FROM background_tasks" - f" WHERE status='queued' AND task_type IN ({placeholders})" - f" ORDER BY created_at ASC", - task_types_list, - ).fetchall() + with sqlite3.connect(self._db_path) as conn: + rows = conn.execute( + f"SELECT id, task_type, job_id, params FROM background_tasks" + f" WHERE status='queued' AND task_type IN ({placeholders})" + f" ORDER BY created_at ASC", + task_types_list, + ).fetchall() except sqlite3.OperationalError: # Table not yet created (first run before migrations) rows = [] - finally: - conn.close() for row_id, task_type, job_id, params in rows: q = self._queues.setdefault(task_type, deque()) @@ -288,7 +294,7 @@ _scheduler_lock = threading.Lock() def get_scheduler( db_path: Path, - run_task_fn: Optional[Callable] = None, + run_task_fn: Optional[RunTaskFn] = None, task_types: Optional[frozenset[str]] = None, vram_budgets: Optional[dict[str, float]] = None, max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH, @@ -297,28 +303,35 @@ def get_scheduler( ``run_task_fn``, ``task_types``, and ``vram_budgets`` are required on the first call; ignored on subsequent calls (singleton already constructed). + + VRAM detection (which may involve a network call) is performed outside the + lock so the lock is never held across blocking I/O. """ global _scheduler - if _scheduler is None: - with _scheduler_lock: - if _scheduler is None: - if ( - run_task_fn is None - or task_types is None - or vram_budgets is None - ): - raise ValueError( - "run_task_fn, task_types, and vram_budgets are required " - "on the first call to get_scheduler()" - ) - _scheduler = TaskScheduler( - db_path=db_path, - run_task_fn=run_task_fn, - task_types=task_types, - vram_budgets=vram_budgets, - max_queue_depth=max_queue_depth, - ) - _scheduler.start() + if _scheduler is not None: + return _scheduler + # Build outside the lock — TaskScheduler.__init__ may call detect_available_vram_gb() + # which makes an httpx network call (up to 2 s). Holding the lock during that + # would block any concurrent caller for the full duration. + if run_task_fn is None or task_types is None or vram_budgets is None: + raise ValueError( + "run_task_fn, task_types, and vram_budgets are required " + "on the first call to get_scheduler()" + ) + candidate = TaskScheduler( + db_path=db_path, + run_task_fn=run_task_fn, + task_types=task_types, + vram_budgets=vram_budgets, + max_queue_depth=max_queue_depth, + ) + candidate.start() + with _scheduler_lock: + if _scheduler is None: + _scheduler = candidate + else: + # Another thread beat us — shut down our candidate and use the winner. + candidate.shutdown() return _scheduler diff --git a/tests/test_tasks/test_scheduler.py b/tests/test_tasks/test_scheduler.py index 99e9ea8..4428e42 100644 --- a/tests/test_tasks/test_scheduler.py +++ b/tests/test_tasks/test_scheduler.py @@ -158,7 +158,7 @@ def test_enqueue_returns_false_when_queue_full(tmp_db: Path): results = [sched.enqueue(i, "fast_task", 0, None) for i in range(1, 10)] gate.set() sched.shutdown() - assert False in results + assert not all(results), "Expected at least one enqueue to be rejected" def test_scheduler_drains_multiple_tasks(tmp_db: Path): @@ -183,6 +183,8 @@ def test_scheduler_drains_multiple_tasks(tmp_db: Path): def test_vram_budget_blocks_second_type(tmp_db: Path): """Second task type is not started when VRAM would be exceeded.""" + type_a_started = threading.Event() + type_b_started = threading.Event() gate_a = threading.Event() gate_b = threading.Event() started = [] @@ -190,8 +192,10 @@ def test_vram_budget_blocks_second_type(tmp_db: Path): def run_fn(db_path, task_id, task_type, job_id, params): started.append(task_type) if task_type == "type_a": + type_a_started.set() gate_a.wait() else: + type_b_started.set() gate_b.wait() two_types = frozenset({"type_a", "type_b"}) @@ -204,14 +208,14 @@ def test_vram_budget_blocks_second_type(tmp_db: Path): sched.enqueue(1, "type_a", 0, None) sched.enqueue(2, "type_b", 0, None) - time.sleep(0.2) - assert started == ["type_a"] + assert type_a_started.wait(timeout=3.0), "type_a never started" + assert not type_b_started.is_set(), "type_b should be blocked by VRAM" gate_a.set() - time.sleep(0.2) + assert type_b_started.wait(timeout=3.0), "type_b never started after type_a finished" gate_b.set() sched.shutdown() - assert "type_b" in started + assert sorted(started) == ["type_a", "type_b"] def test_get_scheduler_singleton(tmp_db: Path): @@ -264,3 +268,18 @@ def test_load_queued_tasks_missing_table_does_not_crash(tmp_path: Path): sched.start() sched.shutdown() # No exception = pass + + +def test_reserved_vram_zero_after_task_completes(tmp_db: Path): + """_reserved_vram returns to 0.0 after a task finishes — no double-decrement.""" + done = threading.Event() + + def run_fn(db_path, task_id, task_type, job_id, params): + done.set() + + sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0) + sched.start() + sched.enqueue(1, "fast_task", 0, None) + assert done.wait(timeout=3.0), "Task never completed" + sched.shutdown() + assert sched._reserved_vram == 0.0, f"Expected 0.0, got {sched._reserved_vram}" -- 2.45.2 From c027fe6137a28b1a7d9226d943c853f302e11a58 Mon Sep 17 00:00:00 2001 From: pyr0ball Date: Tue, 31 Mar 2026 10:37:51 -0700 Subject: [PATCH 4/4] fix(core): SQLite timeout=30, INSERT OR IGNORE migrations, parameterize tier unlockables - get_connection(): add timeout=30 to both sqlite3 and pysqlcipher3 paths so concurrent writers retry instead of immediately raising OperationalError - run_migrations(): INSERT OR IGNORE so two Store() calls racing on first boot don't hit a UNIQUE constraint on the migrations table - can_use() / tier_label(): accept _byok_unlockable and _local_vision_unlockable overrides so products pass their own frozensets rather than sharing module-level constants (required for circuitforge-core to serve multiple products cleanly) --- circuitforge_core/db/base.py | 6 ++++-- circuitforge_core/db/migrations.py | 4 +++- circuitforge_core/tiers/tiers.py | 34 +++++++++++++++++++++--------- 3 files changed, 31 insertions(+), 13 deletions(-) diff --git a/circuitforge_core/db/base.py b/circuitforge_core/db/base.py index baed4d0..5a0293c 100644 --- a/circuitforge_core/db/base.py +++ b/circuitforge_core/db/base.py @@ -22,7 +22,9 @@ def get_connection(db_path: Path, key: str = "") -> sqlite3.Connection: cloud_mode = os.environ.get("CLOUD_MODE", "").lower() in ("1", "true", "yes") if cloud_mode and key: from pysqlcipher3 import dbapi2 as _sqlcipher # type: ignore - conn = _sqlcipher.connect(str(db_path)) + conn = _sqlcipher.connect(str(db_path), timeout=30) conn.execute(f"PRAGMA key='{key}'") return conn - return sqlite3.connect(str(db_path)) + # timeout=30: retry for up to 30s when another writer holds the lock (WAL mode + # allows concurrent readers but only one writer at a time). + return sqlite3.connect(str(db_path), timeout=30) diff --git a/circuitforge_core/db/migrations.py b/circuitforge_core/db/migrations.py index ddcf331..f3b3cac 100644 --- a/circuitforge_core/db/migrations.py +++ b/circuitforge_core/db/migrations.py @@ -23,5 +23,7 @@ def run_migrations(conn: sqlite3.Connection, migrations_dir: Path) -> None: if sql_file.name in applied: continue conn.executescript(sql_file.read_text()) - conn.execute("INSERT INTO _migrations (name) VALUES (?)", (sql_file.name,)) + # OR IGNORE: safe if two Store() calls race on the same DB — second writer + # just skips the insert rather than raising UNIQUE constraint failed. + conn.execute("INSERT OR IGNORE INTO _migrations (name) VALUES (?)", (sql_file.name,)) conn.commit() diff --git a/circuitforge_core/tiers/tiers.py b/circuitforge_core/tiers/tiers.py index d243ad6..3b5d2d8 100644 --- a/circuitforge_core/tiers/tiers.py +++ b/circuitforge_core/tiers/tiers.py @@ -30,26 +30,35 @@ def can_use( has_byok: bool = False, has_local_vision: bool = False, _features: dict[str, str] | None = None, + _byok_unlockable: frozenset[str] | None = None, + _local_vision_unlockable: frozenset[str] | None = None, ) -> bool: """ Return True if the given tier (and optional unlocks) can access feature. Args: - feature: Feature key string. - tier: User's current tier ("free", "paid", "premium", "ultra"). - has_byok: True if user has a configured LLM backend. - has_local_vision: True if user has a local vision model configured. - _features: Feature→min_tier map. Products pass their own dict here. - If None, all features are free. + feature: Feature key string. + tier: User's current tier ("free", "paid", "premium", "ultra"). + has_byok: True if user has a configured LLM backend. + has_local_vision: True if user has a local vision model configured. + _features: Feature→min_tier map. Products pass their own dict here. + If None, all features are free. + _byok_unlockable: Product-specific BYOK-unlockable features. + If None, uses module-level BYOK_UNLOCKABLE. + _local_vision_unlockable: Product-specific local vision unlockable features. + If None, uses module-level LOCAL_VISION_UNLOCKABLE. """ features = _features or {} + byok_unlockable = _byok_unlockable if _byok_unlockable is not None else BYOK_UNLOCKABLE + local_vision_unlockable = _local_vision_unlockable if _local_vision_unlockable is not None else LOCAL_VISION_UNLOCKABLE + if feature not in features: return True - if has_byok and feature in BYOK_UNLOCKABLE: + if has_byok and feature in byok_unlockable: return True - if has_local_vision and feature in LOCAL_VISION_UNLOCKABLE: + if has_local_vision and feature in local_vision_unlockable: return True min_tier = features[feature] @@ -64,13 +73,18 @@ def tier_label( has_byok: bool = False, has_local_vision: bool = False, _features: dict[str, str] | None = None, + _byok_unlockable: frozenset[str] | None = None, + _local_vision_unlockable: frozenset[str] | None = None, ) -> str: """Return a human-readable label for the minimum tier needed for feature.""" features = _features or {} + byok_unlockable = _byok_unlockable if _byok_unlockable is not None else BYOK_UNLOCKABLE + local_vision_unlockable = _local_vision_unlockable if _local_vision_unlockable is not None else LOCAL_VISION_UNLOCKABLE + if feature not in features: return "free" - if has_byok and feature in BYOK_UNLOCKABLE: + if has_byok and feature in byok_unlockable: return "free (BYOK)" - if has_local_vision and feature in LOCAL_VISION_UNLOCKABLE: + if has_local_vision and feature in local_vision_unlockable: return "free (local vision)" return features[feature] -- 2.45.2