fix(tasks): fix VRAM accounting race, lock scope, type annotations

- C1: Remove _reserved_vram decrement from _scheduler_loop reaper; sole
  responsibility now belongs to _batch_worker's finally block, eliminating
  the double-decrement race that could drive _reserved_vram negative.
- C2: Move TaskScheduler construction (including VRAM detection httpx call)
  outside _scheduler_lock in get_scheduler(); lock is now only held for the
  final singleton assignment, preventing 2s lock contention on first call.
- I1: Add RunTaskFn type alias (Callable[...]) and use it in __init__ and
  get_scheduler() instead of bare Callable.
- I2: Replace namedtuple TaskSpec with typed NamedTuple class.
- I3: Parameterize _queues annotation as dict[str, deque[TaskSpec]].
- I4: Wrap _queues read in start() with self._lock.
- I5: Replace time.sleep() ordering assertion in test_vram_budget_blocks_second_type
  with event-based synchronization using type_a_started/type_b_started events.
- M2: Use sqlite3.connect() as context manager in _load_queued_tasks.
- M3: Strengthen weak assertion in test_enqueue_returns_false_when_queue_full.
- M4: Add test_reserved_vram_zero_after_task_completes to catch C1 regression.
This commit is contained in:
pyr0ball 2026-03-31 09:15:09 -07:00
parent 09a5087c72
commit 22bad8590a
2 changed files with 76 additions and 44 deletions

View file

