From 4d055f6bcde4bd42fdeee225e961ae3f3a13060d Mon Sep 17 00:00:00 2001 From: pyr0ball Date: Sun, 15 Mar 2026 04:05:22 -0700 Subject: [PATCH] feat(scheduler): implement enqueue() with depth guard and ghost-row cleanup --- scripts/task_scheduler.py | 23 +++++++++++++ tests/test_task_scheduler.py | 67 ++++++++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+) diff --git a/scripts/task_scheduler.py b/scripts/task_scheduler.py index b8871db..c9ee0b1 100644 --- a/scripts/task_scheduler.py +++ b/scripts/task_scheduler.py @@ -91,6 +91,29 @@ class TaskScheduler: except Exception: self._available_vram = 999.0 + def enqueue(self, task_id: int, task_type: str, job_id: int, + params: Optional[str]) -> None: + """Add an LLM task to the scheduler queue. + + If the queue for this type is at max_queue_depth, the task is marked + failed in SQLite immediately (no ghost queued rows) and a warning is logged. + """ + from scripts.db import update_task_status + + with self._lock: + q = self._queues.setdefault(task_type, deque()) + if len(q) >= self._max_queue_depth: + logger.warning( + "Queue depth limit reached for %s (max=%d) — task %d dropped", + task_type, self._max_queue_depth, task_id, + ) + update_task_status(self._db_path, task_id, "failed", + error="Queue depth limit reached") + return + q.append(TaskSpec(task_id, job_id, params)) + + self._wake.set() + # ── Singleton ───────────────────────────────────────────────────────────────── diff --git a/tests/test_task_scheduler.py b/tests/test_task_scheduler.py index de0dc6e..68f977d 100644 --- a/tests/test_task_scheduler.py +++ b/tests/test_task_scheduler.py @@ -1,6 +1,7 @@ # tests/test_task_scheduler.py """Tests for scripts/task_scheduler.py and related db helpers.""" import sqlite3 +from collections import deque from pathlib import Path import pytest @@ -128,3 +129,69 @@ def test_gpu_vram_summed_across_all_gpus(tmp_db, monkeypatch): monkeypatch.setattr("scripts.task_scheduler._get_gpus", lambda: fake_gpus) s = 
TaskScheduler(tmp_db, _noop_run_task) assert s._available_vram == 48.0 + + +def test_enqueue_adds_taskspec_to_deque(tmp_db): + """enqueue() appends a TaskSpec to the correct per-type deque.""" + s = TaskScheduler(tmp_db, _noop_run_task) + s.enqueue(1, "cover_letter", 10, None) + s.enqueue(2, "cover_letter", 11, '{"key": "val"}') + + assert len(s._queues["cover_letter"]) == 2 + assert s._queues["cover_letter"][0].id == 1 + assert s._queues["cover_letter"][1].id == 2 + + +def test_enqueue_wakes_scheduler(tmp_db): + """enqueue() sets the _wake event so the scheduler loop re-evaluates.""" + s = TaskScheduler(tmp_db, _noop_run_task) + assert not s._wake.is_set() + s.enqueue(1, "cover_letter", 10, None) + assert s._wake.is_set() + + +def test_max_queue_depth_marks_task_failed(tmp_db): + """When queue is at max_queue_depth, dropped task is marked failed in DB.""" + from scripts.db import insert_task + + s = TaskScheduler(tmp_db, _noop_run_task) + s._max_queue_depth = 2 + + # Fill the queue to the limit via direct deque manipulation (no DB rows needed) + from scripts.task_scheduler import TaskSpec + s._queues.setdefault("cover_letter", deque()) + s._queues["cover_letter"].append(TaskSpec(99, 1, None)) + s._queues["cover_letter"].append(TaskSpec(100, 2, None)) + + # Insert a real DB row for the task we're about to drop + task_id, _ = insert_task(tmp_db, "cover_letter", 3) + + # This enqueue should be rejected and the DB row marked failed + s.enqueue(task_id, "cover_letter", 3, None) + + conn = sqlite3.connect(tmp_db) + row = conn.execute( + "SELECT status, error FROM background_tasks WHERE id=?", (task_id,) + ).fetchone() + conn.close() + + assert row[0] == "failed" + assert "depth" in row[1].lower() + # Queue length unchanged + assert len(s._queues["cover_letter"]) == 2 + + +def test_max_queue_depth_logs_warning(tmp_db, caplog): + """Queue depth overflow logs a WARNING.""" + import logging + from scripts.db import insert_task + from scripts.task_scheduler import TaskSpec + 
+ s = TaskScheduler(tmp_db, _noop_run_task) + s._max_queue_depth = 0  # immediately at limit + + task_id, _ = insert_task(tmp_db, "cover_letter", 1) + with caplog.at_level(logging.WARNING, logger="scripts.task_scheduler"): + s.enqueue(task_id, "cover_letter", 1, None) + + assert any("depth" in r.getMessage().lower() for r in caplog.records)