fix(tasks): fix VRAM accounting race, lock scope, type annotations
- C1: Remove _reserved_vram decrement from _scheduler_loop reaper; sole responsibility now belongs to _batch_worker's finally block, eliminating the double-decrement race that could drive _reserved_vram negative.
- C2: Move TaskScheduler construction (including VRAM detection httpx call) outside _scheduler_lock in get_scheduler(); lock is now only held for the final singleton assignment, preventing 2s lock contention on first call.
- I1: Add RunTaskFn type alias (Callable[...]) and use it in __init__ and get_scheduler() instead of bare Callable.
- I2: Replace namedtuple TaskSpec with typed NamedTuple class.
- I3: Parameterize _queues annotation as dict[str, deque[TaskSpec]].
- I4: Wrap _queues read in start() with self._lock.
- I5: Replace time.sleep() ordering assertion in test_vram_budget_blocks_second_type with event-based synchronization using type_a_started/type_b_started events.
- M2: Use sqlite3.connect() as context manager in _load_queued_tasks.
- M3: Strengthen weak assertion in test_enqueue_returns_false_when_queue_full.
- M4: Add test_reserved_vram_zero_after_task_completes to catch C1 regression.
This commit is contained in:
parent
09a5087c72
commit
22bad8590a
2 changed files with 76 additions and 44 deletions
|
|
@@ -22,9 +22,9 @@ from __future__ import annotations
|
|||
import logging
|
||||
import sqlite3
|
||||
import threading
|
||||
from collections import deque, namedtuple
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
from typing import Callable, NamedTuple, Optional
|
||||
|
||||
try:
|
||||
import httpx as httpx
|
||||
|
|
@@ -33,7 +33,13 @@ except ImportError:
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TaskSpec = namedtuple("TaskSpec", ["id", "job_id", "params"])
|
||||
RunTaskFn = Callable[["Path", int, str, int, Optional[str]], None]
|
||||
|
||||
|
||||
class TaskSpec(NamedTuple):
|
||||
id: int
|
||||
job_id: int
|
||||
params: Optional[str]
|
||||
|
||||
_DEFAULT_MAX_QUEUE_DEPTH = 500
|
||||
|
||||
|
|
@@ -116,7 +122,7 @@ class TaskScheduler:
|
|||
def __init__(
|
||||
self,
|
||||
db_path: Path,
|
||||
run_task_fn: Callable,
|
||||
run_task_fn: RunTaskFn,
|
||||
task_types: frozenset[str],
|
||||
vram_budgets: dict[str, float],
|
||||
available_vram_gb: Optional[float] = None,
|
||||
|
|
@@ -131,7 +137,7 @@ class TaskScheduler:
|
|||
self._lock = threading.Lock()
|
||||
self._wake = threading.Event()
|
||||
self._stop = threading.Event()
|
||||
self._queues: dict[str, deque] = {}
|
||||
self._queues: dict[str, deque[TaskSpec]] = {}
|
||||
self._active: dict[str, threading.Thread] = {}
|
||||
self._reserved_vram: float = 0.0
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
|
|
@@ -185,8 +191,9 @@ class TaskScheduler:
|
|||
)
|
||||
self._thread.start()
|
||||
# Wake the loop immediately so tasks loaded from DB at startup are dispatched
|
||||
if any(self._queues.values()):
|
||||
self._wake.set()
|
||||
with self._lock:
|
||||
if any(self._queues.values()):
|
||||
self._wake.set()
|
||||
|
||||
def shutdown(self, timeout: float = 5.0) -> None:
|
||||
"""Signal the scheduler to stop and wait for it to exit."""
|
||||
|
|
@@ -200,10 +207,11 @@ class TaskScheduler:
|
|||
self._wake.wait(timeout=30)
|
||||
self._wake.clear()
|
||||
with self._lock:
|
||||
# Reap batch threads that finished without waking us
|
||||
# Reap batch threads that finished without waking us.
|
||||
# VRAM accounting is handled solely by _batch_worker's finally block;
|
||||
# the reaper only removes dead entries from _active.
|
||||
for t, thread in list(self._active.items()):
|
||||
if not thread.is_alive():
|
||||
self._reserved_vram -= self._budgets.get(t, 0.0)
|
||||
del self._active[t]
|
||||
# Start new type batches while VRAM budget allows
|
||||
candidates = sorted(
|
||||
|
|
@@ -256,19 +264,17 @@ class TaskScheduler:
|
|||
return
|
||||
task_types_list = sorted(self._task_types)
|
||||
placeholders = ",".join("?" * len(task_types_list))
|
||||
conn = sqlite3.connect(self._db_path)
|
||||
try:
|
||||
rows = conn.execute(
|
||||
f"SELECT id, task_type, job_id, params FROM background_tasks"
|
||||
f" WHERE status='queued' AND task_type IN ({placeholders})"
|
||||
f" ORDER BY created_at ASC",
|
||||
task_types_list,
|
||||
).fetchall()
|
||||
with sqlite3.connect(self._db_path) as conn:
|
||||
rows = conn.execute(
|
||||
f"SELECT id, task_type, job_id, params FROM background_tasks"
|
||||
f" WHERE status='queued' AND task_type IN ({placeholders})"
|
||||
f" ORDER BY created_at ASC",
|
||||
task_types_list,
|
||||
).fetchall()
|
||||
except sqlite3.OperationalError:
|
||||
# Table not yet created (first run before migrations)
|
||||
rows = []
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
for row_id, task_type, job_id, params in rows:
|
||||
q = self._queues.setdefault(task_type, deque())
|
||||
|
|
@@ -288,7 +294,7 @@ _scheduler_lock = threading.Lock()
|
|||
|
||||
def get_scheduler(
|
||||
db_path: Path,
|
||||
run_task_fn: Optional[Callable] = None,
|
||||
run_task_fn: Optional[RunTaskFn] = None,
|
||||
task_types: Optional[frozenset[str]] = None,
|
||||
vram_budgets: Optional[dict[str, float]] = None,
|
||||
max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH,
|
||||
|
|
@@ -297,28 +303,35 @@ def get_scheduler(
|
|||
|
||||
``run_task_fn``, ``task_types``, and ``vram_budgets`` are required on the
|
||||
first call; ignored on subsequent calls (singleton already constructed).
|
||||
|
||||
VRAM detection (which may involve a network call) is performed outside the
|
||||
lock so the lock is never held across blocking I/O.
|
||||
"""
|
||||
global _scheduler
|
||||
if _scheduler is None:
|
||||
with _scheduler_lock:
|
||||
if _scheduler is None:
|
||||
if (
|
||||
run_task_fn is None
|
||||
or task_types is None
|
||||
or vram_budgets is None
|
||||
):
|
||||
raise ValueError(
|
||||
"run_task_fn, task_types, and vram_budgets are required "
|
||||
"on the first call to get_scheduler()"
|
||||
)
|
||||
_scheduler = TaskScheduler(
|
||||
db_path=db_path,
|
||||
run_task_fn=run_task_fn,
|
||||
task_types=task_types,
|
||||
vram_budgets=vram_budgets,
|
||||
max_queue_depth=max_queue_depth,
|
||||
)
|
||||
_scheduler.start()
|
||||
if _scheduler is not None:
|
||||
return _scheduler
|
||||
# Build outside the lock — TaskScheduler.__init__ may call detect_available_vram_gb()
|
||||
# which makes an httpx network call (up to 2 s). Holding the lock during that
|
||||
# would block any concurrent caller for the full duration.
|
||||
if run_task_fn is None or task_types is None or vram_budgets is None:
|
||||
raise ValueError(
|
||||
"run_task_fn, task_types, and vram_budgets are required "
|
||||
"on the first call to get_scheduler()"
|
||||
)
|
||||
candidate = TaskScheduler(
|
||||
db_path=db_path,
|
||||
run_task_fn=run_task_fn,
|
||||
task_types=task_types,
|
||||
vram_budgets=vram_budgets,
|
||||
max_queue_depth=max_queue_depth,
|
||||
)
|
||||
candidate.start()
|
||||
with _scheduler_lock:
|
||||
if _scheduler is None:
|
||||
_scheduler = candidate
|
||||
else:
|
||||
# Another thread beat us — shut down our candidate and use the winner.
|
||||
candidate.shutdown()
|
||||
return _scheduler
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@@ -158,7 +158,7 @@ def test_enqueue_returns_false_when_queue_full(tmp_db: Path):
|
|||
results = [sched.enqueue(i, "fast_task", 0, None) for i in range(1, 10)]
|
||||
gate.set()
|
||||
sched.shutdown()
|
||||
assert False in results
|
||||
assert not all(results), "Expected at least one enqueue to be rejected"
|
||||
|
||||
|
||||
def test_scheduler_drains_multiple_tasks(tmp_db: Path):
|
||||
|
|
@@ -183,6 +183,8 @@ def test_scheduler_drains_multiple_tasks(tmp_db: Path):
|
|||
|
||||
def test_vram_budget_blocks_second_type(tmp_db: Path):
|
||||
"""Second task type is not started when VRAM would be exceeded."""
|
||||
type_a_started = threading.Event()
|
||||
type_b_started = threading.Event()
|
||||
gate_a = threading.Event()
|
||||
gate_b = threading.Event()
|
||||
started = []
|
||||
|
|
@@ -190,8 +192,10 @@ def test_vram_budget_blocks_second_type(tmp_db: Path):
|
|||
def run_fn(db_path, task_id, task_type, job_id, params):
|
||||
started.append(task_type)
|
||||
if task_type == "type_a":
|
||||
type_a_started.set()
|
||||
gate_a.wait()
|
||||
else:
|
||||
type_b_started.set()
|
||||
gate_b.wait()
|
||||
|
||||
two_types = frozenset({"type_a", "type_b"})
|
||||
|
|
@@ -204,14 +208,14 @@ def test_vram_budget_blocks_second_type(tmp_db: Path):
|
|||
sched.enqueue(1, "type_a", 0, None)
|
||||
sched.enqueue(2, "type_b", 0, None)
|
||||
|
||||
time.sleep(0.2)
|
||||
assert started == ["type_a"]
|
||||
assert type_a_started.wait(timeout=3.0), "type_a never started"
|
||||
assert not type_b_started.is_set(), "type_b should be blocked by VRAM"
|
||||
|
||||
gate_a.set()
|
||||
time.sleep(0.2)
|
||||
assert type_b_started.wait(timeout=3.0), "type_b never started after type_a finished"
|
||||
gate_b.set()
|
||||
sched.shutdown()
|
||||
assert "type_b" in started
|
||||
assert sorted(started) == ["type_a", "type_b"]
|
||||
|
||||
|
||||
def test_get_scheduler_singleton(tmp_db: Path):
|
||||
|
|
@@ -264,3 +268,18 @@ def test_load_queued_tasks_missing_table_does_not_crash(tmp_path: Path):
|
|||
sched.start()
|
||||
sched.shutdown()
|
||||
# No exception = pass
|
||||
|
||||
|
||||
def test_reserved_vram_zero_after_task_completes(tmp_db: Path):
|
||||
"""_reserved_vram returns to 0.0 after a task finishes — no double-decrement."""
|
||||
done = threading.Event()
|
||||
|
||||
def run_fn(db_path, task_id, task_type, job_id, params):
|
||||
done.set()
|
||||
|
||||
sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
|
||||
sched.start()
|
||||
sched.enqueue(1, "fast_task", 0, None)
|
||||
assert done.wait(timeout=3.0), "Task never completed"
|
||||
sched.shutdown()
|
||||
assert sched._reserved_vram == 0.0, f"Expected 0.0, got {sched._reserved_vram}"
|
||||
|
|
|
|||
Loading…
Reference in a new issue