feat(scheduler): implement scheduler loop and batch worker with VRAM-aware scheduling

Parent: 605e820fa6
Commit: 84ce68af46
2 changed files with 207 additions and 0 deletions
@ -114,6 +114,78 @@ class TaskScheduler:
|
||||||
|
|
||||||
self._wake.set()
|
self._wake.set()
|
||||||
|
|
||||||
|
def start(self) -> None:
    """Start the background scheduler loop thread. Call once after construction.

    Raises:
        RuntimeError: if the scheduler loop was already started. Without this
            guard a second call would silently spawn a duplicate daemon loop,
            and both loops would race to launch batch workers.
    """
    # getattr guards against an __init__ that never sets _thread — TODO confirm
    # against the constructor (not visible in this chunk).
    if getattr(self, "_thread", None) is not None:
        raise RuntimeError("TaskScheduler.start() called more than once")
    self._thread = threading.Thread(
        target=self._scheduler_loop, name="task-scheduler", daemon=True
    )
    self._thread.start()
def shutdown(self, timeout: float = 5.0) -> None:
    """Signal the scheduler to stop and wait for it to exit.

    Args:
        timeout: maximum seconds to wait for the loop thread to finish.
    """
    self._stop.set()
    # Kick the loop out of its wait() so it observes the stop flag promptly.
    self._wake.set()
    worker = self._thread
    if worker is not None and worker.is_alive():
        worker.join(timeout=timeout)
def _scheduler_loop(self) -> None:
    """Main scheduler daemon — wakes on enqueue or batch completion."""
    while not self._stop.is_set():
        # Block until enqueue()/a finishing batch sets _wake, or 30 s as a
        # safety-net poll. Clear before scanning so a wake that arrives
        # during the scan is seen on the next iteration, not lost.
        self._wake.wait(timeout=30)
        self._wake.clear()

        with self._lock:
            # Defense in depth: reap externally-killed batch threads.
            # In normal operation _active.pop() runs in finally before _wake fires,
            # so this reap finds nothing — no double-decrement risk.
            for t, thread in list(self._active.items()):
                if not thread.is_alive():
                    self._reserved_vram -= self._budgets.get(t, 0.0)
                    del self._active[t]

            # Start new type batches while VRAM allows
            # Deepest queue first: ties go to whichever sorted() saw first.
            candidates = sorted(
                [t for t in self._queues if self._queues[t] and t not in self._active],
                key=lambda t: len(self._queues[t]),
                reverse=True,
            )
            for task_type in candidates:
                budget = self._budgets.get(task_type, 0.0)
                # Always allow at least one batch to run even if its budget
                # exceeds _available_vram (prevents permanent starvation when
                # a single type's budget is larger than the VRAM ceiling).
                if self._reserved_vram == 0.0 or self._reserved_vram + budget <= self._available_vram:
                    thread = threading.Thread(
                        target=self._batch_worker,
                        args=(task_type,),
                        name=f"batch-{task_type}",
                        daemon=True,
                    )
                    # Reserve before start() so a fast worker's release in its
                    # finally block can never run ahead of the reservation.
                    self._active[task_type] = thread
                    self._reserved_vram += budget
                    thread.start()
def _batch_worker(self, task_type: str) -> None:
    """Serial consumer for one task type. Runs until the type's deque is empty."""
    try:
        while True:
            # Hold the lock only long enough to inspect/pop the deque;
            # task execution happens outside it so enqueue() and the
            # scheduler loop are never blocked by a long-running task.
            with self._lock:
                q = self._queues.get(task_type)
                if not q:
                    break
                task = q.popleft()
            # _run_task is scripts.task_runner._run_task (passed at construction)
            # NOTE(review): if it raises, the exception ends this batch thread;
            # remaining queued tasks wait for the scheduler to start a new batch.
            self._run_task(
                self._db_path, task.id, task_type, task.job_id, task.params
            )
    finally:
        # Always release — even if _run_task raises.
        # _active.pop here prevents the scheduler loop reap from double-decrementing.
        with self._lock:
            self._active.pop(task_type, None)
            self._reserved_vram -= self._budgets.get(task_type, 0.0)
            self._wake.set()
# ── Singleton ─────────────────────────────────────────────────────────────────
@ -1,6 +1,7 @@
|
||||||
# tests/test_task_scheduler.py
|
# tests/test_task_scheduler.py
|
||||||
"""Tests for scripts/task_scheduler.py and related db helpers."""
|
"""Tests for scripts/task_scheduler.py and related db helpers."""
|
||||||
import sqlite3
|
import sqlite3
|
||||||
|
import threading
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
@ -195,3 +196,137 @@ def test_max_queue_depth_logs_warning(tmp_db, caplog):
|
||||||
s.enqueue(task_id, "cover_letter", 1, None)
|
s.enqueue(task_id, "cover_letter", 1, None)
|
||||||
|
|
||||||
assert any("depth" in r.message.lower() for r in caplog.records)
|
assert any("depth" in r.message.lower() for r in caplog.records)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Threading helpers ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _make_recording_run_task(log: list, done_event: threading.Event, expected: int):
|
||||||
|
"""Returns a mock _run_task that records (task_id, task_type) and sets done when expected count reached."""
|
||||||
|
def _run(db_path, task_id, task_type, job_id, params):
|
||||||
|
log.append((task_id, task_type))
|
||||||
|
if len(log) >= expected:
|
||||||
|
done_event.set()
|
||||||
|
return _run
|
||||||
|
|
||||||
|
|
||||||
|
def _start_scheduler(tmp_db, run_task_fn, available_vram=999.0):
    """Construct, configure, and start a TaskScheduler for tests.

    The default VRAM ceiling is effectively unlimited so every type batch
    can run; pass a smaller value to exercise the budget logic.
    """
    scheduler = TaskScheduler(tmp_db, run_task_fn)
    scheduler._available_vram = available_vram
    scheduler.start()
    return scheduler
