feat(scheduler): implement scheduler loop and batch worker with VRAM-aware scheduling

pyr0ball 2026-03-15 04:14:56 -07:00
parent 68d257d278
commit a53a03d593
2 changed files with 207 additions and 0 deletions
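
For orientation, a minimal usage sketch pieced together from the API the tests below exercise; the no-op run_task stub and the db path are illustrative, and the real callable is scripts.task_runner._run_task:

def run_task(db_path, task_id, task_type, job_id, params):
    ...  # hypothetical stub; executes one task against the db

scheduler = TaskScheduler("tasks.db", run_task)  # db path is illustrative
scheduler.start()                                # spawn the scheduler loop thread
scheduler.enqueue(1, "cover_letter", 1, None)    # (task_id, task_type, job_id, params)
scheduler.shutdown()                             # stop the loop and join it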

scripts/task_scheduler.py

@@ -114,6 +114,78 @@ class TaskScheduler:
        self._wake.set()

    def start(self) -> None:
        """Start the background scheduler loop thread. Call once after construction."""
        self._thread = threading.Thread(
            target=self._scheduler_loop, name="task-scheduler", daemon=True
        )
        self._thread.start()

    def shutdown(self, timeout: float = 5.0) -> None:
        """Signal the scheduler to stop and wait for it to exit."""
        self._stop.set()
        self._wake.set()  # unblock any wait()
        if self._thread and self._thread.is_alive():
            self._thread.join(timeout=timeout)

    def _scheduler_loop(self) -> None:
        """Main scheduler daemon — wakes on enqueue or batch completion."""
        while not self._stop.is_set():
            self._wake.wait(timeout=30)
            self._wake.clear()
            with self._lock:
                # Defense in depth: reap externally-killed batch threads.
                # In normal operation _active.pop() runs in finally before _wake fires,
                # so this reap finds nothing — no double-decrement risk.
                for t, thread in list(self._active.items()):
                    if not thread.is_alive():
                        self._reserved_vram -= self._budgets.get(t, 0.0)
                        del self._active[t]
                # Start new type batches while VRAM allows.
                candidates = sorted(
                    [t for t in self._queues if self._queues[t] and t not in self._active],
                    key=lambda t: len(self._queues[t]),
                    reverse=True,
                )
                for task_type in candidates:
                    budget = self._budgets.get(task_type, 0.0)
                    # Always allow at least one batch to run even if its budget
                    # exceeds _available_vram (prevents permanent starvation when
                    # a single type's budget is larger than the VRAM ceiling).
                    if self._reserved_vram == 0.0 or self._reserved_vram + budget <= self._available_vram:
                        thread = threading.Thread(
                            target=self._batch_worker,
                            args=(task_type,),
                            name=f"batch-{task_type}",
                            daemon=True,
                        )
                        self._active[task_type] = thread
                        self._reserved_vram += budget
                        thread.start()

    def _batch_worker(self, task_type: str) -> None:
        """Serial consumer for one task type. Runs until the type's deque is empty."""
        try:
            while True:
                with self._lock:
                    q = self._queues.get(task_type)
                    if not q:
                        break
                    task = q.popleft()
                # _run_task is scripts.task_runner._run_task (passed at construction)
                self._run_task(
                    self._db_path, task.id, task_type, task.job_id, task.params
                )
        finally:
            # Always release — even if _run_task raises.
            # _active.pop here prevents the scheduler loop reap from double-decrementing.
            with self._lock:
                self._active.pop(task_type, None)
                self._reserved_vram -= self._budgets.get(task_type, 0.0)
            self._wake.set()

# ── Singleton ─────────────────────────────────────────────────────────────────
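
The admission check in _scheduler_loop reduces to a single predicate. A standalone restatement (function name hypothetical), using the budget figures the tests assume, makes the starvation guard concrete:

def can_admit(reserved: float, budget: float, ceiling: float) -> bool:
    # Mirrors the scheduler-loop check: an idle scheduler (reserved == 0.0)
    # always admits one batch; otherwise the batch must fit under the ceiling.
    return reserved == 0.0 or reserved + budget <= ceiling

assert can_admit(0.0, 5.0, 3.0)      # oversized batch still runs when nothing is active
assert not can_admit(2.5, 5.0, 3.0)  # 2.5 + 5.0 > 3.0, so the second batch must wait
assert can_admit(2.5, 0.5, 3.0)      # landing exactly at the ceiling still fits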

tests/test_task_scheduler.py

@@ -1,6 +1,7 @@
# tests/test_task_scheduler.py
"""Tests for scripts/task_scheduler.py and related db helpers."""
import sqlite3
import threading
from collections import deque
from pathlib import Path
@@ -195,3 +196,137 @@ def test_max_queue_depth_logs_warning(tmp_db, caplog):
    s.enqueue(task_id, "cover_letter", 1, None)
    assert any("depth" in r.message.lower() for r in caplog.records)

# ── Threading helpers ─────────────────────────────────────────────────────────

def _make_recording_run_task(log: list, done_event: threading.Event, expected: int):
    """Returns a mock _run_task that records (task_id, task_type) and sets done when expected count reached."""
    def _run(db_path, task_id, task_type, job_id, params):
        log.append((task_id, task_type))
        if len(log) >= expected:
            done_event.set()
    return _run

def _start_scheduler(tmp_db, run_task_fn, available_vram=999.0):
    s = TaskScheduler(tmp_db, run_task_fn)
    s._available_vram = available_vram
    s.start()
    return s

# ── Tests ─────────────────────────────────────────────────────────────────────

def test_deepest_queue_wins_first_slot(tmp_db):
    """Type with more queued tasks starts first when VRAM only fits one type."""
    log, done = [], threading.Event()
    # Build scheduler but DO NOT start it yet — enqueue all tasks first
    # so the scheduler sees the full picture on its very first wake.
    run_task_fn = _make_recording_run_task(log, done, 4)
    s = TaskScheduler(tmp_db, run_task_fn)
    s._available_vram = 3.0  # fits cover_letter (2.5) but not +company_research (5.0)
    # Enqueue cover_letter (3 tasks) and company_research (1 task) before start.
    # cover_letter has the deeper queue and must win the first batch slot.
    for i in range(3):
        s.enqueue(i + 1, "cover_letter", i + 1, None)
    s.enqueue(4, "company_research", 4, None)
    s.start()  # scheduler now sees all tasks atomically on its first iteration
    assert done.wait(timeout=5.0), "timed out — not all 4 tasks completed"
    s.shutdown()
    assert len(log) == 4
    cl = [i for i, (_, t) in enumerate(log) if t == "cover_letter"]
    cr = [i for i, (_, t) in enumerate(log) if t == "company_research"]
    assert len(cl) == 3 and len(cr) == 1
    assert max(cl) < min(cr), "All cover_letter tasks must finish before company_research starts"

def test_fifo_within_type(tmp_db):
    """Tasks of the same type execute in arrival (FIFO) order."""
    log, done = [], threading.Event()
    s = _start_scheduler(tmp_db, _make_recording_run_task(log, done, 3))
    for task_id in [10, 20, 30]:
        s.enqueue(task_id, "cover_letter", task_id, None)
    assert done.wait(timeout=5.0), "timed out — not all 3 tasks completed"
    s.shutdown()
    assert [task_id for task_id, _ in log] == [10, 20, 30]

def test_concurrent_batches_when_vram_allows(tmp_db):
    """Two type batches start simultaneously when VRAM fits both."""
    started = {"cover_letter": threading.Event(), "company_research": threading.Event()}
    all_done = threading.Event()
    log = []

    def run_task(db_path, task_id, task_type, job_id, params):
        started[task_type].set()
        log.append(task_type)
        if len(log) >= 2:
            all_done.set()

    # VRAM=10.0 fits both cover_letter (2.5) and company_research (5.0) simultaneously
    s = _start_scheduler(tmp_db, run_task, available_vram=10.0)
    s.enqueue(1, "cover_letter", 1, None)
    s.enqueue(2, "company_research", 2, None)
    assert all_done.wait(timeout=5.0), "timed out — both tasks should have completed"
    s.shutdown()
    # Both types should have started (possibly overlapping)
    assert started["cover_letter"].is_set()
    assert started["company_research"].is_set()

def test_new_tasks_picked_up_mid_batch(tmp_db):
    """A task enqueued while a batch is running is consumed in the same batch."""
    log, done = [], threading.Event()
    task1_started = threading.Event()  # fires when task 1 begins executing
    task2_ready = threading.Event()    # fires when task 2 has been enqueued

    def run_task(db_path, task_id, task_type, job_id, params):
        if task_id == 1:
            task1_started.set()          # signal: task 1 is now running
            task2_ready.wait(timeout=2.0)  # wait for task 2 to be in the deque
        log.append(task_id)
        if len(log) >= 2:
            done.set()

    s = _start_scheduler(tmp_db, run_task)
    s.enqueue(1, "cover_letter", 1, None)
    task1_started.wait(timeout=2.0)  # wait until task 1 is actually executing
    s.enqueue(2, "cover_letter", 2, None)
    task2_ready.set()  # unblock task 1 so it finishes
    assert done.wait(timeout=5.0), "timed out — task 2 never picked up mid-batch"
    s.shutdown()
    assert log == [1, 2]

def test_worker_crash_releases_vram(tmp_db):
    """If _run_task raises, _reserved_vram returns to 0 and scheduler continues."""
    log, done = [], threading.Event()

    def run_task(db_path, task_id, task_type, job_id, params):
        if task_id == 1:
            raise RuntimeError("simulated failure")
        log.append(task_id)
        done.set()

    s = _start_scheduler(tmp_db, run_task, available_vram=3.0)
    s.enqueue(1, "cover_letter", 1, None)
    s.enqueue(2, "cover_letter", 2, None)
    assert done.wait(timeout=5.0), "timed out — task 2 never completed after task 1 crash"
    s.shutdown()
    # Second task still ran, VRAM was released
    assert 2 in log
    assert s._reserved_vram == 0.0
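
One behavior the new code documents but this suite does not yet pin down is the starvation guard: a type whose budget exceeds _available_vram must still get a batch when the scheduler is idle. A sketch in the suite's own style, assuming the cover_letter budget of 2.5 noted in the comments above:

def test_oversized_budget_still_runs_alone(tmp_db):
    """Sketch (not part of this diff): a budget above the VRAM ceiling still runs alone."""
    log, done = [], threading.Event()
    s = _start_scheduler(tmp_db, _make_recording_run_task(log, done, 1), available_vram=1.0)
    s.enqueue(1, "cover_letter", 1, None)  # budget 2.5 > ceiling 1.0, but reserved == 0.0
    assert done.wait(timeout=5.0), "timed out: oversized batch never started"
    s.shutdown()
    assert log == [(1, "cover_letter")]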