From dfd2f0214eca72bab847274aa1b5f7ea1467417a Mon Sep 17 00:00:00 2001
From: pyr0ball
Date: Sun, 15 Mar 2026 04:24:11 -0700
Subject: [PATCH] =?UTF-8?q?feat(scheduler):=20add=20durability=20=E2=80=94?=
 =?UTF-8?q?=20re-queue=20surviving=20LLM=20tasks=20on=20startup?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 scripts/task_scheduler.py    | 23 ++++++++++++++++
 tests/test_task_scheduler.py | 53 ++++++++++++++++++++++++++++++++++++
 2 files changed, 76 insertions(+)

diff --git a/scripts/task_scheduler.py b/scripts/task_scheduler.py
index 307ba7f..baca6a8 100644
--- a/scripts/task_scheduler.py
+++ b/scripts/task_scheduler.py
@@ -91,6 +91,9 @@ class TaskScheduler:
         except Exception:
             self._available_vram = 999.0
 
+        # Durability: reload surviving 'queued' LLM tasks from prior run
+        self._load_queued_tasks()
+
     def enqueue(self, task_id: int, task_type: str, job_id: int, params: Optional[str]) -> None:
         """Add an LLM task to the scheduler queue.
 
@@ -186,6 +189,26 @@ class TaskScheduler:
             self._reserved_vram -= self._budgets.get(task_type, 0.0)
         self._wake.set()
+
+    def _load_queued_tasks(self) -> None:
+        """Load pre-existing queued LLM tasks from SQLite into deques (called once in __init__)."""
+        llm_types = sorted(LLM_TASK_TYPES)  # sorted for deterministic SQL params in logs
+        placeholders = ",".join("?" * len(llm_types))
+        conn = sqlite3.connect(self._db_path)
+        rows = conn.execute(
+            f"SELECT id, task_type, job_id, params FROM background_tasks"
+            f" WHERE status='queued' AND task_type IN ({placeholders})"
+            f" ORDER BY created_at ASC",
+            llm_types,
+        ).fetchall()
+        conn.close()
+
+        for row_id, task_type, job_id, params in rows:
+            q = self._queues.setdefault(task_type, deque())
+            q.append(TaskSpec(row_id, job_id, params))
+
+        if rows:
+            logger.info("Scheduler: resumed %d queued task(s) from prior run", len(rows))
 
 
 # ── Singleton ─────────────────────────────────────────────────────────────────
 
diff --git a/tests/test_task_scheduler.py b/tests/test_task_scheduler.py
index 4128467..2992463 100644
--- a/tests/test_task_scheduler.py
+++ b/tests/test_task_scheduler.py
@@ -371,3 +371,56 @@ def test_reset_scheduler_cleans_up(tmp_db):
     # After reset, get_scheduler creates a fresh instance
     s2 = get_scheduler(tmp_db, _noop_run_task)
     assert s2 is not s
+
+
+def test_durability_loads_queued_llm_tasks_on_startup(tmp_db):
+    """Scheduler loads pre-existing queued LLM tasks into deques at construction."""
+    from scripts.db import insert_task
+
+    # Pre-insert queued rows simulating a prior run
+    id1, _ = insert_task(tmp_db, "cover_letter", 1)
+    id2, _ = insert_task(tmp_db, "company_research", 2)
+
+    s = TaskScheduler(tmp_db, _noop_run_task)
+
+    assert len(s._queues.get("cover_letter", [])) == 1
+    assert s._queues["cover_letter"][0].id == id1
+    assert len(s._queues.get("company_research", [])) == 1
+    assert s._queues["company_research"][0].id == id2
+
+
+def test_durability_excludes_non_llm_queued_tasks(tmp_db):
+    """Non-LLM queued tasks are not loaded into the scheduler deques."""
+    from scripts.db import insert_task
+
+    insert_task(tmp_db, "discovery", 0)
+    insert_task(tmp_db, "email_sync", 0)
+
+    s = TaskScheduler(tmp_db, _noop_run_task)
+
+    assert "discovery" not in s._queues or len(s._queues["discovery"]) == 0
+    assert "email_sync" not in s._queues or len(s._queues["email_sync"]) == 0
+
+
+def test_durability_preserves_fifo_order(tmp_db):
+    """Queued tasks are loaded in created_at (FIFO) order."""
+    conn = sqlite3.connect(tmp_db)
+    # Insert with explicit timestamps to control order
+    conn.execute(
+        "INSERT INTO background_tasks (task_type, job_id, params, status, created_at)"
+        " VALUES (?,?,?,?,?)", ("cover_letter", 1, None, "queued", "2026-01-01 10:00:00")
+    )
+    conn.execute(
+        "INSERT INTO background_tasks (task_type, job_id, params, status, created_at)"
+        " VALUES (?,?,?,?,?)", ("cover_letter", 2, None, "queued", "2026-01-01 09:00:00")
+    )
+    conn.commit()
+    ids = [r[0] for r in conn.execute(
+        "SELECT id FROM background_tasks ORDER BY created_at ASC"
+    ).fetchall()]
+    conn.close()
+
+    s = TaskScheduler(tmp_db, _noop_run_task)
+
+    loaded_ids = [t.id for t in s._queues["cover_letter"]]
+    assert loaded_ids == ids