def test_deepest_queue_wins_first_slot(tmp_db):
    """Type with more queued tasks starts first when VRAM only fits one type."""
    log, done = [], threading.Event()

    # Build scheduler but DO NOT start it yet — enqueue all tasks first
    # so the scheduler sees the full picture on its very first wake.
    s = TaskScheduler(tmp_db, _make_recording_run_task(log, done, 4))
    s._available_vram = 3.0  # fits cover_letter (2.5) but not +company_research (5.0)

    # cover_letter gets the deeper queue (3 tasks vs 1) and must therefore
    # win the first batch slot.
    for task_id in (1, 2, 3):
        s.enqueue(task_id, "cover_letter", task_id, None)
    s.enqueue(4, "company_research", 4, None)

    s.start()  # scheduler now sees all tasks atomically on its first iteration
    assert done.wait(timeout=5.0), "timed out — not all 4 tasks completed"
    s.shutdown()

    assert len(log) == 4
    cover_positions = [i for i, (_, t) in enumerate(log) if t == "cover_letter"]
    research_positions = [i for i, (_, t) in enumerate(log) if t == "company_research"]
    assert len(cover_positions) == 3 and len(research_positions) == 1
    assert max(cover_positions) < min(research_positions), "All cover_letter tasks must finish before company_research starts"
def test_fifo_within_type(tmp_db):
    """Tasks of the same type execute in arrival (FIFO) order."""
    log, done = [], threading.Event()
    scheduler = _start_scheduler(tmp_db, _make_recording_run_task(log, done, 3))

    for tid in (10, 20, 30):
        scheduler.enqueue(tid, "cover_letter", tid, None)

    assert done.wait(timeout=5.0), "timed out — not all 3 tasks completed"
    scheduler.shutdown()

    # Execution order must match enqueue order within a single type.
    assert [tid for tid, _ in log] == [10, 20, 30]
def test_concurrent_batches_when_vram_allows(tmp_db):
    """Two type batches start simultaneously when VRAM fits both."""
    started = {"cover_letter": threading.Event(), "company_research": threading.Event()}
    all_done = threading.Event()
    log = []

    def run_task(db_path, task_id, task_type, job_id, params):
        started[task_type].set()
        log.append(task_type)
        if len(log) >= 2:
            all_done.set()

    # VRAM=10.0 fits both cover_letter (2.5) and company_research (5.0) simultaneously
    s = _start_scheduler(tmp_db, run_task, available_vram=10.0)
    s.enqueue(1, "cover_letter", 1, None)
    s.enqueue(2, "company_research", 2, None)

    # Fail loudly on timeout instead of silently falling through to the
    # event asserts below with a less informative failure.
    assert all_done.wait(timeout=5.0), "timed out — both tasks did not complete"
    s.shutdown()

    # Both types should have started (possibly overlapping)
    assert started["cover_letter"].is_set()
    assert started["company_research"].is_set()
def test_new_tasks_picked_up_mid_batch(tmp_db):
    """A task enqueued while a batch is running is consumed in the same batch."""
    log, done = [], threading.Event()
    task1_started = threading.Event()  # fires when task 1 begins executing
    task2_ready = threading.Event()  # fires when task 2 has been enqueued

    def run_task(db_path, task_id, task_type, job_id, params):
        if task_id == 1:
            task1_started.set()  # signal: task 1 is now running
            task2_ready.wait(timeout=2.0)  # wait for task 2 to be in the deque
        log.append(task_id)
        if len(log) >= 2:
            done.set()

    s = _start_scheduler(tmp_db, run_task)
    s.enqueue(1, "cover_letter", 1, None)
    # The premise requires task 1 to actually be mid-execution before task 2
    # is enqueued — assert the wait so the test cannot pass vacuously when
    # task 1 never starts.
    assert task1_started.wait(timeout=2.0), "task 1 never started executing"
    s.enqueue(2, "cover_letter", 2, None)
    task2_ready.set()  # unblock task 1 so it finishes

    assert done.wait(timeout=5.0), "timed out — task 2 never picked up mid-batch"
    s.shutdown()

    assert log == [1, 2]
def test_worker_crash_releases_vram(tmp_db):
    """If _run_task raises, _reserved_vram returns to 0 and scheduler continues."""
    log, done = [], threading.Event()

    def run_task(db_path, task_id, task_type, job_id, params):
        # First task blows up; only successful tasks are recorded.
        if task_id == 1:
            raise RuntimeError("simulated failure")
        log.append(task_id)
        done.set()

    scheduler = _start_scheduler(tmp_db, run_task, available_vram=3.0)
    scheduler.enqueue(1, "cover_letter", 1, None)
    scheduler.enqueue(2, "cover_letter", 2, None)

    assert done.wait(timeout=5.0), "timed out — task 2 never completed after task 1 crash"
    scheduler.shutdown()

    # The crash must not leak the batch's VRAM reservation, and the
    # surviving task must still have run.
    assert 2 in log
    assert scheduler._reserved_vram == 0.0
|
||||||
Loading…
Reference in a new issue