avocet/app/train/train.py
pyr0ball e11db5ccd9 fix: align train job/results API envelope, config_json key, progress SSE, dashboard model_key
- GET /api/train/jobs now returns {"jobs":[...]} instead of bare array
- GET /api/train/results now returns {"results":[...]} instead of bare array
- POST /api/train/jobs body key renamed config -> config_json to match Pydantic model
- SSE log handler now handles 'progress' event type (backend never emits 'log')
- Dashboard _get_active_jobs() adds model_key to SELECT and return dict
- corrections.py docstring updated: both /api/corrections and /api/sft prefixes noted
- test_train.py assertions updated to unwrap new envelope shapes
2026-05-02 21:22:18 -07:00

339 lines
12 KiB
Python

"""Avocet -- train job queue API.
SQLite-backed job queue for finetune jobs. Replaces the ad-hoc
_running_procs dict in api.py with a persistent, inspectable queue.
Routes (all under /api/train when api.py mounts with prefix="/api/train"):
GET /jobs -- list all jobs, newest first
POST /jobs -- create a new job
GET /jobs/{id} -- get one job by id
DELETE /jobs/{id}/cancel -- cancel a queued or running job
GET /jobs/{id}/run -- SSE: run the job, stream stdout
GET /results -- list completed models with training_info.json metrics
SQLite schema:
CREATE TABLE IF NOT EXISTS jobs (
id TEXT PRIMARY KEY,
type TEXT NOT NULL, -- 'classifier' | 'llm-sft'
model_key TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'queued',
config_json TEXT NOT NULL DEFAULT '{}',
created_at TEXT NOT NULL,
started_at TEXT,
completed_at TEXT,
error TEXT
)
Testability seam: _DB_PATH global, override via set_db_path().
"""
from __future__ import annotations
import json
import logging
import os
import sqlite3
import subprocess as _subprocess
import uuid
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Generator
import yaml
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
logger = logging.getLogger(__name__)
# Project root: three .parent hops up from this file's path.
_ROOT = Path(__file__).parent.parent.parent
# SQLite job store; tests redirect it with set_db_path().
_DB_PATH: Path = _ROOT / "data" / "train_jobs.db"
# Trained-model directory; list_results() reads <sub>/training_info.json under it.
_MODELS_DIR: Path = _ROOT / "models"
_CONFIG_DIR: Path | None = None  # override in tests via set_config_dir()
# job_id -> Popen handle of a running training script. run_job registers it and
# cancel_job pops it to terminate the process. In-memory only: lost on restart.
_running_procs: dict[str, Any] = {}
# Mounted by api.py (per the module docstring, under prefix="/api/train").
router = APIRouter()
# -- Testability seams -------------------------------------------------
def set_db_path(path: "Path | str") -> None:
    """Point the job store at *path* (testability seam).

    Accepts any str/PathLike and normalises it to ``Path``, because the rest of the
    module calls Path methods on ``_DB_PATH`` (e.g. ``.parent.mkdir`` in _init_db).
    """
    global _DB_PATH
    _DB_PATH = Path(path)
def set_models_dir(path: "Path | str") -> None:
    """Point list_results() at a different models directory (testability seam).

    Accepts any str/PathLike and normalises it to ``Path`` so ``.exists()`` /
    ``.iterdir()`` keep working for string arguments.
    """
    global _MODELS_DIR
    _MODELS_DIR = Path(path)
def set_config_dir(path: "Path | str | None") -> None:
    """Override the directory holding label_tool.yaml (testability seam).

    ``None`` restores the default (``<root>/config``). Any other value is normalised
    to ``Path`` because _config_file() joins it with the ``/`` operator.
    """
    global _CONFIG_DIR
    _CONFIG_DIR = None if path is None else Path(path)
# -- Config helpers ----------------------------------------------------
def _config_file() -> Path:
    """Locate label_tool.yaml: the test override dir if set, else <root>/config."""
    base_dir = _ROOT / "config" if _CONFIG_DIR is None else _CONFIG_DIR
    return base_dir / "label_tool.yaml"
