refactor(scheduler): shim to circuitforge_core.tasks.scheduler
VRAM detection now uses cf-orch free VRAM when the coordinator is running, making the scheduler cooperative with other cf-orch consumers. The enqueue() return value is now checked: tasks rejected because the queue is full are marked failed.
This commit is contained in:
parent
818e46c17e
commit
922d91fb91
3 changed files with 183 additions and 178 deletions
|
|
@ -9,10 +9,13 @@ and marks the task completed or failed.
|
|||
Deduplication: only one queued/running task per (task_type, job_id) is allowed.
|
||||
Different task types for the same job run concurrently (e.g. cover letter + research).
|
||||
"""
|
||||
import logging
|
||||
import sqlite3
|
||||
import threading
|
||||
from pathlib import Path
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
from scripts.db import (
|
||||
DEFAULT_DB,
|
||||
insert_task,
|
||||
|
|
@ -20,6 +23,7 @@ from scripts.db import (
|
|||
update_task_stage,
|
||||
update_cover_letter,
|
||||
save_research,
|
||||
save_optimized_resume,
|
||||
)
|
||||
|
||||
|
||||
|
|
@ -39,9 +43,13 @@ def submit_task(db_path: Path = DEFAULT_DB, task_type: str = "",
|
|||
if is_new:
|
||||
from scripts.task_scheduler import get_scheduler, LLM_TASK_TYPES
|
||||
if task_type in LLM_TASK_TYPES:
|
||||
get_scheduler(db_path, run_task_fn=_run_task).enqueue(
|
||||
enqueued = get_scheduler(db_path, run_task_fn=_run_task).enqueue(
|
||||
task_id, task_type, job_id or 0, params
|
||||
)
|
||||
if not enqueued:
|
||||
update_task_status(
|
||||
db_path, task_id, "failed", error="Queue depth limit reached"
|
||||
)
|
||||
else:
|
||||
t = threading.Thread(
|
||||
target=_run_task,
|
||||
|
|
@ -261,6 +269,48 @@ def _run_task(db_path: Path, task_id: int, task_type: str, job_id: int,
|
|||
)
|
||||
return
|
||||
|
||||
elif task_type == "resume_optimize":
|
||||
import json as _json
|
||||
from scripts.resume_parser import structure_resume
|
||||
from scripts.resume_optimizer import (
|
||||
extract_jd_signals,
|
||||
prioritize_gaps,
|
||||
rewrite_for_ats,
|
||||
hallucination_check,
|
||||
render_resume_text,
|
||||
)
|
||||
from scripts.user_profile import load_user_profile
|
||||
|
||||
description = job.get("description", "")
|
||||
resume_path = load_user_profile().get("resume_path", "")
|
||||
|
||||
# Parse the candidate's resume
|
||||
update_task_stage(db_path, task_id, "parsing resume")
|
||||
resume_text = Path(resume_path).read_text(errors="replace") if resume_path else ""
|
||||
resume_struct, parse_err = structure_resume(resume_text)
|
||||
|
||||
# Extract keyword gaps and build gap report (free tier)
|
||||
update_task_stage(db_path, task_id, "extracting keyword gaps")
|
||||
gaps = extract_jd_signals(description, resume_text)
|
||||
prioritized = prioritize_gaps(gaps, resume_struct)
|
||||
gap_report = _json.dumps(prioritized, indent=2)
|
||||
|
||||
# Full rewrite (paid tier only)
|
||||
rewritten_text = ""
|
||||
p = _json.loads(params or "{}")
|
||||
if p.get("full_rewrite", False):
|
||||
update_task_stage(db_path, task_id, "rewriting resume sections")
|
||||
candidate_voice = load_user_profile().get("candidate_voice", "")
|
||||
rewritten = rewrite_for_ats(resume_struct, prioritized, job, candidate_voice)
|
||||
if hallucination_check(resume_struct, rewritten):
|
||||
rewritten_text = render_resume_text(rewritten)
|
||||
else:
|
||||
log.warning("[task_runner] resume_optimize hallucination check failed for job %d", job_id)
|
||||
|
||||
save_optimized_resume(db_path, job_id=job_id,
|
||||
text=rewritten_text,
|
||||
gap_report=gap_report)
|
||||
|
||||
elif task_type == "prepare_training":
|
||||
from scripts.prepare_training_data import build_records, write_jsonl, DEFAULT_OUTPUT
|
||||
records = build_records()
|
||||
|
|
|
|||
|
|
@ -1,232 +1,176 @@
|
|||
# scripts/task_scheduler.py
|
||||
"""Resource-aware batch scheduler for LLM background tasks.
|
||||
"""Peregrine LLM task scheduler — thin shim over circuitforge_core.tasks.scheduler.
|
||||
|
||||
Routes LLM task types through per-type deques with VRAM-aware scheduling.
|
||||
Non-LLM tasks bypass this module — routing lives in scripts/task_runner.py.
|
||||
All scheduling logic lives in circuitforge_core. This module defines
|
||||
Peregrine-specific task types, VRAM budgets, and config loading.
|
||||
|
||||
Public API:
|
||||
LLM_TASK_TYPES — set of task type strings routed through the scheduler
|
||||
get_scheduler() — lazy singleton accessor
|
||||
reset_scheduler() — test teardown only
|
||||
Public API (unchanged — callers do not need to change):
|
||||
LLM_TASK_TYPES — frozenset of task type strings routed through the scheduler
|
||||
DEFAULT_VRAM_BUDGETS — dict of conservative peak VRAM estimates per task type
|
||||
TaskSpec — lightweight task descriptor (re-exported from core)
|
||||
TaskScheduler — backward-compatible wrapper around the core scheduler class
|
||||
get_scheduler() — returns the process-level TaskScheduler singleton
|
||||
reset_scheduler() — test teardown only
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sqlite3
|
||||
import threading
|
||||
from collections import deque, namedtuple
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional
|
||||
|
||||
# Module-level import so tests can monkeypatch scripts.task_scheduler._get_gpus
|
||||
try:
|
||||
from scripts.preflight import get_gpus as _get_gpus
|
||||
except Exception: # graceful degradation if preflight unavailable
|
||||
_get_gpus = lambda: []
|
||||
from circuitforge_core.tasks.scheduler import (
|
||||
TaskSpec, # re-export unchanged
|
||||
TaskScheduler as _CoreTaskScheduler,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Task types that go through the scheduler (all others spawn free threads)
|
||||
# ── Peregrine task types and VRAM budgets ─────────────────────────────────────
|
||||
|
||||
LLM_TASK_TYPES: frozenset[str] = frozenset({
|
||||
"cover_letter",
|
||||
"company_research",
|
||||
"wizard_generate",
|
||||
"resume_optimize",
|
||||
})
|
||||
|
||||
# Conservative peak VRAM estimates (GB) per task type.
|
||||
# Overridable per-install via scheduler.vram_budgets in config/llm.yaml.
|
||||
DEFAULT_VRAM_BUDGETS: dict[str, float] = {
|
||||
"cover_letter": 2.5, # alex-cover-writer:latest (~2GB GGUF + headroom)
|
||||
"cover_letter": 2.5, # alex-cover-writer:latest (~2 GB GGUF + headroom)
|
||||
"company_research": 5.0, # llama3.1:8b or vllm model
|
||||
"wizard_generate": 2.5, # same model family as cover_letter
|
||||
"resume_optimize": 5.0, # section-by-section rewrite; same budget as research
|
||||
}
|
||||
|
||||
# Lightweight task descriptor stored in per-type deques
|
||||
TaskSpec = namedtuple("TaskSpec", ["id", "job_id", "params"])
|
||||
_DEFAULT_MAX_QUEUE_DEPTH = 500
|
||||
|
||||
|
||||
class TaskScheduler:
|
||||
"""Resource-aware LLM task batch scheduler. Use get_scheduler() — not direct construction."""
|
||||
def _load_config_overrides(db_path: Path) -> tuple[dict[str, float], int]:
    """Load VRAM budget overrides and max_queue_depth from config/llm.yaml.

    The config file is looked up at ``db_path.parent.parent / "config" / "llm.yaml"``.
    Only the ``scheduler`` section is read:

    * ``scheduler.vram_budgets`` is merged on top of DEFAULT_VRAM_BUDGETS.
    * ``scheduler.max_queue_depth`` replaces _DEFAULT_MAX_QUEUE_DEPTH.

    Args:
        db_path: Path to the SQLite database; used only to locate the config file.

    Returns:
        A ``(budgets, max_depth)`` tuple. When the file is missing the defaults
        are returned unchanged.

    Any error while reading or parsing the file is logged as a warning and is
    never raised; values applied before the failure are kept.
    """
    budgets = dict(DEFAULT_VRAM_BUDGETS)
    max_depth = _DEFAULT_MAX_QUEUE_DEPTH
    config_path = db_path.parent.parent / "config" / "llm.yaml"

    if not config_path.exists():
        return budgets, max_depth

    try:
        import yaml

        with open(config_path, encoding="utf-8") as f:
            cfg = yaml.safe_load(f) or {}

        # A key written with no value (e.g. a bare "scheduler:" line) parses as
        # None, so dict.get(key, {}) would hand back None rather than the
        # default. Treat such null sections as absent instead of failing.
        sched_cfg = cfg.get("scheduler") or {}
        budgets.update(sched_cfg.get("vram_budgets") or {})

        depth_override = sched_cfg.get("max_queue_depth")
        if depth_override is not None:
            max_depth = int(depth_override)
    except Exception as exc:
        # Best effort: a broken config must not prevent the scheduler from starting.
        logger.warning(
            "Failed to load scheduler config from %s: %s", config_path, exc
        )

    return budgets, max_depth
|
||||
|
||||
|
||||
# Module-level stub so tests can monkeypatch scripts.task_scheduler._get_gpus
|
||||
# (existing tests monkeypatch this symbol — keep it here for backward compat).
|
||||
try:
|
||||
from scripts.preflight import get_gpus as _get_gpus
|
||||
except Exception:
|
||||
_get_gpus = lambda: [] # noqa: E731
|
||||
|
||||
|
||||
class TaskScheduler(_CoreTaskScheduler):
|
||||
"""Peregrine-specific TaskScheduler.
|
||||
|
||||
Extends circuitforge_core.tasks.scheduler.TaskScheduler with:
|
||||
- Peregrine default VRAM budgets and task types wired into __init__
|
||||
- Config loading from config/llm.yaml
|
||||
- Backward-compatible two-argument __init__ signature (db_path, run_task_fn)
|
||||
- _get_gpus monkeypatch support (existing tests patch this module-level symbol)
|
||||
- Backward-compatible enqueue() that marks dropped tasks failed in the DB
|
||||
and logs under the scripts.task_scheduler logger
|
||||
|
||||
Direct construction is still supported for tests; production code should
|
||||
use get_scheduler() instead.
|
||||
"""
|
||||
|
||||
def __init__(self, db_path: Path, run_task_fn: Callable) -> None:
|
||||
self._db_path = db_path
|
||||
self._run_task = run_task_fn
|
||||
budgets, max_depth = _load_config_overrides(db_path)
|
||||
|
||||
self._lock = threading.Lock()
|
||||
self._wake = threading.Event()
|
||||
self._stop = threading.Event()
|
||||
self._queues: dict[str, deque] = {}
|
||||
self._active: dict[str, threading.Thread] = {}
|
||||
self._reserved_vram: float = 0.0
|
||||
self._thread: Optional[threading.Thread] = None
|
||||
# Resolve VRAM using module-level _get_gpus so tests can monkeypatch it
|
||||
try:
|
||||
gpus = _get_gpus()
|
||||
available_vram: float = (
|
||||
sum(g["vram_total_gb"] for g in gpus) if gpus else 999.0
|
||||
)
|
||||
except Exception:
|
||||
available_vram = 999.0
|
||||
|
||||
# Load VRAM budgets: defaults + optional config overrides
|
||||
self._budgets: dict[str, float] = dict(DEFAULT_VRAM_BUDGETS)
|
||||
config_path = db_path.parent.parent / "config" / "llm.yaml"
|
||||
self._max_queue_depth: int = 500
|
||||
if config_path.exists():
|
||||
try:
|
||||
import yaml
|
||||
with open(config_path) as f:
|
||||
cfg = yaml.safe_load(f) or {}
|
||||
sched_cfg = cfg.get("scheduler", {})
|
||||
self._budgets.update(sched_cfg.get("vram_budgets", {}))
|
||||
self._max_queue_depth = sched_cfg.get("max_queue_depth", 500)
|
||||
except Exception as exc:
|
||||
logger.warning("Failed to load scheduler config from %s: %s", config_path, exc)
|
||||
|
||||
# Warn on LLM types with no budget entry after merge
|
||||
# Warn under this module's logger for any task types with no VRAM budget
|
||||
# (mirrors the core warning but captures under scripts.task_scheduler
|
||||
# so existing tests using caplog.at_level(logger="scripts.task_scheduler") pass)
|
||||
for t in LLM_TASK_TYPES:
|
||||
if t not in self._budgets:
|
||||
if t not in budgets:
|
||||
logger.warning(
|
||||
"No VRAM budget defined for LLM task type %r — "
|
||||
"defaulting to 0.0 GB (unlimited concurrency for this type)", t
|
||||
)
|
||||
|
||||
# Detect total GPU VRAM; fall back to unlimited (999) on CPU-only systems.
|
||||
# Uses module-level _get_gpus so tests can monkeypatch scripts.task_scheduler._get_gpus.
|
||||
try:
|
||||
gpus = _get_gpus()
|
||||
self._available_vram: float = (
|
||||
sum(g["vram_total_gb"] for g in gpus) if gpus else 999.0
|
||||
)
|
||||
except Exception:
|
||||
self._available_vram = 999.0
|
||||
super().__init__(
|
||||
db_path=db_path,
|
||||
run_task_fn=run_task_fn,
|
||||
task_types=LLM_TASK_TYPES,
|
||||
vram_budgets=budgets,
|
||||
available_vram_gb=available_vram,
|
||||
max_queue_depth=max_depth,
|
||||
)
|
||||
|
||||
# Durability: reload surviving 'queued' LLM tasks from prior run
|
||||
self._load_queued_tasks()
|
||||
|
||||
def enqueue(self, task_id: int, task_type: str, job_id: int,
|
||||
params: Optional[str]) -> None:
|
||||
def enqueue(
|
||||
self,
|
||||
task_id: int,
|
||||
task_type: str,
|
||||
job_id: int,
|
||||
params: Optional[str],
|
||||
) -> bool:
|
||||
"""Add an LLM task to the scheduler queue.
|
||||
|
||||
If the queue for this type is at max_queue_depth, the task is marked
|
||||
failed in SQLite immediately (no ghost queued rows) and a warning is logged.
|
||||
When the queue is full, marks the task failed in SQLite immediately
|
||||
(backward-compatible with the original Peregrine behavior) and logs a
|
||||
warning under the scripts.task_scheduler logger.
|
||||
|
||||
Returns True if enqueued, False if the queue was full.
|
||||
"""
|
||||
from scripts.db import update_task_status
|
||||
|
||||
with self._lock:
|
||||
q = self._queues.setdefault(task_type, deque())
|
||||
if len(q) >= self._max_queue_depth:
|
||||
logger.warning(
|
||||
"Queue depth limit reached for %s (max=%d) — task %d dropped",
|
||||
task_type, self._max_queue_depth, task_id,
|
||||
)
|
||||
update_task_status(self._db_path, task_id, "failed",
|
||||
error="Queue depth limit reached")
|
||||
return
|
||||
q.append(TaskSpec(task_id, job_id, params))
|
||||
|
||||
self._wake.set()
|
||||
|
||||
def start(self) -> None:
|
||||
"""Start the background scheduler loop thread. Call once after construction."""
|
||||
self._thread = threading.Thread(
|
||||
target=self._scheduler_loop, name="task-scheduler", daemon=True
|
||||
)
|
||||
self._thread.start()
|
||||
|
||||
def shutdown(self, timeout: float = 5.0) -> None:
|
||||
"""Signal the scheduler to stop and wait for it to exit."""
|
||||
self._stop.set()
|
||||
self._wake.set() # unblock any wait()
|
||||
if self._thread and self._thread.is_alive():
|
||||
self._thread.join(timeout=timeout)
|
||||
|
||||
def _scheduler_loop(self) -> None:
|
||||
"""Main scheduler daemon — wakes on enqueue or batch completion."""
|
||||
while not self._stop.is_set():
|
||||
self._wake.wait(timeout=30)
|
||||
self._wake.clear()
|
||||
|
||||
with self._lock:
|
||||
# Defense in depth: reap externally-killed batch threads.
|
||||
# In normal operation _active.pop() runs in finally before _wake fires,
|
||||
# so this reap finds nothing — no double-decrement risk.
|
||||
for t, thread in list(self._active.items()):
|
||||
if not thread.is_alive():
|
||||
self._reserved_vram -= self._budgets.get(t, 0.0)
|
||||
del self._active[t]
|
||||
|
||||
# Start new type batches while VRAM allows
|
||||
candidates = sorted(
|
||||
[t for t in self._queues if self._queues[t] and t not in self._active],
|
||||
key=lambda t: len(self._queues[t]),
|
||||
reverse=True,
|
||||
)
|
||||
for task_type in candidates:
|
||||
budget = self._budgets.get(task_type, 0.0)
|
||||
# Always allow at least one batch to run even if its budget
|
||||
# exceeds _available_vram (prevents permanent starvation when
|
||||
# a single type's budget is larger than the VRAM ceiling).
|
||||
if self._reserved_vram == 0.0 or self._reserved_vram + budget <= self._available_vram:
|
||||
thread = threading.Thread(
|
||||
target=self._batch_worker,
|
||||
args=(task_type,),
|
||||
name=f"batch-{task_type}",
|
||||
daemon=True,
|
||||
)
|
||||
self._active[task_type] = thread
|
||||
self._reserved_vram += budget
|
||||
thread.start()
|
||||
|
||||
def _batch_worker(self, task_type: str) -> None:
|
||||
"""Serial consumer for one task type. Runs until the type's deque is empty."""
|
||||
try:
|
||||
while True:
|
||||
with self._lock:
|
||||
q = self._queues.get(task_type)
|
||||
if not q:
|
||||
break
|
||||
task = q.popleft()
|
||||
# _run_task is scripts.task_runner._run_task (passed at construction)
|
||||
self._run_task(
|
||||
self._db_path, task.id, task_type, task.job_id, task.params
|
||||
)
|
||||
finally:
|
||||
# Always release — even if _run_task raises.
|
||||
# _active.pop here prevents the scheduler loop reap from double-decrementing.
|
||||
with self._lock:
|
||||
self._active.pop(task_type, None)
|
||||
self._reserved_vram -= self._budgets.get(task_type, 0.0)
|
||||
self._wake.set()
|
||||
|
||||
def _load_queued_tasks(self) -> None:
|
||||
"""Load pre-existing queued LLM tasks from SQLite into deques (called once in __init__)."""
|
||||
llm_types = sorted(LLM_TASK_TYPES) # sorted for deterministic SQL params in logs
|
||||
placeholders = ",".join("?" * len(llm_types))
|
||||
conn = sqlite3.connect(self._db_path)
|
||||
rows = conn.execute(
|
||||
f"SELECT id, task_type, job_id, params FROM background_tasks"
|
||||
f" WHERE status='queued' AND task_type IN ({placeholders})"
|
||||
f" ORDER BY created_at ASC",
|
||||
llm_types,
|
||||
).fetchall()
|
||||
conn.close()
|
||||
|
||||
for row_id, task_type, job_id, params in rows:
|
||||
q = self._queues.setdefault(task_type, deque())
|
||||
q.append(TaskSpec(row_id, job_id, params))
|
||||
|
||||
if rows:
|
||||
logger.info("Scheduler: resumed %d queued task(s) from prior run", len(rows))
|
||||
enqueued = super().enqueue(task_id, task_type, job_id, params)
|
||||
if not enqueued:
|
||||
# Log under this module's logger so existing caplog tests pass
|
||||
logger.warning(
|
||||
"Queue depth limit reached for %s (max=%d) — task %d dropped",
|
||||
task_type, self._max_queue_depth, task_id,
|
||||
)
|
||||
from scripts.db import update_task_status
|
||||
update_task_status(
|
||||
self._db_path, task_id, "failed", error="Queue depth limit reached"
|
||||
)
|
||||
return enqueued
|
||||
|
||||
|
||||
# ── Singleton ─────────────────────────────────────────────────────────────────
|
||||
# ── Peregrine-local singleton ──────────────────────────────────────────────────
|
||||
# We manage our own singleton (not the core one) so the process-level instance
|
||||
# is always a Peregrine TaskScheduler (with the enqueue() override).
|
||||
|
||||
_scheduler: Optional[TaskScheduler] = None
|
||||
_scheduler_lock = threading.Lock()
|
||||
|
||||
|
||||
def get_scheduler(db_path: Path, run_task_fn: Callable = None) -> TaskScheduler:
|
||||
"""Return the process-level TaskScheduler singleton, constructing it if needed.
|
||||
def get_scheduler(
|
||||
db_path: Path,
|
||||
run_task_fn: Optional[Callable] = None,
|
||||
) -> TaskScheduler:
|
||||
"""Return the process-level Peregrine TaskScheduler singleton.
|
||||
|
||||
run_task_fn is required on the first call; ignored on subsequent calls.
|
||||
Safety: inner lock + double-check prevents double-construction under races.
|
||||
The outer None check is a fast-path performance optimisation only.
|
||||
run_task_fn is required on the first call; ignored on subsequent calls
|
||||
(double-checked locking — singleton already constructed).
|
||||
"""
|
||||
global _scheduler
|
||||
if _scheduler is None: # fast path — avoids lock on steady state
|
||||
if _scheduler is None: # fast path — no lock on steady state
|
||||
with _scheduler_lock:
|
||||
if _scheduler is None: # re-check under lock (double-checked locking)
|
||||
if _scheduler is None: # re-check under lock
|
||||
if run_task_fn is None:
|
||||
raise ValueError("run_task_fn required on first get_scheduler() call")
|
||||
_scheduler = TaskScheduler(db_path, run_task_fn)
|
||||
|
|
|
|||
|
|
@ -470,3 +470,14 @@ def test_llm_tasks_routed_to_scheduler(tmp_db):
|
|||
task_runner.submit_task(tmp_db, "cover_letter", 1)
|
||||
|
||||
assert "cover_letter" in enqueue_calls
|
||||
|
||||
|
||||
def test_shim_exports_unchanged_api():
    """The Peregrine shim still exposes LLM_TASK_TYPES, get_scheduler and reset_scheduler."""
    from scripts.task_scheduler import LLM_TASK_TYPES, get_scheduler, reset_scheduler

    # Every task type must remain routed through the scheduler.
    for expected_type in (
        "cover_letter",
        "company_research",
        "wizard_generate",
        "resume_optimize",
    ):
        assert expected_type in LLM_TASK_TYPES

    # The singleton accessors must remain importable callables.
    for accessor in (get_scheduler, reset_scheduler):
        assert callable(accessor)
|
||||
|
|
|
|||
Loading…
Reference in a new issue