From 22bad8590ac0d2a8b21134698a9d91a3f5701f50 Mon Sep 17 00:00:00 2001 From: pyr0ball Date: Tue, 31 Mar 2026 09:15:09 -0700 Subject: [PATCH] fix(tasks): fix VRAM accounting race, lock scope, type annotations - C1: Remove _reserved_vram decrement from _scheduler_loop reaper; sole responsibility now belongs to _batch_worker's finally block, eliminating the double-decrement race that could drive _reserved_vram negative. - C2: Move TaskScheduler construction (including VRAM detection httpx call) outside _scheduler_lock in get_scheduler(); lock is now only held for the final singleton assignment, preventing 2s lock contention on first call. - I1: Add RunTaskFn type alias (Callable[...]) and use it in __init__ and get_scheduler() instead of bare Callable. - I2: Replace namedtuple TaskSpec with typed NamedTuple class. - I3: Parameterize _queues annotation as dict[str, deque[TaskSpec]]. - I4: Wrap _queues read in start() with self._lock. - I5: Replace time.sleep() ordering assertion in test_vram_budget_blocks_second_type with event-based synchronization using type_a_started/type_b_started events. - M2: Use sqlite3.connect() as context manager in _load_queued_tasks. - M3: Strengthen weak assertion in test_enqueue_returns_false_when_queue_full. - M4: Add test_reserved_vram_zero_after_task_completes to catch C1 regression. --- circuitforge_core/tasks/scheduler.py | 91 ++++++++++++++++------------ tests/test_tasks/test_scheduler.py | 29 +++++++-- 2 files changed, 76 insertions(+), 44 deletions(-) diff --git a/circuitforge_core/tasks/scheduler.py b/circuitforge_core/tasks/scheduler.py index 7ccd189..a6c4453 100644 --- a/circuitforge_core/tasks/scheduler.py +++ b/circuitforge_core/tasks/scheduler.py @@ -22,9 +22,9 @@ from __future__ import annotations import logging import sqlite3 import threading -from collections import deque, namedtuple +from collections import deque from pathlib import Path -from typing import Callable, Optional +from typing import Callable, NamedTuple, Optional try: import httpx as httpx @@ -33,7 +33,13 @@ except ImportError: logger = logging.getLogger(__name__) -TaskSpec = namedtuple("TaskSpec", ["id", "job_id", "params"]) +RunTaskFn = Callable[["Path", int, str, int, Optional[str]], None] + + +class TaskSpec(NamedTuple): + id: int + job_id: int + params: Optional[str] _DEFAULT_MAX_QUEUE_DEPTH = 500 @@ -116,7 +122,7 @@ class TaskScheduler: def __init__( self, db_path: Path, - run_task_fn: Callable, + run_task_fn: RunTaskFn, task_types: frozenset[str], vram_budgets: dict[str, float], available_vram_gb: Optional[float] = None, @@ -131,7 +137,7 @@ class TaskScheduler: self._lock = threading.Lock() self._wake = threading.Event() self._stop = threading.Event() - self._queues: dict[str, deque] = {} + self._queues: dict[str, deque[TaskSpec]] = {} self._active: dict[str, threading.Thread] = {} self._reserved_vram: float = 0.0 self._thread: Optional[threading.Thread] = None @@ -185,8 +191,9 @@ class TaskScheduler: ) self._thread.start() # Wake the loop immediately so tasks loaded from DB at startup are dispatched - if any(self._queues.values()): - self._wake.set() + with self._lock: + if any(self._queues.values()): + self._wake.set() def shutdown(self, timeout: float = 5.0) -> None: """Signal the scheduler to stop and wait for it to exit.""" @@ -200,10 +207,11 @@ class TaskScheduler: self._wake.wait(timeout=30) self._wake.clear() with self._lock: - # Reap batch threads that finished without waking us + # Reap batch threads that finished without waking us. + # VRAM accounting is handled solely by _batch_worker's finally block; + # the reaper only removes dead entries from _active. for t, thread in list(self._active.items()): if not thread.is_alive(): - self._reserved_vram -= self._budgets.get(t, 0.0) del self._active[t] # Start new type batches while VRAM budget allows candidates = sorted( @@ -256,19 +264,17 @@ class TaskScheduler: return task_types_list = sorted(self._task_types) placeholders = ",".join("?" * len(task_types_list)) - conn = sqlite3.connect(self._db_path) try: - rows = conn.execute( - f"SELECT id, task_type, job_id, params FROM background_tasks" - f" WHERE status='queued' AND task_type IN ({placeholders})" - f" ORDER BY created_at ASC", - task_types_list, - ).fetchall() + with sqlite3.connect(self._db_path) as conn: + rows = conn.execute( + f"SELECT id, task_type, job_id, params FROM background_tasks" + f" WHERE status='queued' AND task_type IN ({placeholders})" + f" ORDER BY created_at ASC", + task_types_list, + ).fetchall() except sqlite3.OperationalError: # Table not yet created (first run before migrations) rows = [] - finally: - conn.close() for row_id, task_type, job_id, params in rows: q = self._queues.setdefault(task_type, deque()) @@ -288,7 +294,7 @@ _scheduler_lock = threading.Lock() def get_scheduler( db_path: Path, - run_task_fn: Optional[Callable] = None, + run_task_fn: Optional[RunTaskFn] = None, task_types: Optional[frozenset[str]] = None, vram_budgets: Optional[dict[str, float]] = None, max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH, @@ -297,28 +303,35 @@ def get_scheduler( ``run_task_fn``, ``task_types``, and ``vram_budgets`` are required on the first call; ignored on subsequent calls (singleton already constructed). + + VRAM detection (which may involve a network call) is performed outside the + lock so the lock is never held across blocking I/O. """ global _scheduler - if _scheduler is None: - with _scheduler_lock: - if _scheduler is None: - if ( - run_task_fn is None - or task_types is None - or vram_budgets is None - ): - raise ValueError( - "run_task_fn, task_types, and vram_budgets are required " - "on the first call to get_scheduler()" - ) - _scheduler = TaskScheduler( - db_path=db_path, - run_task_fn=run_task_fn, - task_types=task_types, - vram_budgets=vram_budgets, - max_queue_depth=max_queue_depth, - ) - _scheduler.start() + if _scheduler is not None: + return _scheduler + # Build outside the lock — TaskScheduler.__init__ may call detect_available_vram_gb() + # which makes an httpx network call (up to 2 s). Holding the lock during that + # would block any concurrent caller for the full duration. + if run_task_fn is None or task_types is None or vram_budgets is None: + raise ValueError( + "run_task_fn, task_types, and vram_budgets are required " + "on the first call to get_scheduler()" + ) + candidate = TaskScheduler( + db_path=db_path, + run_task_fn=run_task_fn, + task_types=task_types, + vram_budgets=vram_budgets, + max_queue_depth=max_queue_depth, + ) + candidate.start() + with _scheduler_lock: + if _scheduler is None: + _scheduler = candidate + else: + # Another thread beat us — shut down our candidate and use the winner. + candidate.shutdown() return _scheduler diff --git a/tests/test_tasks/test_scheduler.py b/tests/test_tasks/test_scheduler.py index 99e9ea8..4428e42 100644 --- a/tests/test_tasks/test_scheduler.py +++ b/tests/test_tasks/test_scheduler.py @@ -158,7 +158,7 @@ def test_enqueue_returns_false_when_queue_full(tmp_db: Path): results = [sched.enqueue(i, "fast_task", 0, None) for i in range(1, 10)] gate.set() sched.shutdown() - assert False in results + assert not all(results), "Expected at least one enqueue to be rejected" def test_scheduler_drains_multiple_tasks(tmp_db: Path): @@ -183,6 +183,8 @@ def test_scheduler_drains_multiple_tasks(tmp_db: Path): def test_vram_budget_blocks_second_type(tmp_db: Path): """Second task type is not started when VRAM would be exceeded.""" + type_a_started = threading.Event() + type_b_started = threading.Event() gate_a = threading.Event() gate_b = threading.Event() started = [] @@ -190,8 +192,10 @@ def test_vram_budget_blocks_second_type(tmp_db: Path): def run_fn(db_path, task_id, task_type, job_id, params): started.append(task_type) if task_type == "type_a": + type_a_started.set() gate_a.wait() else: + type_b_started.set() gate_b.wait() two_types = frozenset({"type_a", "type_b"}) @@ -204,14 +208,14 @@ def test_vram_budget_blocks_second_type(tmp_db: Path): sched.enqueue(1, "type_a", 0, None) sched.enqueue(2, "type_b", 0, None) - time.sleep(0.2) - assert started == ["type_a"] + assert type_a_started.wait(timeout=3.0), "type_a never started" + assert not type_b_started.is_set(), "type_b should be blocked by VRAM" gate_a.set() - time.sleep(0.2) + assert type_b_started.wait(timeout=3.0), "type_b never started after type_a finished" gate_b.set() sched.shutdown() - assert "type_b" in started + assert sorted(started) == ["type_a", "type_b"] def test_get_scheduler_singleton(tmp_db: Path): @@ -264,3 +268,18 @@ def test_load_queued_tasks_missing_table_does_not_crash(tmp_path: Path): sched.start() sched.shutdown() # No exception = pass + + +def test_reserved_vram_zero_after_task_completes(tmp_db: Path): + """_reserved_vram returns to 0.0 after a task finishes — no double-decrement.""" + done = threading.Event() + + def run_fn(db_path, task_id, task_type, job_id, params): + done.set() + + sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0) + sched.start() + sched.enqueue(1, "fast_task", 0, None) + assert done.wait(timeout=3.0), "Task never completed" + sched.shutdown() + assert sched._reserved_vram == 0.0, f"Expected 0.0, got {sched._reserved_vram}"