def _load_train_config() -> dict:
"""Read python_bin from label_tool.yaml.
Priority (highest to lowest):
1. label_tool.yaml train: python_bin
2. label_tool.yaml cforch: python_bin
3. Hardcoded default (classifiers conda env)
"""
_DEFAULT_PYTHON_BIN = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
f = _config_file()
train_cfg: dict = {}
cforch_cfg: dict = {}
if f.exists():
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
train_cfg = raw.get("train", {}) or {}
cforch_cfg = raw.get("cforch", {}) or {}
except yaml.YAMLError as exc:
logger.warning("Failed to parse train config %s: %s", f, exc)
return {
"python_bin": train_cfg.get(
"python_bin",
cforch_cfg.get("python_bin", _DEFAULT_PYTHON_BIN),
),
}
# -- Database helpers --------------------------------------------------
@contextmanager
def _db() -> Generator[sqlite3.Connection, None, None]:
    """Open a short-lived connection to the job database.

    Rows come back as ``sqlite3.Row`` (indexable by column name). The transaction is
    committed only when the ``with`` body finishes without raising; the connection
    is closed in every case, so uncommitted changes are discarded on error.
    """
    connection = sqlite3.connect(str(_DB_PATH))
    try:
        connection.row_factory = sqlite3.Row
        yield connection
        connection.commit()
    finally:
        connection.close()
def _init_db() -> None:
    """Make sure the DB directory and the ``jobs`` table exist.

    Idempotent (``CREATE TABLE IF NOT EXISTS``). The job routes call it before each
    query instead of relying on a start-up migration step.
    """
    ddl = """
CREATE TABLE IF NOT EXISTS jobs (
id TEXT PRIMARY KEY,
type TEXT NOT NULL,
model_key TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'queued',
config_json TEXT NOT NULL DEFAULT '{}',
created_at TEXT NOT NULL,
started_at TEXT,
completed_at TEXT,
error TEXT
)
"""
    db_dir = _DB_PATH.parent
    db_dir.mkdir(parents=True, exist_ok=True)
    with _db() as conn:
        conn.execute(ddl)
def _row_to_dict(row: sqlite3.Row) -> dict:
return {k: row[k] for k in row.keys()}
# -- GPU selection (copied from api.py) --------------------------------
def _best_cuda_device() -> str:
"""Return index of GPU with most free VRAM, or empty string."""
try:
out = _subprocess.check_output(
["nvidia-smi", "--query-gpu=index,memory.free",
"--format=csv,noheader,nounits"],
text=True, timeout=5,
)
best_idx, best_free = "", 0
for line in out.strip().splitlines():
parts = line.strip().split(", ")
if len(parts) == 2:
idx, free = parts[0].strip(), int(parts[1].strip())
if free > best_free:
best_free, best_idx = free, idx
return best_idx
except Exception:
return ""
# -- Pydantic models ---------------------------------------------------
class CreateJobRequest(BaseModel):
    """Request body for ``POST /jobs``.

    ``type`` is validated by create_job rather than by this model, so an unsupported
    value produces that route's 400 response instead of a schema validation error.
    """
    type: str  # "classifier" | "llm-sft"
    model_key: str  # e.g. "deberta-small"
    # Free-form job options, persisted as JSON text in jobs.config_json. run_job reads
    # "epochs" and, for classifier jobs, "score_files" from it.
    # NOTE(review): the shared {} default relies on Pydantic copying field defaults per
    # instance -- confirm for the pinned Pydantic version.
    config_json: dict = {}
# -- Routes ------------------------------------------------------------
@router.get("/jobs")
def list_jobs() -> dict:
    """Return every job, most recently created first, as ``{"jobs": [...]}``.

    Each entry mirrors the DB row; ``config_json`` is the stored JSON text.
    """
    _init_db()
    with _db() as conn:
        cursor = conn.execute("SELECT * FROM jobs ORDER BY created_at DESC")
        fetched = cursor.fetchall()
    jobs = list(map(_row_to_dict, fetched))
    return {"jobs": jobs}
@router.post("/jobs")
def create_job(req: CreateJobRequest) -> dict:
    """Insert a new job in the 'queued' state and return its record.

    Raises:
        HTTPException 400: ``req.type`` is not 'classifier' or 'llm-sft'.

    NOTE(review): the response echoes ``config_json`` as a dict, whereas the GET
    routes return the stored JSON text -- confirm clients expect both shapes.
    """
    if req.type not in ("classifier", "llm-sft"):
        raise HTTPException(400, f"Unknown job type: {req.type!r}. Must be 'classifier' or 'llm-sft'.")
    _init_db()
    record = {
        "id": str(uuid.uuid4()),
        "type": req.type,
        "model_key": req.model_key,
        "status": "queued",
        "config_json": req.config_json,
        "created_at": datetime.now(timezone.utc).isoformat(),
        "started_at": None,
        "completed_at": None,
        "error": None,
    }
    insert_sql = (
        "INSERT INTO jobs (id, type, model_key, status, config_json, created_at) "
        "VALUES (?, ?, ?, 'queued', ?, ?)"
    )
    params = (
        record["id"],
        record["type"],
        record["model_key"],
        json.dumps(req.config_json),
        record["created_at"],
    )
    with _db() as conn:
        conn.execute(insert_sql, params)
    return record
