diff --git a/scripts/task_scheduler.py b/scripts/task_scheduler.py index c9ee0b1..574d020 100644 --- a/scripts/task_scheduler.py +++ b/scripts/task_scheduler.py @@ -114,6 +114,78 @@ class TaskScheduler: self._wake.set() + def start(self) -> None: + """Start the background scheduler loop thread. Call once after construction.""" + self._thread = threading.Thread( + target=self._scheduler_loop, name="task-scheduler", daemon=True + ) + self._thread.start() + + def shutdown(self, timeout: float = 5.0) -> None: + """Signal the scheduler to stop and wait for it to exit.""" + self._stop.set() + self._wake.set() # unblock any wait() + if self._thread and self._thread.is_alive(): + self._thread.join(timeout=timeout) + + def _scheduler_loop(self) -> None: + """Main scheduler daemon — wakes on enqueue or batch completion.""" + while not self._stop.is_set(): + self._wake.wait(timeout=30) + self._wake.clear() + + with self._lock: + # Defense in depth: reap externally-killed batch threads. + # In normal operation _active.pop() runs in finally before _wake fires, + # so this reap finds nothing — no double-decrement risk. + for t, thread in list(self._active.items()): + if not thread.is_alive(): + self._reserved_vram -= self._budgets.get(t, 0.0) + del self._active[t] + + # Start new type batches while VRAM allows + candidates = sorted( + [t for t in self._queues if self._queues[t] and t not in self._active], + key=lambda t: len(self._queues[t]), + reverse=True, + ) + for task_type in candidates: + budget = self._budgets.get(task_type, 0.0) + # Always allow at least one batch to run even if its budget + # exceeds _available_vram (prevents permanent starvation when + # a single type's budget is larger than the VRAM ceiling). 
+ if self._reserved_vram == 0.0 or self._reserved_vram + budget <= self._available_vram: + thread = threading.Thread( + target=self._batch_worker, + args=(task_type,), + name=f"batch-{task_type}", + daemon=True, + ) + self._active[task_type] = thread + self._reserved_vram += budget + thread.start() + + def _batch_worker(self, task_type: str) -> None: + """Serial consumer for one task type. Runs until the type's deque is empty.""" + try: + while True: + with self._lock: + q = self._queues.get(task_type) + if not q: + break + task = q.popleft() + # _run_task is scripts.task_runner._run_task (passed at construction) + self._run_task( + self._db_path, task.id, task_type, task.job_id, task.params + ) + finally: + # Always release — even if _run_task raises. + # _active.pop here prevents the scheduler loop reap from double-decrementing. + with self._lock: + self._active.pop(task_type, None) + self._reserved_vram -= self._budgets.get(task_type, 0.0) + self._wake.set() + # ── Singleton ───────────────────────────────────────────────────────────────── diff --git a/tests/test_task_scheduler.py b/tests/test_task_scheduler.py index 68f977d..f174c08 100644 --- a/tests/test_task_scheduler.py +++ b/tests/test_task_scheduler.py @@ -1,6 +1,7 @@ # tests/test_task_scheduler.py """Tests for scripts/task_scheduler.py and related db helpers.""" import sqlite3 +import threading from collections import deque from pathlib import Path @@ -195,3 +196,137 @@ def test_max_queue_depth_logs_warning(tmp_db, caplog): s.enqueue(task_id, "cover_letter", 1, None) assert any("depth" in r.message.lower() for r in caplog.records) + + +# ── Threading helpers ───────────────────────────────────────────────────────── + +def _make_recording_run_task(log: list, done_event: threading.Event, expected: int): + """Returns a mock _run_task that records (task_id, task_type) and sets done when expected count reached.""" + def _run(db_path, task_id, task_type, job_id, params): + log.append((task_id, task_type)) + if 
def _make_recording_run_task(log: list, done_event: threading.Event, expected: int):
    """Returns a mock _run_task that records (task_id, task_type) and sets done when expected count reached."""
    def _run(db_path, task_id, task_type, job_id, params):
        log.append((task_id, task_type))
        if len(log) >= expected:
            done_event.set()
    return _run


def _start_scheduler(tmp_db, run_task_fn, available_vram=999.0):
    # Default VRAM of 999.0 is effectively unlimited — every type budget fits.
    s = TaskScheduler(tmp_db, run_task_fn)
    s._available_vram = available_vram
    s.start()
    return s


# ── Tests ─────────────────────────────────────────────────────────────────────

