feat: LLM queue optimizer — resource-aware batch scheduler (closes #2) #13
10 changed files with 2564 additions and 15 deletions
4
.gitignore
vendored
4
.gitignore
vendored
|
|
@ -44,3 +44,7 @@ config/label_tool.yaml
|
|||
config/server.yaml
|
||||
|
||||
demo/data/*.db
|
||||
demo/seed_demo.py
|
||||
|
||||
# Git worktrees
|
||||
.worktrees/
|
||||
|
|
|
|||
10
app/app.py
10
app/app.py
|
|
@ -42,12 +42,12 @@ def _startup() -> None:
|
|||
2. Auto-queues re-runs for any research generated without SearXNG data,
|
||||
if SearXNG is now reachable.
|
||||
"""
|
||||
# Reset only in-flight tasks — queued tasks survive for the scheduler to resume.
|
||||
# MUST run before any submit_task() call in this function.
|
||||
from scripts.db import reset_running_tasks
|
||||
reset_running_tasks(get_db_path())
|
||||
|
||||
conn = sqlite3.connect(get_db_path())
|
||||
conn.execute(
|
||||
"UPDATE background_tasks SET status='failed', error='Interrupted by server restart',"
|
||||
" finished_at=datetime('now') WHERE status IN ('queued','running')"
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
# Auto-recovery: re-run LLM-only research when SearXNG is available
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -64,3 +64,14 @@ vision_fallback_order:
|
|||
# Note: 'ollama' (alex-cover-writer) intentionally excluded — research
|
||||
# must never use the fine-tuned writer model, and this also avoids evicting
|
||||
# the writer from GPU memory while a cover letter task is in flight.
|
||||
|
||||
# ── Scheduler — LLM batch queue optimizer ─────────────────────────────────────
|
||||
# The scheduler batches LLM tasks by model type to avoid GPU model switching.
|
||||
# VRAM budgets are conservative peak estimates (GB) for each task type.
|
||||
# Increase if your models are larger; decrease if tasks share GPU memory well.
|
||||
scheduler:
|
||||
vram_budgets:
|
||||
cover_letter: 2.5 # alex-cover-writer:latest (~2GB GGUF + headroom)
|
||||
company_research: 5.0 # llama3.1:8b or vllm model
|
||||
wizard_generate: 2.5 # same model family as cover_letter
|
||||
max_queue_depth: 500 # max pending tasks per type before drops (with logged warning)
|
||||
|
|
|
|||
1306
docs/superpowers/plans/2026-03-14-llm-queue-optimizer.md
Normal file
1306
docs/superpowers/plans/2026-03-14-llm-queue-optimizer.md
Normal file
File diff suppressed because it is too large
Load diff
477
docs/superpowers/specs/2026-03-14-llm-queue-optimizer-design.md
Normal file
477
docs/superpowers/specs/2026-03-14-llm-queue-optimizer-design.md
Normal file
|
|
@ -0,0 +1,477 @@
|
|||
# LLM Queue Optimizer — Design Spec
|
||||
|
||||
**Date:** 2026-03-14
|
||||
**Branch:** `feature/llm-queue-optimizer`
|
||||
**Closes:** [#2](https://git.opensourcesolarpunk.com/Circuit-Forge/peregrine/issues/2)
|
||||
**Author:** pyr0ball
|
||||
|
||||
---
|
||||
|
||||
## Problem
|
||||
|
||||
On single-GPU and CPU-only systems, the background task runner spawns a daemon thread for every task immediately on submission. When a user approves N jobs at once, N threads race to load their respective LLM models simultaneously, causing repeated model swaps and significant latency overhead.
|
||||
|
||||
The root issue is that `submit_task()` is a spawn-per-task model with no scheduling layer. SQLite's `background_tasks` table is a status log, not a consumed work queue.
|
||||
|
||||
Additionally, on restart all `queued` and `running` tasks are cleared to `failed` (inline SQL in `app.py`'s `_startup()`), discarding pending work that had not yet started executing.
|
||||
|
||||
---
|
||||
|
||||
## Goals
|
||||
|
||||
- Eliminate unnecessary model switching by batching LLM tasks by type
|
||||
- Allow concurrent model execution when VRAM permits multiple models simultaneously
|
||||
- Preserve FIFO ordering within each task type
|
||||
- Survive process restarts — `queued` tasks resume after restart; only `running` tasks (whose results are unknown) are reset to `failed`
|
||||
- Apply to all tiers (no tier gating)
|
||||
- Keep non-LLM tasks (discovery, email sync, scrape, enrich) unaffected — they continue to spawn free threads
|
||||
|
||||
---
|
||||
|
||||
## Non-Goals
|
||||
|
||||
- Changing the LLM router fallback chain
|
||||
- Adding new task types
|
||||
- Tier gating on the scheduler
|
||||
- Persistent task history in memory
|
||||
- Durability for non-LLM task types (discovery, email_sync, etc. — these do not survive restarts, same as current behavior)
|
||||
- Dynamic VRAM tracking — `_available_vram` is read once at startup and not refreshed (see Known Limitations)
|
||||
|
||||
---
|
||||
|
||||
## Architecture
|
||||
|
||||
### Task Classification
|
||||
|
||||
```python
|
||||
LLM_TASK_TYPES = {"cover_letter", "company_research", "wizard_generate"}
|
||||
```
|
||||
|
||||
The routing rule is: if `task_type in LLM_TASK_TYPES`, route through the scheduler. Everything else spawns a free thread unchanged from the current implementation. **Future task types default to bypass mode** unless explicitly added to `LLM_TASK_TYPES` — which is the safe default (bypass = current behavior).
|
||||
|
||||
`LLM_TASK_TYPES` is defined in `scripts/task_scheduler.py` and imported by `scripts/task_runner.py` for routing. This import direction (task_runner imports from task_scheduler) avoids circular imports because `task_scheduler.py` does **not** import from `task_runner.py`.
|
||||
|
||||
Current non-LLM types (all bypass scheduler): `discovery`, `email_sync`, `scrape_url`, `enrich_descriptions`, `enrich_craigslist`, `prepare_training`.
|
||||
|
||||
### Routing in `submit_task()` — No Circular Import
|
||||
|
||||
The routing split lives entirely in `submit_task()` in `task_runner.py`:
|
||||
|
||||
```python
|
||||
def submit_task(db_path, task_type, job_id=None, params=None):
|
||||
task_id, is_new = insert_task(db_path, task_type, job_id or 0, params=params)
|
||||
if is_new:
|
||||
from scripts.task_scheduler import get_scheduler, LLM_TASK_TYPES
|
||||
if task_type in LLM_TASK_TYPES:
|
||||
get_scheduler(db_path).enqueue(task_id, task_type, job_id or 0, params)
|
||||
else:
|
||||
t = threading.Thread(
|
||||
target=_run_task,
|
||||
args=(db_path, task_id, task_type, job_id or 0, params),
|
||||
daemon=True,
|
||||
)
|
||||
t.start()
|
||||
return task_id, is_new
|
||||
```
|
||||
|
||||
`TaskScheduler.enqueue()` only handles LLM task types and never imports or calls `_run_task`. This eliminates any circular import between `task_runner` and `task_scheduler`.
|
||||
|
||||
### Component Overview
|
||||
|
||||
```
|
||||
submit_task()
|
||||
│
|
||||
├── task_type in LLM_TASK_TYPES?
|
||||
│ │ yes │ no
|
||||
│ ▼ ▼
|
||||
│ get_scheduler().enqueue() spawn free thread (unchanged)
|
||||
│ │
|
||||
│ ▼
|
||||
│ per-type deque
|
||||
│ │
|
||||
│ ▼
|
||||
│ Scheduler loop (daemon thread)
|
||||
│ (wakes on enqueue or batch completion)
|
||||
│ │
|
||||
│ Sort eligible types by queue depth (desc)
|
||||
│ │
|
||||
│ For each type:
|
||||
│ reserved_vram + budget[type] ≤ available_vram?
|
||||
│ │ yes │ no
|
||||
│ ▼ ▼
|
||||
│ Start batch worker skip (wait for slot)
|
||||
│ (serial: one task at a time)
|
||||
│ │
|
||||
│ Batch worker signals done → scheduler re-evaluates
|
||||
```
|
||||
|
||||
### New File: `scripts/task_scheduler.py`
|
||||
|
||||
**State:**
|
||||
|
||||
| Attribute | Type | Purpose |
|
||||
|---|---|---|
|
||||
| `_queues` | `dict[str, deque[TaskSpec]]` | Per-type pending task deques |
|
||||
| `_active` | `dict[str, Thread]` | Currently running batch worker per type |
|
||||
| `_budgets` | `dict[str, float]` | VRAM budget per task type (GB). Loaded at construction by merging `DEFAULT_VRAM_BUDGETS` with `scheduler.vram_budgets` from `config/llm.yaml`. Config path derived from `db_path` (e.g. `db_path.parent.parent / "config/llm.yaml"`). Missing file or key → defaults used as-is. At construction, a warning is logged for any type in `LLM_TASK_TYPES` with no budget entry after the merge. |
|
||||
| `_reserved_vram` | `float` | Sum of `_budgets` values for currently active type batches |
|
||||
| `_available_vram` | `float` | Total VRAM from `get_gpus()` summed across all GPUs at construction; 999.0 on CPU-only systems. Static — not refreshed after startup (see Known Limitations). |
|
||||
| `_max_queue_depth` | `int` | Max tasks per type queue before drops. From `scheduler.max_queue_depth` in config; default 500. |
|
||||
| `_lock` | `threading.Lock` | Protects all mutable scheduler state |
|
||||
| `_wake` | `threading.Event` | Pulsed on enqueue or batch completion |
|
||||
| `_stop` | `threading.Event` | Set by `shutdown()` to terminate the loop |
|
||||
|
||||
**Default VRAM budgets (module-level constant):**
|
||||
|
||||
```python
|
||||
DEFAULT_VRAM_BUDGETS: dict[str, float] = {
|
||||
"cover_letter": 2.5, # alex-cover-writer:latest (~2GB GGUF + headroom)
|
||||
"company_research": 5.0, # llama3.1:8b or vllm model
|
||||
"wizard_generate": 2.5, # same model family as cover_letter
|
||||
}
|
||||
```
|
||||
|
||||
At construction, the scheduler validates that every type in `LLM_TASK_TYPES` has an entry
|
||||
in the merged `_budgets`. If any type is missing, a warning is logged:
|
||||
|
||||
```
|
||||
WARNING task_scheduler: No VRAM budget defined for LLM task type 'foo' — defaulting to 0.0 GB (unlimited concurrency for this type)
|
||||
```
|
||||
|
||||
**Scheduler loop:**
|
||||
|
||||
```python
|
||||
while not _stop.is_set():
|
||||
_wake.wait(timeout=30)
|
||||
_wake.clear()
|
||||
|
||||
with _lock:
|
||||
# Defense in depth: reap dead threads not yet cleaned by their finally block.
|
||||
# In the normal path, a batch worker's finally block calls _active.pop() and
|
||||
# decrements _reserved_vram BEFORE firing _wake — so by the time we scan here,
|
||||
# the entry is already gone and there is no double-decrement risk.
|
||||
# This reap only catches threads killed externally (daemon exit on shutdown).
|
||||
for t, thread in list(_active.items()):
|
||||
if not thread.is_alive():
|
||||
_reserved_vram -= _budgets.get(t, 0)
|
||||
del _active[t]
|
||||
|
||||
# Start new batches where VRAM allows
|
||||
candidates = sorted(
|
||||
[t for t in _queues if _queues[t] and t not in _active],
|
||||
key=lambda t: len(_queues[t]),
|
||||
reverse=True,
|
||||
)
|
||||
for task_type in candidates:
|
||||
budget = _budgets.get(task_type, 0)
|
||||
if _reserved_vram + budget <= _available_vram:
|
||||
thread = Thread(target=_batch_worker, args=(task_type,), daemon=True)
|
||||
_active[task_type] = thread
|
||||
_reserved_vram += budget
|
||||
thread.start()
|
||||
```
|
||||
|
||||
**Batch worker:**
|
||||
|
||||
The `finally` block is the single authoritative path for releasing `_reserved_vram` and
|
||||
removing the entry from `_active`. Because `_active.pop` runs in `finally` before
|
||||
`_wake.set()`, the scheduler loop's dead-thread scan will never find this entry —
|
||||
no double-decrement is possible in the normal execution path.
|
||||
|
||||
```python
|
||||
def _batch_worker(task_type: str) -> None:
|
||||
try:
|
||||
while True:
|
||||
with _lock:
|
||||
if not _queues[task_type]:
|
||||
break
|
||||
task = _queues[task_type].popleft()
|
||||
_run_task(db_path, task.id, task_type, task.job_id, task.params)
|
||||
finally:
|
||||
with _lock:
|
||||
_active.pop(task_type, None)
|
||||
_reserved_vram -= _budgets.get(task_type, 0)
|
||||
_wake.set()
|
||||
```
|
||||
|
||||
`_run_task` here refers to `task_runner._run_task`, passed in as a callable at
|
||||
construction (e.g. `self._run_task = run_task_fn`). The caller (`task_runner.py`)
|
||||
passes `_run_task` when constructing the scheduler, avoiding any import of `task_runner`
|
||||
from within `task_scheduler`.
|
||||
|
||||
**`enqueue()` method:**
|
||||
|
||||
`enqueue()` only accepts LLM task types. Non-LLM routing is handled in `submit_task()`
|
||||
before `enqueue()` is called (see Routing section above).
|
||||
|
||||
```python
|
||||
def enqueue(self, task_id: int, task_type: str, job_id: int, params: str | None) -> None:
|
||||
with self._lock:
|
||||
q = self._queues.setdefault(task_type, deque())
|
||||
if len(q) >= self._max_queue_depth:
|
||||
logger.warning(
|
||||
"Queue depth limit reached for %s (max=%d) — task %d dropped",
|
||||
task_type, self._max_queue_depth, task_id,
|
||||
)
|
||||
update_task_status(self._db_path, task_id, "failed",
|
||||
error="Queue depth limit reached")
|
||||
return
|
||||
q.append(TaskSpec(task_id, job_id, params))
|
||||
self._wake.set()
|
||||
```
|
||||
|
||||
When a task is dropped at the depth limit, `update_task_status()` marks it `failed` in
|
||||
SQLite immediately — the row inserted by `insert_task()` is never left as a permanent
|
||||
ghost in `queued` state.
|
||||
|
||||
**Singleton access — thread-safe initialization:**
|
||||
|
||||
```python
|
||||
_scheduler: TaskScheduler | None = None
|
||||
_scheduler_lock = threading.Lock()
|
||||
|
||||
def get_scheduler(db_path: Path) -> TaskScheduler:
|
||||
global _scheduler
|
||||
if _scheduler is None: # fast path — avoids lock on steady state
|
||||
with _scheduler_lock:
|
||||
if _scheduler is None: # re-check under lock (double-checked locking)
|
||||
_scheduler = TaskScheduler(db_path)
|
||||
_scheduler.start()
|
||||
return _scheduler
|
||||
|
||||
def reset_scheduler() -> None:
|
||||
"""Tear down and clear singleton. Test teardown only."""
|
||||
global _scheduler
|
||||
with _scheduler_lock:
|
||||
if _scheduler:
|
||||
_scheduler.shutdown()
|
||||
_scheduler = None
|
||||
```
|
||||
|
||||
The safety guarantee comes from the **inner `with _scheduler_lock:` block and re-check**,
|
||||
not from GIL atomicity. The outer `if _scheduler is None` is a performance optimization
|
||||
(avoid acquiring the lock on every `submit_task()` call once the scheduler is running).
|
||||
Two threads racing at startup will both pass the outer check, but only one will win the
|
||||
inner lock and construct the scheduler; the other will see a non-None value on its
|
||||
inner re-check and return the already-constructed instance.
|
||||
|
||||
---
|
||||
|
||||
## Required Call Ordering in `app.py`
|
||||
|
||||
`reset_running_tasks()` **must complete before** `get_scheduler()` is ever called.
|
||||
The scheduler's durability query reads `status='queued'` rows; if `reset_running_tasks()`
|
||||
has not yet run, a row stuck in `status='running'` from a prior crash would be loaded
|
||||
into the deque and re-executed, producing a duplicate result.
|
||||
|
||||
In practice, the first call to `get_scheduler()` is triggered by the `submit_task()` call
|
||||
inside `_startup()`'s SearXNG auto-recovery block — not by a user action. The ordering
|
||||
holds because `reset_running_tasks()` is called on an earlier line within the same
|
||||
`_startup()` function body. **Do not reorder these calls.**
|
||||
|
||||
```python
|
||||
@st.cache_resource
|
||||
def _startup() -> None:
|
||||
# Step 1: Reset interrupted tasks — MUST come first
|
||||
from scripts.db import reset_running_tasks
|
||||
reset_running_tasks(get_db_path())
|
||||
|
||||
# Step 2 (later in same function): SearXNG re-queue calls submit_task(),
|
||||
# which triggers get_scheduler() for the first time. Ordering is guaranteed
|
||||
# because _startup() runs synchronously and step 1 is already complete.
|
||||
conn = sqlite3.connect(get_db_path())
|
||||
# ... existing SearXNG re-queue logic using conn ...
|
||||
conn.close()
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Changes to Existing Files
|
||||
|
||||
### `scripts/task_runner.py`
|
||||
|
||||
`submit_task()` gains routing logic; `_run_task` is passed to the scheduler at first call:
|
||||
|
||||
```python
|
||||
def submit_task(db_path, task_type, job_id=None, params=None):
|
||||
task_id, is_new = insert_task(db_path, task_type, job_id or 0, params=params)
|
||||
if is_new:
|
||||
from scripts.task_scheduler import get_scheduler, LLM_TASK_TYPES
|
||||
if task_type in LLM_TASK_TYPES:
|
||||
get_scheduler(db_path, run_task_fn=_run_task).enqueue(
|
||||
task_id, task_type, job_id or 0, params
|
||||
)
|
||||
else:
|
||||
t = threading.Thread(
|
||||
target=_run_task,
|
||||
args=(db_path, task_id, task_type, job_id or 0, params),
|
||||
daemon=True,
|
||||
)
|
||||
t.start()
|
||||
return task_id, is_new
|
||||
```
|
||||
|
||||
`get_scheduler()` accepts `run_task_fn` only on first call (when constructing); subsequent
|
||||
calls ignore it (singleton already initialized). `_run_task()` and all handler branches
|
||||
remain unchanged.
|
||||
|
||||
### `scripts/db.py`
|
||||
|
||||
Add `reset_running_tasks()` alongside the existing `kill_stuck_tasks()`. Like
|
||||
`kill_stuck_tasks()`, it uses a plain `sqlite3.connect()` — consistent with the
|
||||
existing pattern in this file, and appropriate because this call happens before the
|
||||
app's connection pooling is established:
|
||||
|
||||
```python
|
||||
def reset_running_tasks(db_path: Path = DEFAULT_DB) -> int:
|
||||
"""On restart: mark in-flight tasks failed. Queued tasks survive for the scheduler."""
|
||||
conn = sqlite3.connect(db_path)
|
||||
count = conn.execute(
|
||||
"UPDATE background_tasks SET status='failed', error='Interrupted by restart',"
|
||||
" finished_at=datetime('now') WHERE status='running'"
|
||||
).rowcount
|
||||
conn.commit()
|
||||
conn.close()
|
||||
return count
|
||||
```
|
||||
|
||||
### `app/app.py`
|
||||
|
||||
Inside `_startup()`, replace the inline SQL block that wipes both `queued` and `running`
|
||||
rows with a call to `reset_running_tasks()`. The replacement must be the **first operation
|
||||
in `_startup()`** — before the SearXNG re-queue logic that calls `submit_task()`:
|
||||
|
||||
```python
|
||||
# REMOVE this block:
|
||||
conn.execute(
|
||||
"UPDATE background_tasks SET status='failed', error='Interrupted by server restart',"
|
||||
" finished_at=datetime('now') WHERE status IN ('queued','running')"
|
||||
)
|
||||
|
||||
# ADD at the top of _startup(), before any submit_task() calls:
|
||||
from scripts.db import reset_running_tasks
|
||||
reset_running_tasks(get_db_path())
|
||||
```
|
||||
|
||||
The existing `conn` used for subsequent SearXNG logic is unaffected — `reset_running_tasks()`
|
||||
opens and closes its own connection.
|
||||
|
||||
### `config/llm.yaml.example`
|
||||
|
||||
Add `scheduler:` section:
|
||||
|
||||
```yaml
|
||||
scheduler:
|
||||
vram_budgets:
|
||||
cover_letter: 2.5 # alex-cover-writer:latest (~2GB GGUF + headroom)
|
||||
company_research: 5.0 # llama3.1:8b or vllm model
|
||||
wizard_generate: 2.5 # same model family as cover_letter
|
||||
max_queue_depth: 500
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Data Model
|
||||
|
||||
No schema changes. The existing `background_tasks` table supports all scheduler needs:
|
||||
|
||||
| Column | Scheduler use |
|
||||
|---|---|
|
||||
| `task_type` | Queue routing — determines which deque receives the task |
|
||||
| `status` | `queued` → in deque; `running` → batch worker executing; `completed`/`failed` → done |
|
||||
| `created_at` | FIFO ordering within type (durability startup query sorts by this) |
|
||||
| `params` | Passed through to `_run_task()` unchanged |
|
||||
|
||||
---
|
||||
|
||||
## Durability
|
||||
|
||||
Scope: **LLM task types only** (`cover_letter`, `company_research`, `wizard_generate`).
|
||||
Non-LLM tasks do not survive restarts, same as current behavior.
|
||||
|
||||
On construction, `TaskScheduler.__init__()` queries:
|
||||
|
||||
```sql
|
||||
SELECT id, task_type, job_id, params
|
||||
FROM background_tasks
|
||||
WHERE status = 'queued'
|
||||
AND task_type IN ('cover_letter', 'company_research', 'wizard_generate')
|
||||
ORDER BY created_at ASC
|
||||
```
|
||||
|
||||
Results are pushed onto their respective deques. This query runs inside `__init__` before
|
||||
`start()` is called (before the scheduler loop thread exists), so there is no concurrency
|
||||
concern with deque population.
|
||||
|
||||
`running` rows are reset to `failed` by `reset_running_tasks()` before `get_scheduler()`
|
||||
is called — see Required Call Ordering above.
|
||||
|
||||
---
|
||||
|
||||
## Known Limitations
|
||||
|
||||
**Static `_available_vram`:** Total GPU VRAM is read from `get_gpus()` once at scheduler
|
||||
construction and never refreshed. Changes after startup — another process releasing VRAM,
|
||||
a GPU going offline, Ollama unloading a model — are not reflected. The scheduler's
|
||||
correctness depends on per-task VRAM budgets being conservative estimates of **peak model
|
||||
footprint** (not free VRAM at a given moment). On a system where Ollama and vLLM share
|
||||
the GPU, budgets should account for both models potentially resident simultaneously.
|
||||
Dynamic VRAM polling is a future enhancement.
|
||||
|
||||
---
|
||||
|
||||
## Memory Safety
|
||||
|
||||
- **`finally` block owns VRAM release** — batch worker always decrements `_reserved_vram`
|
||||
and removes its `_active` entry before firing `_wake`, even on exception. The scheduler
|
||||
loop's dead-thread scan is defense in depth for externally-killed daemons only; it cannot
|
||||
double-decrement because `_active.pop` in `finally` runs first.
|
||||
- **Max queue depth with DB cleanup** — `enqueue()` rejects tasks past `max_queue_depth`,
|
||||
logs a warning, and immediately marks the dropped task `failed` in SQLite to prevent
|
||||
permanent ghost rows in `queued` state.
|
||||
- **No in-memory history** — deques hold only pending `TaskSpec` namedtuples. Completed
|
||||
and failed state lives exclusively in SQLite. Memory footprint is `O(pending tasks)`.
|
||||
- **Thread-safe singleton** — double-checked locking with `_scheduler_lock` prevents
|
||||
double-construction. Safety comes from the inner lock + re-check; the outer `None`
|
||||
check is a performance optimization only.
|
||||
- **Missing budget warning** — any `LLM_TASK_TYPES` entry with no budget entry after
|
||||
config merge logs a warning at construction; defaults to 0.0 GB (unlimited concurrency
|
||||
for that type). This prevents silent incorrect scheduling for future task types.
|
||||
- **`reset_scheduler()`** — explicit teardown for test isolation: sets `_stop`, joins
|
||||
scheduler thread with timeout, clears module-level reference under `_scheduler_lock`.
|
||||
|
||||
---
|
||||
|
||||
## Testing (`tests/test_task_scheduler.py`)
|
||||
|
||||
All tests mock `_run_task` to avoid real LLM calls. `reset_scheduler()` is called in
|
||||
an `autouse` fixture for isolation between test cases.
|
||||
|
||||
| Test | What it verifies |
|
||||
|---|---|
|
||||
| `test_deepest_queue_wins_first_slot` | N cover_letter + M research enqueued (N > M); cover_letter batch starts first when `_available_vram` only fits one model budget, because it has the deeper queue |
|
||||
| `test_fifo_within_type` | Arrival order preserved within a type batch |
|
||||
| `test_concurrent_batches_when_vram_allows` | Two type batches start simultaneously when `_available_vram` fits both budgets combined |
|
||||
| `test_new_tasks_picked_up_mid_batch` | Task enqueued via `enqueue()` while a batch is active is consumed by the running worker in the same batch |
|
||||
| `test_worker_crash_releases_vram` | `_run_task` raises; `_reserved_vram` returns to 0; scheduler continues; no double-decrement |
|
||||
| `test_non_llm_tasks_bypass_scheduler` | `discovery`, `email_sync` etc. spawn free threads via `submit_task()`; scheduler deques untouched |
|
||||
| `test_durability_llm_tasks_on_startup` | DB has existing `queued` LLM-type rows; scheduler loads them into deques on construction |
|
||||
| `test_durability_excludes_non_llm` | `queued` non-LLM rows in DB are not loaded into deques on startup |
|
||||
| `test_running_rows_reset_before_scheduler` | `reset_running_tasks()` sets `running` → `failed`; `queued` rows untouched |
|
||||
| `test_max_queue_depth_marks_failed` | Enqueue past limit logs warning, does not add to deque, and marks task `failed` in DB |
|
||||
| `test_missing_budget_logs_warning` | Type in `LLM_TASK_TYPES` with no budget entry at construction logs a warning |
|
||||
| `test_singleton_thread_safe` | Concurrent calls to `get_scheduler()` produce exactly one scheduler instance |
|
||||
| `test_reset_scheduler_cleans_up` | `reset_scheduler()` stops loop thread; no lingering threads after call |
|
||||
|
||||
---
|
||||
|
||||
## Files Touched
|
||||
|
||||
| File | Change |
|
||||
|---|---|
|
||||
| `scripts/task_scheduler.py` | **New** — ~180 lines |
|
||||
| `scripts/task_runner.py` | `submit_task()` routing shim — ~12 lines changed |
|
||||
| `scripts/db.py` | `reset_running_tasks()` added — ~10 lines |
|
||||
| `app/app.py` | `_startup()`: inline SQL block → `reset_running_tasks()` call, placed first |
|
||||
| `config/llm.yaml.example` | Add `scheduler:` section |
|
||||
| `tests/test_task_scheduler.py` | **New** — ~240 lines |
|
||||
|
|
@ -366,6 +366,18 @@ def kill_stuck_tasks(db_path: Path = DEFAULT_DB) -> int:
|
|||
return count
|
||||
|
||||
|
||||
def reset_running_tasks(db_path: Path = DEFAULT_DB) -> int:
    """Mark in-flight ('running') tasks as failed after a restart.

    Queued tasks are left untouched so the scheduler can resume them.
    Uses a plain sqlite3.connect() because this runs before the app's
    connection pooling is established.

    Args:
        db_path: Path to the SQLite database containing background_tasks.

    Returns:
        Number of rows transitioned from 'running' to 'failed'.
    """
    conn = sqlite3.connect(db_path)
    try:
        # Only 'running' rows are reset — their results are unknown after a
        # crash. 'queued' rows survive for the scheduler's durability reload.
        count = conn.execute(
            "UPDATE background_tasks SET status='failed', error='Interrupted by restart',"
            " finished_at=datetime('now') WHERE status='running'"
        ).rowcount
        conn.commit()
        return count
    finally:
        # Close even if the UPDATE raises (locked DB, missing table) so the
        # connection handle is never leaked during startup.
        conn.close()
|
||||
|
||||
|
||||
def purge_email_data(db_path: Path = DEFAULT_DB) -> tuple[int, int]:
|
||||
"""Delete all job_contacts rows and email-sourced pending jobs.
|
||||
Returns (contacts_deleted, jobs_deleted).
|
||||
|
|
|
|||
|
|
@ -26,19 +26,29 @@ from scripts.db import (
|
|||
def submit_task(db_path: Path = DEFAULT_DB, task_type: str = "",
                job_id: int | None = None,
                params: str | None = None) -> tuple[int, bool]:
    """Submit a background task.

    LLM task types (cover_letter, company_research, wizard_generate) are routed
    through the TaskScheduler for VRAM-aware batch scheduling.
    All other types spawn a free daemon thread as before.

    Args:
        db_path: SQLite database path used for task bookkeeping.
        task_type: Task type string; membership in LLM_TASK_TYPES decides routing.
        job_id: Associated job row id; 0 is stored when omitted.
        params: Opaque params string passed through to the task handler.

    Returns (task_id, True) if a new task was queued.
    Returns (existing_id, False) if an identical task is already in-flight.
    """
    task_id, is_new = insert_task(db_path, task_type, job_id or 0, params=params)
    if is_new:
        # Imported here (not at module top) so task_scheduler never needs to
        # import task_runner — this one-way import avoids a circular import.
        from scripts.task_scheduler import get_scheduler, LLM_TASK_TYPES

        if task_type in LLM_TASK_TYPES:
            # Scheduler owns execution: the task waits in a per-type deque
            # until its VRAM budget fits. _run_task is injected so the
            # scheduler module stays free of task_runner imports.
            get_scheduler(db_path, run_task_fn=_run_task).enqueue(
                task_id, task_type, job_id or 0, params
            )
        else:
            # Non-LLM tasks keep the legacy behavior: fire-and-forget thread.
            t = threading.Thread(
                target=_run_task,
                args=(db_path, task_id, task_type, job_id or 0, params),
                daemon=True,
            )
            t.start()
    return task_id, is_new
|
||||
|
||||
|
||||
|
|
|
|||
243
scripts/task_scheduler.py
Normal file
243
scripts/task_scheduler.py
Normal file
|
|
@ -0,0 +1,243 @@
|
|||
# scripts/task_scheduler.py
|
||||
"""Resource-aware batch scheduler for LLM background tasks.
|
||||
|
||||
Routes LLM task types through per-type deques with VRAM-aware scheduling.
|
||||
Non-LLM tasks bypass this module — routing lives in scripts/task_runner.py.
|
||||
|
||||
Public API:
|
||||
LLM_TASK_TYPES — set of task type strings routed through the scheduler
|
||||
get_scheduler() — lazy singleton accessor
|
||||
reset_scheduler() — test teardown only
|
||||
"""
|
||||
import logging
|
||||
import sqlite3
|
||||
import threading
|
||||
from collections import deque, namedtuple
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
# Module-level import so tests can monkeypatch scripts.task_scheduler._get_gpus
|
||||
try:
|
||||
from scripts.preflight import get_gpus as _get_gpus
|
||||
except Exception: # graceful degradation if preflight unavailable
|
||||
_get_gpus = lambda: []
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Task types that go through the scheduler (all others spawn free threads)
LLM_TASK_TYPES: frozenset[str] = frozenset((
    "cover_letter",
    "company_research",
    "wizard_generate",
))

# Conservative peak VRAM estimates (GB) per task type.
# Overridable per-install via scheduler.vram_budgets in config/llm.yaml.
DEFAULT_VRAM_BUDGETS: dict[str, float] = {
    # alex-cover-writer:latest (~2GB GGUF + headroom)
    "cover_letter": 2.5,
    # llama3.1:8b or vllm model
    "company_research": 5.0,
    # same model family as cover_letter
    "wizard_generate": 2.5,
}

# Lightweight task descriptor stored in per-type deques
TaskSpec = namedtuple("TaskSpec", ["id", "job_id", "params"])
|
||||
|
||||
|
||||
class TaskScheduler:
|
||||
"""Resource-aware LLM task batch scheduler. Use get_scheduler() — not direct construction."""
|
||||
|
||||
def __init__(self, db_path: Path, run_task_fn: Callable) -> None:
|
||||
self._db_path = db_path
|
||||
self._run_task = run_task_fn
|
||||
|
||||
self._lock = threading.Lock()
|
||||
self._wake = threading.Event()
|
||||
self._stop = threading.Event()
|
||||
self._queues: dict[str, deque] = {}
|
||||
self._active: dict[str, threading.Thread] = {}
|
||||
self._reserved_vram: float = 0.0
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
|
||||
# Load VRAM budgets: defaults + optional config overrides
|
||||
self._budgets: dict[str, float] = dict(DEFAULT_VRAM_BUDGETS)
|
||||
config_path = db_path.parent.parent / "config" / "llm.yaml"
|
||||
self._max_queue_depth: int = 500
|
||||
if config_path.exists():
|
||||
try:
|
||||
import yaml
|
||||
with open(config_path) as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
sched_cfg = cfg.get("scheduler", {})
|
||||
self._budgets.update(sched_cfg.get("vram_budgets", {}))
|
||||
self._max_queue_depth = sched_cfg.get("max_queue_depth", 500)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to load scheduler config from %s: %s", config_path, exc)
|
||||
|
||||
# Warn on LLM types with no budget entry after merge
|
||||
for t in LLM_TASK_TYPES:
|
||||
if t not in self._budgets:
|
||||
logger.warning(
|
||||
"No VRAM budget defined for LLM task type %r — "
|
||||
"defaulting to 0.0 GB (unlimited concurrency for this type)", t
|
||||
)
|
||||
|
||||
# Detect total GPU VRAM; fall back to unlimited (999) on CPU-only systems.
|
||||
# Uses module-level _get_gpus so tests can monkeypatch scripts.task_scheduler._get_gpus.
|
||||
try:
|
||||
gpus = _get_gpus()
|
||||
self._available_vram: float = (
|
||||
sum(g["vram_total_gb"] for g in gpus) if gpus else 999.0
|
||||
)
|
||||
except Exception:
|
||||
self._available_vram = 999.0
|
||||
|
||||
# Durability: reload surviving 'queued' LLM tasks from prior run
|
||||
self._load_queued_tasks()
|
||||
|
||||
def enqueue(self, task_id: int, task_type: str, job_id: int,
|
||||
params: Optional[str]) -> None:
|
||||
"""Add an LLM task to the scheduler queue.
|
||||
|
||||
If the queue for this type is at max_queue_depth, the task is marked
|
||||
failed in SQLite immediately (no ghost queued rows) and a warning is logged.
|
||||
"""
|
||||
from scripts.db import update_task_status
|
||||
|
||||
with self._lock:
|
||||
q = self._queues.setdefault(task_type, deque())
|
||||
if len(q) >= self._max_queue_depth:
|
||||
logger.warning(
|
||||
"Queue depth limit reached for %s (max=%d) — task %d dropped",
|
||||
task_type, self._max_queue_depth, task_id,
|
||||
)
|
||||
update_task_status(self._db_path, task_id, "failed",
|
||||
error="Queue depth limit reached")
|
||||
return
|
||||
q.append(TaskSpec(task_id, job_id, params))
|
||||
|
||||
self._wake.set()
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the background scheduler loop thread. Call once after construction."""
|
||||
self._thread = threading.Thread(
|
||||
target=self._scheduler_loop, name="task-scheduler", daemon=True
|
||||
)
|
||||
self._thread.start()
|
||||
|
||||
def shutdown(self, timeout: float = 5.0) -> None:
|
||||
"""Signal the scheduler to stop and wait for it to exit."""
|
||||
self._stop.set()
|
||||
self._wake.set() # unblock any wait()
|
||||
if self._thread and self._thread.is_alive():
|
||||
self._thread.join(timeout=timeout)
|
||||
|
||||
    def _scheduler_loop(self) -> None:
        """Main scheduler daemon — wakes on enqueue or batch completion.

        Each iteration (at latest every 30 s, sooner whenever _wake is set):
          1. Reaps batch threads that died without releasing their VRAM.
          2. Starts one batch thread per eligible task type — deepest queue
             first — while reserved VRAM stays within _available_vram.
        Exits when _stop is set (see shutdown()).
        """
        while not self._stop.is_set():
            # The 30 s timeout is a watchdog: a missed wake-up only delays
            # scheduling, it can never deadlock it.
            self._wake.wait(timeout=30)
            self._wake.clear()

            with self._lock:
                # Defense in depth: reap externally-killed batch threads.
                # In normal operation _active.pop() runs in finally before _wake fires,
                # so this reap finds nothing — no double-decrement risk.
                for t, thread in list(self._active.items()):
                    if not thread.is_alive():
                        self._reserved_vram -= self._budgets.get(t, 0.0)
                        del self._active[t]

                # Start new type batches while VRAM allows.
                # Deepest queue first: drain the largest backlog to minimise
                # model switching per completed task.
                candidates = sorted(
                    [t for t in self._queues if self._queues[t] and t not in self._active],
                    key=lambda t: len(self._queues[t]),
                    reverse=True,
                )
                for task_type in candidates:
                    budget = self._budgets.get(task_type, 0.0)
                    # Always allow at least one batch to run even if its budget
                    # exceeds _available_vram (prevents permanent starvation when
                    # a single type's budget is larger than the VRAM ceiling).
                    if self._reserved_vram == 0.0 or self._reserved_vram + budget <= self._available_vram:
                        thread = threading.Thread(
                            target=self._batch_worker,
                            args=(task_type,),
                            name=f"batch-{task_type}",
                            daemon=True,
                        )
                        # Register + reserve BEFORE start(): the worker's
                        # finally-release in _batch_worker then cannot underflow
                        # the reservation counter.
                        self._active[task_type] = thread
                        self._reserved_vram += budget
                        thread.start()
|
||||
|
||||
    def _batch_worker(self, task_type: str) -> None:
        """Serial consumer for one task type. Runs until the type's deque is empty.

        The lock is held only around the deque pop; the (potentially slow)
        task execution runs unlocked so enqueue() and the scheduler loop stay
        responsive while an LLM call is in flight.
        """
        try:
            while True:
                with self._lock:
                    q = self._queues.get(task_type)
                    if not q:
                        break
                    task = q.popleft()
                # _run_task is scripts.task_runner._run_task (passed at construction)
                self._run_task(
                    self._db_path, task.id, task_type, task.job_id, task.params
                )
        finally:
            # Always release — even if _run_task raises.
            # _active.pop here prevents the scheduler loop reap from double-decrementing.
            with self._lock:
                self._active.pop(task_type, None)
                self._reserved_vram -= self._budgets.get(task_type, 0.0)
            # Wake the loop so it can start the next eligible batch.
            self._wake.set()
|
||||
|
||||
def _load_queued_tasks(self) -> None:
|
||||
"""Load pre-existing queued LLM tasks from SQLite into deques (called once in __init__)."""
|
||||
llm_types = sorted(LLM_TASK_TYPES) # sorted for deterministic SQL params in logs
|
||||
placeholders = ",".join("?" * len(llm_types))
|
||||
conn = sqlite3.connect(self._db_path)
|
||||
rows = conn.execute(
|
||||
f"SELECT id, task_type, job_id, params FROM background_tasks"
|
||||
f" WHERE status='queued' AND task_type IN ({placeholders})"
|
||||
f" ORDER BY created_at ASC",
|
||||
llm_types,
|
||||
).fetchall()
|
||||
conn.close()
|
||||
|
||||
for row_id, task_type, job_id, params in rows:
|
||||
q = self._queues.setdefault(task_type, deque())
|
||||
q.append(TaskSpec(row_id, job_id, params))
|
||||
|
||||
if rows:
|
||||
logger.info("Scheduler: resumed %d queued task(s) from prior run", len(rows))
|
||||
|
||||
|
||||
# ── Singleton ─────────────────────────────────────────────────────────────────
|
||||
|
||||
# Process-level singleton; read/written only via get_scheduler()/reset_scheduler().
_scheduler: Optional[TaskScheduler] = None
# Guards singleton construction and teardown (double-checked locking below).
_scheduler_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_scheduler(db_path: Path, run_task_fn: Optional[Callable] = None) -> TaskScheduler:
    """Return the process-level TaskScheduler singleton, constructing it if needed.

    Args:
        db_path: SQLite database path passed to the scheduler on first construction.
        run_task_fn: Task-execution callable; required on the first call and
            ignored on every subsequent call.

    Raises:
        ValueError: if the singleton does not exist yet and run_task_fn is None.

    Safety: inner lock + double-check prevents double-construction under races.
    The outer None check is a fast-path performance optimisation only.
    (Fix vs. original: annotation was ``Callable = None``, which contradicts
    the None default — now ``Optional[Callable]``.)
    """
    global _scheduler
    if _scheduler is None:  # fast path — avoids lock on steady state
        with _scheduler_lock:
            if _scheduler is None:  # re-check under lock (double-checked locking)
                if run_task_fn is None:
                    raise ValueError("run_task_fn required on first get_scheduler() call")
                _scheduler = TaskScheduler(db_path, run_task_fn)
                _scheduler.start()
    return _scheduler
|
||||
|
||||
|
||||
def reset_scheduler() -> None:
    """Tear down and clear the process singleton. TEST TEARDOWN ONLY."""
    global _scheduler
    with _scheduler_lock:
        instance = _scheduler
        if instance is not None:
            instance.shutdown()
            _scheduler = None
|
||||
|
|
@ -6,6 +6,14 @@ from unittest.mock import patch
|
|||
import sqlite3
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def clean_scheduler():
    """Dispose of the TaskScheduler singleton in teardown so no state leaks between tests."""
    yield
    from scripts.task_scheduler import reset_scheduler

    reset_scheduler()
|
||||
|
||||
|
||||
def _make_db(tmp_path):
|
||||
from scripts.db import init_db, insert_job
|
||||
db = tmp_path / "test.db"
|
||||
|
|
@ -143,14 +151,20 @@ def test_run_task_email_sync_file_not_found(tmp_path):
|
|||
|
||||
|
||||
def test_submit_task_actually_completes(tmp_path):
|
||||
"""Integration: submit_task spawns a thread that completes asynchronously."""
|
||||
"""Integration: submit_task routes LLM tasks through the scheduler and they complete."""
|
||||
db, job_id = _make_db(tmp_path)
|
||||
from scripts.db import get_task_for_job
|
||||
from scripts.task_scheduler import get_scheduler
|
||||
from scripts.task_runner import _run_task
|
||||
|
||||
# Prime the singleton with the correct db_path before submit_task runs.
|
||||
# get_scheduler() already calls start() internally.
|
||||
get_scheduler(db, run_task_fn=_run_task)
|
||||
|
||||
with patch("scripts.generate_cover_letter.generate", return_value="Cover letter text"):
|
||||
from scripts.task_runner import submit_task
|
||||
task_id, _ = submit_task(db, "cover_letter", job_id)
|
||||
# Wait for thread to complete (max 5s)
|
||||
# Wait for scheduler to complete the task (max 5s)
|
||||
for _ in range(50):
|
||||
task = get_task_for_job(db, "cover_letter", job_id)
|
||||
if task and task["status"] in ("completed", "failed"):
|
||||
|
|
|
|||
472
tests/test_task_scheduler.py
Normal file
472
tests/test_task_scheduler.py
Normal file
|
|
@ -0,0 +1,472 @@
|
|||
# tests/test_task_scheduler.py
|
||||
"""Tests for scripts/task_scheduler.py and related db helpers."""
|
||||
import sqlite3
|
||||
import threading
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from scripts.db import init_db, reset_running_tasks
|
||||
|
||||
|
||||
@pytest.fixture
def tmp_db(tmp_path):
    """Provide an initialised SQLite database inside a per-test temp directory."""
    path = tmp_path / "test.db"
    init_db(path)
    return path
|
||||
|
||||
|
||||
def test_reset_running_tasks_resets_only_running(tmp_db):
    """reset_running_tasks() marks running→failed but leaves queued untouched."""
    conn = sqlite3.connect(tmp_db)
    with conn:
        conn.executemany(
            "INSERT INTO background_tasks (task_type, job_id, status) VALUES (?,?,?)",
            [("cover_letter", 1, "running"), ("company_research", 2, "queued")],
        )
    conn.close()

    assert reset_running_tasks(tmp_db) == 1

    conn = sqlite3.connect(tmp_db)
    statuses = dict(
        conn.execute("SELECT task_type, status FROM background_tasks").fetchall()
    )
    conn.close()

    assert statuses["cover_letter"] == "failed"
    assert statuses["company_research"] == "queued"


def test_reset_running_tasks_returns_zero_when_nothing_running(tmp_db):
    """Returns 0 when no running tasks exist."""
    conn = sqlite3.connect(tmp_db)
    with conn:
        conn.execute(
            "INSERT INTO background_tasks (task_type, job_id, status) VALUES (?,?,?)",
            ("cover_letter", 1, "queued"),
        )
    conn.close()

    assert reset_running_tasks(tmp_db) == 0
|
||||
|
||||
|
||||
from scripts.task_scheduler import (
|
||||
TaskScheduler, LLM_TASK_TYPES, DEFAULT_VRAM_BUDGETS,
|
||||
get_scheduler, reset_scheduler,
|
||||
)
|
||||
|
||||
|
||||
def _noop_run_task(*args, **kwargs):
|
||||
"""Stand-in for _run_task that does nothing."""
|
||||
pass
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def clean_scheduler():
    """Dispose of the scheduler singleton after every test in this module."""
    yield
    reset_scheduler()
|
||||
|
||||
|
||||
def test_default_budgets_used_when_no_config(tmp_db):
    """Scheduler falls back to DEFAULT_VRAM_BUDGETS when config key absent."""
    scheduler = TaskScheduler(tmp_db, _noop_run_task)
    assert scheduler._budgets == DEFAULT_VRAM_BUDGETS


def test_config_budgets_override_defaults(tmp_db, tmp_path):
    """Values in llm.yaml scheduler.vram_budgets override defaults."""
    config_dir = tmp_db.parent.parent / "config"
    config_dir.mkdir(parents=True, exist_ok=True)
    (config_dir / "llm.yaml").write_text(
        "scheduler:\n  vram_budgets:\n    cover_letter: 9.9\n"
    )

    scheduler = TaskScheduler(tmp_db, _noop_run_task)

    assert scheduler._budgets["cover_letter"] == 9.9
    # Keys absent from the override still come from the defaults table.
    assert scheduler._budgets["company_research"] == DEFAULT_VRAM_BUDGETS["company_research"]
|
||||
|
||||
|
||||
def test_missing_budget_logs_warning(tmp_db, caplog):
    """A type in LLM_TASK_TYPES with no budget entry logs a warning.

    Fixes vs. the original: drops the unused ``s`` binding and the redundant
    ``hasattr(..., 'copy')`` dance — ``set()`` snapshots any iterable of
    strings unconditionally.
    """
    import logging

    from scripts import task_scheduler as ts

    original = set(LLM_TASK_TYPES)  # snapshot for restoration
    ts.LLM_TASK_TYPES = frozenset(original | {"orphan_type"})
    try:
        with caplog.at_level(logging.WARNING, logger="scripts.task_scheduler"):
            TaskScheduler(tmp_db, _noop_run_task)
        assert any("orphan_type" in r.message for r in caplog.records)
    finally:
        ts.LLM_TASK_TYPES = frozenset(original)
|
||||
|
||||
|
||||
def test_cpu_only_system_gets_unlimited_vram(tmp_db, monkeypatch):
    """_available_vram is 999.0 when _get_gpus() returns empty list."""
    # Patch the module-level _get_gpus in task_scheduler (not preflight)
    # so __init__'s lookup picks up the mock.
    monkeypatch.setattr("scripts.task_scheduler._get_gpus", lambda: [])
    scheduler = TaskScheduler(tmp_db, _noop_run_task)
    assert scheduler._available_vram == 999.0


def test_gpu_vram_summed_across_all_gpus(tmp_db, monkeypatch):
    """_available_vram sums vram_total_gb across all detected GPUs."""
    twin_gpus = [
        {"name": "RTX 3090", "vram_total_gb": 24.0, "vram_free_gb": 20.0},
        {"name": "RTX 3090", "vram_total_gb": 24.0, "vram_free_gb": 18.0},
    ]
    monkeypatch.setattr("scripts.task_scheduler._get_gpus", lambda: twin_gpus)
    scheduler = TaskScheduler(tmp_db, _noop_run_task)
    assert scheduler._available_vram == 48.0
|
||||
|
||||
|
||||
def test_enqueue_adds_taskspec_to_deque(tmp_db):
    """enqueue() appends a TaskSpec to the correct per-type deque."""
    scheduler = TaskScheduler(tmp_db, _noop_run_task)
    scheduler.enqueue(1, "cover_letter", 10, None)
    scheduler.enqueue(2, "cover_letter", 11, '{"key": "val"}')

    pending = scheduler._queues["cover_letter"]
    assert [spec.id for spec in pending] == [1, 2]


def test_enqueue_wakes_scheduler(tmp_db):
    """enqueue() sets the _wake event so the scheduler loop re-evaluates."""
    scheduler = TaskScheduler(tmp_db, _noop_run_task)
    assert not scheduler._wake.is_set()
    scheduler.enqueue(1, "cover_letter", 10, None)
    assert scheduler._wake.is_set()
|
||||
|
||||
|
||||
def test_max_queue_depth_marks_task_failed(tmp_db):
    """When queue is at max_queue_depth, dropped task is marked failed in DB."""
    from scripts.db import insert_task
    from scripts.task_scheduler import TaskSpec

    scheduler = TaskScheduler(tmp_db, _noop_run_task)
    scheduler._max_queue_depth = 2

    # Pre-fill the deque to capacity without touching the DB.
    backlog = scheduler._queues.setdefault("cover_letter", deque())
    backlog.extend([TaskSpec(99, 1, None), TaskSpec(100, 2, None)])

    # A real DB row for the task that is about to be rejected.
    task_id, _ = insert_task(tmp_db, "cover_letter", 3)
    scheduler.enqueue(task_id, "cover_letter", 3, None)

    conn = sqlite3.connect(tmp_db)
    status, error = conn.execute(
        "SELECT status, error FROM background_tasks WHERE id=?", (task_id,)
    ).fetchone()
    conn.close()

    assert status == "failed"
    assert "depth" in error.lower()
    # The rejected task never entered the deque.
    assert len(backlog) == 2
|
||||
|
||||
|
||||
def test_max_queue_depth_logs_warning(tmp_db, caplog):
    """Queue depth overflow logs a WARNING.

    Fix vs. the original: removed the unused ``TaskSpec`` import.
    """
    import logging
    from scripts.db import insert_task

    s = TaskScheduler(tmp_db, _noop_run_task)
    s._max_queue_depth = 0  # immediately at limit

    task_id, _ = insert_task(tmp_db, "cover_letter", 1)
    with caplog.at_level(logging.WARNING, logger="scripts.task_scheduler"):
        s.enqueue(task_id, "cover_letter", 1, None)

    assert any("depth" in r.message.lower() for r in caplog.records)
|
||||
|
||||
|
||||
# ── Threading helpers ─────────────────────────────────────────────────────────
|
||||
|
||||
def _make_recording_run_task(log: list, done_event: threading.Event, expected: int):
|
||||
"""Returns a mock _run_task that records (task_id, task_type) and sets done when expected count reached."""
|
||||
def _run(db_path, task_id, task_type, job_id, params):
|
||||
log.append((task_id, task_type))
|
||||
if len(log) >= expected:
|
||||
done_event.set()
|
||||
return _run
|
||||
|
||||
|
||||
def _start_scheduler(tmp_db, run_task_fn, available_vram=999.0):
    """Construct a TaskScheduler, pin its VRAM ceiling, and start its loop."""
    scheduler = TaskScheduler(tmp_db, run_task_fn)
    scheduler._available_vram = available_vram
    scheduler.start()
    return scheduler
|
||||
|
||||
|
||||
# ── Tests ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
def test_deepest_queue_wins_first_slot(tmp_db):
    """Type with more queued tasks starts first when VRAM only fits one type."""
    log, done = [], threading.Event()

    # Build scheduler but DO NOT start it yet — enqueue all tasks first
    # so the scheduler sees the full picture on its very first wake.
    run_task_fn = _make_recording_run_task(log, done, 4)
    s = TaskScheduler(tmp_db, run_task_fn)
    s._available_vram = 3.0  # fits cover_letter (2.5) but not +company_research (5.0)

    # Enqueue cover_letter (3 tasks) and company_research (1 task) before start.
    # cover_letter has the deeper queue and must win the first batch slot.
    for i in range(3):
        s.enqueue(i + 1, "cover_letter", i + 1, None)
    s.enqueue(4, "company_research", 4, None)

    s.start()  # scheduler now sees all tasks atomically on its first iteration
    assert done.wait(timeout=5.0), "timed out — not all 4 tasks completed"
    s.shutdown()

    assert len(log) == 4
    # cl/cr hold the log positions at which each type executed.
    cl = [i for i, (_, t) in enumerate(log) if t == "cover_letter"]
    cr = [i for i, (_, t) in enumerate(log) if t == "company_research"]
    assert len(cl) == 3 and len(cr) == 1
    assert max(cl) < min(cr), "All cover_letter tasks must finish before company_research starts"
|
||||
|
||||
|
||||
def test_fifo_within_type(tmp_db):
    """Tasks of the same type execute in arrival (FIFO) order."""
    executed, done = [], threading.Event()
    scheduler = _start_scheduler(tmp_db, _make_recording_run_task(executed, done, 3))

    for tid in (10, 20, 30):
        scheduler.enqueue(tid, "cover_letter", tid, None)

    assert done.wait(timeout=5.0), "timed out — not all 3 tasks completed"
    scheduler.shutdown()

    assert [tid for tid, _ in executed] == [10, 20, 30]
|
||||
|
||||
|
||||
def test_concurrent_batches_when_vram_allows(tmp_db):
    """Two type batches start simultaneously when VRAM fits both.

    Fix vs. the original: the result of ``all_done.wait()`` is now asserted,
    matching the sibling tests — a timeout fails loudly here instead of
    falling through to the later assertions.
    """
    started = {"cover_letter": threading.Event(), "company_research": threading.Event()}
    all_done = threading.Event()
    log = []

    def run_task(db_path, task_id, task_type, job_id, params):
        started[task_type].set()
        log.append(task_type)
        if len(log) >= 2:
            all_done.set()

    # VRAM=10.0 fits both cover_letter (2.5) and company_research (5.0) simultaneously
    s = _start_scheduler(tmp_db, run_task, available_vram=10.0)
    s.enqueue(1, "cover_letter", 1, None)
    s.enqueue(2, "company_research", 2, None)

    assert all_done.wait(timeout=5.0), "timed out — both batches should have run"
    s.shutdown()

    # Both types should have started (possibly overlapping)
    assert started["cover_letter"].is_set()
    assert started["company_research"].is_set()
|
||||
|
||||
|
||||
def test_new_tasks_picked_up_mid_batch(tmp_db):
    """A task enqueued while a batch is running is consumed in the same batch.

    Fix vs. the original: the ``task1_started.wait()`` result is asserted, so
    a scheduler that never starts task 1 fails here with a clear message
    instead of on a later, more confusing timeout.
    """
    log, done = [], threading.Event()
    task1_started = threading.Event()  # fires when task 1 begins executing
    task2_ready = threading.Event()    # fires when task 2 has been enqueued

    def run_task(db_path, task_id, task_type, job_id, params):
        if task_id == 1:
            task1_started.set()            # signal: task 1 is now running
            task2_ready.wait(timeout=2.0)  # wait for task 2 to be in the deque
        log.append(task_id)
        if len(log) >= 2:
            done.set()

    s = _start_scheduler(tmp_db, run_task)
    s.enqueue(1, "cover_letter", 1, None)
    assert task1_started.wait(timeout=2.0), "task 1 never started executing"
    s.enqueue(2, "cover_letter", 2, None)
    task2_ready.set()  # unblock task 1 so it finishes

    assert done.wait(timeout=5.0), "timed out — task 2 never picked up mid-batch"
    s.shutdown()

    assert log == [1, 2]
|
||||
|
||||
|
||||
def test_worker_crash_releases_vram(tmp_db):
    """If _run_task raises, _reserved_vram returns to 0 and scheduler continues."""
    completed, done = [], threading.Event()

    def exploding_run_task(db_path, task_id, task_type, job_id, params):
        if task_id == 1:
            raise RuntimeError("simulated failure")
        completed.append(task_id)
        done.set()

    scheduler = _start_scheduler(tmp_db, exploding_run_task, available_vram=3.0)
    scheduler.enqueue(1, "cover_letter", 1, None)
    scheduler.enqueue(2, "cover_letter", 2, None)

    assert done.wait(timeout=5.0), "timed out — task 2 never completed after task 1 crash"
    scheduler.shutdown()

    # The crash released the reservation and the follow-up task still ran.
    assert 2 in completed
    assert scheduler._reserved_vram == 0.0
|
||||
|
||||
|
||||
def test_get_scheduler_returns_singleton(tmp_db):
    """Multiple calls to get_scheduler() return the same instance."""
    first = get_scheduler(tmp_db, _noop_run_task)
    second = get_scheduler(tmp_db, _noop_run_task)
    assert first is second


def test_singleton_thread_safe(tmp_db):
    """Concurrent get_scheduler() calls produce exactly one instance."""
    instances, errors = [], []

    def grab():
        try:
            instances.append(get_scheduler(tmp_db, _noop_run_task))
        except Exception as exc:
            errors.append(exc)

    workers = [threading.Thread(target=grab) for _ in range(20)]
    for w in workers:
        w.start()
    for w in workers:
        w.join()

    assert not errors
    assert len({id(inst) for inst in instances}) == 1  # all the same object


def test_reset_scheduler_cleans_up(tmp_db):
    """reset_scheduler() shuts down the scheduler; no threads linger."""
    original = get_scheduler(tmp_db, _noop_run_task)
    loop_thread = original._thread
    assert loop_thread.is_alive()

    reset_scheduler()

    loop_thread.join(timeout=2.0)
    assert not loop_thread.is_alive()

    # After reset, get_scheduler constructs a fresh instance.
    assert get_scheduler(tmp_db, _noop_run_task) is not original
|
||||
|
||||
|
||||
def test_durability_loads_queued_llm_tasks_on_startup(tmp_db):
    """Scheduler loads pre-existing queued LLM tasks into deques at construction."""
    from scripts.db import insert_task

    # Pre-insert queued rows simulating a prior run
    cover_id, _ = insert_task(tmp_db, "cover_letter", 1)
    research_id, _ = insert_task(tmp_db, "company_research", 2)

    scheduler = TaskScheduler(tmp_db, _noop_run_task)

    cover_q = scheduler._queues.get("cover_letter", deque())
    research_q = scheduler._queues.get("company_research", deque())
    assert [spec.id for spec in cover_q] == [cover_id]
    assert [spec.id for spec in research_q] == [research_id]


def test_durability_excludes_non_llm_queued_tasks(tmp_db):
    """Non-LLM queued tasks are not loaded into the scheduler deques."""
    from scripts.db import insert_task

    for non_llm_type in ("discovery", "email_sync"):
        insert_task(tmp_db, non_llm_type, 0)

    scheduler = TaskScheduler(tmp_db, _noop_run_task)

    for non_llm_type in ("discovery", "email_sync"):
        assert not scheduler._queues.get(non_llm_type)
|
||||
|
||||
|
||||
def test_durability_preserves_fifo_order(tmp_db):
    """Queued tasks are loaded in created_at (FIFO) order."""
    conn = sqlite3.connect(tmp_db)
    # Explicit timestamps: the row inserted second has the EARLIER created_at,
    # so loading must reorder relative to insertion order.
    for job_id, stamp in ((1, "2026-01-01 10:00:00"), (2, "2026-01-01 09:00:00")):
        conn.execute(
            "INSERT INTO background_tasks (task_type, job_id, params, status, created_at)"
            " VALUES (?,?,?,?,?)", ("cover_letter", job_id, None, "queued", stamp)
        )
    conn.commit()
    expected_ids = [r[0] for r in conn.execute(
        "SELECT id FROM background_tasks ORDER BY created_at ASC"
    ).fetchall()]
    conn.close()

    scheduler = TaskScheduler(tmp_db, _noop_run_task)

    assert [spec.id for spec in scheduler._queues["cover_letter"]] == expected_ids
|
||||
|
||||
|
||||
def test_non_llm_tasks_bypass_scheduler(tmp_db):
    """submit_task() for non-LLM types invokes _run_task directly, not enqueue().

    Fixes vs. the original: removes the two unused ``original_*`` bindings,
    and adds the missing positive assertion — the test recorded
    ``run_task_calls`` but never checked that discovery actually reached
    _run_task. A bounded poll is used because submit_task may dispatch the
    task on a worker thread.
    """
    import time
    import unittest.mock as mock

    from scripts import task_runner

    # Initialize the singleton properly so submit_task routes correctly
    s = get_scheduler(tmp_db, _noop_run_task)

    run_task_calls = []
    enqueue_calls = []

    def recording_run_task(*args, **kwargs):
        run_task_calls.append(args[2])  # task_type is 3rd arg

    def recording_enqueue(task_id, task_type, job_id, params):
        enqueue_calls.append(task_type)

    with mock.patch.object(task_runner, "_run_task", recording_run_task), \
         mock.patch.object(s, "enqueue", recording_enqueue):
        task_runner.submit_task(tmp_db, "discovery", 0)
        # Bounded wait in case the dispatch happens asynchronously.
        for _ in range(100):
            if run_task_calls:
                break
            time.sleep(0.02)

    # discovery goes directly to _run_task; enqueue is never called
    assert "discovery" in run_task_calls
    assert "discovery" not in enqueue_calls
    # The scheduler deque is untouched
    assert not s._queues.get("discovery")
|
||||
|
||||
|
||||
def test_llm_tasks_routed_to_scheduler(tmp_db):
    """submit_task() for LLM types calls enqueue(), not _run_task directly."""
    import unittest.mock as mock

    from scripts import task_runner

    s = get_scheduler(tmp_db, _noop_run_task)

    enqueue_calls = []
    real_enqueue = s.enqueue

    def spying_enqueue(*args, **kwargs):
        enqueue_calls.append(args[1])
        return real_enqueue(*args, **kwargs)

    with mock.patch.object(s, "enqueue", side_effect=spying_enqueue):
        task_runner.submit_task(tmp_db, "cover_letter", 1)

    assert "cover_letter" in enqueue_calls
|
||||
Loading…
Reference in a new issue