@router.get("/jobs/{job_id}")
def get_job(job_id: str) -> dict:
    """Return one job row as a dict.

    Raises:
        HTTPException 404: no job has this id.
    """
    _init_db()
    with _db() as conn:
        found = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone()
    if found is not None:
        return _row_to_dict(found)
    raise HTTPException(404, f"Job {job_id!r} not found")
@router.delete("/jobs/{job_id}/cancel")
def cancel_job(job_id: str) -> dict:
    """Cancel a queued or running job and stop its process if one is tracked.

    The status change is a single conditional UPDATE. A job that finishes (or is
    cancelled by another request) between lookup and update is therefore reported
    as 409 rather than silently relabelled. The process is signalled only after that
    transaction has committed, so run_job's post-exit status write is not held up by
    this request's write lock.

    Raises:
        HTTPException 404: no job has this id.
        HTTPException 409: the job is not queued or running.
    """
    _init_db()
    now = datetime.now(timezone.utc).isoformat()
    with _db() as conn:
        updated = conn.execute(
            "UPDATE jobs SET status='cancelled', completed_at=? "
            "WHERE id=? AND status IN ('queued', 'running')",
            (now, job_id),
        ).rowcount
        if not updated:
            row = conn.execute("SELECT status FROM jobs WHERE id = ?", (job_id,)).fetchone()
            if row is None:
                raise HTTPException(404, f"Job {job_id!r} not found")
            raise HTTPException(409, f"Job is {row['status']} -- cannot cancel")
    # Transaction committed; now stop the process (if this worker started one).
    proc = _running_procs.pop(job_id, None)
    if proc is not None:
        try:
            proc.terminate()
            proc.wait(timeout=3)
        except _subprocess.TimeoutExpired:
            try:
                proc.kill()
            except OSError:
                pass
        except OSError:
            pass  # the process already exited
    return {"status": "cancelled"}
