feat(scheduler): implement enqueue() with depth guard and ghost-row cleanup
parent 28e66001a3
commit 4d055f6bcd
2 changed files with 90 additions and 0 deletions
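For readers skimming the diff: the first hunk leans on a few names that already exist in scripts/task_scheduler.py. Below is a minimal sketch of the assumed shapes, inferred from the calls in the diff (TaskSpec(task_id, job_id, params)) and from the tests reading spec.id; it is not copied from the committed source and the real definitions may differ.

```python
# Sketch only: field names and order inferred from TaskSpec(task_id, job_id,
# params) in the diff and from the tests asserting on spec.id.
from dataclasses import dataclass
from typing import Optional


@dataclass
class TaskSpec:
    id: int                 # background_tasks row id
    job_id: int             # parent job this task belongs to
    params: Optional[str]   # JSON-encoded parameters, or None
```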
scripts/task_scheduler.py
@@ -91,6 +91,29 @@ class TaskScheduler:
         except Exception:
             self._available_vram = 999.0
 
+    def enqueue(self, task_id: int, task_type: str, job_id: int,
+                params: Optional[str]) -> None:
+        """Add an LLM task to the scheduler queue.
+
+        If the queue for this type is at max_queue_depth, the task is marked
+        failed in SQLite immediately (no ghost queued rows) and a warning is logged.
+        """
+        from scripts.db import update_task_status
+
+        with self._lock:
+            q = self._queues.setdefault(task_type, deque())
+            if len(q) >= self._max_queue_depth:
+                logger.warning(
+                    "Queue depth limit reached for %s (max=%d) — task %d dropped",
+                    task_type, self._max_queue_depth, task_id,
+                )
+                update_task_status(self._db_path, task_id, "failed",
+                                   error="Queue depth limit reached")
+                return
+            q.append(TaskSpec(task_id, job_id, params))
+
+        self._wake.set()
+
 
 # ── Singleton ─────────────────────────────────────────────────────────────────
 
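Note that self._wake.set() sits after the with self._lock block, so the scheduler loop never wakes while enqueue() still holds the lock. A hypothetical caller, illustrating the depth-guard contract; the DB path and task ids are placeholders, and the over-depth rejection assumes a matching row in background_tasks for update_task_status to mark:

```python
# Hypothetical usage sketch, not from the commit. The scheduler is built the
# same way the tests build one.
scheduler = TaskScheduler("data/app.db", _noop_run_task)  # placeholder path/callback
scheduler._max_queue_depth = 1

scheduler.enqueue(1, "cover_letter", 10, None)   # appended; _wake is set
scheduler.enqueue(2, "cover_letter", 11, None)   # queue full: a WARNING is logged
                                                 # and row 2 is marked "failed"
```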
tests/test_task_scheduler.py
@@ -1,6 +1,7 @@
 # tests/test_task_scheduler.py
 """Tests for scripts/task_scheduler.py and related db helpers."""
 import sqlite3
+from collections import deque
 from pathlib import Path
 
 import pytest
@@ -128,3 +129,69 @@ def test_gpu_vram_summed_across_all_gpus(tmp_db, monkeypatch):
     monkeypatch.setattr("scripts.task_scheduler._get_gpus", lambda: fake_gpus)
     s = TaskScheduler(tmp_db, _noop_run_task)
     assert s._available_vram == 48.0
+
+
+def test_enqueue_adds_taskspec_to_deque(tmp_db):
+    """enqueue() appends a TaskSpec to the correct per-type deque."""
+    s = TaskScheduler(tmp_db, _noop_run_task)
+    s.enqueue(1, "cover_letter", 10, None)
+    s.enqueue(2, "cover_letter", 11, '{"key": "val"}')
+
+    assert len(s._queues["cover_letter"]) == 2
+    assert s._queues["cover_letter"][0].id == 1
+    assert s._queues["cover_letter"][1].id == 2
+
+
+def test_enqueue_wakes_scheduler(tmp_db):
+    """enqueue() sets the _wake event so the scheduler loop re-evaluates."""
+    s = TaskScheduler(tmp_db, _noop_run_task)
+    assert not s._wake.is_set()
+    s.enqueue(1, "cover_letter", 10, None)
+    assert s._wake.is_set()
+
+
+def test_max_queue_depth_marks_task_failed(tmp_db):
+    """When queue is at max_queue_depth, dropped task is marked failed in DB."""
+    from scripts.db import insert_task
+
+    s = TaskScheduler(tmp_db, _noop_run_task)
+    s._max_queue_depth = 2
+
+    # Fill the queue to the limit via direct deque manipulation (no DB rows needed)
+    from scripts.task_scheduler import TaskSpec
+    s._queues.setdefault("cover_letter", deque())
+    s._queues["cover_letter"].append(TaskSpec(99, 1, None))
+    s._queues["cover_letter"].append(TaskSpec(100, 2, None))
+
+    # Insert a real DB row for the task we're about to drop
+    task_id, _ = insert_task(tmp_db, "cover_letter", 3)
+
+    # This enqueue should be rejected and the DB row marked failed
+    s.enqueue(task_id, "cover_letter", 3, None)
+
+    conn = sqlite3.connect(tmp_db)
+    row = conn.execute(
+        "SELECT status, error FROM background_tasks WHERE id=?", (task_id,)
+    ).fetchone()
+    conn.close()
+
+    assert row[0] == "failed"
+    assert "depth" in row[1].lower()
+    # Queue length unchanged
+    assert len(s._queues["cover_letter"]) == 2
+
+
+def test_max_queue_depth_logs_warning(tmp_db, caplog):
+    """Queue depth overflow logs a WARNING."""
+    import logging
+    from scripts.db import insert_task
+    from scripts.task_scheduler import TaskSpec
+
+    s = TaskScheduler(tmp_db, _noop_run_task)
+    s._max_queue_depth = 0  # immediately at limit
+
+    task_id, _ = insert_task(tmp_db, "cover_letter", 1)
+    with caplog.at_level(logging.WARNING, logger="scripts.task_scheduler"):
+        s.enqueue(task_id, "cover_letter", 1, None)
+
+    assert any("depth" in r.message.lower() for r in caplog.records)
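Marking the rejected task "failed" inside enqueue() is what the title's "ghost-row cleanup" refers to: without it, the background_tasks row would sit in its initial status indefinitely, since no worker ever picks up a task that never reached a queue. To run just the new tests, something like `pytest tests/test_task_scheduler.py -k "enqueue or depth"` should select all four, assuming the repo's standard pytest setup and the `tmp_db` / `_noop_run_task` fixtures defined earlier in the file.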