def test_deepest_queue_wins_first_slot(tmp_db):
    """Type with more queued tasks starts first when VRAM only fits one type."""
    log, done = [], threading.Event()

    # Build scheduler but DO NOT start it yet — enqueue all tasks first
    # so the scheduler sees the full picture on its very first wake.
    run_task_fn = _make_recording_run_task(log, done, 4)
    s = TaskScheduler(tmp_db, run_task_fn)
    s._available_vram = 3.0  # fits cover_letter (2.5) but not company_research (5.0)

    # Enqueue cover_letter (3 tasks) and company_research (1 task) before start.
    # cover_letter has the deeper queue and must win the first batch slot.
    for i in range(3):
        s.enqueue(i + 1, "cover_letter", i + 1, None)
    s.enqueue(4, "company_research", 4, None)

    s.start()  # scheduler now sees all tasks atomically on its first iteration
    assert done.wait(timeout=5.0), "timed out — not all 4 tasks completed"
    s.shutdown()

    assert len(log) == 4
    # Positions of each type in the execution log; cover_letter must fully
    # precede company_research (it held the only VRAM slot until drained).
    cl = [i for i, (_, t) in enumerate(log) if t == "cover_letter"]
    cr = [i for i, (_, t) in enumerate(log) if t == "company_research"]
    assert len(cl) == 3 and len(cr) == 1
    assert max(cl) < min(cr), "All cover_letter tasks must finish before company_research starts"


def test_fifo_within_type(tmp_db):
    """Tasks of the same type execute in arrival (FIFO) order."""
    log, done = [], threading.Event()
    s = _start_scheduler(tmp_db, _make_recording_run_task(log, done, 3))

    for task_id in [10, 20, 30]:
        s.enqueue(task_id, "cover_letter", task_id, None)

    assert done.wait(timeout=5.0), "timed out — not all 3 tasks completed"
    s.shutdown()

    assert [task_id for task_id, _ in log] == [10, 20, 30]
def test_concurrent_batches_when_vram_allows(tmp_db):
    """Two type batches start simultaneously when VRAM fits both."""
    started = {"cover_letter": threading.Event(), "company_research": threading.Event()}
    all_done = threading.Event()
    log = []

    def run_task(db_path, task_id, task_type, job_id, params):
        started[task_type].set()
        log.append(task_type)
        if len(log) >= 2:
            all_done.set()

    # VRAM=10.0 fits both cover_letter (2.5) and company_research (5.0) simultaneously
    s = _start_scheduler(tmp_db, run_task, available_vram=10.0)
    s.enqueue(1, "cover_letter", 1, None)
    s.enqueue(2, "company_research", 2, None)

    # FIX: the wait() result was previously discarded — a timeout fell through
    # silently and the failure surfaced as a confusing assert below.
    assert all_done.wait(timeout=5.0), "timed out — both tasks did not complete"
    s.shutdown()

    # Both types should have started (possibly overlapping)
    assert started["cover_letter"].is_set()
    assert started["company_research"].is_set()


def test_new_tasks_picked_up_mid_batch(tmp_db):
    """A task enqueued while a batch is running is consumed in the same batch."""
    log, done = [], threading.Event()
    task1_started = threading.Event()  # fires when task 1 begins executing
    task2_ready = threading.Event()    # fires when task 2 has been enqueued

    def run_task(db_path, task_id, task_type, job_id, params):
        if task_id == 1:
            task1_started.set()           # signal: task 1 is now running
            task2_ready.wait(timeout=2.0)  # wait for task 2 to be in the deque
        log.append(task_id)
        if len(log) >= 2:
            done.set()

    s = _start_scheduler(tmp_db, run_task)
    s.enqueue(1, "cover_letter", 1, None)
    # FIX: assert the wait result — previously a timeout here let the test
    # proceed and fail later with a misleading message.
    assert task1_started.wait(timeout=2.0), "task 1 never started executing"
    s.enqueue(2, "cover_letter", 2, None)
    task2_ready.set()  # unblock task 1 so it finishes

    assert done.wait(timeout=5.0), "timed out — task 2 never picked up mid-batch"
    s.shutdown()

    assert log == [1, 2]


def test_worker_crash_releases_vram(tmp_db):
    """If _run_task raises, _reserved_vram returns to 0 and scheduler continues."""
    import time  # local: only this test needs wall-clock polling

    log, done = [], threading.Event()

    def run_task(db_path, task_id, task_type, job_id, params):
        if task_id == 1:
            raise RuntimeError("simulated failure")
        log.append(task_id)
        done.set()

    s = _start_scheduler(tmp_db, run_task, available_vram=3.0)
    s.enqueue(1, "cover_letter", 1, None)
    s.enqueue(2, "cover_letter", 2, None)

    assert done.wait(timeout=5.0), "timed out — task 2 never completed after task 1 crash"
    s.shutdown()

    # Second task still ran
    assert 2 in log
    # FIX (flaky assert): done.set() fires inside run_task, BEFORE the batch
    # worker's `finally` decrements _reserved_vram, and shutdown() joins only
    # the scheduler thread — not the batch worker. Poll briefly so the
    # worker's cleanup has a chance to run before we check the invariant.
    deadline = time.monotonic() + 2.0
    while time.monotonic() < deadline and s._reserved_vram != 0.0:
        time.sleep(0.01)
    assert s._reserved_vram == 0.0