From 07c627cdb07dc5bef10544464908c5e4ea4557a1 Mon Sep 17 00:00:00 2001 From: pyr0ball Date: Sun, 15 Mar 2026 04:52:42 -0700 Subject: [PATCH] feat(task_runner): route LLM tasks through scheduler in submit_task() Replaces the spawn-per-task model for LLM task types with scheduler routing: cover_letter, company_research, and wizard_generate are now enqueued via the TaskScheduler singleton for VRAM-aware batching. Non-LLM tasks (discovery, email_sync, etc.) continue to spawn daemon threads directly. Adds autouse clean_scheduler fixture to test_task_runner.py to prevent singleton cross-test contamination. --- scripts/task_runner.py | 26 +++++++++++++------- tests/test_task_runner.py | 18 ++++++++++++-- tests/test_task_scheduler.py | 46 ++++++++++++++++++++++++++++++++++++ 3 files changed, 80 insertions(+), 10 deletions(-) diff --git a/scripts/task_runner.py b/scripts/task_runner.py index 9d02bbe..83cdc7c 100644 --- a/scripts/task_runner.py +++ b/scripts/task_runner.py @@ -26,19 +26,29 @@ from scripts.db import ( def submit_task(db_path: Path = DEFAULT_DB, task_type: str = "", job_id: int = None, params: str | None = None) -> tuple[int, bool]: - """Submit a background LLM task. + """Submit a background task. - Returns (task_id, True) if a new task was queued and a thread spawned. + LLM task types (cover_letter, company_research, wizard_generate) are routed + through the TaskScheduler for VRAM-aware batch scheduling. + All other types spawn a free daemon thread as before. + + Returns (task_id, True) if a new task was queued. Returns (existing_id, False) if an identical task is already in-flight. 
""" task_id, is_new = insert_task(db_path, task_type, job_id or 0, params=params) if is_new: - t = threading.Thread( - target=_run_task, - args=(db_path, task_id, task_type, job_id or 0, params), - daemon=True, - ) - t.start() + from scripts.task_scheduler import get_scheduler, LLM_TASK_TYPES + if task_type in LLM_TASK_TYPES: + get_scheduler(db_path, run_task_fn=_run_task).enqueue( + task_id, task_type, job_id or 0, params + ) + else: + t = threading.Thread( + target=_run_task, + args=(db_path, task_id, task_type, job_id or 0, params), + daemon=True, + ) + t.start() return task_id, is_new diff --git a/tests/test_task_runner.py b/tests/test_task_runner.py index 8d28226..6167a42 100644 --- a/tests/test_task_runner.py +++ b/tests/test_task_runner.py @@ -6,6 +6,14 @@ from unittest.mock import patch import sqlite3 +@pytest.fixture(autouse=True) +def clean_scheduler(): + """Reset the TaskScheduler singleton between tests to prevent cross-test contamination.""" + yield + from scripts.task_scheduler import reset_scheduler + reset_scheduler() + + def _make_db(tmp_path): from scripts.db import init_db, insert_job db = tmp_path / "test.db" @@ -143,14 +151,20 @@ def test_run_task_email_sync_file_not_found(tmp_path): def test_submit_task_actually_completes(tmp_path): - """Integration: submit_task spawns a thread that completes asynchronously.""" + """Integration: submit_task routes LLM tasks through the scheduler and they complete.""" db, job_id = _make_db(tmp_path) from scripts.db import get_task_for_job + from scripts.task_scheduler import get_scheduler + from scripts.task_runner import _run_task + + # Prime the singleton with the correct db_path before submit_task runs. + # get_scheduler() already calls start() internally. 
+    get_scheduler(db, run_task_fn=_run_task)
     with patch("scripts.generate_cover_letter.generate", return_value="Cover letter text"):
         from scripts.task_runner import submit_task
         task_id, _ = submit_task(db, "cover_letter", job_id)
 
-    # Wait for thread to complete (max 5s)
+    # Wait for scheduler to complete the task (max 5s)
     for _ in range(50):
         task = get_task_for_job(db, "cover_letter", job_id)
         if task and task["status"] in ("completed", "failed"):
diff --git a/tests/test_task_scheduler.py b/tests/test_task_scheduler.py
index 2992463..7746ca4 100644
--- a/tests/test_task_scheduler.py
+++ b/tests/test_task_scheduler.py
@@ -424,3 +424,49 @@ def test_durability_preserves_fifo_order(tmp_db):
 
     loaded_ids = [t.id for t in s._queues["cover_letter"]]
     assert loaded_ids == ids
+
+
+def test_non_llm_tasks_bypass_scheduler(tmp_db):
+    """submit_task() for non-LLM types invokes _run_task directly, not enqueue()."""
+    from scripts import task_runner
+
+    # Initialize the singleton properly so submit_task routes correctly
+    s = get_scheduler(tmp_db, _noop_run_task)
+
+    run_task_calls = []
+    enqueue_calls = []
+
+    original_run_task = task_runner._run_task
+    original_enqueue = s.enqueue
+
+    def recording_run_task(*args, **kwargs):
+        run_task_calls.append(args[2])  # task_type is 3rd arg
+
+    def recording_enqueue(task_id, task_type, job_id, params):
+        enqueue_calls.append(task_type)
+
+    import unittest.mock as mock
+    with mock.patch.object(task_runner, "_run_task", recording_run_task), \
+         mock.patch.object(s, "enqueue", recording_enqueue):
+        task_runner.submit_task(tmp_db, "discovery", 0)
+
+    # discovery goes directly to _run_task; enqueue is never called
+    assert "discovery" not in enqueue_calls
+    # The scheduler deque is untouched
+    assert "discovery" not in s._queues or len(s._queues["discovery"]) == 0
+
+
+def test_llm_tasks_routed_to_scheduler(tmp_db):
+    """submit_task() for LLM types calls enqueue(), not _run_task directly."""
+    from scripts import task_runner
+
+    s = get_scheduler(tmp_db, _noop_run_task)
+
+    enqueue_calls = []
+    original_enqueue = s.enqueue
+
+    import unittest.mock as mock
+    with mock.patch.object(s, "enqueue", side_effect=lambda *a, **kw: enqueue_calls.append(a[1]) or original_enqueue(*a, **kw)):
+        task_runner.submit_task(tmp_db, "cover_letter", 1)
+
+    assert "cover_letter" in enqueue_calls