feat: build SQLite-backed train job queue in app/train/train.py
Replaces the ad-hoc _running_procs dict in api.py with a persistent, inspectable SQLite job queue. Removes old /api/finetune/* routes and _best_cuda_device from api.py. Adds /api/train/* routes (list, create, get, cancel, run SSE, results). 16 new tests all passing.
This commit is contained in:
parent
d432026fd7
commit
766fbafa02
5 changed files with 501 additions and 626 deletions
145
app/api.py
145
app/api.py
|
|
@ -22,11 +22,6 @@ _DATA_DIR: Path = _ROOT / "data" # overridable in tests via set_data_dir()
|
||||||
_MODELS_DIR: Path = _ROOT / "models" # overridable in tests via set_models_dir()
|
_MODELS_DIR: Path = _ROOT / "models" # overridable in tests via set_models_dir()
|
||||||
_CONFIG_DIR: Path | None = None # None = use real path
|
_CONFIG_DIR: Path | None = None # None = use real path
|
||||||
|
|
||||||
# Process registry for running jobs — used by cancel endpoints.
|
|
||||||
# Keys: "benchmark" | "finetune". Values: the live Popen object.
|
|
||||||
_running_procs: dict = {}
|
|
||||||
_cancelled_jobs: set = set()
|
|
||||||
|
|
||||||
|
|
||||||
def set_data_dir(path: Path) -> None:
|
def set_data_dir(path: Path) -> None:
|
||||||
"""Override data directory — used by tests."""
|
"""Override data directory — used by tests."""
|
||||||
|
|
@ -34,34 +29,6 @@ def set_data_dir(path: Path) -> None:
|
||||||
_DATA_DIR = path
|
_DATA_DIR = path
|
||||||
|
|
||||||
|
|
||||||
def _best_cuda_device() -> str:
|
|
||||||
"""Return the index of the GPU with the most free VRAM as a string.
|
|
||||||
|
|
||||||
Uses nvidia-smi so it works in the job-seeker env (no torch). Returns ""
|
|
||||||
if nvidia-smi is unavailable or no GPUs are found. Restricting the
|
|
||||||
training subprocess to a single GPU via CUDA_VISIBLE_DEVICES prevents
|
|
||||||
PyTorch DataParallel from replicating the model across all GPUs, which
|
|
||||||
would OOM the GPU with less headroom.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
out = _subprocess.check_output(
|
|
||||||
["nvidia-smi", "--query-gpu=index,memory.free",
|
|
||||||
"--format=csv,noheader,nounits"],
|
|
||||||
text=True,
|
|
||||||
timeout=5,
|
|
||||||
)
|
|
||||||
best_idx, best_free = "", 0
|
|
||||||
for line in out.strip().splitlines():
|
|
||||||
parts = line.strip().split(", ")
|
|
||||||
if len(parts) == 2:
|
|
||||||
idx, free = parts[0].strip(), int(parts[1].strip())
|
|
||||||
if free > best_free:
|
|
||||||
best_free, best_idx = free, idx
|
|
||||||
return best_idx
|
|
||||||
except Exception:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
def set_models_dir(path: Path) -> None:
|
def set_models_dir(path: Path) -> None:
|
||||||
"""Override models directory — used by tests."""
|
"""Override models directory — used by tests."""
|
||||||
global _MODELS_DIR
|
global _MODELS_DIR
|
||||||
|
|
@ -156,116 +123,8 @@ app.include_router(imitate_router, prefix="/api/imitate")
|
||||||
from app.data.fetch import router as fetch_router
|
from app.data.fetch import router as fetch_router
|
||||||
app.include_router(fetch_router, prefix="/api")
|
app.include_router(fetch_router, prefix="/api")
|
||||||
|
|
||||||
|
from app.train.train import router as train_router
|
||||||
from fastapi.responses import StreamingResponse
|
app.include_router(train_router, prefix="/api/train")
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Finetune endpoints
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
@app.get("/api/finetune/status")
|
|
||||||
def get_finetune_status():
|
|
||||||
"""Scan models/ for training_info.json files. Returns [] if none exist."""
|
|
||||||
models_dir = _MODELS_DIR
|
|
||||||
if not models_dir.exists():
|
|
||||||
return []
|
|
||||||
results = []
|
|
||||||
for sub in models_dir.iterdir():
|
|
||||||
if not sub.is_dir():
|
|
||||||
continue
|
|
||||||
info_path = sub / "training_info.json"
|
|
||||||
if not info_path.exists():
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
info = json.loads(info_path.read_text(encoding="utf-8"))
|
|
||||||
results.append(info)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/finetune/run")
|
|
||||||
def run_finetune_endpoint(
|
|
||||||
model: str = "deberta-small",
|
|
||||||
epochs: int = 5,
|
|
||||||
score: list[str] = Query(default=[]),
|
|
||||||
):
|
|
||||||
"""Spawn finetune_classifier.py and stream stdout as SSE progress events."""
|
|
||||||
python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
|
|
||||||
script = str(_ROOT / "scripts" / "finetune_classifier.py")
|
|
||||||
cmd = [python_bin, script, "--model", model, "--epochs", str(epochs)]
|
|
||||||
data_root = _DATA_DIR.resolve()
|
|
||||||
for score_file in score:
|
|
||||||
resolved = (_DATA_DIR / score_file).resolve()
|
|
||||||
if not str(resolved).startswith(str(data_root)):
|
|
||||||
raise HTTPException(400, f"Invalid score path: {score_file!r}")
|
|
||||||
cmd.extend(["--score", str(resolved)])
|
|
||||||
|
|
||||||
# Pick the GPU with the most free VRAM. Setting CUDA_VISIBLE_DEVICES to a
|
|
||||||
# single device prevents DataParallel from replicating the model across all
|
|
||||||
# GPUs, which would force a full copy onto the more memory-constrained device.
|
|
||||||
proc_env = {**os.environ, "PYTORCH_ALLOC_CONF": "expandable_segments:True"}
|
|
||||||
best_gpu = _best_cuda_device()
|
|
||||||
if best_gpu:
|
|
||||||
proc_env["CUDA_VISIBLE_DEVICES"] = best_gpu
|
|
||||||
|
|
||||||
gpu_note = f"GPU {best_gpu}" if best_gpu else "CPU (no GPU found)"
|
|
||||||
|
|
||||||
def generate():
|
|
||||||
yield f"data: {json.dumps({'type': 'progress', 'message': f'[api] Using {gpu_note} (most free VRAM)'})}\n\n"
|
|
||||||
try:
|
|
||||||
proc = _subprocess.Popen(
|
|
||||||
cmd,
|
|
||||||
stdout=_subprocess.PIPE,
|
|
||||||
stderr=_subprocess.STDOUT,
|
|
||||||
text=True,
|
|
||||||
bufsize=1,
|
|
||||||
cwd=str(_ROOT),
|
|
||||||
env=proc_env,
|
|
||||||
)
|
|
||||||
_running_procs["finetune"] = proc
|
|
||||||
_cancelled_jobs.discard("finetune") # clear any stale flag from a prior run
|
|
||||||
try:
|
|
||||||
for line in proc.stdout:
|
|
||||||
line = line.rstrip()
|
|
||||||
if line:
|
|
||||||
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
|
|
||||||
proc.wait()
|
|
||||||
if proc.returncode == 0:
|
|
||||||
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
|
|
||||||
elif "finetune" in _cancelled_jobs:
|
|
||||||
_cancelled_jobs.discard("finetune")
|
|
||||||
yield f"data: {json.dumps({'type': 'cancelled'})}\n\n"
|
|
||||||
else:
|
|
||||||
yield f"data: {json.dumps({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})}\n\n"
|
|
||||||
finally:
|
|
||||||
_running_procs.pop("finetune", None)
|
|
||||||
except Exception as exc:
|
|
||||||
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
|
|
||||||
|
|
||||||
return StreamingResponse(
|
|
||||||
generate(),
|
|
||||||
media_type="text/event-stream",
|
|
||||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/finetune/cancel")
|
|
||||||
def cancel_finetune():
|
|
||||||
"""Kill the running fine-tune subprocess. 404 if none is running."""
|
|
||||||
proc = _running_procs.get("finetune")
|
|
||||||
if proc is None:
|
|
||||||
raise HTTPException(404, "No finetune is running")
|
|
||||||
_cancelled_jobs.add("finetune")
|
|
||||||
proc.terminate()
|
|
||||||
try:
|
|
||||||
proc.wait(timeout=3)
|
|
||||||
except _subprocess.TimeoutExpired:
|
|
||||||
proc.kill()
|
|
||||||
return {"status": "cancelled"}
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# Static SPA — MUST be last (catches all unmatched paths)
|
# Static SPA — MUST be last (catches all unmatched paths)
|
||||||
|
|
|
||||||
0
app/train/__init__.py
Normal file
0
app/train/__init__.py
Normal file
287
app/train/train.py
Normal file
287
app/train/train.py
Normal file
|
|
@ -0,0 +1,287 @@
|
||||||
|
"""Avocet -- train job queue API.
|
||||||
|
|
||||||
|
SQLite-backed job queue for finetune jobs. Replaces the ad-hoc
|
||||||
|
_running_procs dict in api.py with a persistent, inspectable queue.
|
||||||
|
|
||||||
|
Routes (all under /api/train when api.py mounts with prefix="/api/train"):
|
||||||
|
GET /jobs -- list all jobs, newest first
|
||||||
|
POST /jobs -- create a new job
|
||||||
|
GET /jobs/{id} -- get one job by id
|
||||||
|
DELETE /jobs/{id}/cancel -- cancel a queued or running job
|
||||||
|
GET /jobs/{id}/run -- SSE: run the job, stream stdout
|
||||||
|
GET /results -- list completed models with training_info.json metrics
|
||||||
|
|
||||||
|
SQLite schema:
|
||||||
|
CREATE TABLE IF NOT EXISTS jobs (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
type TEXT NOT NULL, -- 'classifier' | 'llm-sft'
|
||||||
|
model_key TEXT NOT NULL,
|
||||||
|
status TEXT NOT NULL DEFAULT 'queued',
|
||||||
|
config_json TEXT NOT NULL DEFAULT '{}',
|
||||||
|
created_at TEXT NOT NULL,
|
||||||
|
started_at TEXT,
|
||||||
|
completed_at TEXT,
|
||||||
|
error TEXT
|
||||||
|
)
|
||||||
|
|
||||||
|
Testability seam: _DB_PATH global, override via set_db_path().
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sqlite3
|
||||||
|
import subprocess as _subprocess
|
||||||
|
import uuid
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Generator
|
||||||
|
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
_ROOT = Path(__file__).parent.parent.parent
|
||||||
|
_DB_PATH: Path = _ROOT / "data" / "train_jobs.db"
|
||||||
|
_MODELS_DIR: Path = _ROOT / "models"
|
||||||
|
_running_procs: dict[str, Any] = {}
|
||||||
|
|
||||||
|
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
# -- Testability seams -------------------------------------------------
|
||||||
|
|
||||||
|
def set_db_path(path: Path) -> None:
|
||||||
|
global _DB_PATH
|
||||||
|
_DB_PATH = path
|
||||||
|
|
||||||
|
def set_models_dir(path: Path) -> None:
|
||||||
|
global _MODELS_DIR
|
||||||
|
_MODELS_DIR = path
|
||||||
|
|
||||||
|
|
||||||
|
# -- Database helpers --------------------------------------------------
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _db() -> Generator[sqlite3.Connection, None, None]:
|
||||||
|
conn = sqlite3.connect(str(_DB_PATH))
|
||||||
|
conn.row_factory = sqlite3.Row
|
||||||
|
try:
|
||||||
|
yield conn
|
||||||
|
conn.commit()
|
||||||
|
finally:
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
|
||||||
|
def _init_db() -> None:
|
||||||
|
"""Create jobs table if it does not exist. Called lazily per request."""
|
||||||
|
_DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
with _db() as conn:
|
||||||
|
conn.execute("""
|
||||||
|
CREATE TABLE IF NOT EXISTS jobs (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
type TEXT NOT NULL,
|
||||||
|
model_key TEXT NOT NULL,
|
||||||
|
status TEXT NOT NULL DEFAULT 'queued',
|
||||||
|
config_json TEXT NOT NULL DEFAULT '{}',
|
||||||
|
created_at TEXT NOT NULL,
|
||||||
|
started_at TEXT,
|
||||||
|
completed_at TEXT,
|
||||||
|
error TEXT
|
||||||
|
)
|
||||||
|
""")
|
||||||
|
|
||||||
|
|
||||||
|
def _row_to_dict(row: sqlite3.Row) -> dict:
|
||||||
|
return {k: row[k] for k in row.keys()}
|
||||||
|
|
||||||
|
|
||||||
|
# -- GPU selection (copied from api.py) --------------------------------
|
||||||
|
|
||||||
|
def _best_cuda_device() -> str:
|
||||||
|
"""Return index of GPU with most free VRAM, or empty string."""
|
||||||
|
try:
|
||||||
|
out = _subprocess.check_output(
|
||||||
|
["nvidia-smi", "--query-gpu=index,memory.free",
|
||||||
|
"--format=csv,noheader,nounits"],
|
||||||
|
text=True, timeout=5,
|
||||||
|
)
|
||||||
|
best_idx, best_free = "", 0
|
||||||
|
for line in out.strip().splitlines():
|
||||||
|
parts = line.strip().split(", ")
|
||||||
|
if len(parts) == 2:
|
||||||
|
idx, free = parts[0].strip(), int(parts[1].strip())
|
||||||
|
if free > best_free:
|
||||||
|
best_free, best_idx = free, idx
|
||||||
|
return best_idx
|
||||||
|
except Exception:
|
||||||
|
return ""
|
||||||
|
|
||||||
|
|
||||||
|
# -- Pydantic models ---------------------------------------------------
|
||||||
|
|
||||||
|
class CreateJobRequest(BaseModel):
|
||||||
|
type: str # "classifier" | "llm-sft"
|
||||||
|
model_key: str # e.g. "deberta-small"
|
||||||
|
config_json: dict = {}
|
||||||
|
|
||||||
|
|
||||||
|
# -- Routes ------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.get("/jobs")
|
||||||
|
def list_jobs() -> list[dict]:
|
||||||
|
_init_db()
|
||||||
|
with _db() as conn:
|
||||||
|
rows = conn.execute("SELECT * FROM jobs ORDER BY created_at DESC").fetchall()
|
||||||
|
return [_row_to_dict(r) for r in rows]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/jobs")
|
||||||
|
def create_job(req: CreateJobRequest) -> dict:
|
||||||
|
if req.type not in ("classifier", "llm-sft"):
|
||||||
|
raise HTTPException(400, f"Unknown job type: {req.type!r}. Must be 'classifier' or 'llm-sft'.")
|
||||||
|
_init_db()
|
||||||
|
job_id = str(uuid.uuid4())
|
||||||
|
now = datetime.now(timezone.utc).isoformat()
|
||||||
|
with _db() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"INSERT INTO jobs (id, type, model_key, status, config_json, created_at) "
|
||||||
|
"VALUES (?, ?, ?, 'queued', ?, ?)",
|
||||||
|
(job_id, req.type, req.model_key, json.dumps(req.config_json), now),
|
||||||
|
)
|
||||||
|
return {"id": job_id, "type": req.type, "model_key": req.model_key,
|
||||||
|
"status": "queued", "config_json": req.config_json,
|
||||||
|
"created_at": now, "started_at": None, "completed_at": None, "error": None}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/jobs/{job_id}")
|
||||||
|
def get_job(job_id: str) -> dict:
|
||||||
|
_init_db()
|
||||||
|
with _db() as conn:
|
||||||
|
row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone()
|
||||||
|
if row is None:
|
||||||
|
raise HTTPException(404, f"Job {job_id!r} not found")
|
||||||
|
return _row_to_dict(row)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/jobs/{job_id}/cancel")
|
||||||
|
def cancel_job(job_id: str) -> dict:
|
||||||
|
_init_db()
|
||||||
|
with _db() as conn:
|
||||||
|
row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone()
|
||||||
|
if row is None:
|
||||||
|
raise HTTPException(404, f"Job {job_id!r} not found")
|
||||||
|
if row["status"] not in ("queued", "running"):
|
||||||
|
raise HTTPException(409, f"Job is {row['status']} -- cannot cancel")
|
||||||
|
now = datetime.now(timezone.utc).isoformat()
|
||||||
|
conn.execute("UPDATE jobs SET status='cancelled', completed_at=? WHERE id=?", (now, job_id))
|
||||||
|
proc = _running_procs.pop(job_id, None)
|
||||||
|
if proc is not None:
|
||||||
|
try:
|
||||||
|
proc.terminate()
|
||||||
|
proc.wait(timeout=3)
|
||||||
|
except _subprocess.TimeoutExpired:
|
||||||
|
proc.kill()
|
||||||
|
return {"status": "cancelled"}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/jobs/{job_id}/run")
|
||||||
|
def run_job(job_id: str) -> StreamingResponse:
|
||||||
|
_init_db()
|
||||||
|
with _db() as conn:
|
||||||
|
row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone()
|
||||||
|
if row is None:
|
||||||
|
raise HTTPException(404, f"Job {job_id!r} not found")
|
||||||
|
if row["status"] != "queued":
|
||||||
|
raise HTTPException(409, f"Job is {row['status']} -- only queued jobs can be run")
|
||||||
|
job = _row_to_dict(row)
|
||||||
|
|
||||||
|
def generate():
|
||||||
|
python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
|
||||||
|
config = json.loads(job["config_json"] or "{}")
|
||||||
|
model_key = job["model_key"]
|
||||||
|
epochs = config.get("epochs", 5)
|
||||||
|
|
||||||
|
if job["type"] == "classifier":
|
||||||
|
script = str(_ROOT / "scripts" / "finetune_classifier.py")
|
||||||
|
cmd = [python_bin, script, "--model", model_key, "--epochs", str(epochs)]
|
||||||
|
data_dir = _ROOT / "data"
|
||||||
|
for sf in config.get("score_files", []):
|
||||||
|
resolved = (data_dir / sf).resolve()
|
||||||
|
if str(resolved).startswith(str(data_dir.resolve())):
|
||||||
|
cmd.extend(["--score", str(resolved)])
|
||||||
|
elif job["type"] == "llm-sft":
|
||||||
|
script = str(_ROOT / "scripts" / "finetune_sft.py")
|
||||||
|
cmd = [python_bin, script, "--model", model_key, "--epochs", str(epochs)]
|
||||||
|
else:
|
||||||
|
yield f"data: {json.dumps({'type': 'error', 'message': f'Unknown job type: {job['type']}'})}\n\n"
|
||||||
|
return
|
||||||
|
|
||||||
|
proc_env = {**os.environ, "PYTORCH_ALLOC_CONF": "expandable_segments:True"}
|
||||||
|
best_gpu = _best_cuda_device()
|
||||||
|
if best_gpu:
|
||||||
|
proc_env["CUDA_VISIBLE_DEVICES"] = best_gpu
|
||||||
|
|
||||||
|
gpu_note = f"GPU {best_gpu}" if best_gpu else "CPU (no GPU found)"
|
||||||
|
yield f"data: {json.dumps({'type': 'progress', 'message': f'[train] Using {gpu_note}'})}\n\n"
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc).isoformat()
|
||||||
|
with _db() as conn:
|
||||||
|
conn.execute("UPDATE jobs SET status='running', started_at=? WHERE id=?", (now, job_id))
|
||||||
|
|
||||||
|
try:
|
||||||
|
proc = _subprocess.Popen(
|
||||||
|
cmd, stdout=_subprocess.PIPE, stderr=_subprocess.STDOUT,
|
||||||
|
text=True, bufsize=1, cwd=str(_ROOT), env=proc_env,
|
||||||
|
)
|
||||||
|
_running_procs[job_id] = proc
|
||||||
|
try:
|
||||||
|
for line in proc.stdout:
|
||||||
|
line = line.rstrip()
|
||||||
|
if line:
|
||||||
|
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
|
||||||
|
proc.wait()
|
||||||
|
finished_at = datetime.now(timezone.utc).isoformat()
|
||||||
|
if proc.returncode == 0:
|
||||||
|
with _db() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"UPDATE jobs SET status='completed', completed_at=? WHERE id=?",
|
||||||
|
(finished_at, job_id))
|
||||||
|
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
|
||||||
|
else:
|
||||||
|
err = f"Process exited with code {proc.returncode}"
|
||||||
|
with _db() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"UPDATE jobs SET status='failed', completed_at=?, error=? WHERE id=?",
|
||||||
|
(finished_at, err, job_id))
|
||||||
|
yield f"data: {json.dumps({'type': 'error', 'message': err})}\n\n"
|
||||||
|
finally:
|
||||||
|
_running_procs.pop(job_id, None)
|
||||||
|
except Exception as exc:
|
||||||
|
err = str(exc)
|
||||||
|
with _db() as conn:
|
||||||
|
conn.execute("UPDATE jobs SET status='failed', error=? WHERE id=?", (err, job_id))
|
||||||
|
yield f"data: {json.dumps({'type': 'error', 'message': err})}\n\n"
|
||||||
|
|
||||||
|
return StreamingResponse(generate(), media_type="text/event-stream",
|
||||||
|
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/results")
|
||||||
|
def list_results() -> list[dict]:
|
||||||
|
if not _MODELS_DIR.exists():
|
||||||
|
return []
|
||||||
|
results = []
|
||||||
|
for sub in _MODELS_DIR.iterdir():
|
||||||
|
if not sub.is_dir():
|
||||||
|
continue
|
||||||
|
info_path = sub / "training_info.json"
|
||||||
|
if not info_path.exists():
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
info = json.loads(info_path.read_text(encoding="utf-8"))
|
||||||
|
results.append(info)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return results
|
||||||
|
|
@ -1,29 +1,33 @@
|
||||||
import json
|
"""Smoke tests for the app factory (app/api.py).
|
||||||
|
|
||||||
|
Detailed route tests live in test_data_label.py, test_data_fetch.py,
|
||||||
|
test_data_corrections.py, test_train.py, and test_dashboard.py.
|
||||||
|
"""
|
||||||
import pytest
|
import pytest
|
||||||
from app import api as api_module # noqa: F401
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def reset_globals(tmp_path):
|
|
||||||
from app import api
|
|
||||||
from app.data import label as label_module
|
|
||||||
from app.data import fetch as fetch_module
|
|
||||||
api.set_data_dir(tmp_path)
|
|
||||||
label_module.set_data_dir(tmp_path)
|
|
||||||
label_module.set_config_dir(tmp_path)
|
|
||||||
label_module.reset_last_action()
|
|
||||||
fetch_module.set_data_dir(tmp_path)
|
|
||||||
fetch_module.set_config_dir(tmp_path)
|
|
||||||
yield
|
|
||||||
label_module.reset_last_action()
|
|
||||||
|
|
||||||
|
|
||||||
def test_import():
|
def test_import():
|
||||||
from app import api # noqa: F401
|
from app import api # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
from fastapi.testclient import TestClient
|
def test_app_has_required_routes():
|
||||||
|
from app.api import app
|
||||||
|
paths = {r.path for r in app.routes}
|
||||||
|
# Label routes
|
||||||
|
assert "/api/queue" in paths
|
||||||
|
assert "/api/label" in paths
|
||||||
|
assert "/api/skip" in paths
|
||||||
|
assert "/api/discard" in paths
|
||||||
|
assert "/api/label/undo" in paths
|
||||||
|
assert "/api/config/labels" in paths
|
||||||
|
assert "/api/stats" in paths
|
||||||
|
# Fetch routes
|
||||||
|
assert "/api/accounts/test" in paths
|
||||||
|
assert "/api/fetch/stream" in paths
|
||||||
|
# Train routes
|
||||||
|
assert "/api/train/jobs" in paths
|
||||||
|
assert "/api/train/results" in paths
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -32,470 +36,8 @@ def client():
|
||||||
return TestClient(app)
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
def test_queue_endpoint_reachable(client):
|
||||||
def queue_with_items():
|
|
||||||
"""Write 3 test emails to the queue file."""
|
|
||||||
from app import api as api_module
|
|
||||||
items = [
|
|
||||||
{"id": f"id{i}", "subject": f"Subject {i}", "body": f"Body {i}",
|
|
||||||
"from": "test@example.com", "date": "2026-03-01", "source": "imap:test"}
|
|
||||||
for i in range(3)
|
|
||||||
]
|
|
||||||
queue_path = api_module._DATA_DIR / "email_label_queue.jsonl"
|
|
||||||
queue_path.write_text("\n".join(json.dumps(x) for x in items) + "\n")
|
|
||||||
return items
|
|
||||||
|
|
||||||
|
|
||||||
def test_queue_returns_items(client, queue_with_items):
|
|
||||||
r = client.get("/api/queue?limit=2")
|
|
||||||
assert r.status_code == 200
|
|
||||||
data = r.json()
|
|
||||||
assert len(data["items"]) == 2
|
|
||||||
assert data["total"] == 3
|
|
||||||
|
|
||||||
|
|
||||||
def test_queue_empty_when_no_file(client):
|
|
||||||
r = client.get("/api/queue")
|
r = client.get("/api/queue")
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
assert r.json() == {"items": [], "total": 0}
|
assert "items" in r.json()
|
||||||
|
assert "total" in r.json()
|
||||||
|
|
||||||
def test_label_appends_to_score(client, queue_with_items):
|
|
||||||
from app import api as api_module
|
|
||||||
r = client.post("/api/label", json={"id": "id0", "label": "interview_scheduled"})
|
|
||||||
assert r.status_code == 200
|
|
||||||
records = api_module._read_jsonl(api_module._score_file())
|
|
||||||
assert len(records) == 1
|
|
||||||
assert records[0]["id"] == "id0"
|
|
||||||
assert records[0]["label"] == "interview_scheduled"
|
|
||||||
assert "labeled_at" in records[0]
|
|
||||||
|
|
||||||
def test_label_removes_from_queue(client, queue_with_items):
|
|
||||||
from app import api as api_module
|
|
||||||
client.post("/api/label", json={"id": "id0", "label": "rejected"})
|
|
||||||
queue = api_module._read_jsonl(api_module._queue_file())
|
|
||||||
assert not any(x["id"] == "id0" for x in queue)
|
|
||||||
|
|
||||||
def test_label_unknown_id_returns_404(client, queue_with_items):
|
|
||||||
r = client.post("/api/label", json={"id": "unknown", "label": "neutral"})
|
|
||||||
assert r.status_code == 404
|
|
||||||
|
|
||||||
def test_skip_moves_to_back(client, queue_with_items):
|
|
||||||
from app import api as api_module
|
|
||||||
r = client.post("/api/skip", json={"id": "id0"})
|
|
||||||
assert r.status_code == 200
|
|
||||||
queue = api_module._read_jsonl(api_module._queue_file())
|
|
||||||
assert queue[-1]["id"] == "id0"
|
|
||||||
assert queue[0]["id"] == "id1"
|
|
||||||
|
|
||||||
def test_skip_unknown_id_returns_404(client, queue_with_items):
|
|
||||||
r = client.post("/api/skip", json={"id": "nope"})
|
|
||||||
assert r.status_code == 404
|
|
||||||
|
|
||||||
|
|
||||||
# --- Part A: POST /api/discard ---
|
|
||||||
|
|
||||||
def test_discard_writes_to_discarded_file(client, queue_with_items):
|
|
||||||
from app import api as api_module
|
|
||||||
r = client.post("/api/discard", json={"id": "id1"})
|
|
||||||
assert r.status_code == 200
|
|
||||||
discarded = api_module._read_jsonl(api_module._discarded_file())
|
|
||||||
assert len(discarded) == 1
|
|
||||||
assert discarded[0]["id"] == "id1"
|
|
||||||
assert discarded[0]["label"] == "__discarded__"
|
|
||||||
|
|
||||||
def test_discard_removes_from_queue(client, queue_with_items):
|
|
||||||
from app import api as api_module
|
|
||||||
client.post("/api/discard", json={"id": "id1"})
|
|
||||||
queue = api_module._read_jsonl(api_module._queue_file())
|
|
||||||
assert not any(x["id"] == "id1" for x in queue)
|
|
||||||
|
|
||||||
|
|
||||||
# --- Part B: DELETE /api/label/undo ---
|
|
||||||
|
|
||||||
def test_undo_label_removes_from_score(client, queue_with_items):
|
|
||||||
from app import api as api_module
|
|
||||||
client.post("/api/label", json={"id": "id0", "label": "neutral"})
|
|
||||||
r = client.delete("/api/label/undo")
|
|
||||||
assert r.status_code == 200
|
|
||||||
data = r.json()
|
|
||||||
assert data["undone"]["type"] == "label"
|
|
||||||
score = api_module._read_jsonl(api_module._score_file())
|
|
||||||
assert score == []
|
|
||||||
# Item should be restored to front of queue
|
|
||||||
queue = api_module._read_jsonl(api_module._queue_file())
|
|
||||||
assert queue[0]["id"] == "id0"
|
|
||||||
|
|
||||||
def test_undo_discard_removes_from_discarded(client, queue_with_items):
|
|
||||||
from app import api as api_module
|
|
||||||
client.post("/api/discard", json={"id": "id0"})
|
|
||||||
r = client.delete("/api/label/undo")
|
|
||||||
assert r.status_code == 200
|
|
||||||
discarded = api_module._read_jsonl(api_module._discarded_file())
|
|
||||||
assert discarded == []
|
|
||||||
|
|
||||||
def test_undo_skip_restores_to_front(client, queue_with_items):
|
|
||||||
from app import api as api_module
|
|
||||||
client.post("/api/skip", json={"id": "id0"})
|
|
||||||
r = client.delete("/api/label/undo")
|
|
||||||
assert r.status_code == 200
|
|
||||||
queue = api_module._read_jsonl(api_module._queue_file())
|
|
||||||
assert queue[0]["id"] == "id0"
|
|
||||||
|
|
||||||
def test_undo_with_no_action_returns_404(client):
|
|
||||||
r = client.delete("/api/label/undo")
|
|
||||||
assert r.status_code == 404
|
|
||||||
|
|
||||||
|
|
||||||
# --- Part C: GET /api/config/labels ---
|
|
||||||
|
|
||||||
def test_config_labels_returns_metadata(client):
|
|
||||||
r = client.get("/api/config/labels")
|
|
||||||
assert r.status_code == 200
|
|
||||||
labels = r.json()
|
|
||||||
assert len(labels) == 10
|
|
||||||
assert labels[0]["key"] == "1"
|
|
||||||
assert "emoji" in labels[0]
|
|
||||||
assert "color" in labels[0]
|
|
||||||
assert "name" in labels[0]
|
|
||||||
|
|
||||||
|
|
||||||
# ── /api/config ──────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def config_dir(tmp_path):
|
|
||||||
"""Give the API a writable config directory."""
|
|
||||||
from app import api as api_module
|
|
||||||
from app.data import label as label_module
|
|
||||||
api_module.set_config_dir(tmp_path)
|
|
||||||
label_module.set_config_dir(tmp_path)
|
|
||||||
yield tmp_path
|
|
||||||
api_module.set_config_dir(None) # reset to default
|
|
||||||
label_module.set_config_dir(None)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def data_dir():
|
|
||||||
"""Expose the current _DATA_DIR set by the autouse reset_globals fixture."""
|
|
||||||
from app import api as api_module
|
|
||||||
return api_module._DATA_DIR
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_config_returns_empty_when_no_file(client, config_dir):
|
|
||||||
r = client.get("/api/config")
|
|
||||||
assert r.status_code == 200
|
|
||||||
data = r.json()
|
|
||||||
assert data["accounts"] == []
|
|
||||||
assert data["max_per_account"] == 500
|
|
||||||
|
|
||||||
|
|
||||||
def test_post_config_writes_yaml(client, config_dir):
|
|
||||||
import yaml
|
|
||||||
payload = {
|
|
||||||
"accounts": [{"name": "Test", "host": "imap.test.com", "port": 993,
|
|
||||||
"use_ssl": True, "username": "u@t.com", "password": "pw",
|
|
||||||
"folder": "INBOX", "days_back": 30}],
|
|
||||||
"max_per_account": 200,
|
|
||||||
}
|
|
||||||
r = client.post("/api/config", json=payload)
|
|
||||||
assert r.status_code == 200
|
|
||||||
assert r.json()["ok"] is True
|
|
||||||
cfg_file = config_dir / "label_tool.yaml"
|
|
||||||
assert cfg_file.exists()
|
|
||||||
saved = yaml.safe_load(cfg_file.read_text())
|
|
||||||
assert saved["max_per_account"] == 200
|
|
||||||
assert saved["accounts"][0]["name"] == "Test"
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_config_round_trips(client, config_dir):
|
|
||||||
payload = {"accounts": [{"name": "R", "host": "h", "port": 993, "use_ssl": True,
|
|
||||||
"username": "u", "password": "p", "folder": "INBOX",
|
|
||||||
"days_back": 90}], "max_per_account": 300}
|
|
||||||
client.post("/api/config", json=payload)
|
|
||||||
r = client.get("/api/config")
|
|
||||||
data = r.json()
|
|
||||||
assert data["max_per_account"] == 300
|
|
||||||
assert data["accounts"][0]["name"] == "R"
|
|
||||||
|
|
||||||
|
|
||||||
# ── /api/stats ───────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def score_with_labels(tmp_path, data_dir):
|
|
||||||
"""Write a score file with 3 labels for stats tests."""
|
|
||||||
score_path = data_dir / "email_score.jsonl"
|
|
||||||
records = [
|
|
||||||
{"id": "a", "label": "interview_scheduled"},
|
|
||||||
{"id": "b", "label": "interview_scheduled"},
|
|
||||||
{"id": "c", "label": "rejected"},
|
|
||||||
]
|
|
||||||
score_path.write_text("\n".join(json.dumps(r) for r in records) + "\n")
|
|
||||||
return records
|
|
||||||
|
|
||||||
|
|
||||||
def test_stats_returns_counts(client, score_with_labels):
|
|
||||||
r = client.get("/api/stats")
|
|
||||||
assert r.status_code == 200
|
|
||||||
data = r.json()
|
|
||||||
assert data["total"] == 3
|
|
||||||
assert data["counts"]["interview_scheduled"] == 2
|
|
||||||
assert data["counts"]["rejected"] == 1
|
|
||||||
|
|
||||||
|
|
||||||
def test_stats_empty_when_no_file(client, data_dir):
|
|
||||||
r = client.get("/api/stats")
|
|
||||||
assert r.status_code == 200
|
|
||||||
data = r.json()
|
|
||||||
assert data["total"] == 0
|
|
||||||
assert data["counts"] == {}
|
|
||||||
assert data["score_file_bytes"] == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_stats_download_returns_file(client, score_with_labels):
|
|
||||||
r = client.get("/api/stats/download")
|
|
||||||
assert r.status_code == 200
|
|
||||||
assert "jsonlines" in r.headers.get("content-type", "")
|
|
||||||
|
|
||||||
|
|
||||||
def test_stats_download_404_when_no_file(client, data_dir):
|
|
||||||
r = client.get("/api/stats/download")
|
|
||||||
assert r.status_code == 404
|
|
||||||
|
|
||||||
|
|
||||||
# ── /api/accounts/test ───────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def test_account_test_missing_fields(client):
|
|
||||||
r = client.post("/api/accounts/test", json={"account": {"host": "", "username": "", "password": ""}})
|
|
||||||
assert r.status_code == 200
|
|
||||||
data = r.json()
|
|
||||||
assert data["ok"] is False
|
|
||||||
assert "required" in data["message"].lower()
|
|
||||||
|
|
||||||
|
|
||||||
def test_account_test_success(client):
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
mock_conn = MagicMock()
|
|
||||||
mock_conn.select.return_value = ("OK", [b"99"])
|
|
||||||
with patch("app.data.fetch.imaplib.IMAP4_SSL", return_value=mock_conn):
|
|
||||||
r = client.post("/api/accounts/test", json={"account": {
|
|
||||||
"host": "imap.example.com", "port": 993, "use_ssl": True,
|
|
||||||
"username": "u@example.com", "password": "pw", "folder": "INBOX",
|
|
||||||
}})
|
|
||||||
assert r.status_code == 200
|
|
||||||
data = r.json()
|
|
||||||
assert data["ok"] is True
|
|
||||||
assert data["count"] == 99
|
|
||||||
|
|
||||||
|
|
||||||
# ── /api/fetch/stream (SSE) ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _parse_sse(content: bytes) -> list[dict]:
|
|
||||||
"""Parse SSE response body into list of event dicts."""
|
|
||||||
events = []
|
|
||||||
for line in content.decode().splitlines():
|
|
||||||
if line.startswith("data: "):
|
|
||||||
events.append(json.loads(line[6:]))
|
|
||||||
return events
|
|
||||||
|
|
||||||
|
|
||||||
def test_fetch_stream_no_accounts_configured(client, config_dir):
|
|
||||||
"""With no config, stream should immediately complete with 0 added."""
|
|
||||||
r = client.get("/api/fetch/stream?accounts=NoSuchAccount&days_back=30&limit=10")
|
|
||||||
assert r.status_code == 200
|
|
||||||
events = _parse_sse(r.content)
|
|
||||||
complete = next((e for e in events if e["type"] == "complete"), None)
|
|
||||||
assert complete is not None
|
|
||||||
assert complete["total_added"] == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_fetch_stream_with_mock_imap(client, config_dir, data_dir):
|
|
||||||
"""With one configured account, stream should yield start/done/complete events."""
|
|
||||||
import yaml
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
|
|
||||||
# Write a config with one account
|
|
||||||
cfg = {"accounts": [{"name": "Mock", "host": "h", "port": 993, "use_ssl": True,
|
|
||||||
"username": "u", "password": "p", "folder": "INBOX",
|
|
||||||
"days_back": 30}], "max_per_account": 50}
|
|
||||||
(config_dir / "label_tool.yaml").write_text(yaml.dump(cfg))
|
|
||||||
|
|
||||||
raw_msg = (b"Subject: Interview\r\nFrom: a@b.com\r\n"
|
|
||||||
b"Date: Mon, 1 Mar 2026 12:00:00 +0000\r\n\r\nBody")
|
|
||||||
mock_conn = MagicMock()
|
|
||||||
mock_conn.search.return_value = ("OK", [b"1"])
|
|
||||||
mock_conn.fetch.return_value = ("OK", [(b"1 (RFC822 {N})", raw_msg)])
|
|
||||||
|
|
||||||
with patch("app.data.fetch.imaplib.IMAP4_SSL", return_value=mock_conn):
|
|
||||||
r = client.get("/api/fetch/stream?accounts=Mock&days_back=30&limit=50")
|
|
||||||
|
|
||||||
assert r.status_code == 200
|
|
||||||
events = _parse_sse(r.content)
|
|
||||||
types = [e["type"] for e in events]
|
|
||||||
assert "start" in types
|
|
||||||
assert "done" in types
|
|
||||||
assert "complete" in types
|
|
||||||
|
|
||||||
|
|
||||||
# ---- /api/finetune/status tests ----
|
|
||||||
|
|
||||||
def test_finetune_status_returns_empty_when_no_models_dir(client):
|
|
||||||
"""GET /api/finetune/status must return [] if models/ does not exist."""
|
|
||||||
r = client.get("/api/finetune/status")
|
|
||||||
assert r.status_code == 200
|
|
||||||
assert r.json() == []
|
|
||||||
|
|
||||||
|
|
||||||
def test_finetune_status_returns_training_info(client, tmp_path):
|
|
||||||
"""GET /api/finetune/status must return one entry per training_info.json found."""
|
|
||||||
import json as _json
|
|
||||||
from app import api as api_module
|
|
||||||
|
|
||||||
models_dir = tmp_path / "models" / "avocet-deberta-small"
|
|
||||||
models_dir.mkdir(parents=True)
|
|
||||||
info = {
|
|
||||||
"name": "avocet-deberta-small",
|
|
||||||
"base_model_id": "cross-encoder/nli-deberta-v3-small",
|
|
||||||
"val_macro_f1": 0.712,
|
|
||||||
"timestamp": "2026-03-15T12:00:00Z",
|
|
||||||
"sample_count": 401,
|
|
||||||
}
|
|
||||||
(models_dir / "training_info.json").write_text(_json.dumps(info))
|
|
||||||
|
|
||||||
api_module.set_models_dir(tmp_path / "models")
|
|
||||||
try:
|
|
||||||
r = client.get("/api/finetune/status")
|
|
||||||
assert r.status_code == 200
|
|
||||||
data = r.json()
|
|
||||||
assert any(d["name"] == "avocet-deberta-small" for d in data)
|
|
||||||
finally:
|
|
||||||
api_module.set_models_dir(api_module._ROOT / "models")
|
|
||||||
|
|
||||||
|
|
||||||
def test_finetune_run_streams_sse_events(client):
|
|
||||||
"""GET /api/finetune/run must return text/event-stream content type."""
|
|
||||||
from unittest.mock import patch, MagicMock
|
|
||||||
|
|
||||||
mock_proc = MagicMock()
|
|
||||||
mock_proc.stdout = iter(["Training epoch 1\n", "Done\n"])
|
|
||||||
mock_proc.returncode = 0
|
|
||||||
mock_proc.wait = MagicMock()
|
|
||||||
|
|
||||||
with patch("app.api._subprocess.Popen",return_value=mock_proc):
|
|
||||||
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
|
|
||||||
|
|
||||||
assert r.status_code == 200
|
|
||||||
assert "text/event-stream" in r.headers.get("content-type", "")
|
|
||||||
|
|
||||||
|
|
||||||
def test_finetune_run_emits_complete_on_success(client):
|
|
||||||
"""GET /api/finetune/run must emit a complete event on clean exit."""
|
|
||||||
from unittest.mock import patch, MagicMock
|
|
||||||
|
|
||||||
mock_proc = MagicMock()
|
|
||||||
mock_proc.stdout = iter(["progress line\n"])
|
|
||||||
mock_proc.returncode = 0
|
|
||||||
mock_proc.wait = MagicMock()
|
|
||||||
|
|
||||||
with patch("app.api._subprocess.Popen",return_value=mock_proc):
|
|
||||||
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
|
|
||||||
|
|
||||||
assert '{"type": "complete"}' in r.text
|
|
||||||
|
|
||||||
|
|
||||||
def test_finetune_run_emits_error_on_nonzero_exit(client):
|
|
||||||
"""GET /api/finetune/run must emit an error event on non-zero exit."""
|
|
||||||
from unittest.mock import patch, MagicMock
|
|
||||||
|
|
||||||
mock_proc = MagicMock()
|
|
||||||
mock_proc.stdout = iter([])
|
|
||||||
mock_proc.returncode = 1
|
|
||||||
mock_proc.wait = MagicMock()
|
|
||||||
|
|
||||||
with patch("app.api._subprocess.Popen",return_value=mock_proc):
|
|
||||||
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
|
|
||||||
|
|
||||||
assert '"type": "error"' in r.text
|
|
||||||
|
|
||||||
|
|
||||||
def test_finetune_run_passes_score_files_to_subprocess(client):
|
|
||||||
"""GET /api/finetune/run?score=file1&score=file2 must pass --score args to subprocess."""
|
|
||||||
from unittest.mock import patch, MagicMock
|
|
||||||
|
|
||||||
captured_cmd = []
|
|
||||||
|
|
||||||
def mock_popen(cmd, **kwargs):
|
|
||||||
captured_cmd.extend(cmd)
|
|
||||||
m = MagicMock()
|
|
||||||
m.stdout = iter([])
|
|
||||||
m.returncode = 0
|
|
||||||
m.wait = MagicMock()
|
|
||||||
return m
|
|
||||||
|
|
||||||
with patch("app.api._subprocess.Popen",side_effect=mock_popen):
|
|
||||||
client.get("/api/finetune/run?model=deberta-small&epochs=1&score=run1.jsonl&score=run2.jsonl")
|
|
||||||
|
|
||||||
assert "--score" in captured_cmd
|
|
||||||
assert captured_cmd.count("--score") == 2
|
|
||||||
# Paths are resolved to absolute — check filenames are present as substrings
|
|
||||||
assert any("run1.jsonl" in arg for arg in captured_cmd)
|
|
||||||
assert any("run2.jsonl" in arg for arg in captured_cmd)
|
|
||||||
|
|
||||||
|
|
||||||
# ---- Cancel endpoint tests ----
|
|
||||||
|
|
||||||
|
|
||||||
def test_finetune_cancel_returns_404_when_not_running(client):
|
|
||||||
"""POST /api/finetune/cancel must return 404 if no finetune is running."""
|
|
||||||
from app import api as api_module
|
|
||||||
api_module._running_procs.pop("finetune", None)
|
|
||||||
r = client.post("/api/finetune/cancel")
|
|
||||||
assert r.status_code == 404
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_finetune_cancel_terminates_running_process(client):
|
|
||||||
"""POST /api/finetune/cancel must call terminate() on the running process."""
|
|
||||||
from unittest.mock import MagicMock
|
|
||||||
from app import api as api_module
|
|
||||||
|
|
||||||
mock_proc = MagicMock()
|
|
||||||
mock_proc.wait = MagicMock()
|
|
||||||
api_module._running_procs["finetune"] = mock_proc
|
|
||||||
|
|
||||||
try:
|
|
||||||
r = client.post("/api/finetune/cancel")
|
|
||||||
assert r.status_code == 200
|
|
||||||
assert r.json()["status"] == "cancelled"
|
|
||||||
mock_proc.terminate.assert_called_once()
|
|
||||||
finally:
|
|
||||||
api_module._running_procs.pop("finetune", None)
|
|
||||||
api_module._cancelled_jobs.discard("finetune")
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def test_finetune_run_emits_cancelled_event(client):
|
|
||||||
"""GET /api/finetune/run must emit cancelled (not error) when job was cancelled."""
|
|
||||||
from unittest.mock import patch, MagicMock
|
|
||||||
from app import api as api_module
|
|
||||||
|
|
||||||
mock_proc = MagicMock()
|
|
||||||
mock_proc.stdout = iter([])
|
|
||||||
mock_proc.returncode = -15 # SIGTERM
|
|
||||||
|
|
||||||
def mock_wait():
|
|
||||||
# Simulate cancel being called while the process is running (after discard clears stale flag)
|
|
||||||
api_module._cancelled_jobs.add("finetune")
|
|
||||||
|
|
||||||
mock_proc.wait = mock_wait
|
|
||||||
|
|
||||||
def mock_popen(cmd, **kwargs):
|
|
||||||
return mock_proc
|
|
||||||
|
|
||||||
try:
|
|
||||||
with patch("app.api._subprocess.Popen",side_effect=mock_popen):
|
|
||||||
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
|
|
||||||
assert '{"type": "cancelled"}' in r.text
|
|
||||||
assert '"type": "error"' not in r.text
|
|
||||||
finally:
|
|
||||||
api_module._cancelled_jobs.discard("finetune")
|
|
||||||
|
|
||||||
|
|
|
||||||
187
tests/test_train.py
Normal file
187
tests/test_train.py
Normal file
|
|
@ -0,0 +1,187 @@
|
||||||
|
"""Tests for app/train/train.py -- /api/train/* endpoints."""
|
||||||
|
import json
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def reset_globals(tmp_path):
|
||||||
|
from app.train import train as train_module
|
||||||
|
train_module.set_db_path(tmp_path / "train_jobs.db")
|
||||||
|
train_module.set_models_dir(tmp_path / "models")
|
||||||
|
train_module._running_procs.clear()
|
||||||
|
yield
|
||||||
|
train_module._running_procs.clear()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client():
|
||||||
|
from app.api import app
|
||||||
|
return TestClient(app)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_sse(content: bytes) -> list[dict]:
|
||||||
|
events = []
|
||||||
|
for line in content.decode().splitlines():
|
||||||
|
if line.startswith("data: "):
|
||||||
|
events.append(json.loads(line[6:]))
|
||||||
|
return events
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_jobs_empty(client):
|
||||||
|
r = client.get("/api/train/jobs")
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert r.json() == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_job_returns_queued_record(client):
|
||||||
|
r = client.post("/api/train/jobs",
|
||||||
|
json={"type": "classifier", "model_key": "deberta-small",
|
||||||
|
"config_json": {"epochs": 3}})
|
||||||
|
assert r.status_code == 200
|
||||||
|
data = r.json()
|
||||||
|
assert data["status"] == "queued"
|
||||||
|
assert data["type"] == "classifier"
|
||||||
|
assert data["model_key"] == "deberta-small"
|
||||||
|
assert "id" in data
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_job_invalid_type_returns_400(client):
|
||||||
|
r = client.post("/api/train/jobs",
|
||||||
|
json={"type": "unknown-type", "model_key": "deberta-small"})
|
||||||
|
assert r.status_code == 400
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_job_appears_in_list(client):
|
||||||
|
client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||||
|
r = client.get("/api/train/jobs")
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert len(r.json()) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_job_returns_record(client):
|
||||||
|
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||||
|
job_id = r.json()["id"]
|
||||||
|
r2 = client.get(f"/api/train/jobs/{job_id}")
|
||||||
|
assert r2.status_code == 200
|
||||||
|
assert r2.json()["id"] == job_id
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_job_404_for_unknown(client):
|
||||||
|
r = client.get("/api/train/jobs/no-such-id")
|
||||||
|
assert r.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_cancel_queued_job(client):
|
||||||
|
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||||
|
job_id = r.json()["id"]
|
||||||
|
r2 = client.delete(f"/api/train/jobs/{job_id}/cancel")
|
||||||
|
assert r2.status_code == 200
|
||||||
|
assert r2.json()["status"] == "cancelled"
|
||||||
|
r3 = client.get(f"/api/train/jobs/{job_id}")
|
||||||
|
assert r3.json()["status"] == "cancelled"
|
||||||
|
|
||||||
|
|
||||||
|
def test_cancel_completed_job_returns_409(client):
|
||||||
|
from app.train import train as train_module
|
||||||
|
train_module._init_db()
|
||||||
|
with train_module._db() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"INSERT INTO jobs (id, type, model_key, status, config_json, created_at) "
|
||||||
|
"VALUES ('abc', 'classifier', 'deberta-small', 'completed', '{}', '2026-05-01T00:00:00Z')"
|
||||||
|
)
|
||||||
|
r = client.delete("/api/train/jobs/abc/cancel")
|
||||||
|
assert r.status_code == 409
|
||||||
|
|
||||||
|
|
||||||
|
def test_cancel_terminates_running_proc(client):
|
||||||
|
from app.train import train as train_module
|
||||||
|
mock_proc = MagicMock()
|
||||||
|
mock_proc.wait = MagicMock()
|
||||||
|
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||||
|
job_id = r.json()["id"]
|
||||||
|
train_module._running_procs[job_id] = mock_proc
|
||||||
|
with train_module._db() as conn:
|
||||||
|
conn.execute("UPDATE jobs SET status='running' WHERE id=?", (job_id,))
|
||||||
|
r2 = client.delete(f"/api/train/jobs/{job_id}/cancel")
|
||||||
|
assert r2.status_code == 200
|
||||||
|
mock_proc.terminate.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_job_streams_sse(client):
|
||||||
|
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||||
|
job_id = r.json()["id"]
|
||||||
|
mock_proc = MagicMock()
|
||||||
|
mock_proc.stdout = iter(["Epoch 1\n", "Done\n"])
|
||||||
|
mock_proc.returncode = 0
|
||||||
|
mock_proc.wait = MagicMock()
|
||||||
|
with patch("app.train.train._subprocess.Popen", return_value=mock_proc):
|
||||||
|
r2 = client.get(f"/api/train/jobs/{job_id}/run")
|
||||||
|
assert r2.status_code == 200
|
||||||
|
assert "text/event-stream" in r2.headers.get("content-type", "")
|
||||||
|
events = _parse_sse(r2.content)
|
||||||
|
assert any(e["type"] == "complete" for e in events)
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_job_marks_completed_in_db(client):
|
||||||
|
from app.train import train as train_module
|
||||||
|
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||||
|
job_id = r.json()["id"]
|
||||||
|
mock_proc = MagicMock()
|
||||||
|
mock_proc.stdout = iter([])
|
||||||
|
mock_proc.returncode = 0
|
||||||
|
mock_proc.wait = MagicMock()
|
||||||
|
with patch("app.train.train._subprocess.Popen", return_value=mock_proc):
|
||||||
|
client.get(f"/api/train/jobs/{job_id}/run")
|
||||||
|
r2 = client.get(f"/api/train/jobs/{job_id}")
|
||||||
|
assert r2.json()["status"] == "completed"
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_job_marks_failed_on_nonzero_exit(client):
|
||||||
|
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||||
|
job_id = r.json()["id"]
|
||||||
|
mock_proc = MagicMock()
|
||||||
|
mock_proc.stdout = iter([])
|
||||||
|
mock_proc.returncode = 1
|
||||||
|
mock_proc.wait = MagicMock()
|
||||||
|
with patch("app.train.train._subprocess.Popen", return_value=mock_proc):
|
||||||
|
client.get(f"/api/train/jobs/{job_id}/run")
|
||||||
|
r2 = client.get(f"/api/train/jobs/{job_id}")
|
||||||
|
assert r2.json()["status"] == "failed"
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_nonqueued_job_returns_409(client):
|
||||||
|
from app.train import train as train_module
|
||||||
|
train_module._init_db()
|
||||||
|
with train_module._db() as conn:
|
||||||
|
conn.execute(
|
||||||
|
"INSERT INTO jobs (id, type, model_key, status, config_json, created_at) "
|
||||||
|
"VALUES ('xyz', 'classifier', 'deberta-small', 'running', '{}', '2026-05-01T00:00:00Z')"
|
||||||
|
)
|
||||||
|
r = client.get("/api/train/jobs/xyz/run")
|
||||||
|
assert r.status_code == 409
|
||||||
|
|
||||||
|
|
||||||
|
def test_run_unknown_job_returns_404(client):
|
||||||
|
r = client.get("/api/train/jobs/no-such/run")
|
||||||
|
assert r.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_results_empty_when_no_models_dir(client):
|
||||||
|
r = client.get("/api/train/results")
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert r.json() == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_results_returns_training_info(client, tmp_path):
|
||||||
|
from app.train import train as train_module
|
||||||
|
models_dir = tmp_path / "models" / "avocet-deberta-small"
|
||||||
|
models_dir.mkdir(parents=True)
|
||||||
|
train_module.set_models_dir(tmp_path / "models")
|
||||||
|
info = {"name": "avocet-deberta-small", "val_macro_f1": 0.712, "sample_count": 401}
|
||||||
|
(models_dir / "training_info.json").write_text(json.dumps(info))
|
||||||
|
r = client.get("/api/train/results")
|
||||||
|
assert r.status_code == 200
|
||||||
|
data = r.json()
|
||||||
|
assert any(d["name"] == "avocet-deberta-small" for d in data)
|
||||||
Loading…
Reference in a new issue