feat(task_runner): route LLM tasks through scheduler in submit_task()
Replaces the spawn-per-task model for LLM task types with scheduler routing: cover_letter, company_research, and wizard_generate are now enqueued via the TaskScheduler singleton for VRAM-aware batching. Non-LLM tasks (discovery, email_sync, etc.) continue to spawn daemon threads directly. Adds autouse clean_scheduler fixture to test_task_runner.py to prevent singleton cross-test contamination.
This commit is contained in:
parent
3e3c6f1fc5
commit
690a1ccf93
3 changed files with 80 additions and 10 deletions
|
|
@ -26,13 +26,23 @@ from scripts.db import (
|
|||
def submit_task(db_path: Path = DEFAULT_DB, task_type: str = "",
|
||||
job_id: int = None,
|
||||
params: str | None = None) -> tuple[int, bool]:
|
||||
"""Submit a background LLM task.
|
||||
"""Submit a background task.
|
||||
|
||||
Returns (task_id, True) if a new task was queued and a thread spawned.
|
||||
LLM task types (cover_letter, company_research, wizard_generate) are routed
|
||||
through the TaskScheduler for VRAM-aware batch scheduling.
|
||||
All other types spawn a free daemon thread as before.
|
||||
|
||||
Returns (task_id, True) if a new task was queued.
|
||||
Returns (existing_id, False) if an identical task is already in-flight.
|
||||
"""
|
||||
task_id, is_new = insert_task(db_path, task_type, job_id or 0, params=params)
|
||||
if is_new:
|
||||
from scripts.task_scheduler import get_scheduler, LLM_TASK_TYPES
|
||||
if task_type in LLM_TASK_TYPES:
|
||||
get_scheduler(db_path, run_task_fn=_run_task).enqueue(
|
||||
task_id, task_type, job_id or 0, params
|
||||
)
|
||||
else:
|
||||
t = threading.Thread(
|
||||
target=_run_task,
|
||||
args=(db_path, task_id, task_type, job_id or 0, params),
|
||||
|
|
|
|||
|
|
@ -6,6 +6,14 @@ from unittest.mock import patch
|
|||
import sqlite3
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def clean_scheduler():
    """Automatically tear down the TaskScheduler singleton after every test.

    The scheduler is a process-wide singleton, so without this reset any
    state created by one test (queues, worker threads, the cached db_path)
    would leak into the next test.
    """
    # Hand control to the test body; everything below is teardown.
    yield
    # Imported lazily so merely collecting tests never pulls in the scheduler.
    from scripts.task_scheduler import reset_scheduler
    reset_scheduler()
|
||||
|
||||
|
||||
def _make_db(tmp_path):
|
||||
from scripts.db import init_db, insert_job
|
||||
db = tmp_path / "test.db"
|
||||
|
|
@ -143,14 +151,20 @@ def test_run_task_email_sync_file_not_found(tmp_path):
|
|||
|
||||
|
||||
def test_submit_task_actually_completes(tmp_path):
|
||||
"""Integration: submit_task spawns a thread that completes asynchronously."""
|
||||
"""Integration: submit_task routes LLM tasks through the scheduler and they complete."""
|
||||
db, job_id = _make_db(tmp_path)
|
||||
from scripts.db import get_task_for_job
|
||||
from scripts.task_scheduler import get_scheduler
|
||||
from scripts.task_runner import _run_task
|
||||
|
||||
# Prime the singleton with the correct db_path before submit_task runs.
|
||||
# get_scheduler() already calls start() internally.
|
||||
get_scheduler(db, run_task_fn=_run_task)
|
||||
|
||||
with patch("scripts.generate_cover_letter.generate", return_value="Cover letter text"):
|
||||
from scripts.task_runner import submit_task
|
||||
task_id, _ = submit_task(db, "cover_letter", job_id)
|
||||
# Wait for thread to complete (max 5s)
|
||||
# Wait for scheduler to complete the task (max 5s)
|
||||
for _ in range(50):
|
||||
task = get_task_for_job(db, "cover_letter", job_id)
|
||||
if task and task["status"] in ("completed", "failed"):
|
||||
|
|
|
|||
|
|
@ -424,3 +424,49 @@ def test_durability_preserves_fifo_order(tmp_db):
|
|||
|
||||
loaded_ids = [t.id for t in s._queues["cover_letter"]]
|
||||
assert loaded_ids == ids
|
||||
|
||||
|
||||
def test_non_llm_tasks_bypass_scheduler(tmp_db):
    """submit_task() for non-LLM types invokes _run_task directly, not enqueue()."""
    from scripts import task_runner

    # Initialize the singleton properly so submit_task routes correctly.
    s = get_scheduler(tmp_db, _noop_run_task)

    run_task_calls = []
    enqueue_calls = []

    def recording_run_task(*args, **kwargs):
        # _run_task(db_path, task_id, task_type, job_id, params): task_type is 3rd.
        run_task_calls.append(args[2])

    def recording_enqueue(task_id, task_type, job_id, params):
        enqueue_calls.append(task_type)

    import unittest.mock as mock
    # Patch the module attribute (submit_task resolves _run_task at call time)
    # and the scheduler's bound enqueue so both routes can be observed.
    with mock.patch.object(task_runner, "_run_task", recording_run_task), \
         mock.patch.object(s, "enqueue", recording_enqueue):
        task_runner.submit_task(tmp_db, "discovery", 0)

    # discovery goes directly to _run_task; enqueue is never called.
    assert "discovery" not in enqueue_calls
    # The scheduler deque is untouched.
    assert "discovery" not in s._queues or len(s._queues["discovery"]) == 0
|
||||
|
||||
|
||||
def test_llm_tasks_routed_to_scheduler(tmp_db):
    """submit_task() for LLM types calls enqueue(), not _run_task directly."""
    from scripts import task_runner

    s = get_scheduler(tmp_db, _noop_run_task)

    seen_types = []
    real_enqueue = s.enqueue

    # Spy wrapper: record the task_type, then delegate to the real enqueue
    # so the task is still queued exactly as in production.
    def spy_enqueue(*args, **kwargs):
        seen_types.append(args[1])
        return real_enqueue(*args, **kwargs)

    import unittest.mock as mock
    with mock.patch.object(s, "enqueue", side_effect=spy_enqueue):
        task_runner.submit_task(tmp_db, "cover_letter", 1)

    assert "cover_letter" in seen_types
|
||||
|
|
|
|||
Loading…
Reference in a new issue