feat(tasks): shared VRAM-aware LLM task scheduler #2

Merged
pyr0ball merged 4 commits from feature/shared-task-scheduler into main 2026-03-31 10:45:21 -07:00
2 changed files with 76 additions and 44 deletions
Showing only changes of commit 22bad8590a - Show all commits

View file

@ -22,9 +22,9 @@ from __future__ import annotations
import logging
import sqlite3
import threading
from collections import deque, namedtuple
from collections import deque
from pathlib import Path
from typing import Callable, Optional
from typing import Callable, NamedTuple, Optional
try:
import httpx as httpx
@ -33,7 +33,13 @@ except ImportError:
logger = logging.getLogger(__name__)
TaskSpec = namedtuple("TaskSpec", ["id", "job_id", "params"])
RunTaskFn = Callable[["Path", int, str, int, Optional[str]], None]
class TaskSpec(NamedTuple):
    """Immutable description of one queued background task.

    Mirrors the columns selected from the ``background_tasks`` table
    (``SELECT id, task_type, job_id, params ...``); the ``task_type`` column
    is used as the queue key rather than stored on the spec itself.
    """

    id: int  # row id in background_tasks
    job_id: int  # parent job this task belongs to
    params: Optional[str]  # serialized task parameters; None when the task takes none
_DEFAULT_MAX_QUEUE_DEPTH = 500
@ -116,7 +122,7 @@ class TaskScheduler:
def __init__(
self,
db_path: Path,
run_task_fn: Callable,
run_task_fn: RunTaskFn,
task_types: frozenset[str],
vram_budgets: dict[str, float],
available_vram_gb: Optional[float] = None,
@ -131,7 +137,7 @@ class TaskScheduler:
self._lock = threading.Lock()
self._wake = threading.Event()
self._stop = threading.Event()
self._queues: dict[str, deque] = {}
self._queues: dict[str, deque[TaskSpec]] = {}
self._active: dict[str, threading.Thread] = {}
self._reserved_vram: float = 0.0
self._thread: Optional[threading.Thread] = None
@ -185,8 +191,9 @@ class TaskScheduler:
)
self._thread.start()
# Wake the loop immediately so tasks loaded from DB at startup are dispatched
if any(self._queues.values()):
self._wake.set()
with self._lock:
if any(self._queues.values()):
self._wake.set()
def shutdown(self, timeout: float = 5.0) -> None:
"""Signal the scheduler to stop and wait for it to exit."""
@ -200,10 +207,11 @@ class TaskScheduler:
self._wake.wait(timeout=30)
self._wake.clear()
with self._lock:
# Reap batch threads that finished without waking us
# Reap batch threads that finished without waking us.
# VRAM accounting is handled solely by _batch_worker's finally block;
# the reaper only removes dead entries from _active.
for t, thread in list(self._active.items()):
if not thread.is_alive():
self._reserved_vram -= self._budgets.get(t, 0.0)
del self._active[t]
# Start new type batches while VRAM budget allows
candidates = sorted(
@ -256,19 +264,17 @@ class TaskScheduler:
return
task_types_list = sorted(self._task_types)
placeholders = ",".join("?" * len(task_types_list))
conn = sqlite3.connect(self._db_path)
try:
rows = conn.execute(
f"SELECT id, task_type, job_id, params FROM background_tasks"
f" WHERE status='queued' AND task_type IN ({placeholders})"
f" ORDER BY created_at ASC",
task_types_list,
).fetchall()
with sqlite3.connect(self._db_path) as conn:
rows = conn.execute(
f"SELECT id, task_type, job_id, params FROM background_tasks"
f" WHERE status='queued' AND task_type IN ({placeholders})"
f" ORDER BY created_at ASC",
task_types_list,
).fetchall()
except sqlite3.OperationalError:
# Table not yet created (first run before migrations)
rows = []
finally:
conn.close()
for row_id, task_type, job_id, params in rows:
q = self._queues.setdefault(task_type, deque())
@ -288,7 +294,7 @@ _scheduler_lock = threading.Lock()
def get_scheduler(
db_path: Path,
run_task_fn: Optional[Callable] = None,
run_task_fn: Optional[RunTaskFn] = None,
task_types: Optional[frozenset[str]] = None,
vram_budgets: Optional[dict[str, float]] = None,
max_queue_depth: int = _DEFAULT_MAX_QUEUE_DEPTH,
@ -297,28 +303,35 @@ def get_scheduler(
``run_task_fn``, ``task_types``, and ``vram_budgets`` are required on the
first call; ignored on subsequent calls (singleton already constructed).
VRAM detection (which may involve a network call) is performed outside the
lock so the lock is never held across blocking I/O.
"""
global _scheduler
if _scheduler is None:
with _scheduler_lock:
if _scheduler is None:
if (
run_task_fn is None
or task_types is None
or vram_budgets is None
):
raise ValueError(
"run_task_fn, task_types, and vram_budgets are required "
"on the first call to get_scheduler()"
)
_scheduler = TaskScheduler(
db_path=db_path,
run_task_fn=run_task_fn,
task_types=task_types,
vram_budgets=vram_budgets,
max_queue_depth=max_queue_depth,
)
_scheduler.start()
if _scheduler is not None:
return _scheduler
# Build outside the lock — TaskScheduler.__init__ may call detect_available_vram_gb()
# which makes an httpx network call (up to 2 s). Holding the lock during that
# would block any concurrent caller for the full duration.
if run_task_fn is None or task_types is None or vram_budgets is None:
raise ValueError(
"run_task_fn, task_types, and vram_budgets are required "
"on the first call to get_scheduler()"
)
candidate = TaskScheduler(
db_path=db_path,
run_task_fn=run_task_fn,
task_types=task_types,
vram_budgets=vram_budgets,
max_queue_depth=max_queue_depth,
)
candidate.start()
with _scheduler_lock:
if _scheduler is None:
_scheduler = candidate
else:
# Another thread beat us — shut down our candidate and use the winner.
candidate.shutdown()
return _scheduler

View file

@ -158,7 +158,7 @@ def test_enqueue_returns_false_when_queue_full(tmp_db: Path):
results = [sched.enqueue(i, "fast_task", 0, None) for i in range(1, 10)]
gate.set()
sched.shutdown()
assert False in results
assert not all(results), "Expected at least one enqueue to be rejected"
def test_scheduler_drains_multiple_tasks(tmp_db: Path):
@ -183,6 +183,8 @@ def test_scheduler_drains_multiple_tasks(tmp_db: Path):
def test_vram_budget_blocks_second_type(tmp_db: Path):
"""Second task type is not started when VRAM would be exceeded."""
type_a_started = threading.Event()
type_b_started = threading.Event()
gate_a = threading.Event()
gate_b = threading.Event()
started = []
@ -190,8 +192,10 @@ def test_vram_budget_blocks_second_type(tmp_db: Path):
def run_fn(db_path, task_id, task_type, job_id, params):
started.append(task_type)
if task_type == "type_a":
type_a_started.set()
gate_a.wait()
else:
type_b_started.set()
gate_b.wait()
two_types = frozenset({"type_a", "type_b"})
@ -204,14 +208,14 @@ def test_vram_budget_blocks_second_type(tmp_db: Path):
sched.enqueue(1, "type_a", 0, None)
sched.enqueue(2, "type_b", 0, None)
time.sleep(0.2)
assert started == ["type_a"]
assert type_a_started.wait(timeout=3.0), "type_a never started"
assert not type_b_started.is_set(), "type_b should be blocked by VRAM"
gate_a.set()
time.sleep(0.2)
assert type_b_started.wait(timeout=3.0), "type_b never started after type_a finished"
gate_b.set()
sched.shutdown()
assert "type_b" in started
assert sorted(started) == ["type_a", "type_b"]
def test_get_scheduler_singleton(tmp_db: Path):
@ -264,3 +268,18 @@ def test_load_queued_tasks_missing_table_does_not_crash(tmp_path: Path):
sched.start()
sched.shutdown()
# No exception = pass
def test_reserved_vram_zero_after_task_completes(tmp_db: Path):
    """After a task runs to completion, the scheduler's VRAM reservation
    must drain back to exactly 0.0 (guards against the budget being
    decremented twice — once by the worker and once by the reaper)."""
    finished = threading.Event()

    # The run callback just records that it was invoked; no real work.
    task_runner = lambda db_path, task_id, task_type, job_id, params: finished.set()

    sched = TaskScheduler(tmp_db, task_runner, TASK_TYPES, BUDGETS, available_vram_gb=8.0)
    sched.start()
    sched.enqueue(1, "fast_task", 0, None)
    assert finished.wait(timeout=3.0), "Task never completed"
    sched.shutdown()
    assert sched._reserved_vram == 0.0, f"Expected 0.0, got {sched._reserved_vram}"