fix(tasks): fix VRAM accounting race, lock scope, type annotations
- C1: Remove _reserved_vram decrement from _scheduler_loop reaper; sole responsibility now belongs to _batch_worker's finally block, eliminating the double-decrement race that could drive _reserved_vram negative.
- C2: Move TaskScheduler construction (including VRAM detection httpx call) outside _scheduler_lock in get_scheduler(); lock is now only held for the final singleton assignment, preventing 2s lock contention on first call.
- I1: Add RunTaskFn type alias (Callable[...]) and use it in __init__ and get_scheduler() instead of bare Callable.
- I2: Replace namedtuple TaskSpec with typed NamedTuple class.
- I3: Parameterize _queues annotation as dict[str, deque[TaskSpec]].
- I4: Wrap _queues read in start() with self._lock.
- I5: Replace time.sleep() ordering assertion in test_vram_budget_blocks_second_type with event-based synchronization using type_a_started/type_b_started events.
- M2: Use sqlite3.connect() as context manager in _load_queued_tasks.
- M3: Strengthen weak assertion in test_enqueue_returns_false_when_queue_full.
- M4: Add test_reserved_vram_zero_after_task_completes to catch C1 regression.
This commit is contained in:
parent
09a5087c72
commit
22bad8590a
2 changed files with 76 additions and 44 deletions
|
|
@@ -22,9 +22,9 @@ from __future__ import annotations
|
||||||
import logging
|
import logging
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import threading
|
import threading
|
||||||
from collections import deque, namedtuple
|
from collections import deque
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Optional
|
from typing import Callable, NamedTuple, Optional
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import httpx as httpx
|
import httpx as httpx
|
||||||
|
|
@@ -33,7 +33,13 @@ except ImportError:
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
TaskSpec = namedtuple("TaskSpec", ["id", "job_id", "params"])
|
RunTaskFn = Callable[["Path", int, str, int, Optional[str]], None]
|
||||||
|
|
||||||
|
|
||||||
|
class TaskSpec(NamedTuple):
|
||||||
|
id: int
|
||||||
|
job_id: int
|
||||||
|
params: Optional[str]
|
||||||
|
|
||||||
_DEFAULT_MAX_QUEUE_DEPTH = 500
|
_DEFAULT_MAX_QUEUE_DEPTH = 500
|
||||||
|
|
||||||
|
|
@@ -116,7 +122,7 @@ class TaskScheduler:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
db_path: Path,
|
db_path: Path,
|
||||||
run_task_fn: Callable,
|
run_task_fn: RunTaskFn,
|
||||||
task_types: frozenset[str],
|
task_types: frozenset[str],
|
||||||
vram_budgets: dict[str, float],
|
vram_budgets: dict[str, float],
|
||||||
available_vram_gb: Optional[float] = None,
|
available_vram_gb: Optional[float] = None,
|
||||||
|
|
@@ -131,7 +137,7 @@ class TaskScheduler:
|
||||||
self._lock = threading.Lock()
|
self._lock = threading.Lock()
|
||||||
self._wake = threading.Event()
|
self._wake = threading.Event()
|
||||||
self._stop = threading.Event()
|
self._stop = threading.Event()
|
||||||
self._queues: dict[str, deque] = {}
|
self._queues: dict[str, deque[TaskSpec]] = {}
|
||||||
self._active: dict[str, threading.Thread] = {}
|
self._active: dict[str, threading.Thread] = {}
|
||||||
self._reserved_vram: float = 0.0
|
self._reserved_vram: float = 0.0
|
||||||
self._thread: Optional[threading.Thread] = None
|
self._thread: Optional[threading.Thread] = None
|
||||||
|
|
@@ -185,6 +191,7 @@ class TaskScheduler:
|
||||||
)
|
)
|
||||||
self._thread.start()
|
self._thread.start()
|
||||||
# Wake the loop immediately so tasks loaded from DB at startup are dispatched
|
# Wake the loop immediately so tasks loaded from DB at startup are dispatched
|
||||||
|
with self._lock:
|
||||||
if any(self._queues.values()):
|
if any(self._queues.values()):
|
||||||
self._wake.set()
|
self._wake.set()
|
||||||
|
|
||||||
|
|
@@ -200,10 +207,11 @@ class TaskScheduler:
|
||||||
self._wake.wait(timeout=30)
|
self._wake.wait(timeout=30)
|
||||||
self._wake.clear()
|
self._wake.clear()
|
||||||
with self._lock:
|
with self._lock:
|
||||||
# Reap batch threads that finished without waking us
|
# Reap batch threads that finished without waking us.
|
||||||
|
# VRAM accounting is handled solely by _batch_worker's finally block;
|
||||||
|
# the reaper only removes dead entries from _active.
|
||||||
for t, thread in list(self._active.items()):
|
for t, thread in list(self._active.items()):
|
||||||
if not thread.is_alive():
|
if not thread.is_alive():
|
||||||
self._reserved_vram -= self._budgets.get(t, 0.0)
|
|
||||||
del self._active[t]
|
del self._active[t]
|
||||||
# Start new type batches while VRAM budget allows
|
# Start new type batches while VRAM budget allows
|
||||||
candidates = sorted(
|
candidates = sorted(
|
||||||
|
|
@@ -256,8 +264,8 @@ class TaskScheduler:
|
||||||
return
|
return
|
||||||
task_types_list = sorted(self._task_types)
|
task_types_list = sorted(self._task_types)
|
||||||
placeholders = ",".join("?" * len(task_types_list))
|
placeholders = ",".join("?" * len(task_types_list))
|
||||||
conn = sqlite3.connect(self._db_path)
|
|
||||||
try:
|
try:
|
||||||
|
with sqlite3.connect(self._db_path) as conn:
|
||||||
rows = conn.execute(
|
rows = conn.execute(
|
||||||
f"SELECT id, task_type, job_id, params FROM background_tasks"
|
f"SELECT id, task_type, job_id, params FROM background_tasks"
|
||||||
f" WHERE status='queued' AND task_type IN ({placeholders})"
|
f" WHERE status='queued' AND task_type IN ({placeholders})"
|
||||||
|
|
@@ -267,8 +275,6 @@ class TaskScheduler:
|
||||||
except sqlite3.OperationalError:
|
except sqlite3.OperationalError:
|
||||||
# Table not yet created (first run before migrations)
|
# Table not yet created (first run before migrations)
|
||||||
rows = []
|
rows = []
|
||||||
finally:
|
|
||||||
conn.close()
|
|
||||||
|
|
||||||
for row_id, task_type, job_id, params in rows:
|
for row_id, task_type, job_id, params in rows:
|
||||||
q = self._queues.setdefault(task_type, deque())
|
q = self._queues.setdefault(task_type, deque())
|
||||||
|
|
@@ -288,7 +294,7 @@ _scheduler_lock = threading.Lock()
|
||||||
|
|
||||||
def get_scheduler(
|
def get_scheduler(
|
||||||
db_path: Path,
|
db_path: Path,
|
||||||
run_task_fn: Optional[Callable] = None,
|
run_task_fn: Optional[RunTaskFn] = None,
|
||||||
task_types: Optional[frozenset[str]] = None,
|
task_types: Optional[frozenset[str]] = None,
|
||||||
vram_budgets: Optional[dict[str, float]] = None,
|
vram_budgets: Optional[dict[str, float]] = None,
|
||||||
max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH,
|
max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH,
|
||||||
|
|
@@ -297,28 +303,35 @@ def get_scheduler(
|
||||||
|
|
||||||
``run_task_fn``, ``task_types``, and ``vram_budgets`` are required on the
|
``run_task_fn``, ``task_types``, and ``vram_budgets`` are required on the
|
||||||
first call; ignored on subsequent calls (singleton already constructed).
|
first call; ignored on subsequent calls (singleton already constructed).
|
||||||
|
|
||||||
|
VRAM detection (which may involve a network call) is performed outside the
|
||||||
|
lock so the lock is never held across blocking I/O.
|
||||||
"""
|
"""
|
||||||
global _scheduler
|
global _scheduler
|
||||||
if _scheduler is None:
|
if _scheduler is not None:
|
||||||
with _scheduler_lock:
|
return _scheduler
|
||||||
if _scheduler is None:
|
# Build outside the lock — TaskScheduler.__init__ may call detect_available_vram_gb()
|
||||||
if (
|
# which makes an httpx network call (up to 2 s). Holding the lock during that
|
||||||
run_task_fn is None
|
# would block any concurrent caller for the full duration.
|
||||||
or task_types is None
|
if run_task_fn is None or task_types is None or vram_budgets is None:
|
||||||
or vram_budgets is None
|
|
||||||
):
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"run_task_fn, task_types, and vram_budgets are required "
|
"run_task_fn, task_types, and vram_budgets are required "
|
||||||
"on the first call to get_scheduler()"
|
"on the first call to get_scheduler()"
|
||||||
)
|
)
|
||||||
_scheduler = TaskScheduler(
|
candidate = TaskScheduler(
|
||||||
db_path=db_path,
|
db_path=db_path,
|
||||||
run_task_fn=run_task_fn,
|
run_task_fn=run_task_fn,
|
||||||
task_types=task_types,
|
task_types=task_types,
|
||||||
vram_budgets=vram_budgets,
|
vram_budgets=vram_budgets,
|
||||||
max_queue_depth=max_queue_depth,
|
max_queue_depth=max_queue_depth,
|
||||||
)
|
)
|
||||||
_scheduler.start()
|
candidate.start()
|
||||||
|
with _scheduler_lock:
|
||||||
|
if _scheduler is None:
|
||||||
|
_scheduler = candidate
|
||||||
|
else:
|
||||||
|
# Another thread beat us — shut down our candidate and use the winner.
|
||||||
|
candidate.shutdown()
|
||||||
return _scheduler
|
return _scheduler
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@@ -158,7 +158,7 @@ def test_enqueue_returns_false_when_queue_full(tmp_db: Path):
|
||||||
results = [sched.enqueue(i, "fast_task", 0, None) for i in range(1, 10)]
|
results = [sched.enqueue(i, "fast_task", 0, None) for i in range(1, 10)]
|
||||||
gate.set()
|
gate.set()
|
||||||
sched.shutdown()
|
sched.shutdown()
|
||||||
assert False in results
|
assert not all(results), "Expected at least one enqueue to be rejected"
|
||||||
|
|
||||||
|
|
||||||
def test_scheduler_drains_multiple_tasks(tmp_db: Path):
|
def test_scheduler_drains_multiple_tasks(tmp_db: Path):
|
||||||
|
|
@@ -183,6 +183,8 @@ def test_scheduler_drains_multiple_tasks(tmp_db: Path):
|
||||||
|
|
||||||
def test_vram_budget_blocks_second_type(tmp_db: Path):
|
def test_vram_budget_blocks_second_type(tmp_db: Path):
|
||||||
"""Second task type is not started when VRAM would be exceeded."""
|
"""Second task type is not started when VRAM would be exceeded."""
|
||||||
|
type_a_started = threading.Event()
|
||||||
|
type_b_started = threading.Event()
|
||||||
gate_a = threading.Event()
|
gate_a = threading.Event()
|
||||||
gate_b = threading.Event()
|
gate_b = threading.Event()
|
||||||
started = []
|
started = []
|
||||||
|
|
@@ -190,8 +192,10 @@ def test_vram_budget_blocks_second_type(tmp_db: Path):
|
||||||
def run_fn(db_path, task_id, task_type, job_id, params):
|
def run_fn(db_path, task_id, task_type, job_id, params):
|
||||||
started.append(task_type)
|
started.append(task_type)
|
||||||
if task_type == "type_a":
|
if task_type == "type_a":
|
||||||
|
type_a_started.set()
|
||||||
gate_a.wait()
|
gate_a.wait()
|
||||||
else:
|
else:
|
||||||
|
type_b_started.set()
|
||||||
gate_b.wait()
|
gate_b.wait()
|
||||||
|
|
||||||
two_types = frozenset({"type_a", "type_b"})
|
two_types = frozenset({"type_a", "type_b"})
|
||||||
|
|
@@ -204,14 +208,14 @@ def test_vram_budget_blocks_second_type(tmp_db: Path):
|
||||||
sched.enqueue(1, "type_a", 0, None)
|
sched.enqueue(1, "type_a", 0, None)
|
||||||
sched.enqueue(2, "type_b", 0, None)
|
sched.enqueue(2, "type_b", 0, None)
|
||||||
|
|
||||||
time.sleep(0.2)
|
assert type_a_started.wait(timeout=3.0), "type_a never started"
|
||||||
assert started == ["type_a"]
|
assert not type_b_started.is_set(), "type_b should be blocked by VRAM"
|
||||||
|
|
||||||
gate_a.set()
|
gate_a.set()
|
||||||
time.sleep(0.2)
|
assert type_b_started.wait(timeout=3.0), "type_b never started after type_a finished"
|
||||||
gate_b.set()
|
gate_b.set()
|
||||||
sched.shutdown()
|
sched.shutdown()
|
||||||
assert "type_b" in started
|
assert sorted(started) == ["type_a", "type_b"]
|
||||||
|
|
||||||
|
|
||||||
def test_get_scheduler_singleton(tmp_db: Path):
|
def test_get_scheduler_singleton(tmp_db: Path):
|
||||||
|
|
@@ -264,3 +268,18 @@ def test_load_queued_tasks_missing_table_does_not_crash(tmp_path: Path):
|
||||||
sched.start()
|
sched.start()
|
||||||
sched.shutdown()
|
sched.shutdown()
|
||||||
# No exception = pass
|
# No exception = pass
|
||||||
|
|
||||||
|
|
||||||
|
def test_reserved_vram_zero_after_task_completes(tmp_db: Path):
|
||||||
|
"""_reserved_vram returns to 0.0 after a task finishes — no double-decrement."""
|
||||||
|
done = threading.Event()
|
||||||
|
|
||||||
|
def run_fn(db_path, task_id, task_type, job_id, params):
|
||||||
|
done.set()
|
||||||
|
|
||||||
|
sched = TaskScheduler(tmp_db, run_fn, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
|
||||||
|
sched.start()
|
||||||
|
sched.enqueue(1, "fast_task", 0, None)
|
||||||
|
assert done.wait(timeout=3.0), "Task never completed"
|
||||||
|
sched.shutdown()
|
||||||
|
assert sched._reserved_vram == 0.0, f"Expected 0.0, got {sched._reserved_vram}"
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue