feat: build SQLite-backed train job queue in app/train/train.py
Replaces the ad-hoc _running_procs dict in api.py with a persistent, inspectable SQLite job queue. Removes old /api/finetune/* routes and _best_cuda_device from api.py. Adds /api/train/* routes (list, create, get, cancel, run SSE, results). 16 new tests all passing.
This commit is contained in:
parent
d432026fd7
commit
766fbafa02
5 changed files with 501 additions and 626 deletions
145
app/api.py
145
app/api.py
|
|
@ -22,11 +22,6 @@ _DATA_DIR: Path = _ROOT / "data" # overridable in tests via set_data_dir()
|
|||
_MODELS_DIR: Path = _ROOT / "models" # overridable in tests via set_models_dir()
|
||||
_CONFIG_DIR: Path | None = None # None = use real path
|
||||
|
||||
# Process registry for running jobs — used by cancel endpoints.
|
||||
# Keys: "benchmark" | "finetune". Values: the live Popen object.
|
||||
_running_procs: dict = {}
|
||||
_cancelled_jobs: set = set()
|
||||
|
||||
|
||||
def set_data_dir(path: Path) -> None:
|
||||
"""Override data directory — used by tests."""
|
||||
|
|
@ -34,34 +29,6 @@ def set_data_dir(path: Path) -> None:
|
|||
_DATA_DIR = path
|
||||
|
||||
|
||||
def _best_cuda_device() -> str:
|
||||
"""Return the index of the GPU with the most free VRAM as a string.
|
||||
|
||||
Uses nvidia-smi so it works in the job-seeker env (no torch). Returns ""
|
||||
if nvidia-smi is unavailable or no GPUs are found. Restricting the
|
||||
training subprocess to a single GPU via CUDA_VISIBLE_DEVICES prevents
|
||||
PyTorch DataParallel from replicating the model across all GPUs, which
|
||||
would OOM the GPU with less headroom.
|
||||
"""
|
||||
try:
|
||||
out = _subprocess.check_output(
|
||||
["nvidia-smi", "--query-gpu=index,memory.free",
|
||||
"--format=csv,noheader,nounits"],
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
best_idx, best_free = "", 0
|
||||
for line in out.strip().splitlines():
|
||||
parts = line.strip().split(", ")
|
||||
if len(parts) == 2:
|
||||
idx, free = parts[0].strip(), int(parts[1].strip())
|
||||
if free > best_free:
|
||||
best_free, best_idx = free, idx
|
||||
return best_idx
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
def set_models_dir(path: Path) -> None:
|
||||
"""Override models directory — used by tests."""
|
||||
global _MODELS_DIR
|
||||
|
|
@ -156,116 +123,8 @@ app.include_router(imitate_router, prefix="/api/imitate")
|
|||
from app.data.fetch import router as fetch_router
|
||||
app.include_router(fetch_router, prefix="/api")
|
||||
|
||||
|
||||
from fastapi.responses import StreamingResponse
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Finetune endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.get("/api/finetune/status")
|
||||
def get_finetune_status():
|
||||
"""Scan models/ for training_info.json files. Returns [] if none exist."""
|
||||
models_dir = _MODELS_DIR
|
||||
if not models_dir.exists():
|
||||
return []
|
||||
results = []
|
||||
for sub in models_dir.iterdir():
|
||||
if not sub.is_dir():
|
||||
continue
|
||||
info_path = sub / "training_info.json"
|
||||
if not info_path.exists():
|
||||
continue
|
||||
try:
|
||||
info = json.loads(info_path.read_text(encoding="utf-8"))
|
||||
results.append(info)
|
||||
except Exception:
|
||||
pass
|
||||
return results
|
||||
|
||||
|
||||
@app.get("/api/finetune/run")
|
||||
def run_finetune_endpoint(
|
||||
model: str = "deberta-small",
|
||||
epochs: int = 5,
|
||||
score: list[str] = Query(default=[]),
|
||||
):
|
||||
"""Spawn finetune_classifier.py and stream stdout as SSE progress events."""
|
||||
python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
|
||||
script = str(_ROOT / "scripts" / "finetune_classifier.py")
|
||||
cmd = [python_bin, script, "--model", model, "--epochs", str(epochs)]
|
||||
data_root = _DATA_DIR.resolve()
|
||||
for score_file in score:
|
||||
resolved = (_DATA_DIR / score_file).resolve()
|
||||
if not str(resolved).startswith(str(data_root)):
|
||||
raise HTTPException(400, f"Invalid score path: {score_file!r}")
|
||||
cmd.extend(["--score", str(resolved)])
|
||||
|
||||
# Pick the GPU with the most free VRAM. Setting CUDA_VISIBLE_DEVICES to a
|
||||
# single device prevents DataParallel from replicating the model across all
|
||||
# GPUs, which would force a full copy onto the more memory-constrained device.
|
||||
proc_env = {**os.environ, "PYTORCH_ALLOC_CONF": "expandable_segments:True"}
|
||||
best_gpu = _best_cuda_device()
|
||||
if best_gpu:
|
||||
proc_env["CUDA_VISIBLE_DEVICES"] = best_gpu
|
||||
|
||||
gpu_note = f"GPU {best_gpu}" if best_gpu else "CPU (no GPU found)"
|
||||
|
||||
def generate():
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': f'[api] Using {gpu_note} (most free VRAM)'})}\n\n"
|
||||
try:
|
||||
proc = _subprocess.Popen(
|
||||
cmd,
|
||||
stdout=_subprocess.PIPE,
|
||||
stderr=_subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
cwd=str(_ROOT),
|
||||
env=proc_env,
|
||||
)
|
||||
_running_procs["finetune"] = proc
|
||||
_cancelled_jobs.discard("finetune") # clear any stale flag from a prior run
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
line = line.rstrip()
|
||||
if line:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
|
||||
proc.wait()
|
||||
if proc.returncode == 0:
|
||||
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
|
||||
elif "finetune" in _cancelled_jobs:
|
||||
_cancelled_jobs.discard("finetune")
|
||||
yield f"data: {json.dumps({'type': 'cancelled'})}\n\n"
|
||||
else:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})}\n\n"
|
||||
finally:
|
||||
_running_procs.pop("finetune", None)
|
||||
except Exception as exc:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
|
||||
@app.post("/api/finetune/cancel")
|
||||
def cancel_finetune():
|
||||
"""Kill the running fine-tune subprocess. 404 if none is running."""
|
||||
proc = _running_procs.get("finetune")
|
||||
if proc is None:
|
||||
raise HTTPException(404, "No finetune is running")
|
||||
_cancelled_jobs.add("finetune")
|
||||
proc.terminate()
|
||||
try:
|
||||
proc.wait(timeout=3)
|
||||
except _subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
return {"status": "cancelled"}
|
||||
|
||||
from app.train.train import router as train_router
|
||||
app.include_router(train_router, prefix="/api/train")
|
||||
|
||||
|
||||
# Static SPA — MUST be last (catches all unmatched paths)
|
||||
|
|
|
|||
0
app/train/__init__.py
Normal file
0
app/train/__init__.py
Normal file
287
app/train/train.py
Normal file
287
app/train/train.py
Normal file
|
|
@ -0,0 +1,287 @@
|
|||
"""Avocet -- train job queue API.
|
||||
|
||||
SQLite-backed job queue for finetune jobs. Replaces the ad-hoc
|
||||
_running_procs dict in api.py with a persistent, inspectable queue.
|
||||
|
||||
Routes (all under /api/train when api.py mounts with prefix="/api/train"):
|
||||
GET /jobs -- list all jobs, newest first
|
||||
POST /jobs -- create a new job
|
||||
GET /jobs/{id} -- get one job by id
|
||||
DELETE /jobs/{id}/cancel -- cancel a queued or running job
|
||||
GET /jobs/{id}/run -- SSE: run the job, stream stdout
|
||||
GET /results -- list completed models with training_info.json metrics
|
||||
|
||||
SQLite schema:
|
||||
CREATE TABLE IF NOT EXISTS jobs (
|
||||
id TEXT PRIMARY KEY,
|
||||
type TEXT NOT NULL, -- 'classifier' | 'llm-sft'
|
||||
model_key TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'queued',
|
||||
config_json TEXT NOT NULL DEFAULT '{}',
|
||||
created_at TEXT NOT NULL,
|
||||
started_at TEXT,
|
||||
completed_at TEXT,
|
||||
error TEXT
|
||||
)
|
||||
|
||||
Testability seam: _DB_PATH global, override via set_db_path().
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import sqlite3
|
||||
import subprocess as _subprocess
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any, Generator
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
_ROOT = Path(__file__).parent.parent.parent
|
||||
_DB_PATH: Path = _ROOT / "data" / "train_jobs.db"
|
||||
_MODELS_DIR: Path = _ROOT / "models"
|
||||
_running_procs: dict[str, Any] = {}
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# -- Testability seams -------------------------------------------------
|
||||
|
||||
def set_db_path(path: Path) -> None:
|
||||
global _DB_PATH
|
||||
_DB_PATH = path
|
||||
|
||||
def set_models_dir(path: Path) -> None:
|
||||
global _MODELS_DIR
|
||||
_MODELS_DIR = path
|
||||
|
||||
|
||||
# -- Database helpers --------------------------------------------------
|
||||
|
||||
@contextmanager
|
||||
def _db() -> Generator[sqlite3.Connection, None, None]:
|
||||
conn = sqlite3.connect(str(_DB_PATH))
|
||||
conn.row_factory = sqlite3.Row
|
||||
try:
|
||||
yield conn
|
||||
conn.commit()
|
||||
finally:
|
||||
conn.close()
|
||||
|
||||
|
||||
def _init_db() -> None:
|
||||
"""Create jobs table if it does not exist. Called lazily per request."""
|
||||
_DB_PATH.parent.mkdir(parents=True, exist_ok=True)
|
||||
with _db() as conn:
|
||||
conn.execute("""
|
||||
CREATE TABLE IF NOT EXISTS jobs (
|
||||
id TEXT PRIMARY KEY,
|
||||
type TEXT NOT NULL,
|
||||
model_key TEXT NOT NULL,
|
||||
status TEXT NOT NULL DEFAULT 'queued',
|
||||
config_json TEXT NOT NULL DEFAULT '{}',
|
||||
created_at TEXT NOT NULL,
|
||||
started_at TEXT,
|
||||
completed_at TEXT,
|
||||
error TEXT
|
||||
)
|
||||
""")
|
||||
|
||||
|
||||
def _row_to_dict(row: sqlite3.Row) -> dict:
|
||||
return {k: row[k] for k in row.keys()}
|
||||
|
||||
|
||||
# -- GPU selection (copied from api.py) --------------------------------
|
||||
|
||||
def _best_cuda_device() -> str:
|
||||
"""Return index of GPU with most free VRAM, or empty string."""
|
||||
try:
|
||||
out = _subprocess.check_output(
|
||||
["nvidia-smi", "--query-gpu=index,memory.free",
|
||||
"--format=csv,noheader,nounits"],
|
||||
text=True, timeout=5,
|
||||
)
|
||||
best_idx, best_free = "", 0
|
||||
for line in out.strip().splitlines():
|
||||
parts = line.strip().split(", ")
|
||||
if len(parts) == 2:
|
||||
idx, free = parts[0].strip(), int(parts[1].strip())
|
||||
if free > best_free:
|
||||
best_free, best_idx = free, idx
|
||||
return best_idx
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
# -- Pydantic models ---------------------------------------------------
|
||||
|
||||
class CreateJobRequest(BaseModel):
|
||||
type: str # "classifier" | "llm-sft"
|
||||
model_key: str # e.g. "deberta-small"
|
||||
config_json: dict = {}
|
||||
|
||||
|
||||
# -- Routes ------------------------------------------------------------
|
||||
|
||||
@router.get("/jobs")
|
||||
def list_jobs() -> list[dict]:
|
||||
_init_db()
|
||||
with _db() as conn:
|
||||
rows = conn.execute("SELECT * FROM jobs ORDER BY created_at DESC").fetchall()
|
||||
return [_row_to_dict(r) for r in rows]
|
||||
|
||||
|
||||
@router.post("/jobs")
|
||||
def create_job(req: CreateJobRequest) -> dict:
|
||||
if req.type not in ("classifier", "llm-sft"):
|
||||
raise HTTPException(400, f"Unknown job type: {req.type!r}. Must be 'classifier' or 'llm-sft'.")
|
||||
_init_db()
|
||||
job_id = str(uuid.uuid4())
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
with _db() as conn:
|
||||
conn.execute(
|
||||
"INSERT INTO jobs (id, type, model_key, status, config_json, created_at) "
|
||||
"VALUES (?, ?, ?, 'queued', ?, ?)",
|
||||
(job_id, req.type, req.model_key, json.dumps(req.config_json), now),
|
||||
)
|
||||
return {"id": job_id, "type": req.type, "model_key": req.model_key,
|
||||
"status": "queued", "config_json": req.config_json,
|
||||
"created_at": now, "started_at": None, "completed_at": None, "error": None}
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}")
|
||||
def get_job(job_id: str) -> dict:
|
||||
_init_db()
|
||||
with _db() as conn:
|
||||
row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(404, f"Job {job_id!r} not found")
|
||||
return _row_to_dict(row)
|
||||
|
||||
|
||||
@router.delete("/jobs/{job_id}/cancel")
|
||||
def cancel_job(job_id: str) -> dict:
|
||||
_init_db()
|
||||
with _db() as conn:
|
||||
row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(404, f"Job {job_id!r} not found")
|
||||
if row["status"] not in ("queued", "running"):
|
||||
raise HTTPException(409, f"Job is {row['status']} -- cannot cancel")
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
conn.execute("UPDATE jobs SET status='cancelled', completed_at=? WHERE id=?", (now, job_id))
|
||||
proc = _running_procs.pop(job_id, None)
|
||||
if proc is not None:
|
||||
try:
|
||||
proc.terminate()
|
||||
proc.wait(timeout=3)
|
||||
except _subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
return {"status": "cancelled"}
|
||||
|
||||
|
||||
@router.get("/jobs/{job_id}/run")
|
||||
def run_job(job_id: str) -> StreamingResponse:
|
||||
_init_db()
|
||||
with _db() as conn:
|
||||
row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone()
|
||||
if row is None:
|
||||
raise HTTPException(404, f"Job {job_id!r} not found")
|
||||
if row["status"] != "queued":
|
||||
raise HTTPException(409, f"Job is {row['status']} -- only queued jobs can be run")
|
||||
job = _row_to_dict(row)
|
||||
|
||||
def generate():
|
||||
python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
|
||||
config = json.loads(job["config_json"] or "{}")
|
||||
model_key = job["model_key"]
|
||||
epochs = config.get("epochs", 5)
|
||||
|
||||
if job["type"] == "classifier":
|
||||
script = str(_ROOT / "scripts" / "finetune_classifier.py")
|
||||
cmd = [python_bin, script, "--model", model_key, "--epochs", str(epochs)]
|
||||
data_dir = _ROOT / "data"
|
||||
for sf in config.get("score_files", []):
|
||||
resolved = (data_dir / sf).resolve()
|
||||
if str(resolved).startswith(str(data_dir.resolve())):
|
||||
cmd.extend(["--score", str(resolved)])
|
||||
elif job["type"] == "llm-sft":
|
||||
script = str(_ROOT / "scripts" / "finetune_sft.py")
|
||||
cmd = [python_bin, script, "--model", model_key, "--epochs", str(epochs)]
|
||||
else:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': f'Unknown job type: {job['type']}'})}\n\n"
|
||||
return
|
||||
|
||||
proc_env = {**os.environ, "PYTORCH_ALLOC_CONF": "expandable_segments:True"}
|
||||
best_gpu = _best_cuda_device()
|
||||
if best_gpu:
|
||||
proc_env["CUDA_VISIBLE_DEVICES"] = best_gpu
|
||||
|
||||
gpu_note = f"GPU {best_gpu}" if best_gpu else "CPU (no GPU found)"
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': f'[train] Using {gpu_note}'})}\n\n"
|
||||
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
with _db() as conn:
|
||||
conn.execute("UPDATE jobs SET status='running', started_at=? WHERE id=?", (now, job_id))
|
||||
|
||||
try:
|
||||
proc = _subprocess.Popen(
|
||||
cmd, stdout=_subprocess.PIPE, stderr=_subprocess.STDOUT,
|
||||
text=True, bufsize=1, cwd=str(_ROOT), env=proc_env,
|
||||
)
|
||||
_running_procs[job_id] = proc
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
line = line.rstrip()
|
||||
if line:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
|
||||
proc.wait()
|
||||
finished_at = datetime.now(timezone.utc).isoformat()
|
||||
if proc.returncode == 0:
|
||||
with _db() as conn:
|
||||
conn.execute(
|
||||
"UPDATE jobs SET status='completed', completed_at=? WHERE id=?",
|
||||
(finished_at, job_id))
|
||||
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
|
||||
else:
|
||||
err = f"Process exited with code {proc.returncode}"
|
||||
with _db() as conn:
|
||||
conn.execute(
|
||||
"UPDATE jobs SET status='failed', completed_at=?, error=? WHERE id=?",
|
||||
(finished_at, err, job_id))
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': err})}\n\n"
|
||||
finally:
|
||||
_running_procs.pop(job_id, None)
|
||||
except Exception as exc:
|
||||
err = str(exc)
|
||||
with _db() as conn:
|
||||
conn.execute("UPDATE jobs SET status='failed', error=? WHERE id=?", (err, job_id))
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': err})}\n\n"
|
||||
|
||||
return StreamingResponse(generate(), media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
|
||||
|
||||
|
||||
@router.get("/results")
|
||||
def list_results() -> list[dict]:
|
||||
if not _MODELS_DIR.exists():
|
||||
return []
|
||||
results = []
|
||||
for sub in _MODELS_DIR.iterdir():
|
||||
if not sub.is_dir():
|
||||
continue
|
||||
info_path = sub / "training_info.json"
|
||||
if not info_path.exists():
|
||||
continue
|
||||
try:
|
||||
info = json.loads(info_path.read_text(encoding="utf-8"))
|
||||
results.append(info)
|
||||
except Exception:
|
||||
pass
|
||||
return results
|
||||
|
|
@ -1,29 +1,33 @@
|
|||
import json
|
||||
"""Smoke tests for the app factory (app/api.py).
|
||||
|
||||
Detailed route tests live in test_data_label.py, test_data_fetch.py,
|
||||
test_data_corrections.py, test_train.py, and test_dashboard.py.
|
||||
"""
|
||||
import pytest
|
||||
from app import api as api_module # noqa: F401
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_globals(tmp_path):
|
||||
from app import api
|
||||
from app.data import label as label_module
|
||||
from app.data import fetch as fetch_module
|
||||
api.set_data_dir(tmp_path)
|
||||
label_module.set_data_dir(tmp_path)
|
||||
label_module.set_config_dir(tmp_path)
|
||||
label_module.reset_last_action()
|
||||
fetch_module.set_data_dir(tmp_path)
|
||||
fetch_module.set_config_dir(tmp_path)
|
||||
yield
|
||||
label_module.reset_last_action()
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
|
||||
def test_import():
|
||||
from app import api # noqa: F401
|
||||
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
def test_app_has_required_routes():
|
||||
from app.api import app
|
||||
paths = {r.path for r in app.routes}
|
||||
# Label routes
|
||||
assert "/api/queue" in paths
|
||||
assert "/api/label" in paths
|
||||
assert "/api/skip" in paths
|
||||
assert "/api/discard" in paths
|
||||
assert "/api/label/undo" in paths
|
||||
assert "/api/config/labels" in paths
|
||||
assert "/api/stats" in paths
|
||||
# Fetch routes
|
||||
assert "/api/accounts/test" in paths
|
||||
assert "/api/fetch/stream" in paths
|
||||
# Train routes
|
||||
assert "/api/train/jobs" in paths
|
||||
assert "/api/train/results" in paths
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -32,470 +36,8 @@ def client():
|
|||
return TestClient(app)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def queue_with_items():
|
||||
"""Write 3 test emails to the queue file."""
|
||||
from app import api as api_module
|
||||
items = [
|
||||
{"id": f"id{i}", "subject": f"Subject {i}", "body": f"Body {i}",
|
||||
"from": "test@example.com", "date": "2026-03-01", "source": "imap:test"}
|
||||
for i in range(3)
|
||||
]
|
||||
queue_path = api_module._DATA_DIR / "email_label_queue.jsonl"
|
||||
queue_path.write_text("\n".join(json.dumps(x) for x in items) + "\n")
|
||||
return items
|
||||
|
||||
|
||||
def test_queue_returns_items(client, queue_with_items):
|
||||
r = client.get("/api/queue?limit=2")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert len(data["items"]) == 2
|
||||
assert data["total"] == 3
|
||||
|
||||
|
||||
def test_queue_empty_when_no_file(client):
|
||||
def test_queue_endpoint_reachable(client):
|
||||
r = client.get("/api/queue")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == {"items": [], "total": 0}
|
||||
|
||||
|
||||
def test_label_appends_to_score(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
r = client.post("/api/label", json={"id": "id0", "label": "interview_scheduled"})
|
||||
assert r.status_code == 200
|
||||
records = api_module._read_jsonl(api_module._score_file())
|
||||
assert len(records) == 1
|
||||
assert records[0]["id"] == "id0"
|
||||
assert records[0]["label"] == "interview_scheduled"
|
||||
assert "labeled_at" in records[0]
|
||||
|
||||
def test_label_removes_from_queue(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
client.post("/api/label", json={"id": "id0", "label": "rejected"})
|
||||
queue = api_module._read_jsonl(api_module._queue_file())
|
||||
assert not any(x["id"] == "id0" for x in queue)
|
||||
|
||||
def test_label_unknown_id_returns_404(client, queue_with_items):
|
||||
r = client.post("/api/label", json={"id": "unknown", "label": "neutral"})
|
||||
assert r.status_code == 404
|
||||
|
||||
def test_skip_moves_to_back(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
r = client.post("/api/skip", json={"id": "id0"})
|
||||
assert r.status_code == 200
|
||||
queue = api_module._read_jsonl(api_module._queue_file())
|
||||
assert queue[-1]["id"] == "id0"
|
||||
assert queue[0]["id"] == "id1"
|
||||
|
||||
def test_skip_unknown_id_returns_404(client, queue_with_items):
|
||||
r = client.post("/api/skip", json={"id": "nope"})
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
# --- Part A: POST /api/discard ---
|
||||
|
||||
def test_discard_writes_to_discarded_file(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
r = client.post("/api/discard", json={"id": "id1"})
|
||||
assert r.status_code == 200
|
||||
discarded = api_module._read_jsonl(api_module._discarded_file())
|
||||
assert len(discarded) == 1
|
||||
assert discarded[0]["id"] == "id1"
|
||||
assert discarded[0]["label"] == "__discarded__"
|
||||
|
||||
def test_discard_removes_from_queue(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
client.post("/api/discard", json={"id": "id1"})
|
||||
queue = api_module._read_jsonl(api_module._queue_file())
|
||||
assert not any(x["id"] == "id1" for x in queue)
|
||||
|
||||
|
||||
# --- Part B: DELETE /api/label/undo ---
|
||||
|
||||
def test_undo_label_removes_from_score(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
client.post("/api/label", json={"id": "id0", "label": "neutral"})
|
||||
r = client.delete("/api/label/undo")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["undone"]["type"] == "label"
|
||||
score = api_module._read_jsonl(api_module._score_file())
|
||||
assert score == []
|
||||
# Item should be restored to front of queue
|
||||
queue = api_module._read_jsonl(api_module._queue_file())
|
||||
assert queue[0]["id"] == "id0"
|
||||
|
||||
def test_undo_discard_removes_from_discarded(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
client.post("/api/discard", json={"id": "id0"})
|
||||
r = client.delete("/api/label/undo")
|
||||
assert r.status_code == 200
|
||||
discarded = api_module._read_jsonl(api_module._discarded_file())
|
||||
assert discarded == []
|
||||
|
||||
def test_undo_skip_restores_to_front(client, queue_with_items):
|
||||
from app import api as api_module
|
||||
client.post("/api/skip", json={"id": "id0"})
|
||||
r = client.delete("/api/label/undo")
|
||||
assert r.status_code == 200
|
||||
queue = api_module._read_jsonl(api_module._queue_file())
|
||||
assert queue[0]["id"] == "id0"
|
||||
|
||||
def test_undo_with_no_action_returns_404(client):
|
||||
r = client.delete("/api/label/undo")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
# --- Part C: GET /api/config/labels ---
|
||||
|
||||
def test_config_labels_returns_metadata(client):
|
||||
r = client.get("/api/config/labels")
|
||||
assert r.status_code == 200
|
||||
labels = r.json()
|
||||
assert len(labels) == 10
|
||||
assert labels[0]["key"] == "1"
|
||||
assert "emoji" in labels[0]
|
||||
assert "color" in labels[0]
|
||||
assert "name" in labels[0]
|
||||
|
||||
|
||||
# ── /api/config ──────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture
|
||||
def config_dir(tmp_path):
|
||||
"""Give the API a writable config directory."""
|
||||
from app import api as api_module
|
||||
from app.data import label as label_module
|
||||
api_module.set_config_dir(tmp_path)
|
||||
label_module.set_config_dir(tmp_path)
|
||||
yield tmp_path
|
||||
api_module.set_config_dir(None) # reset to default
|
||||
label_module.set_config_dir(None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def data_dir():
|
||||
"""Expose the current _DATA_DIR set by the autouse reset_globals fixture."""
|
||||
from app import api as api_module
|
||||
return api_module._DATA_DIR
|
||||
|
||||
|
||||
def test_get_config_returns_empty_when_no_file(client, config_dir):
|
||||
r = client.get("/api/config")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["accounts"] == []
|
||||
assert data["max_per_account"] == 500
|
||||
|
||||
|
||||
def test_post_config_writes_yaml(client, config_dir):
|
||||
import yaml
|
||||
payload = {
|
||||
"accounts": [{"name": "Test", "host": "imap.test.com", "port": 993,
|
||||
"use_ssl": True, "username": "u@t.com", "password": "pw",
|
||||
"folder": "INBOX", "days_back": 30}],
|
||||
"max_per_account": 200,
|
||||
}
|
||||
r = client.post("/api/config", json=payload)
|
||||
assert r.status_code == 200
|
||||
assert r.json()["ok"] is True
|
||||
cfg_file = config_dir / "label_tool.yaml"
|
||||
assert cfg_file.exists()
|
||||
saved = yaml.safe_load(cfg_file.read_text())
|
||||
assert saved["max_per_account"] == 200
|
||||
assert saved["accounts"][0]["name"] == "Test"
|
||||
|
||||
|
||||
def test_get_config_round_trips(client, config_dir):
|
||||
payload = {"accounts": [{"name": "R", "host": "h", "port": 993, "use_ssl": True,
|
||||
"username": "u", "password": "p", "folder": "INBOX",
|
||||
"days_back": 90}], "max_per_account": 300}
|
||||
client.post("/api/config", json=payload)
|
||||
r = client.get("/api/config")
|
||||
data = r.json()
|
||||
assert data["max_per_account"] == 300
|
||||
assert data["accounts"][0]["name"] == "R"
|
||||
|
||||
|
||||
# ── /api/stats ───────────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture
|
||||
def score_with_labels(tmp_path, data_dir):
|
||||
"""Write a score file with 3 labels for stats tests."""
|
||||
score_path = data_dir / "email_score.jsonl"
|
||||
records = [
|
||||
{"id": "a", "label": "interview_scheduled"},
|
||||
{"id": "b", "label": "interview_scheduled"},
|
||||
{"id": "c", "label": "rejected"},
|
||||
]
|
||||
score_path.write_text("\n".join(json.dumps(r) for r in records) + "\n")
|
||||
return records
|
||||
|
||||
|
||||
def test_stats_returns_counts(client, score_with_labels):
|
||||
r = client.get("/api/stats")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["total"] == 3
|
||||
assert data["counts"]["interview_scheduled"] == 2
|
||||
assert data["counts"]["rejected"] == 1
|
||||
|
||||
|
||||
def test_stats_empty_when_no_file(client, data_dir):
|
||||
r = client.get("/api/stats")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["total"] == 0
|
||||
assert data["counts"] == {}
|
||||
assert data["score_file_bytes"] == 0
|
||||
|
||||
|
||||
def test_stats_download_returns_file(client, score_with_labels):
|
||||
r = client.get("/api/stats/download")
|
||||
assert r.status_code == 200
|
||||
assert "jsonlines" in r.headers.get("content-type", "")
|
||||
|
||||
|
||||
def test_stats_download_404_when_no_file(client, data_dir):
|
||||
r = client.get("/api/stats/download")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
# ── /api/accounts/test ───────────────────────────────────────────────────────
|
||||
|
||||
def test_account_test_missing_fields(client):
|
||||
r = client.post("/api/accounts/test", json={"account": {"host": "", "username": "", "password": ""}})
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["ok"] is False
|
||||
assert "required" in data["message"].lower()
|
||||
|
||||
|
||||
def test_account_test_success(client):
|
||||
from unittest.mock import MagicMock, patch
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.select.return_value = ("OK", [b"99"])
|
||||
with patch("app.data.fetch.imaplib.IMAP4_SSL", return_value=mock_conn):
|
||||
r = client.post("/api/accounts/test", json={"account": {
|
||||
"host": "imap.example.com", "port": 993, "use_ssl": True,
|
||||
"username": "u@example.com", "password": "pw", "folder": "INBOX",
|
||||
}})
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["ok"] is True
|
||||
assert data["count"] == 99
|
||||
|
||||
|
||||
# ── /api/fetch/stream (SSE) ──────────────────────────────────────────────────
|
||||
|
||||
def _parse_sse(content: bytes) -> list[dict]:
|
||||
"""Parse SSE response body into list of event dicts."""
|
||||
events = []
|
||||
for line in content.decode().splitlines():
|
||||
if line.startswith("data: "):
|
||||
events.append(json.loads(line[6:]))
|
||||
return events
|
||||
|
||||
|
||||
def test_fetch_stream_no_accounts_configured(client, config_dir):
|
||||
"""With no config, stream should immediately complete with 0 added."""
|
||||
r = client.get("/api/fetch/stream?accounts=NoSuchAccount&days_back=30&limit=10")
|
||||
assert r.status_code == 200
|
||||
events = _parse_sse(r.content)
|
||||
complete = next((e for e in events if e["type"] == "complete"), None)
|
||||
assert complete is not None
|
||||
assert complete["total_added"] == 0
|
||||
|
||||
|
||||
def test_fetch_stream_with_mock_imap(client, config_dir, data_dir):
|
||||
"""With one configured account, stream should yield start/done/complete events."""
|
||||
import yaml
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
# Write a config with one account
|
||||
cfg = {"accounts": [{"name": "Mock", "host": "h", "port": 993, "use_ssl": True,
|
||||
"username": "u", "password": "p", "folder": "INBOX",
|
||||
"days_back": 30}], "max_per_account": 50}
|
||||
(config_dir / "label_tool.yaml").write_text(yaml.dump(cfg))
|
||||
|
||||
raw_msg = (b"Subject: Interview\r\nFrom: a@b.com\r\n"
|
||||
b"Date: Mon, 1 Mar 2026 12:00:00 +0000\r\n\r\nBody")
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.search.return_value = ("OK", [b"1"])
|
||||
mock_conn.fetch.return_value = ("OK", [(b"1 (RFC822 {N})", raw_msg)])
|
||||
|
||||
with patch("app.data.fetch.imaplib.IMAP4_SSL", return_value=mock_conn):
|
||||
r = client.get("/api/fetch/stream?accounts=Mock&days_back=30&limit=50")
|
||||
|
||||
assert r.status_code == 200
|
||||
events = _parse_sse(r.content)
|
||||
types = [e["type"] for e in events]
|
||||
assert "start" in types
|
||||
assert "done" in types
|
||||
assert "complete" in types
|
||||
|
||||
|
||||
# ---- /api/finetune/status tests ----
|
||||
|
||||
def test_finetune_status_returns_empty_when_no_models_dir(client):
|
||||
"""GET /api/finetune/status must return [] if models/ does not exist."""
|
||||
r = client.get("/api/finetune/status")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == []
|
||||
|
||||
|
||||
def test_finetune_status_returns_training_info(client, tmp_path):
|
||||
"""GET /api/finetune/status must return one entry per training_info.json found."""
|
||||
import json as _json
|
||||
from app import api as api_module
|
||||
|
||||
models_dir = tmp_path / "models" / "avocet-deberta-small"
|
||||
models_dir.mkdir(parents=True)
|
||||
info = {
|
||||
"name": "avocet-deberta-small",
|
||||
"base_model_id": "cross-encoder/nli-deberta-v3-small",
|
||||
"val_macro_f1": 0.712,
|
||||
"timestamp": "2026-03-15T12:00:00Z",
|
||||
"sample_count": 401,
|
||||
}
|
||||
(models_dir / "training_info.json").write_text(_json.dumps(info))
|
||||
|
||||
api_module.set_models_dir(tmp_path / "models")
|
||||
try:
|
||||
r = client.get("/api/finetune/status")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert any(d["name"] == "avocet-deberta-small" for d in data)
|
||||
finally:
|
||||
api_module.set_models_dir(api_module._ROOT / "models")
|
||||
|
||||
|
||||
def test_finetune_run_streams_sse_events(client):
|
||||
"""GET /api/finetune/run must return text/event-stream content type."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter(["Training epoch 1\n", "Done\n"])
|
||||
mock_proc.returncode = 0
|
||||
mock_proc.wait = MagicMock()
|
||||
|
||||
with patch("app.api._subprocess.Popen",return_value=mock_proc):
|
||||
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
|
||||
|
||||
assert r.status_code == 200
|
||||
assert "text/event-stream" in r.headers.get("content-type", "")
|
||||
|
||||
|
||||
def test_finetune_run_emits_complete_on_success(client):
|
||||
"""GET /api/finetune/run must emit a complete event on clean exit."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter(["progress line\n"])
|
||||
mock_proc.returncode = 0
|
||||
mock_proc.wait = MagicMock()
|
||||
|
||||
with patch("app.api._subprocess.Popen",return_value=mock_proc):
|
||||
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
|
||||
|
||||
assert '{"type": "complete"}' in r.text
|
||||
|
||||
|
||||
def test_finetune_run_emits_error_on_nonzero_exit(client):
|
||||
"""GET /api/finetune/run must emit an error event on non-zero exit."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter([])
|
||||
mock_proc.returncode = 1
|
||||
mock_proc.wait = MagicMock()
|
||||
|
||||
with patch("app.api._subprocess.Popen",return_value=mock_proc):
|
||||
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
|
||||
|
||||
assert '"type": "error"' in r.text
|
||||
|
||||
|
||||
def test_finetune_run_passes_score_files_to_subprocess(client):
|
||||
"""GET /api/finetune/run?score=file1&score=file2 must pass --score args to subprocess."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
captured_cmd = []
|
||||
|
||||
def mock_popen(cmd, **kwargs):
|
||||
captured_cmd.extend(cmd)
|
||||
m = MagicMock()
|
||||
m.stdout = iter([])
|
||||
m.returncode = 0
|
||||
m.wait = MagicMock()
|
||||
return m
|
||||
|
||||
with patch("app.api._subprocess.Popen",side_effect=mock_popen):
|
||||
client.get("/api/finetune/run?model=deberta-small&epochs=1&score=run1.jsonl&score=run2.jsonl")
|
||||
|
||||
assert "--score" in captured_cmd
|
||||
assert captured_cmd.count("--score") == 2
|
||||
# Paths are resolved to absolute — check filenames are present as substrings
|
||||
assert any("run1.jsonl" in arg for arg in captured_cmd)
|
||||
assert any("run2.jsonl" in arg for arg in captured_cmd)
|
||||
|
||||
|
||||
# ---- Cancel endpoint tests ----
|
||||
|
||||
|
||||
def test_finetune_cancel_returns_404_when_not_running(client):
|
||||
"""POST /api/finetune/cancel must return 404 if no finetune is running."""
|
||||
from app import api as api_module
|
||||
api_module._running_procs.pop("finetune", None)
|
||||
r = client.post("/api/finetune/cancel")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
|
||||
def test_finetune_cancel_terminates_running_process(client):
|
||||
"""POST /api/finetune/cancel must call terminate() on the running process."""
|
||||
from unittest.mock import MagicMock
|
||||
from app import api as api_module
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.wait = MagicMock()
|
||||
api_module._running_procs["finetune"] = mock_proc
|
||||
|
||||
try:
|
||||
r = client.post("/api/finetune/cancel")
|
||||
assert r.status_code == 200
|
||||
assert r.json()["status"] == "cancelled"
|
||||
mock_proc.terminate.assert_called_once()
|
||||
finally:
|
||||
api_module._running_procs.pop("finetune", None)
|
||||
api_module._cancelled_jobs.discard("finetune")
|
||||
|
||||
|
||||
|
||||
def test_finetune_run_emits_cancelled_event(client):
|
||||
"""GET /api/finetune/run must emit cancelled (not error) when job was cancelled."""
|
||||
from unittest.mock import patch, MagicMock
|
||||
from app import api as api_module
|
||||
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter([])
|
||||
mock_proc.returncode = -15 # SIGTERM
|
||||
|
||||
def mock_wait():
|
||||
# Simulate cancel being called while the process is running (after discard clears stale flag)
|
||||
api_module._cancelled_jobs.add("finetune")
|
||||
|
||||
mock_proc.wait = mock_wait
|
||||
|
||||
def mock_popen(cmd, **kwargs):
|
||||
return mock_proc
|
||||
|
||||
try:
|
||||
with patch("app.api._subprocess.Popen",side_effect=mock_popen):
|
||||
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
|
||||
assert '{"type": "cancelled"}' in r.text
|
||||
assert '"type": "error"' not in r.text
|
||||
finally:
|
||||
api_module._cancelled_jobs.discard("finetune")
|
||||
|
||||
assert "items" in r.json()
|
||||
assert "total" in r.json()
|
||||
|
|
|
|||
187
tests/test_train.py
Normal file
187
tests/test_train.py
Normal file
|
|
@ -0,0 +1,187 @@
|
|||
"""Tests for app/train/train.py -- /api/train/* endpoints."""
|
||||
import json
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_globals(tmp_path):
|
||||
from app.train import train as train_module
|
||||
train_module.set_db_path(tmp_path / "train_jobs.db")
|
||||
train_module.set_models_dir(tmp_path / "models")
|
||||
train_module._running_procs.clear()
|
||||
yield
|
||||
train_module._running_procs.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client():
|
||||
from app.api import app
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
def _parse_sse(content: bytes) -> list[dict]:
|
||||
events = []
|
||||
for line in content.decode().splitlines():
|
||||
if line.startswith("data: "):
|
||||
events.append(json.loads(line[6:]))
|
||||
return events
|
||||
|
||||
|
||||
def test_list_jobs_empty(client):
|
||||
r = client.get("/api/train/jobs")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == []
|
||||
|
||||
|
||||
def test_create_job_returns_queued_record(client):
|
||||
r = client.post("/api/train/jobs",
|
||||
json={"type": "classifier", "model_key": "deberta-small",
|
||||
"config_json": {"epochs": 3}})
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert data["status"] == "queued"
|
||||
assert data["type"] == "classifier"
|
||||
assert data["model_key"] == "deberta-small"
|
||||
assert "id" in data
|
||||
|
||||
|
||||
def test_create_job_invalid_type_returns_400(client):
|
||||
r = client.post("/api/train/jobs",
|
||||
json={"type": "unknown-type", "model_key": "deberta-small"})
|
||||
assert r.status_code == 400
|
||||
|
||||
|
||||
def test_create_job_appears_in_list(client):
|
||||
client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||
r = client.get("/api/train/jobs")
|
||||
assert r.status_code == 200
|
||||
assert len(r.json()) == 1
|
||||
|
||||
|
||||
def test_get_job_returns_record(client):
|
||||
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||
job_id = r.json()["id"]
|
||||
r2 = client.get(f"/api/train/jobs/{job_id}")
|
||||
assert r2.status_code == 200
|
||||
assert r2.json()["id"] == job_id
|
||||
|
||||
|
||||
def test_get_job_404_for_unknown(client):
|
||||
r = client.get("/api/train/jobs/no-such-id")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_cancel_queued_job(client):
|
||||
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||
job_id = r.json()["id"]
|
||||
r2 = client.delete(f"/api/train/jobs/{job_id}/cancel")
|
||||
assert r2.status_code == 200
|
||||
assert r2.json()["status"] == "cancelled"
|
||||
r3 = client.get(f"/api/train/jobs/{job_id}")
|
||||
assert r3.json()["status"] == "cancelled"
|
||||
|
||||
|
||||
def test_cancel_completed_job_returns_409(client):
|
||||
from app.train import train as train_module
|
||||
train_module._init_db()
|
||||
with train_module._db() as conn:
|
||||
conn.execute(
|
||||
"INSERT INTO jobs (id, type, model_key, status, config_json, created_at) "
|
||||
"VALUES ('abc', 'classifier', 'deberta-small', 'completed', '{}', '2026-05-01T00:00:00Z')"
|
||||
)
|
||||
r = client.delete("/api/train/jobs/abc/cancel")
|
||||
assert r.status_code == 409
|
||||
|
||||
|
||||
def test_cancel_terminates_running_proc(client):
|
||||
from app.train import train as train_module
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.wait = MagicMock()
|
||||
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||
job_id = r.json()["id"]
|
||||
train_module._running_procs[job_id] = mock_proc
|
||||
with train_module._db() as conn:
|
||||
conn.execute("UPDATE jobs SET status='running' WHERE id=?", (job_id,))
|
||||
r2 = client.delete(f"/api/train/jobs/{job_id}/cancel")
|
||||
assert r2.status_code == 200
|
||||
mock_proc.terminate.assert_called_once()
|
||||
|
||||
|
||||
def test_run_job_streams_sse(client):
|
||||
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||
job_id = r.json()["id"]
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter(["Epoch 1\n", "Done\n"])
|
||||
mock_proc.returncode = 0
|
||||
mock_proc.wait = MagicMock()
|
||||
with patch("app.train.train._subprocess.Popen", return_value=mock_proc):
|
||||
r2 = client.get(f"/api/train/jobs/{job_id}/run")
|
||||
assert r2.status_code == 200
|
||||
assert "text/event-stream" in r2.headers.get("content-type", "")
|
||||
events = _parse_sse(r2.content)
|
||||
assert any(e["type"] == "complete" for e in events)
|
||||
|
||||
|
||||
def test_run_job_marks_completed_in_db(client):
|
||||
from app.train import train as train_module
|
||||
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||
job_id = r.json()["id"]
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter([])
|
||||
mock_proc.returncode = 0
|
||||
mock_proc.wait = MagicMock()
|
||||
with patch("app.train.train._subprocess.Popen", return_value=mock_proc):
|
||||
client.get(f"/api/train/jobs/{job_id}/run")
|
||||
r2 = client.get(f"/api/train/jobs/{job_id}")
|
||||
assert r2.json()["status"] == "completed"
|
||||
|
||||
|
||||
def test_run_job_marks_failed_on_nonzero_exit(client):
|
||||
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
|
||||
job_id = r.json()["id"]
|
||||
mock_proc = MagicMock()
|
||||
mock_proc.stdout = iter([])
|
||||
mock_proc.returncode = 1
|
||||
mock_proc.wait = MagicMock()
|
||||
with patch("app.train.train._subprocess.Popen", return_value=mock_proc):
|
||||
client.get(f"/api/train/jobs/{job_id}/run")
|
||||
r2 = client.get(f"/api/train/jobs/{job_id}")
|
||||
assert r2.json()["status"] == "failed"
|
||||
|
||||
|
||||
def test_run_nonqueued_job_returns_409(client):
|
||||
from app.train import train as train_module
|
||||
train_module._init_db()
|
||||
with train_module._db() as conn:
|
||||
conn.execute(
|
||||
"INSERT INTO jobs (id, type, model_key, status, config_json, created_at) "
|
||||
"VALUES ('xyz', 'classifier', 'deberta-small', 'running', '{}', '2026-05-01T00:00:00Z')"
|
||||
)
|
||||
r = client.get("/api/train/jobs/xyz/run")
|
||||
assert r.status_code == 409
|
||||
|
||||
|
||||
def test_run_unknown_job_returns_404(client):
|
||||
r = client.get("/api/train/jobs/no-such/run")
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
def test_results_empty_when_no_models_dir(client):
|
||||
r = client.get("/api/train/results")
|
||||
assert r.status_code == 200
|
||||
assert r.json() == []
|
||||
|
||||
|
||||
def test_results_returns_training_info(client, tmp_path):
|
||||
from app.train import train as train_module
|
||||
models_dir = tmp_path / "models" / "avocet-deberta-small"
|
||||
models_dir.mkdir(parents=True)
|
||||
train_module.set_models_dir(tmp_path / "models")
|
||||
info = {"name": "avocet-deberta-small", "val_macro_f1": 0.712, "sample_count": 401}
|
||||
(models_dir / "training_info.json").write_text(json.dumps(info))
|
||||
r = client.get("/api/train/results")
|
||||
assert r.status_code == 200
|
||||
data = r.json()
|
||||
assert any(d["name"] == "avocet-deberta-small" for d in data)
|
||||
Loading…
Reference in a new issue