def _sse_event(payload: dict) -> str:
    """Encode *payload* as a single Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(payload)}\n\n"


def _finish_job(job_id: str, status: str, error: str | None = None) -> bool:
    """Record a terminal *status* (and *error*) for a job that is still 'running'.

    The UPDATE is conditional on ``status='running'``, so it never overwrites a
    cancellation written by cancel_job while the process was being stopped.
    Returns True when the row was actually updated.
    """
    finished_at = datetime.now(timezone.utc).isoformat()
    with _db() as conn:
        cur = conn.execute(
            "UPDATE jobs SET status=?, completed_at=?, error=? "
            "WHERE id=? AND status='running'",
            (status, finished_at, error, job_id),
        )
        return cur.rowcount > 0


def _stop_process(proc: Any, timeout: float = 3) -> None:
    """Ask *proc* to terminate; kill it if it is still alive after *timeout* seconds."""
    try:
        proc.terminate()
        proc.wait(timeout=timeout)
    except _subprocess.TimeoutExpired:
        try:
            proc.kill()
            proc.wait(timeout=timeout)
        except (OSError, _subprocess.TimeoutExpired):
            pass
    except OSError:
        pass  # the process is already gone


@router.get("/jobs/{job_id}/run")
def run_job(job_id: str) -> StreamingResponse:
    """SSE endpoint: launch a queued job's training script and stream its output.

    Events are ``data: {json}`` frames whose ``type`` is ``progress`` (one per
    non-empty output line), ``complete`` (exit code 0) or ``error``. The row moves
    queued -> running -> completed/failed. A concurrent cancel_job is respected
    rather than overwritten. If the stream is abandoned mid-run, the process is
    stopped and the job marked failed: the stream is the only reader of the child's
    stdout, and the job would otherwise stay 'running' forever.

    Raises:
        HTTPException 404: no job has this id.
        HTTPException 409: the job is not 'queued'.
    """
    _init_db()
    with _db() as conn:
        row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone()
    if row is None:
        raise HTTPException(404, f"Job {job_id!r} not found")
    if row["status"] != "queued":
        raise HTTPException(409, f"Job is {row['status']} -- only queued jobs can be run")
    job = _row_to_dict(row)

    def generate():
        cfg = _load_train_config()
        python_bin = cfg["python_bin"]
        config = json.loads(job["config_json"] or "{}")
        model_key = job["model_key"]
        epochs = config.get("epochs", 5)
        if job["type"] == "classifier":
            script = str(_ROOT / "scripts" / "finetune_classifier.py")
            cmd = [python_bin, script, "--model", model_key, "--epochs", str(epochs)]
            data_dir = (_ROOT / "data").resolve()
            # `or []` also tolerates an explicit null stored in the config.
            for sf in config.get("score_files") or []:
                resolved = (data_dir / sf).resolve()
                # Only files under data/ are passed on; '..' or absolute paths that
                # escape it are reported and dropped.
                if resolved.is_relative_to(data_dir):
                    cmd.extend(["--score", str(resolved)])
                else:
                    yield _sse_event({
                        "type": "progress",
                        "message": f"[train] Ignoring score file outside data dir: {sf}",
                    })
        elif job["type"] == "llm-sft":
            script = str(_ROOT / "scripts" / "finetune_sft.py")
            cmd = [python_bin, script, "--model", model_key, "--epochs", str(epochs)]
        else:
            yield _sse_event({"type": "error", "message": f"Unknown job type: {job['type']}"})
            return

        proc_env = {**os.environ, "PYTORCH_ALLOC_CONF": "expandable_segments:True"}
        best_gpu = _best_cuda_device()
        if best_gpu:
            proc_env["CUDA_VISIBLE_DEVICES"] = best_gpu
        gpu_note = f"GPU {best_gpu}" if best_gpu else "CPU (no GPU found)"
        yield _sse_event({"type": "progress", "message": f"[train] Using {gpu_note}"})

        # Claim the job atomically: two concurrent /run requests can both pass the
        # 'queued' check above, but only one can flip the row to 'running'.
        now = datetime.now(timezone.utc).isoformat()
        with _db() as conn:
            claimed = conn.execute(
                "UPDATE jobs SET status='running', started_at=? "
                "WHERE id=? AND status='queued'",
                (now, job_id),
            ).rowcount == 1
        if not claimed:
            yield _sse_event({"type": "error",
                              "message": "Job is no longer queued -- it was started or cancelled elsewhere"})
            return

        proc = None
        try:
            proc = _subprocess.Popen(
                cmd, stdout=_subprocess.PIPE, stderr=_subprocess.STDOUT,
                # errors="replace": one undecodable byte must not kill the log stream.
                text=True, errors="replace", bufsize=1, cwd=str(_ROOT), env=proc_env,
            )
            _running_procs[job_id] = proc
            for line in proc.stdout:
                line = line.rstrip()
                if line:
                    yield _sse_event({"type": "progress", "message": line})
            proc.wait()
            if proc.returncode == 0:
                outcome, err = "completed", None
            else:
                outcome, err = "failed", f"Process exited with code {proc.returncode}"
            if not _finish_job(job_id, outcome, err):
                # Status is no longer 'running': cancel_job stopped this process.
                yield _sse_event({"type": "error", "message": "Job was cancelled"})
            elif err is None:
                yield _sse_event({"type": "complete"})
            else:
                yield _sse_event({"type": "error", "message": err})
        except Exception as exc:
            err = str(exc)
            _finish_job(job_id, "failed", err)
            yield _sse_event({"type": "error", "message": err})
        finally:
            _running_procs.pop(job_id, None)
            # Reached with a live process only when the generator is closed early
            # (client disconnect -> GeneratorExit) or reading failed.
            if proc is not None and proc.poll() is None:
                _stop_process(proc)
            # No-op if an outcome was already recorded (or the job was cancelled).
            _finish_job(job_id, "failed", "Log stream closed before the job finished")

    return StreamingResponse(generate(), media_type="text/event-stream",
                             headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
@router.get("/results")
def list_results() -> dict:
    """List metrics for every trained model that has a ``training_info.json``.

    Scans the immediate sub-directories of _MODELS_DIR in name order, so the response
    is stable between calls. Returns ``{"results": [...]}`` holding each parsed JSON
    object. Unreadable, malformed, or non-object files are logged and skipped.
    """
    # is_dir(), not exists(): a stray file at this path would make iterdir() raise.
    if not _MODELS_DIR.is_dir():
        return {"results": []}
    results = []
    for sub in sorted(_MODELS_DIR.iterdir()):
        info_path = sub / "training_info.json"
        if not sub.is_dir() or not info_path.exists():
            continue
        try:
            info = json.loads(info_path.read_text(encoding="utf-8"))
        except (OSError, ValueError) as exc:
            # ValueError covers JSONDecodeError and UnicodeDecodeError.
            logger.warning("Failed to read training_info.json from %s: %s", info_path, exc)
            continue
        if not isinstance(info, dict):
            logger.warning("Ignoring %s: expected a JSON object, got %s",
                           info_path, type(info).__name__)
            continue
        results.append(info)
    return {"results": results}