refactor(scheduler): shim to circuitforge_core.tasks.scheduler

VRAM detection now uses cf-orch free VRAM when coordinator is running,
making the scheduler cooperative with other cf-orch consumers.
Enqueue return value now checked — queue-full tasks are marked failed.
This commit is contained in:
pyr0ball 2026-03-31 09:27:43 -07:00
parent 818e46c17e
commit 922d91fb91
3 changed files with 183 additions and 178 deletions

View file

@ -9,10 +9,13 @@ and marks the task completed or failed.
Deduplication: only one queued/running task per (task_type, job_id) is allowed. Deduplication: only one queued/running task per (task_type, job_id) is allowed.
Different task types for the same job run concurrently (e.g. cover letter + research). Different task types for the same job run concurrently (e.g. cover letter + research).
""" """
import logging
import sqlite3 import sqlite3
import threading import threading
from pathlib import Path from pathlib import Path
log = logging.getLogger(__name__)
from scripts.db import ( from scripts.db import (
DEFAULT_DB, DEFAULT_DB,
insert_task, insert_task,
@ -20,6 +23,7 @@ from scripts.db import (
update_task_stage, update_task_stage,
update_cover_letter, update_cover_letter,
save_research, save_research,
save_optimized_resume,
) )
@ -39,9 +43,13 @@ def submit_task(db_path: Path = DEFAULT_DB, task_type: str = "",
if is_new: if is_new:
from scripts.task_scheduler import get_scheduler, LLM_TASK_TYPES from scripts.task_scheduler import get_scheduler, LLM_TASK_TYPES
if task_type in LLM_TASK_TYPES: if task_type in LLM_TASK_TYPES:
get_scheduler(db_path, run_task_fn=_run_task).enqueue( enqueued = get_scheduler(db_path, run_task_fn=_run_task).enqueue(
task_id, task_type, job_id or 0, params task_id, task_type, job_id or 0, params
) )
if not enqueued:
update_task_status(
db_path, task_id, "failed", error="Queue depth limit reached"
)
else: else:
t = threading.Thread( t = threading.Thread(
target=_run_task, target=_run_task,
@ -261,6 +269,48 @@ def _run_task(db_path: Path, task_id: int, task_type: str, job_id: int,
) )
return return
elif task_type == "resume_optimize":
import json as _json
from scripts.resume_parser import structure_resume
from scripts.resume_optimizer import (
extract_jd_signals,
prioritize_gaps,
rewrite_for_ats,
hallucination_check,
render_resume_text,
)
from scripts.user_profile import load_user_profile
description = job.get("description", "")
resume_path = load_user_profile().get("resume_path", "")
# Parse the candidate's resume
update_task_stage(db_path, task_id, "parsing resume")
resume_text = Path(resume_path).read_text(errors="replace") if resume_path else ""
resume_struct, parse_err = structure_resume(resume_text)
# Extract keyword gaps and build gap report (free tier)
update_task_stage(db_path, task_id, "extracting keyword gaps")
gaps = extract_jd_signals(description, resume_text)
prioritized = prioritize_gaps(gaps, resume_struct)
gap_report = _json.dumps(prioritized, indent=2)
# Full rewrite (paid tier only)
rewritten_text = ""
p = _json.loads(params or "{}")
if p.get("full_rewrite", False):
update_task_stage(db_path, task_id, "rewriting resume sections")
candidate_voice = load_user_profile().get("candidate_voice", "")
rewritten = rewrite_for_ats(resume_struct, prioritized, job, candidate_voice)
if hallucination_check(resume_struct, rewritten):
rewritten_text = render_resume_text(rewritten)
else:
log.warning("[task_runner] resume_optimize hallucination check failed for job %d", job_id)
save_optimized_resume(db_path, job_id=job_id,
text=rewritten_text,
gap_report=gap_report)
elif task_type == "prepare_training": elif task_type == "prepare_training":
from scripts.prepare_training_data import build_records, write_jsonl, DEFAULT_OUTPUT from scripts.prepare_training_data import build_records, write_jsonl, DEFAULT_OUTPUT
records = build_records() records = build_records()

View file

@ -1,232 +1,176 @@
# scripts/task_scheduler.py # scripts/task_scheduler.py
"""Resource-aware batch scheduler for LLM background tasks. """Peregrine LLM task scheduler — thin shim over circuitforge_core.tasks.scheduler.
Routes LLM task types through per-type deques with VRAM-aware scheduling. All scheduling logic lives in circuitforge_core. This module defines
Non-LLM tasks bypass this module routing lives in scripts/task_runner.py. Peregrine-specific task types, VRAM budgets, and config loading.
Public API: Public API (unchanged — callers do not need to change):
LLM_TASK_TYPES set of task type strings routed through the scheduler LLM_TASK_TYPES frozenset of task type strings routed through the scheduler
get_scheduler() lazy singleton accessor DEFAULT_VRAM_BUDGETS dict of conservative peak VRAM estimates per task type
TaskSpec lightweight task descriptor (re-exported from core)
TaskScheduler backward-compatible wrapper around the core scheduler class
get_scheduler() returns the process-level TaskScheduler singleton
reset_scheduler() test teardown only reset_scheduler() test teardown only
""" """
from __future__ import annotations
import logging import logging
import sqlite3
import threading import threading
from collections import deque, namedtuple
from pathlib import Path from pathlib import Path
from typing import Callable, Optional from typing import Callable, Optional
# Module-level import so tests can monkeypatch scripts.task_scheduler._get_gpus from circuitforge_core.tasks.scheduler import (
try: TaskSpec, # re-export unchanged
from scripts.preflight import get_gpus as _get_gpus TaskScheduler as _CoreTaskScheduler,
except Exception: # graceful degradation if preflight unavailable )
_get_gpus = lambda: []
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Task types that go through the scheduler (all others spawn free threads) # ── Peregrine task types and VRAM budgets ─────────────────────────────────────
LLM_TASK_TYPES: frozenset[str] = frozenset({ LLM_TASK_TYPES: frozenset[str] = frozenset({
"cover_letter", "cover_letter",
"company_research", "company_research",
"wizard_generate", "wizard_generate",
"resume_optimize",
}) })
# Conservative peak VRAM estimates (GB) per task type. # Conservative peak VRAM estimates (GB) per task type.
# Overridable per-install via scheduler.vram_budgets in config/llm.yaml. # Overridable per-install via scheduler.vram_budgets in config/llm.yaml.
DEFAULT_VRAM_BUDGETS: dict[str, float] = { DEFAULT_VRAM_BUDGETS: dict[str, float] = {
"cover_letter": 2.5, # alex-cover-writer:latest (~2GB GGUF + headroom) "cover_letter": 2.5, # alex-cover-writer:latest (~2 GB GGUF + headroom)
"company_research": 5.0, # llama3.1:8b or vllm model "company_research": 5.0, # llama3.1:8b or vllm model
"wizard_generate": 2.5, # same model family as cover_letter "wizard_generate": 2.5, # same model family as cover_letter
"resume_optimize": 5.0, # section-by-section rewrite; same budget as research
} }
# Lightweight task descriptor stored in per-type deques _DEFAULT_MAX_QUEUE_DEPTH = 500
TaskSpec = namedtuple("TaskSpec", ["id", "job_id", "params"])
class TaskScheduler: def _load_config_overrides(db_path: Path) -> tuple[dict[str, float], int]:
"""Resource-aware LLM task batch scheduler. Use get_scheduler() — not direct construction.""" """Load VRAM budget overrides and max_queue_depth from config/llm.yaml."""
budgets = dict(DEFAULT_VRAM_BUDGETS)
def __init__(self, db_path: Path, run_task_fn: Callable) -> None: max_depth = _DEFAULT_MAX_QUEUE_DEPTH
self._db_path = db_path
self._run_task = run_task_fn
self._lock = threading.Lock()
self._wake = threading.Event()
self._stop = threading.Event()
self._queues: dict[str, deque] = {}
self._active: dict[str, threading.Thread] = {}
self._reserved_vram: float = 0.0
self._thread: Optional[threading.Thread] = None
# Load VRAM budgets: defaults + optional config overrides
self._budgets: dict[str, float] = dict(DEFAULT_VRAM_BUDGETS)
config_path = db_path.parent.parent / "config" / "llm.yaml" config_path = db_path.parent.parent / "config" / "llm.yaml"
self._max_queue_depth: int = 500
if config_path.exists(): if config_path.exists():
try: try:
import yaml import yaml
with open(config_path) as f: with open(config_path) as f:
cfg = yaml.safe_load(f) or {} cfg = yaml.safe_load(f) or {}
sched_cfg = cfg.get("scheduler", {}) sched_cfg = cfg.get("scheduler", {})
self._budgets.update(sched_cfg.get("vram_budgets", {})) budgets.update(sched_cfg.get("vram_budgets", {}))
self._max_queue_depth = sched_cfg.get("max_queue_depth", 500) max_depth = int(sched_cfg.get("max_queue_depth", max_depth))
except Exception as exc: except Exception as exc:
logger.warning("Failed to load scheduler config from %s: %s", config_path, exc) logger.warning(
"Failed to load scheduler config from %s: %s", config_path, exc
)
return budgets, max_depth
# Warn on LLM types with no budget entry after merge
# Module-level stub so tests can monkeypatch scripts.task_scheduler._get_gpus
# (existing tests monkeypatch this symbol — keep it here for backward compat).
try:
from scripts.preflight import get_gpus as _get_gpus
except Exception:
_get_gpus = lambda: [] # noqa: E731
class TaskScheduler(_CoreTaskScheduler):
"""Peregrine-specific TaskScheduler.
Extends circuitforge_core.tasks.scheduler.TaskScheduler with:
- Peregrine default VRAM budgets and task types wired into __init__
- Config loading from config/llm.yaml
- Backward-compatible two-argument __init__ signature (db_path, run_task_fn)
- _get_gpus monkeypatch support (existing tests patch this module-level symbol)
- Backward-compatible enqueue() that marks dropped tasks failed in the DB
and logs under the scripts.task_scheduler logger
Direct construction is still supported for tests; production code should
use get_scheduler() instead.
"""
def __init__(self, db_path: Path, run_task_fn: Callable) -> None:
budgets, max_depth = _load_config_overrides(db_path)
# Resolve VRAM using module-level _get_gpus so tests can monkeypatch it
try:
gpus = _get_gpus()
available_vram: float = (
sum(g["vram_total_gb"] for g in gpus) if gpus else 999.0
)
except Exception:
available_vram = 999.0
# Warn under this module's logger for any task types with no VRAM budget
# (mirrors the core warning but captures under scripts.task_scheduler
# so existing tests using caplog.at_level(logger="scripts.task_scheduler") pass)
for t in LLM_TASK_TYPES: for t in LLM_TASK_TYPES:
if t not in self._budgets: if t not in budgets:
logger.warning( logger.warning(
"No VRAM budget defined for LLM task type %r" "No VRAM budget defined for LLM task type %r"
"defaulting to 0.0 GB (unlimited concurrency for this type)", t "defaulting to 0.0 GB (unlimited concurrency for this type)", t
) )
# Detect total GPU VRAM; fall back to unlimited (999) on CPU-only systems. super().__init__(
# Uses module-level _get_gpus so tests can monkeypatch scripts.task_scheduler._get_gpus. db_path=db_path,
try: run_task_fn=run_task_fn,
gpus = _get_gpus() task_types=LLM_TASK_TYPES,
self._available_vram: float = ( vram_budgets=budgets,
sum(g["vram_total_gb"] for g in gpus) if gpus else 999.0 available_vram_gb=available_vram,
max_queue_depth=max_depth,
) )
except Exception:
self._available_vram = 999.0
# Durability: reload surviving 'queued' LLM tasks from prior run def enqueue(
self._load_queued_tasks() self,
task_id: int,
def enqueue(self, task_id: int, task_type: str, job_id: int, task_type: str,
params: Optional[str]) -> None: job_id: int,
params: Optional[str],
) -> bool:
"""Add an LLM task to the scheduler queue. """Add an LLM task to the scheduler queue.
If the queue for this type is at max_queue_depth, the task is marked When the queue is full, marks the task failed in SQLite immediately
failed in SQLite immediately (no ghost queued rows) and a warning is logged. (backward-compatible with the original Peregrine behavior) and logs a
""" warning under the scripts.task_scheduler logger.
from scripts.db import update_task_status
with self._lock: Returns True if enqueued, False if the queue was full.
q = self._queues.setdefault(task_type, deque()) """
if len(q) >= self._max_queue_depth: enqueued = super().enqueue(task_id, task_type, job_id, params)
if not enqueued:
# Log under this module's logger so existing caplog tests pass
logger.warning( logger.warning(
"Queue depth limit reached for %s (max=%d) — task %d dropped", "Queue depth limit reached for %s (max=%d) — task %d dropped",
task_type, self._max_queue_depth, task_id, task_type, self._max_queue_depth, task_id,
) )
update_task_status(self._db_path, task_id, "failed", from scripts.db import update_task_status
error="Queue depth limit reached") update_task_status(
return self._db_path, task_id, "failed", error="Queue depth limit reached"
q.append(TaskSpec(task_id, job_id, params))
self._wake.set()
def start(self) -> None:
"""Start the background scheduler loop thread. Call once after construction."""
self._thread = threading.Thread(
target=self._scheduler_loop, name="task-scheduler", daemon=True
) )
self._thread.start() return enqueued
def shutdown(self, timeout: float = 5.0) -> None:
"""Signal the scheduler to stop and wait for it to exit."""
self._stop.set()
self._wake.set() # unblock any wait()
if self._thread and self._thread.is_alive():
self._thread.join(timeout=timeout)
def _scheduler_loop(self) -> None:
"""Main scheduler daemon — wakes on enqueue or batch completion."""
while not self._stop.is_set():
self._wake.wait(timeout=30)
self._wake.clear()
with self._lock:
# Defense in depth: reap externally-killed batch threads.
# In normal operation _active.pop() runs in finally before _wake fires,
# so this reap finds nothing — no double-decrement risk.
for t, thread in list(self._active.items()):
if not thread.is_alive():
self._reserved_vram -= self._budgets.get(t, 0.0)
del self._active[t]
# Start new type batches while VRAM allows
candidates = sorted(
[t for t in self._queues if self._queues[t] and t not in self._active],
key=lambda t: len(self._queues[t]),
reverse=True,
)
for task_type in candidates:
budget = self._budgets.get(task_type, 0.0)
# Always allow at least one batch to run even if its budget
# exceeds _available_vram (prevents permanent starvation when
# a single type's budget is larger than the VRAM ceiling).
if self._reserved_vram == 0.0 or self._reserved_vram + budget <= self._available_vram:
thread = threading.Thread(
target=self._batch_worker,
args=(task_type,),
name=f"batch-{task_type}",
daemon=True,
)
self._active[task_type] = thread
self._reserved_vram += budget
thread.start()
def _batch_worker(self, task_type: str) -> None:
"""Serial consumer for one task type. Runs until the type's deque is empty."""
try:
while True:
with self._lock:
q = self._queues.get(task_type)
if not q:
break
task = q.popleft()
# _run_task is scripts.task_runner._run_task (passed at construction)
self._run_task(
self._db_path, task.id, task_type, task.job_id, task.params
)
finally:
# Always release — even if _run_task raises.
# _active.pop here prevents the scheduler loop reap from double-decrementing.
with self._lock:
self._active.pop(task_type, None)
self._reserved_vram -= self._budgets.get(task_type, 0.0)
self._wake.set()
def _load_queued_tasks(self) -> None:
"""Load pre-existing queued LLM tasks from SQLite into deques (called once in __init__)."""
llm_types = sorted(LLM_TASK_TYPES) # sorted for deterministic SQL params in logs
placeholders = ",".join("?" * len(llm_types))
conn = sqlite3.connect(self._db_path)
rows = conn.execute(
f"SELECT id, task_type, job_id, params FROM background_tasks"
f" WHERE status='queued' AND task_type IN ({placeholders})"
f" ORDER BY created_at ASC",
llm_types,
).fetchall()
conn.close()
for row_id, task_type, job_id, params in rows:
q = self._queues.setdefault(task_type, deque())
q.append(TaskSpec(row_id, job_id, params))
if rows:
logger.info("Scheduler: resumed %d queued task(s) from prior run", len(rows))
# ── Singleton ───────────────────────────────────────────────────────────────── # ── Peregrine-local singleton ──────────────────────────────────────────────────
# We manage our own singleton (not the core one) so the process-level instance
# is always a Peregrine TaskScheduler (with the enqueue() override).
_scheduler: Optional[TaskScheduler] = None _scheduler: Optional[TaskScheduler] = None
_scheduler_lock = threading.Lock() _scheduler_lock = threading.Lock()
def get_scheduler(db_path: Path, run_task_fn: Callable = None) -> TaskScheduler: def get_scheduler(
"""Return the process-level TaskScheduler singleton, constructing it if needed. db_path: Path,
run_task_fn: Optional[Callable] = None,
) -> TaskScheduler:
"""Return the process-level Peregrine TaskScheduler singleton.
run_task_fn is required on the first call; ignored on subsequent calls. run_task_fn is required on the first call; ignored on subsequent calls
Safety: inner lock + double-check prevents double-construction under races. (double-checked locking; singleton constructed at most once).
The outer None check is a fast-path performance optimisation only.
""" """
global _scheduler global _scheduler
if _scheduler is None: # fast path — avoids lock on steady state if _scheduler is None: # fast path — no lock on steady state
with _scheduler_lock: with _scheduler_lock:
if _scheduler is None: # re-check under lock (double-checked locking) if _scheduler is None: # re-check under lock
if run_task_fn is None: if run_task_fn is None:
raise ValueError("run_task_fn required on first get_scheduler() call") raise ValueError("run_task_fn required on first get_scheduler() call")
_scheduler = TaskScheduler(db_path, run_task_fn) _scheduler = TaskScheduler(db_path, run_task_fn)

View file

@ -470,3 +470,14 @@ def test_llm_tasks_routed_to_scheduler(tmp_db):
task_runner.submit_task(tmp_db, "cover_letter", 1) task_runner.submit_task(tmp_db, "cover_letter", 1)
assert "cover_letter" in enqueue_calls assert "cover_letter" in enqueue_calls
def test_shim_exports_unchanged_api():
    """The Peregrine shim must keep re-exporting LLM_TASK_TYPES, get_scheduler, reset_scheduler."""
    from scripts.task_scheduler import LLM_TASK_TYPES, get_scheduler, reset_scheduler

    # Every Peregrine LLM task type must still be routed through the scheduler.
    expected_types = ("cover_letter", "company_research", "wizard_generate", "resume_optimize")
    for task_type in expected_types:
        assert task_type in LLM_TASK_TYPES

    # Both accessor entry points must remain callable from the shim module.
    for exported in (get_scheduler, reset_scheduler):
        assert callable(exported)