@ -22,9 +22,9 @@ from __future__ import annotations
import logging
import sqlite3
import threading
from collections import deque, namedtuple
from collections import deque
from pathlib import Path
from typing import Callable, Optional
from typing import Callable, NamedTuple, Optional
try:
import httpx as httpx
@ -33,7 +33,13 @@ except ImportError:
logger = logging.getLogger(__name__)
TaskSpec = namedtuple("TaskSpec", ["id", "job_id", "params"])
RunTaskFn = Callable[["Path", int, str, int, Optional[str]], None]
# A single queued unit of work: the background_tasks row id, the owning
# job's id, and an optional opaque params payload (serialized string).
TaskSpec = NamedTuple(
    "TaskSpec",
    [("id", int), ("job_id", int), ("params", Optional[str])],
)
_DEFAULT_MAX_QUEUE_DEPTH = 500
@ -116,7 +122,7 @@ class TaskScheduler:
def __init__(
self,
db_path: Path,
run_task_fn: Callable,
run_task_fn: RunTaskFn,
task_types: frozenset[str],
vram_budgets: dict[str, float],
available_vram_gb: Optional[float] = None,
@ -131,7 +137,7 @@ class TaskScheduler:
self._lock = threading.Lock()
self._wake = threading.Event()
self._stop = threading.Event()
self._queues: dict[str, deque] = {}
self._queues: dict[str, deque[TaskSpec]] = {}
self._active: dict[str, threading.Thread] = {}
self._reserved_vram: float = 0.0
self._thread: Optional[threading.Thread] = None
@ -185,8 +191,9 @@ class TaskScheduler:
)
self._thread.start()
# Wake the loop immediately so tasks loaded from DB at startup are dispatched
if any(self._queues.values()):
self._wake.set()
with self._lock:
if any(self._queues.values()):
self._wake.set()
def shutdown(self, timeout: float = 5.0) -> None:
"""Signal the scheduler to stop and wait for it to exit."""
@ -200,10 +207,11 @@ class TaskScheduler:
self._wake.wait(timeout=30)
self._wake.clear()
with self._lock:
# Reap batch threads that finished without waking us
# Reap batch threads that finished without waking us.
# VRAM accounting is handled solely by _batch_worker's finally block;
# the reaper only removes dead entries from _active.
for t, thread in list(self._active.items()):
if not thread.is_alive():
self._reserved_vram -= self._budgets.get(t, 0.0)
del self._active[t]
# Start new type batches while VRAM budget allows
candidates = sorted(
@ -256,19 +264,17 @@ class TaskScheduler:
return
task_types_list = sorted(self._task_types)
placeholders = ",".join("?" * len(task_types_list))
conn = sqlite3.connect(self._db_path)
try:
rows = conn.execute(
f"SELECT id, task_type, job_id, params FROM background_tasks"
f" WHERE status='queued' AND task_type IN ({placeholders})"
f" ORDER BY created_at ASC",
task_types_list,
).fetchall()
with sqlite3.connect(self._db_path) as conn:
rows = conn.execute(
f"SELECT id, task_type, job_id, params FROM background_tasks"
f" WHERE status='queued' AND task_type IN ({placeholders})"
f" ORDER BY created_at ASC",
task_types_list,
).fetchall()
except sqlite3.OperationalError:
# Table not yet created (first run before migrations)
rows = []
finally:
conn.close()
for row_id, task_type, job_id, params in rows:
q = self._queues.setdefault(task_type, deque())
@ -288,7 +294,7 @@ _scheduler_lock = threading.Lock()
def get_scheduler(
db_path: Path,
run_task_fn: Optional[Callable] = None,
run_task_fn: Optional[RunTaskFn] = None,
task_types: Optional[frozenset[str]] = None,
vram_budgets: Optional[dict[str, float]] = None,
max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH,
@ -297,28 +303,35 @@ def get_scheduler(
``run_task_fn``, ``task_types``, and ``vram_budgets`` are required on the
first call; ignored on subsequent calls (singleton already constructed).
VRAM detection (which may involve a network call) is performed outside the
lock so the lock is never held across blocking I/O.
"""
global _scheduler
if _scheduler is None:
with _scheduler_lock:
if _scheduler is None:
if (
run_task_fn is None
or task_types is None
or vram_budgets is None
):
raise ValueError(
"run_task_fn, task_types, and vram_budgets are required "
"on the first call to get_scheduler()"
)
_scheduler = TaskScheduler(
db_path=db_path,
run_task_fn=run_task_fn,
task_types=task_types,
vram_budgets=vram_budgets,
max_queue_depth=max_queue_depth,
)
_scheduler.start()
if _scheduler is not None:
return _scheduler
# Build outside the lock — TaskScheduler.__init__ may call detect_available_vram_gb()
# which makes an httpx network call (up to 2 s). Holding the lock during that
# would block any concurrent caller for the full duration.
if run_task_fn is None or task_types is None or vram_budgets is None:
raise ValueError(
"run_task_fn, task_types, and vram_budgets are required "
"on the first call to get_scheduler()"
)
candidate = TaskScheduler(
db_path=db_path,
run_task_fn=run_task_fn,
task_types=task_types,
vram_budgets=vram_budgets,
max_queue_depth=max_queue_depth,
)
candidate.start()
with _scheduler_lock:
if _scheduler is None:
_scheduler = candidate
else:
# Another thread beat us — shut down our candidate and use the winner.
candidate.shutdown()
return _scheduler

View file

@ -158,7 +158,7 @@ def test_enqueue_returns_false_when_queue_full(tmp_db: Path):
results = [sched.enqueue(i, "fast_task", 0, None) for i in range(1, 10)]
gate.set()
sched.shutdown()
assert False in results
assert not all(results), "Expected at least one enqueue to be rejected"
def test_scheduler_drains_multiple_tasks(tmp_db: Path):
@ -183,6 +183,8 @@ def test_scheduler_drains_multiple_tasks(tmp_db: Path):
def test_vram_budget_blocks_second_type(tmp_db: Path):
"""Second task type is not started when VRAM would be exceeded."""
type_a_started = threading.Event()
type_b_started = threading.Event()
gate_a = threading.Event()
gate_b = threading.Event()
started = []
@ -190,8 +192,10 @@ def test_vram_budget_blocks_second_type(tmp_db: Path):
def run_fn(db_path, task_id, task_type, job_id, params):
started.append(task_type)
if task_type == "type_a":
type_a_started.set()
gate_a.wait()
else:
type_b_started.set()
gate_b.wait()
two_types = frozenset({"type_a", "type_b"})
@ -204,14 +208,14 @@ def test_vram_budget_blocks_second_type(tmp_db: Path):
sched.enqueue(1, "type_a", 0, None)
sched.enqueue(2, "type_b", 0, None)
time.sleep(0.2)
assert started == ["type_a"]
assert type_a_started.wait(timeout=3.0), "type_a never started"
assert not type_b_started.is_set(), "type_b should be blocked by VRAM"
gate_a.set()
time.sleep(0.2)
assert type_b_started.wait(timeout=3.0), "type_b never started after type_a finished"
gate_b.set()
sched.shutdown()
assert "type_b" in started
assert sorted(started) == ["type_a", "type_b"]
def test_get_scheduler_singleton(tmp_db: Path):
@ -264,3 +268,18 @@ def test_load_queued_tasks_missing_table_does_not_crash(tmp_path: Path):
sched.start()
sched.shutdown()
# No exception = pass
def test_reserved_vram_zero_after_task_completes(tmp_db: Path):
    """_reserved_vram returns to 0.0 after a task finishes — no double-decrement.

    Regression test for C1: the _scheduler_loop reaper previously decremented
    _reserved_vram for finished threads even though _batch_worker's finally
    block had already done so, which could drive the counter negative.
    """
    done = threading.Event()

    def run_fn(db_path, task_id, task_type, job_id, params):
        # No real work needed — just prove the task ran to completion.
        done.set()

    sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
    sched.start()
    try:
        sched.enqueue(1, "fast_task", 0, None)
        # Wait for the task to run before inspecting accounting.
        assert done.wait(timeout=3.0), "Task never completed"
    finally:
        # Always stop the scheduler — without this, a timed-out wait leaks
        # the background scheduler thread into subsequent tests.
        sched.shutdown()
    assert sched._reserved_vram == 0.0, f"Expected 0.0, got {sched._reserved_vram}"