feat(scheduler): implement enqueue() with depth guard and ghost-row cleanup

pyr0ball committed 2026-03-15 04:05:22 -07:00
parent 28e66001a3
commit 4d055f6bcd
2 changed files with 90 additions and 0 deletions
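Caller-side sketch of the new API (not part of this diff; the constructor signature follows the tests below, and the db path and run_task body are placeholders):

    from scripts.task_scheduler import TaskScheduler

    def run_task(spec):
        # Worker callback the scheduler invokes for each dequeued TaskSpec.
        ...

    s = TaskScheduler("data/app.db", run_task)
    # params is an optional JSON string; None is allowed.
    s.enqueue(task_id=1, task_type="cover_letter", job_id=10, params=None)
    s.enqueue(task_id=2, task_type="cover_letter", job_id=10,
              params='{"tone": "formal"}')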

scripts/task_scheduler.py

@@ -91,6 +91,29 @@ class TaskScheduler:
        except Exception:
            self._available_vram = 999.0

    def enqueue(self, task_id: int, task_type: str, job_id: int,
                params: Optional[str]) -> None:
        """Add an LLM task to the scheduler queue.

        If the queue for this type is at max_queue_depth, the task is marked
        failed in SQLite immediately (no ghost queued rows) and a warning is
        logged.
        """
        from scripts.db import update_task_status

        with self._lock:
            q = self._queues.setdefault(task_type, deque())
            if len(q) >= self._max_queue_depth:
                logger.warning(
                    "Queue depth limit reached for %s (max=%d) — task %d dropped",
                    task_type, self._max_queue_depth, task_id,
                )
                update_task_status(self._db_path, task_id, "failed",
                                   error="Queue depth limit reached")
                return
            q.append(TaskSpec(task_id, job_id, params))
            self._wake.set()

# ── Singleton ─────────────────────────────────────────────────────────────────
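TaskSpec is constructed in the hunk above but defined outside it; a minimal sketch consistent with the call sites in this commit (field order inferred from the enqueue() call, and the first field assumed to be `id` because the tests read spec.id) could be:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class TaskSpec:
        # Matches TaskSpec(task_id, job_id, params) as called in enqueue().
        id: int
        job_id: int
        params: Optional[str]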

tests/test_task_scheduler.py

@@ -1,6 +1,7 @@
# tests/test_task_scheduler.py
"""Tests for scripts/task_scheduler.py and related db helpers."""
import sqlite3
from collections import deque
from pathlib import Path

import pytest
@@ -128,3 +129,69 @@ def test_gpu_vram_summed_across_all_gpus(tmp_db, monkeypatch):
    monkeypatch.setattr("scripts.task_scheduler._get_gpus", lambda: fake_gpus)
    s = TaskScheduler(tmp_db, _noop_run_task)
    assert s._available_vram == 48.0


def test_enqueue_adds_taskspec_to_deque(tmp_db):
    """enqueue() appends a TaskSpec to the correct per-type deque."""
    s = TaskScheduler(tmp_db, _noop_run_task)
    s.enqueue(1, "cover_letter", 10, None)
    s.enqueue(2, "cover_letter", 11, '{"key": "val"}')
    assert len(s._queues["cover_letter"]) == 2
    assert s._queues["cover_letter"][0].id == 1
    assert s._queues["cover_letter"][1].id == 2


def test_enqueue_wakes_scheduler(tmp_db):
    """enqueue() sets the _wake event so the scheduler loop re-evaluates."""
    s = TaskScheduler(tmp_db, _noop_run_task)
    assert not s._wake.is_set()
    s.enqueue(1, "cover_letter", 10, None)
    assert s._wake.is_set()


def test_max_queue_depth_marks_task_failed(tmp_db):
    """When the queue is at max_queue_depth, the dropped task is marked failed in the DB."""
    from scripts.db import insert_task
    from scripts.task_scheduler import TaskSpec

    s = TaskScheduler(tmp_db, _noop_run_task)
    s._max_queue_depth = 2
    # Fill the queue to the limit via direct deque manipulation (no DB rows needed).
    s._queues.setdefault("cover_letter", deque())
    s._queues["cover_letter"].append(TaskSpec(99, 1, None))
    s._queues["cover_letter"].append(TaskSpec(100, 2, None))
    # Insert a real DB row for the task we're about to drop.
    task_id, _ = insert_task(tmp_db, "cover_letter", 3)
    # This enqueue should be rejected and the DB row marked failed.
    s.enqueue(task_id, "cover_letter", 3, None)
    conn = sqlite3.connect(tmp_db)
    row = conn.execute(
        "SELECT status, error FROM background_tasks WHERE id=?", (task_id,)
    ).fetchone()
    conn.close()
    assert row[0] == "failed"
    assert "depth" in row[1].lower()
    # Queue length unchanged.
    assert len(s._queues["cover_letter"]) == 2


def test_max_queue_depth_logs_warning(tmp_db, caplog):
    """Queue depth overflow logs a WARNING."""
    import logging
    from scripts.db import insert_task

    s = TaskScheduler(tmp_db, _noop_run_task)
    s._max_queue_depth = 0  # immediately at limit
    task_id, _ = insert_task(tmp_db, "cover_letter", 1)
    with caplog.at_level(logging.WARNING, logger="scripts.task_scheduler"):
        s.enqueue(task_id, "cover_letter", 1, None)
    assert any("depth" in r.message.lower() for r in caplog.records)
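These tests assume scripts.db exposes insert_task and update_task_status over a background_tasks table; neither helper is part of this diff. A hedged sketch of update_task_status, with the table and column names taken from the SELECT in test_max_queue_depth_marks_task_failed, might look like:

    import sqlite3
    from typing import Optional

    def update_task_status(db_path: str, task_id: int, status: str,
                           error: Optional[str] = None) -> None:
        # Update one task row; schema assumed from the tests above:
        # background_tasks(id, status, error, ...).
        conn = sqlite3.connect(db_path)
        try:
            conn.execute(
                "UPDATE background_tasks SET status = ?, error = ? WHERE id = ?",
                (status, error, task_id),
            )
            conn.commit()
        finally:
            conn.close()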