feat(scheduler): implement scheduler loop and batch worker with VRAM-aware scheduling
This commit is contained in:
parent
4d055f6bcd
commit
3984a9c743
2 changed files with 207 additions and 0 deletions
|
|
@ -114,6 +114,78 @@ class TaskScheduler:
|
|||
|
||||
self._wake.set()
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the background scheduler loop thread. Call once after construction."""
|
||||
self._thread = threading.Thread(
|
||||
target=self._scheduler_loop, name="task-scheduler", daemon=True
|
||||
)
|
||||
self._thread.start()
|
||||
|
||||
def shutdown(self, timeout: float = 5.0) -> None:
|
||||
"""Signal the scheduler to stop and wait for it to exit."""
|
||||
self._stop.set()
|
||||
self._wake.set() # unblock any wait()
|
||||
if self._thread and self._thread.is_alive():
|
||||
self._thread.join(timeout=timeout)
|
||||
|
||||
def _scheduler_loop(self) -> None:
|
||||
"""Main scheduler daemon — wakes on enqueue or batch completion."""
|
||||
while not self._stop.is_set():
|
||||
self._wake.wait(timeout=30)
|
||||
self._wake.clear()
|
||||
|
||||
with self._lock:
|
||||
# Defense in depth: reap externally-killed batch threads.
|
||||
# In normal operation _active.pop() runs in finally before _wake fires,
|
||||
# so this reap finds nothing — no double-decrement risk.
|
||||
for t, thread in list(self._active.items()):
|
||||
if not thread.is_alive():
|
||||
self._reserved_vram -= self._budgets.get(t, 0.0)
|
||||
del self._active[t]
|
||||
|
||||
# Start new type batches while VRAM allows
|
||||
candidates = sorted(
|
||||
[t for t in self._queues if self._queues[t] and t not in self._active],
|
||||
key=lambda t: len(self._queues[t]),
|
||||
reverse=True,
|
||||
)
|
||||
for task_type in candidates:
|
||||
budget = self._budgets.get(task_type, 0.0)
|
||||
# Always allow at least one batch to run even if its budget
|
||||
# exceeds _available_vram (prevents permanent starvation when
|
||||
# a single type's budget is larger than the VRAM ceiling).
|
||||
if self._reserved_vram == 0.0 or self._reserved_vram + budget <= self._available_vram:
|
||||
thread = threading.Thread(
|
||||
target=self._batch_worker,
|
||||
args=(task_type,),
|
||||
name=f"batch-{task_type}",
|
||||
daemon=True,
|
||||
)
|
||||
self._active[task_type] = thread
|
||||
self._reserved_vram += budget
|
||||
thread.start()
|
||||
|
||||
def _batch_worker(self, task_type: str) -> None:
|
||||
"""Serial consumer for one task type. Runs until the type's deque is empty."""
|
||||
try:
|
||||
while True:
|
||||
with self._lock:
|
||||
q = self._queues.get(task_type)
|
||||
if not q:
|
||||
break
|
||||
task = q.popleft()
|
||||
# _run_task is scripts.task_runner._run_task (passed at construction)
|
||||
self._run_task(
|
||||
self._db_path, task.id, task_type, task.job_id, task.params
|
||||
)
|
||||
finally:
|
||||
# Always release — even if _run_task raises.
|
||||
# _active.pop here prevents the scheduler loop reap from double-decrementing.
|
||||
with self._lock:
|
||||
self._active.pop(task_type, None)
|
||||
self._reserved_vram -= self._budgets.get(task_type, 0.0)
|
||||
self._wake.set()
|
||||
|
||||
|
||||
# ── Singleton ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
# tests/test_task_scheduler.py
|
||||
"""Tests for scripts/task_scheduler.py and related db helpers."""
|
||||
import sqlite3
|
||||
import threading
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
|
||||
|
|
@ -195,3 +196,137 @@ def test_max_queue_depth_logs_warning(tmp_db, caplog):
|
|||
s.enqueue(task_id, "cover_letter", 1, None)
|
||||
|
||||
assert any("depth" in r.message.lower() for r in caplog.records)
|
||||
|
||||
|
||||
# ── Threading helpers ─────────────────────────────────────────────────────────
|
||||
|
||||
def _make_recording_run_task(log: list, done_event: threading.Event, expected: int):
|
||||
"""Returns a mock _run_task that records (task_id, task_type) and sets done when expected count reached."""
|
||||
def _run(db_path, task_id, task_type, job_id, params):
|
||||
log.append((task_id, task_type))
|
||||
if len(log) >= expected:
|
||||
done_event.set()
|
||||
return _run
|
||||
|
||||
|
||||
def _start_scheduler(tmp_db, run_task_fn, available_vram=999.0):
|
||||
s = TaskScheduler(tmp_db, run_task_fn)
|
||||
s._available_vram = available_vram
|
||||
s.start()
|
||||
return s
|
||||
|
||||
|
||||
# ── Tests ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_deepest_queue_wins_first_slot(tmp_db):
|
||||
"""Type with more queued tasks starts first when VRAM only fits one type."""
|
||||
log, done = [], threading.Event()
|
||||
|
||||
# Build scheduler but DO NOT start it yet — enqueue all tasks first
|
||||
# so the scheduler sees the full picture on its very first wake.
|
||||
run_task_fn = _make_recording_run_task(log, done, 4)
|
||||
s = TaskScheduler(tmp_db, run_task_fn)
|
||||
s._available_vram = 3.0 # fits cover_letter (2.5) but not +company_research (5.0)
|
||||
|
||||
# Enqueue cover_letter (3 tasks) and company_research (1 task) before start.
|
||||
# cover_letter has the deeper queue and must win the first batch slot.
|
||||
for i in range(3):
|
||||
s.enqueue(i + 1, "cover_letter", i + 1, None)
|
||||
s.enqueue(4, "company_research", 4, None)
|
||||
|
||||
s.start() # scheduler now sees all tasks atomically on its first iteration
|
||||
assert done.wait(timeout=5.0), "timed out — not all 4 tasks completed"
|
||||
s.shutdown()
|
||||
|
||||
assert len(log) == 4
|
||||
cl = [i for i, (_, t) in enumerate(log) if t == "cover_letter"]
|
||||
cr = [i for i, (_, t) in enumerate(log) if t == "company_research"]
|
||||
assert len(cl) == 3 and len(cr) == 1
|
||||
assert max(cl) < min(cr), "All cover_letter tasks must finish before company_research starts"
|
||||
|
||||
|
||||
def test_fifo_within_type(tmp_db):
|
||||
"""Tasks of the same type execute in arrival (FIFO) order."""
|
||||
log, done = [], threading.Event()
|
||||
s = _start_scheduler(tmp_db, _make_recording_run_task(log, done, 3))
|
||||
|
||||
for task_id in [10, 20, 30]:
|
||||
s.enqueue(task_id, "cover_letter", task_id, None)
|
||||
|
||||
assert done.wait(timeout=5.0), "timed out — not all 3 tasks completed"
|
||||
s.shutdown()
|
||||
|
||||
assert [task_id for task_id, _ in log] == [10, 20, 30]
|
||||
|
||||
|
||||
def test_concurrent_batches_when_vram_allows(tmp_db):
|
||||
"""Two type batches start simultaneously when VRAM fits both."""
|
||||
started = {"cover_letter": threading.Event(), "company_research": threading.Event()}
|
||||
all_done = threading.Event()
|
||||
log = []
|
||||
|
||||
def run_task(db_path, task_id, task_type, job_id, params):
|
||||
started[task_type].set()
|
||||
log.append(task_type)
|
||||
if len(log) >= 2:
|
||||
all_done.set()
|
||||
|
||||
# VRAM=10.0 fits both cover_letter (2.5) and company_research (5.0) simultaneously
|
||||
s = _start_scheduler(tmp_db, run_task, available_vram=10.0)
|
||||
s.enqueue(1, "cover_letter", 1, None)
|
||||
s.enqueue(2, "company_research", 2, None)
|
||||
|
||||
all_done.wait(timeout=5.0)
|
||||
s.shutdown()
|
||||
|
||||
# Both types should have started (possibly overlapping)
|
||||
assert started["cover_letter"].is_set()
|
||||
assert started["company_research"].is_set()
|
||||
|
||||
|
||||
def test_new_tasks_picked_up_mid_batch(tmp_db):
|
||||
"""A task enqueued while a batch is running is consumed in the same batch."""
|
||||
log, done = [], threading.Event()
|
||||
task1_started = threading.Event() # fires when task 1 begins executing
|
||||
task2_ready = threading.Event() # fires when task 2 has been enqueued
|
||||
|
||||
def run_task(db_path, task_id, task_type, job_id, params):
|
||||
if task_id == 1:
|
||||
task1_started.set() # signal: task 1 is now running
|
||||
task2_ready.wait(timeout=2.0) # wait for task 2 to be in the deque
|
||||
log.append(task_id)
|
||||
if len(log) >= 2:
|
||||
done.set()
|
||||
|
||||
s = _start_scheduler(tmp_db, run_task)
|
||||
s.enqueue(1, "cover_letter", 1, None)
|
||||
task1_started.wait(timeout=2.0) # wait until task 1 is actually executing
|
||||
s.enqueue(2, "cover_letter", 2, None)
|
||||
task2_ready.set() # unblock task 1 so it finishes
|
||||
|
||||
assert done.wait(timeout=5.0), "timed out — task 2 never picked up mid-batch"
|
||||
s.shutdown()
|
||||
|
||||
assert log == [1, 2]
|
||||
|
||||
|
||||
def test_worker_crash_releases_vram(tmp_db):
|
||||
"""If _run_task raises, _reserved_vram returns to 0 and scheduler continues."""
|
||||
log, done = [], threading.Event()
|
||||
|
||||
def run_task(db_path, task_id, task_type, job_id, params):
|
||||
if task_id == 1:
|
||||
raise RuntimeError("simulated failure")
|
||||
log.append(task_id)
|
||||
done.set()
|
||||
|
||||
s = _start_scheduler(tmp_db, run_task, available_vram=3.0)
|
||||
s.enqueue(1, "cover_letter", 1, None)
|
||||
s.enqueue(2, "cover_letter", 2, None)
|
||||
|
||||
assert done.wait(timeout=5.0), "timed out — task 2 never completed after task 1 crash"
|
||||
s.shutdown()
|
||||
|
||||
# Second task still ran, VRAM was released
|
||||
assert 2 in log
|
||||
assert s._reserved_vram == 0.0
|
||||
|
|
|
|||
Loading…
Reference in a new issue