feat: build SQLite-backed train job queue in app/train/train.py

Replaces the ad-hoc _running_procs dict in api.py with a persistent,
inspectable SQLite job queue. Removes old /api/finetune/* routes and
_best_cuda_device from api.py. Adds /api/train/* routes (list, create,
get, cancel, run SSE, results). 16 new tests all passing.
This commit is contained in:
pyr0ball 2026-05-01 23:05:11 -07:00
parent d432026fd7
commit 766fbafa02
5 changed files with 501 additions and 626 deletions

View file

@ -22,11 +22,6 @@ _DATA_DIR: Path = _ROOT / "data" # overridable in tests via set_data_dir()
_MODELS_DIR: Path = _ROOT / "models" # overridable in tests via set_models_dir()
_CONFIG_DIR: Path | None = None # None = use real path
# Process registry for running jobs — used by cancel endpoints.
# Keys: "benchmark" | "finetune". Values: the live Popen object.
_running_procs: dict = {}
_cancelled_jobs: set = set()
def set_data_dir(path: Path) -> None:
"""Override data directory — used by tests."""
@ -34,34 +29,6 @@ def set_data_dir(path: Path) -> None:
_DATA_DIR = path
def _best_cuda_device() -> str:
"""Return the index of the GPU with the most free VRAM as a string.
Uses nvidia-smi so it works in the job-seeker env (no torch). Returns ""
if nvidia-smi is unavailable or no GPUs are found. Restricting the
training subprocess to a single GPU via CUDA_VISIBLE_DEVICES prevents
PyTorch DataParallel from replicating the model across all GPUs, which
would OOM the GPU with less headroom.
"""
try:
out = _subprocess.check_output(
["nvidia-smi", "--query-gpu=index,memory.free",
"--format=csv,noheader,nounits"],
text=True,
timeout=5,
)
best_idx, best_free = "", 0
for line in out.strip().splitlines():
parts = line.strip().split(", ")
if len(parts) == 2:
idx, free = parts[0].strip(), int(parts[1].strip())
if free > best_free:
best_free, best_idx = free, idx
return best_idx
except Exception:
return ""
def set_models_dir(path: Path) -> None:
"""Override models directory — used by tests."""
global _MODELS_DIR
@ -156,116 +123,8 @@ app.include_router(imitate_router, prefix="/api/imitate")
from app.data.fetch import router as fetch_router
app.include_router(fetch_router, prefix="/api")
from fastapi.responses import StreamingResponse
# ---------------------------------------------------------------------------
# Finetune endpoints
# ---------------------------------------------------------------------------
@app.get("/api/finetune/status")
def get_finetune_status():
"""Scan models/ for training_info.json files. Returns [] if none exist."""
models_dir = _MODELS_DIR
if not models_dir.exists():
return []
results = []
for sub in models_dir.iterdir():
if not sub.is_dir():
continue
info_path = sub / "training_info.json"
if not info_path.exists():
continue
try:
info = json.loads(info_path.read_text(encoding="utf-8"))
results.append(info)
except Exception:
pass
return results
@app.get("/api/finetune/run")
def run_finetune_endpoint(
model: str = "deberta-small",
epochs: int = 5,
score: list[str] = Query(default=[]),
):
"""Spawn finetune_classifier.py and stream stdout as SSE progress events."""
python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
script = str(_ROOT / "scripts" / "finetune_classifier.py")
cmd = [python_bin, script, "--model", model, "--epochs", str(epochs)]
data_root = _DATA_DIR.resolve()
for score_file in score:
resolved = (_DATA_DIR / score_file).resolve()
if not str(resolved).startswith(str(data_root)):
raise HTTPException(400, f"Invalid score path: {score_file!r}")
cmd.extend(["--score", str(resolved)])
# Pick the GPU with the most free VRAM. Setting CUDA_VISIBLE_DEVICES to a
# single device prevents DataParallel from replicating the model across all
# GPUs, which would force a full copy onto the more memory-constrained device.
proc_env = {**os.environ, "PYTORCH_ALLOC_CONF": "expandable_segments:True"}
best_gpu = _best_cuda_device()
if best_gpu:
proc_env["CUDA_VISIBLE_DEVICES"] = best_gpu
gpu_note = f"GPU {best_gpu}" if best_gpu else "CPU (no GPU found)"
def generate():
yield f"data: {json.dumps({'type': 'progress', 'message': f'[api] Using {gpu_note} (most free VRAM)'})}\n\n"
try:
proc = _subprocess.Popen(
cmd,
stdout=_subprocess.PIPE,
stderr=_subprocess.STDOUT,
text=True,
bufsize=1,
cwd=str(_ROOT),
env=proc_env,
)
_running_procs["finetune"] = proc
_cancelled_jobs.discard("finetune") # clear any stale flag from a prior run
try:
for line in proc.stdout:
line = line.rstrip()
if line:
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
proc.wait()
if proc.returncode == 0:
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
elif "finetune" in _cancelled_jobs:
_cancelled_jobs.discard("finetune")
yield f"data: {json.dumps({'type': 'cancelled'})}\n\n"
else:
yield f"data: {json.dumps({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})}\n\n"
finally:
_running_procs.pop("finetune", None)
except Exception as exc:
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
return StreamingResponse(
generate(),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
)
@app.post("/api/finetune/cancel")
def cancel_finetune():
"""Kill the running fine-tune subprocess. 404 if none is running."""
proc = _running_procs.get("finetune")
if proc is None:
raise HTTPException(404, "No finetune is running")
_cancelled_jobs.add("finetune")
proc.terminate()
try:
proc.wait(timeout=3)
except _subprocess.TimeoutExpired:
proc.kill()
return {"status": "cancelled"}
from app.train.train import router as train_router
app.include_router(train_router, prefix="/api/train")
# Static SPA — MUST be last (catches all unmatched paths)

0
app/train/__init__.py Normal file
View file

287
app/train/train.py Normal file
View file

@ -0,0 +1,287 @@
"""Avocet -- train job queue API.
SQLite-backed job queue for finetune jobs. Replaces the ad-hoc
_running_procs dict in api.py with a persistent, inspectable queue.
Routes (all under /api/train when api.py mounts with prefix="/api/train"):
GET /jobs -- list all jobs, newest first
POST /jobs -- create a new job
GET /jobs/{id} -- get one job by id
DELETE /jobs/{id}/cancel -- cancel a queued or running job
GET /jobs/{id}/run -- SSE: run the job, stream stdout
GET /results -- list completed models with training_info.json metrics
SQLite schema:
CREATE TABLE IF NOT EXISTS jobs (
id TEXT PRIMARY KEY,
type TEXT NOT NULL, -- 'classifier' | 'llm-sft'
model_key TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'queued',
config_json TEXT NOT NULL DEFAULT '{}',
created_at TEXT NOT NULL,
started_at TEXT,
completed_at TEXT,
error TEXT
)
Testability seam: _DB_PATH global, override via set_db_path().
"""
from __future__ import annotations
import json
import os
import sqlite3
import subprocess as _subprocess
import uuid
from contextlib import contextmanager
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Generator
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
_ROOT = Path(__file__).parent.parent.parent
_DB_PATH: Path = _ROOT / "data" / "train_jobs.db"
_MODELS_DIR: Path = _ROOT / "models"
_running_procs: dict[str, Any] = {}
router = APIRouter()
# -- Testability seams -------------------------------------------------
def set_db_path(path: Path) -> None:
global _DB_PATH
_DB_PATH = path
def set_models_dir(path: Path) -> None:
global _MODELS_DIR
_MODELS_DIR = path
# -- Database helpers --------------------------------------------------
@contextmanager
def _db() -> Generator[sqlite3.Connection, None, None]:
conn = sqlite3.connect(str(_DB_PATH))
conn.row_factory = sqlite3.Row
try:
yield conn
conn.commit()
finally:
conn.close()
def _init_db() -> None:
"""Create jobs table if it does not exist. Called lazily per request."""
_DB_PATH.parent.mkdir(parents=True, exist_ok=True)
with _db() as conn:
conn.execute("""
CREATE TABLE IF NOT EXISTS jobs (
id TEXT PRIMARY KEY,
type TEXT NOT NULL,
model_key TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'queued',
config_json TEXT NOT NULL DEFAULT '{}',
created_at TEXT NOT NULL,
started_at TEXT,
completed_at TEXT,
error TEXT
)
""")
def _row_to_dict(row: sqlite3.Row) -> dict:
return {k: row[k] for k in row.keys()}
# -- GPU selection (copied from api.py) --------------------------------
def _best_cuda_device() -> str:
"""Return index of GPU with most free VRAM, or empty string."""
try:
out = _subprocess.check_output(
["nvidia-smi", "--query-gpu=index,memory.free",
"--format=csv,noheader,nounits"],
text=True, timeout=5,
)
best_idx, best_free = "", 0
for line in out.strip().splitlines():
parts = line.strip().split(", ")
if len(parts) == 2:
idx, free = parts[0].strip(), int(parts[1].strip())
if free > best_free:
best_free, best_idx = free, idx
return best_idx
except Exception:
return ""
# -- Pydantic models ---------------------------------------------------
class CreateJobRequest(BaseModel):
type: str # "classifier" | "llm-sft"
model_key: str # e.g. "deberta-small"
config_json: dict = {}
# -- Routes ------------------------------------------------------------
@router.get("/jobs")
def list_jobs() -> list[dict]:
_init_db()
with _db() as conn:
rows = conn.execute("SELECT * FROM jobs ORDER BY created_at DESC").fetchall()
return [_row_to_dict(r) for r in rows]
@router.post("/jobs")
def create_job(req: CreateJobRequest) -> dict:
if req.type not in ("classifier", "llm-sft"):
raise HTTPException(400, f"Unknown job type: {req.type!r}. Must be 'classifier' or 'llm-sft'.")
_init_db()
job_id = str(uuid.uuid4())
now = datetime.now(timezone.utc).isoformat()
with _db() as conn:
conn.execute(
"INSERT INTO jobs (id, type, model_key, status, config_json, created_at) "
"VALUES (?, ?, ?, 'queued', ?, ?)",
(job_id, req.type, req.model_key, json.dumps(req.config_json), now),
)
return {"id": job_id, "type": req.type, "model_key": req.model_key,
"status": "queued", "config_json": req.config_json,
"created_at": now, "started_at": None, "completed_at": None, "error": None}
@router.get("/jobs/{job_id}")
def get_job(job_id: str) -> dict:
_init_db()
with _db() as conn:
row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone()
if row is None:
raise HTTPException(404, f"Job {job_id!r} not found")
return _row_to_dict(row)
@router.delete("/jobs/{job_id}/cancel")
def cancel_job(job_id: str) -> dict:
_init_db()
with _db() as conn:
row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone()
if row is None:
raise HTTPException(404, f"Job {job_id!r} not found")
if row["status"] not in ("queued", "running"):
raise HTTPException(409, f"Job is {row['status']} -- cannot cancel")
now = datetime.now(timezone.utc).isoformat()
conn.execute("UPDATE jobs SET status='cancelled', completed_at=? WHERE id=?", (now, job_id))
proc = _running_procs.pop(job_id, None)
if proc is not None:
try:
proc.terminate()
proc.wait(timeout=3)
except _subprocess.TimeoutExpired:
proc.kill()
return {"status": "cancelled"}
@router.get("/jobs/{job_id}/run")
def run_job(job_id: str) -> StreamingResponse:
_init_db()
with _db() as conn:
row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone()
if row is None:
raise HTTPException(404, f"Job {job_id!r} not found")
if row["status"] != "queued":
raise HTTPException(409, f"Job is {row['status']} -- only queued jobs can be run")
job = _row_to_dict(row)
def generate():
python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
config = json.loads(job["config_json"] or "{}")
model_key = job["model_key"]
epochs = config.get("epochs", 5)
if job["type"] == "classifier":
script = str(_ROOT / "scripts" / "finetune_classifier.py")
cmd = [python_bin, script, "--model", model_key, "--epochs", str(epochs)]
data_dir = _ROOT / "data"
for sf in config.get("score_files", []):
resolved = (data_dir / sf).resolve()
if str(resolved).startswith(str(data_dir.resolve())):
cmd.extend(["--score", str(resolved)])
elif job["type"] == "llm-sft":
script = str(_ROOT / "scripts" / "finetune_sft.py")
cmd = [python_bin, script, "--model", model_key, "--epochs", str(epochs)]
else:
yield f"data: {json.dumps({'type': 'error', 'message': f'Unknown job type: {job['type']}'})}\n\n"
return
proc_env = {**os.environ, "PYTORCH_ALLOC_CONF": "expandable_segments:True"}
best_gpu = _best_cuda_device()
if best_gpu:
proc_env["CUDA_VISIBLE_DEVICES"] = best_gpu
gpu_note = f"GPU {best_gpu}" if best_gpu else "CPU (no GPU found)"
yield f"data: {json.dumps({'type': 'progress', 'message': f'[train] Using {gpu_note}'})}\n\n"
now = datetime.now(timezone.utc).isoformat()
with _db() as conn:
conn.execute("UPDATE jobs SET status='running', started_at=? WHERE id=?", (now, job_id))
try:
proc = _subprocess.Popen(
cmd, stdout=_subprocess.PIPE, stderr=_subprocess.STDOUT,
text=True, bufsize=1, cwd=str(_ROOT), env=proc_env,
)
_running_procs[job_id] = proc
try:
for line in proc.stdout:
line = line.rstrip()
if line:
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
proc.wait()
finished_at = datetime.now(timezone.utc).isoformat()
if proc.returncode == 0:
with _db() as conn:
conn.execute(
"UPDATE jobs SET status='completed', completed_at=? WHERE id=?",
(finished_at, job_id))
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
else:
err = f"Process exited with code {proc.returncode}"
with _db() as conn:
conn.execute(
"UPDATE jobs SET status='failed', completed_at=?, error=? WHERE id=?",
(finished_at, err, job_id))
yield f"data: {json.dumps({'type': 'error', 'message': err})}\n\n"
finally:
_running_procs.pop(job_id, None)
except Exception as exc:
err = str(exc)
with _db() as conn:
conn.execute("UPDATE jobs SET status='failed', error=? WHERE id=?", (err, job_id))
yield f"data: {json.dumps({'type': 'error', 'message': err})}\n\n"
return StreamingResponse(generate(), media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"})
@router.get("/results")
def list_results() -> list[dict]:
if not _MODELS_DIR.exists():
return []
results = []
for sub in _MODELS_DIR.iterdir():
if not sub.is_dir():
continue
info_path = sub / "training_info.json"
if not info_path.exists():
continue
try:
info = json.loads(info_path.read_text(encoding="utf-8"))
results.append(info)
except Exception:
pass
return results

View file

@ -1,29 +1,33 @@
import json
"""Smoke tests for the app factory (app/api.py).
Detailed route tests live in test_data_label.py, test_data_fetch.py,
test_data_corrections.py, test_train.py, and test_dashboard.py.
"""
import pytest
from app import api as api_module # noqa: F401
@pytest.fixture(autouse=True)
def reset_globals(tmp_path):
from app import api
from app.data import label as label_module
from app.data import fetch as fetch_module
api.set_data_dir(tmp_path)
label_module.set_data_dir(tmp_path)
label_module.set_config_dir(tmp_path)
label_module.reset_last_action()
fetch_module.set_data_dir(tmp_path)
fetch_module.set_config_dir(tmp_path)
yield
label_module.reset_last_action()
from fastapi.testclient import TestClient
def test_import():
from app import api # noqa: F401
from fastapi.testclient import TestClient
def test_app_has_required_routes():
from app.api import app
paths = {r.path for r in app.routes}
# Label routes
assert "/api/queue" in paths
assert "/api/label" in paths
assert "/api/skip" in paths
assert "/api/discard" in paths
assert "/api/label/undo" in paths
assert "/api/config/labels" in paths
assert "/api/stats" in paths
# Fetch routes
assert "/api/accounts/test" in paths
assert "/api/fetch/stream" in paths
# Train routes
assert "/api/train/jobs" in paths
assert "/api/train/results" in paths
@pytest.fixture
@ -32,470 +36,8 @@ def client():
return TestClient(app)
@pytest.fixture
def queue_with_items():
"""Write 3 test emails to the queue file."""
from app import api as api_module
items = [
{"id": f"id{i}", "subject": f"Subject {i}", "body": f"Body {i}",
"from": "test@example.com", "date": "2026-03-01", "source": "imap:test"}
for i in range(3)
]
queue_path = api_module._DATA_DIR / "email_label_queue.jsonl"
queue_path.write_text("\n".join(json.dumps(x) for x in items) + "\n")
return items
def test_queue_returns_items(client, queue_with_items):
r = client.get("/api/queue?limit=2")
assert r.status_code == 200
data = r.json()
assert len(data["items"]) == 2
assert data["total"] == 3
def test_queue_empty_when_no_file(client):
def test_queue_endpoint_reachable(client):
r = client.get("/api/queue")
assert r.status_code == 200
assert r.json() == {"items": [], "total": 0}
def test_label_appends_to_score(client, queue_with_items):
from app import api as api_module
r = client.post("/api/label", json={"id": "id0", "label": "interview_scheduled"})
assert r.status_code == 200
records = api_module._read_jsonl(api_module._score_file())
assert len(records) == 1
assert records[0]["id"] == "id0"
assert records[0]["label"] == "interview_scheduled"
assert "labeled_at" in records[0]
def test_label_removes_from_queue(client, queue_with_items):
from app import api as api_module
client.post("/api/label", json={"id": "id0", "label": "rejected"})
queue = api_module._read_jsonl(api_module._queue_file())
assert not any(x["id"] == "id0" for x in queue)
def test_label_unknown_id_returns_404(client, queue_with_items):
r = client.post("/api/label", json={"id": "unknown", "label": "neutral"})
assert r.status_code == 404
def test_skip_moves_to_back(client, queue_with_items):
from app import api as api_module
r = client.post("/api/skip", json={"id": "id0"})
assert r.status_code == 200
queue = api_module._read_jsonl(api_module._queue_file())
assert queue[-1]["id"] == "id0"
assert queue[0]["id"] == "id1"
def test_skip_unknown_id_returns_404(client, queue_with_items):
r = client.post("/api/skip", json={"id": "nope"})
assert r.status_code == 404
# --- Part A: POST /api/discard ---
def test_discard_writes_to_discarded_file(client, queue_with_items):
from app import api as api_module
r = client.post("/api/discard", json={"id": "id1"})
assert r.status_code == 200
discarded = api_module._read_jsonl(api_module._discarded_file())
assert len(discarded) == 1
assert discarded[0]["id"] == "id1"
assert discarded[0]["label"] == "__discarded__"
def test_discard_removes_from_queue(client, queue_with_items):
from app import api as api_module
client.post("/api/discard", json={"id": "id1"})
queue = api_module._read_jsonl(api_module._queue_file())
assert not any(x["id"] == "id1" for x in queue)
# --- Part B: DELETE /api/label/undo ---
def test_undo_label_removes_from_score(client, queue_with_items):
from app import api as api_module
client.post("/api/label", json={"id": "id0", "label": "neutral"})
r = client.delete("/api/label/undo")
assert r.status_code == 200
data = r.json()
assert data["undone"]["type"] == "label"
score = api_module._read_jsonl(api_module._score_file())
assert score == []
# Item should be restored to front of queue
queue = api_module._read_jsonl(api_module._queue_file())
assert queue[0]["id"] == "id0"
def test_undo_discard_removes_from_discarded(client, queue_with_items):
from app import api as api_module
client.post("/api/discard", json={"id": "id0"})
r = client.delete("/api/label/undo")
assert r.status_code == 200
discarded = api_module._read_jsonl(api_module._discarded_file())
assert discarded == []
def test_undo_skip_restores_to_front(client, queue_with_items):
from app import api as api_module
client.post("/api/skip", json={"id": "id0"})
r = client.delete("/api/label/undo")
assert r.status_code == 200
queue = api_module._read_jsonl(api_module._queue_file())
assert queue[0]["id"] == "id0"
def test_undo_with_no_action_returns_404(client):
r = client.delete("/api/label/undo")
assert r.status_code == 404
# --- Part C: GET /api/config/labels ---
def test_config_labels_returns_metadata(client):
r = client.get("/api/config/labels")
assert r.status_code == 200
labels = r.json()
assert len(labels) == 10
assert labels[0]["key"] == "1"
assert "emoji" in labels[0]
assert "color" in labels[0]
assert "name" in labels[0]
# ── /api/config ──────────────────────────────────────────────────────────────
@pytest.fixture
def config_dir(tmp_path):
"""Give the API a writable config directory."""
from app import api as api_module
from app.data import label as label_module
api_module.set_config_dir(tmp_path)
label_module.set_config_dir(tmp_path)
yield tmp_path
api_module.set_config_dir(None) # reset to default
label_module.set_config_dir(None)
@pytest.fixture
def data_dir():
"""Expose the current _DATA_DIR set by the autouse reset_globals fixture."""
from app import api as api_module
return api_module._DATA_DIR
def test_get_config_returns_empty_when_no_file(client, config_dir):
r = client.get("/api/config")
assert r.status_code == 200
data = r.json()
assert data["accounts"] == []
assert data["max_per_account"] == 500
def test_post_config_writes_yaml(client, config_dir):
import yaml
payload = {
"accounts": [{"name": "Test", "host": "imap.test.com", "port": 993,
"use_ssl": True, "username": "u@t.com", "password": "pw",
"folder": "INBOX", "days_back": 30}],
"max_per_account": 200,
}
r = client.post("/api/config", json=payload)
assert r.status_code == 200
assert r.json()["ok"] is True
cfg_file = config_dir / "label_tool.yaml"
assert cfg_file.exists()
saved = yaml.safe_load(cfg_file.read_text())
assert saved["max_per_account"] == 200
assert saved["accounts"][0]["name"] == "Test"
def test_get_config_round_trips(client, config_dir):
payload = {"accounts": [{"name": "R", "host": "h", "port": 993, "use_ssl": True,
"username": "u", "password": "p", "folder": "INBOX",
"days_back": 90}], "max_per_account": 300}
client.post("/api/config", json=payload)
r = client.get("/api/config")
data = r.json()
assert data["max_per_account"] == 300
assert data["accounts"][0]["name"] == "R"
# ── /api/stats ───────────────────────────────────────────────────────────────
@pytest.fixture
def score_with_labels(tmp_path, data_dir):
"""Write a score file with 3 labels for stats tests."""
score_path = data_dir / "email_score.jsonl"
records = [
{"id": "a", "label": "interview_scheduled"},
{"id": "b", "label": "interview_scheduled"},
{"id": "c", "label": "rejected"},
]
score_path.write_text("\n".join(json.dumps(r) for r in records) + "\n")
return records
def test_stats_returns_counts(client, score_with_labels):
r = client.get("/api/stats")
assert r.status_code == 200
data = r.json()
assert data["total"] == 3
assert data["counts"]["interview_scheduled"] == 2
assert data["counts"]["rejected"] == 1
def test_stats_empty_when_no_file(client, data_dir):
r = client.get("/api/stats")
assert r.status_code == 200
data = r.json()
assert data["total"] == 0
assert data["counts"] == {}
assert data["score_file_bytes"] == 0
def test_stats_download_returns_file(client, score_with_labels):
r = client.get("/api/stats/download")
assert r.status_code == 200
assert "jsonlines" in r.headers.get("content-type", "")
def test_stats_download_404_when_no_file(client, data_dir):
r = client.get("/api/stats/download")
assert r.status_code == 404
# ── /api/accounts/test ───────────────────────────────────────────────────────
def test_account_test_missing_fields(client):
r = client.post("/api/accounts/test", json={"account": {"host": "", "username": "", "password": ""}})
assert r.status_code == 200
data = r.json()
assert data["ok"] is False
assert "required" in data["message"].lower()
def test_account_test_success(client):
from unittest.mock import MagicMock, patch
mock_conn = MagicMock()
mock_conn.select.return_value = ("OK", [b"99"])
with patch("app.data.fetch.imaplib.IMAP4_SSL", return_value=mock_conn):
r = client.post("/api/accounts/test", json={"account": {
"host": "imap.example.com", "port": 993, "use_ssl": True,
"username": "u@example.com", "password": "pw", "folder": "INBOX",
}})
assert r.status_code == 200
data = r.json()
assert data["ok"] is True
assert data["count"] == 99
# ── /api/fetch/stream (SSE) ──────────────────────────────────────────────────
def _parse_sse(content: bytes) -> list[dict]:
"""Parse SSE response body into list of event dicts."""
events = []
for line in content.decode().splitlines():
if line.startswith("data: "):
events.append(json.loads(line[6:]))
return events
def test_fetch_stream_no_accounts_configured(client, config_dir):
"""With no config, stream should immediately complete with 0 added."""
r = client.get("/api/fetch/stream?accounts=NoSuchAccount&days_back=30&limit=10")
assert r.status_code == 200
events = _parse_sse(r.content)
complete = next((e for e in events if e["type"] == "complete"), None)
assert complete is not None
assert complete["total_added"] == 0
def test_fetch_stream_with_mock_imap(client, config_dir, data_dir):
"""With one configured account, stream should yield start/done/complete events."""
import yaml
from unittest.mock import MagicMock, patch
# Write a config with one account
cfg = {"accounts": [{"name": "Mock", "host": "h", "port": 993, "use_ssl": True,
"username": "u", "password": "p", "folder": "INBOX",
"days_back": 30}], "max_per_account": 50}
(config_dir / "label_tool.yaml").write_text(yaml.dump(cfg))
raw_msg = (b"Subject: Interview\r\nFrom: a@b.com\r\n"
b"Date: Mon, 1 Mar 2026 12:00:00 +0000\r\n\r\nBody")
mock_conn = MagicMock()
mock_conn.search.return_value = ("OK", [b"1"])
mock_conn.fetch.return_value = ("OK", [(b"1 (RFC822 {N})", raw_msg)])
with patch("app.data.fetch.imaplib.IMAP4_SSL", return_value=mock_conn):
r = client.get("/api/fetch/stream?accounts=Mock&days_back=30&limit=50")
assert r.status_code == 200
events = _parse_sse(r.content)
types = [e["type"] for e in events]
assert "start" in types
assert "done" in types
assert "complete" in types
# ---- /api/finetune/status tests ----
def test_finetune_status_returns_empty_when_no_models_dir(client):
"""GET /api/finetune/status must return [] if models/ does not exist."""
r = client.get("/api/finetune/status")
assert r.status_code == 200
assert r.json() == []
def test_finetune_status_returns_training_info(client, tmp_path):
"""GET /api/finetune/status must return one entry per training_info.json found."""
import json as _json
from app import api as api_module
models_dir = tmp_path / "models" / "avocet-deberta-small"
models_dir.mkdir(parents=True)
info = {
"name": "avocet-deberta-small",
"base_model_id": "cross-encoder/nli-deberta-v3-small",
"val_macro_f1": 0.712,
"timestamp": "2026-03-15T12:00:00Z",
"sample_count": 401,
}
(models_dir / "training_info.json").write_text(_json.dumps(info))
api_module.set_models_dir(tmp_path / "models")
try:
r = client.get("/api/finetune/status")
assert r.status_code == 200
data = r.json()
assert any(d["name"] == "avocet-deberta-small" for d in data)
finally:
api_module.set_models_dir(api_module._ROOT / "models")
def test_finetune_run_streams_sse_events(client):
"""GET /api/finetune/run must return text/event-stream content type."""
from unittest.mock import patch, MagicMock
mock_proc = MagicMock()
mock_proc.stdout = iter(["Training epoch 1\n", "Done\n"])
mock_proc.returncode = 0
mock_proc.wait = MagicMock()
with patch("app.api._subprocess.Popen",return_value=mock_proc):
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
assert r.status_code == 200
assert "text/event-stream" in r.headers.get("content-type", "")
def test_finetune_run_emits_complete_on_success(client):
"""GET /api/finetune/run must emit a complete event on clean exit."""
from unittest.mock import patch, MagicMock
mock_proc = MagicMock()
mock_proc.stdout = iter(["progress line\n"])
mock_proc.returncode = 0
mock_proc.wait = MagicMock()
with patch("app.api._subprocess.Popen",return_value=mock_proc):
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
assert '{"type": "complete"}' in r.text
def test_finetune_run_emits_error_on_nonzero_exit(client):
"""GET /api/finetune/run must emit an error event on non-zero exit."""
from unittest.mock import patch, MagicMock
mock_proc = MagicMock()
mock_proc.stdout = iter([])
mock_proc.returncode = 1
mock_proc.wait = MagicMock()
with patch("app.api._subprocess.Popen",return_value=mock_proc):
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
assert '"type": "error"' in r.text
def test_finetune_run_passes_score_files_to_subprocess(client):
"""GET /api/finetune/run?score=file1&score=file2 must pass --score args to subprocess."""
from unittest.mock import patch, MagicMock
captured_cmd = []
def mock_popen(cmd, **kwargs):
captured_cmd.extend(cmd)
m = MagicMock()
m.stdout = iter([])
m.returncode = 0
m.wait = MagicMock()
return m
with patch("app.api._subprocess.Popen",side_effect=mock_popen):
client.get("/api/finetune/run?model=deberta-small&epochs=1&score=run1.jsonl&score=run2.jsonl")
assert "--score" in captured_cmd
assert captured_cmd.count("--score") == 2
# Paths are resolved to absolute — check filenames are present as substrings
assert any("run1.jsonl" in arg for arg in captured_cmd)
assert any("run2.jsonl" in arg for arg in captured_cmd)
# ---- Cancel endpoint tests ----
def test_finetune_cancel_returns_404_when_not_running(client):
"""POST /api/finetune/cancel must return 404 if no finetune is running."""
from app import api as api_module
api_module._running_procs.pop("finetune", None)
r = client.post("/api/finetune/cancel")
assert r.status_code == 404
def test_finetune_cancel_terminates_running_process(client):
"""POST /api/finetune/cancel must call terminate() on the running process."""
from unittest.mock import MagicMock
from app import api as api_module
mock_proc = MagicMock()
mock_proc.wait = MagicMock()
api_module._running_procs["finetune"] = mock_proc
try:
r = client.post("/api/finetune/cancel")
assert r.status_code == 200
assert r.json()["status"] == "cancelled"
mock_proc.terminate.assert_called_once()
finally:
api_module._running_procs.pop("finetune", None)
api_module._cancelled_jobs.discard("finetune")
def test_finetune_run_emits_cancelled_event(client):
"""GET /api/finetune/run must emit cancelled (not error) when job was cancelled."""
from unittest.mock import patch, MagicMock
from app import api as api_module
mock_proc = MagicMock()
mock_proc.stdout = iter([])
mock_proc.returncode = -15 # SIGTERM
def mock_wait():
# Simulate cancel being called while the process is running (after discard clears stale flag)
api_module._cancelled_jobs.add("finetune")
mock_proc.wait = mock_wait
def mock_popen(cmd, **kwargs):
return mock_proc
try:
with patch("app.api._subprocess.Popen",side_effect=mock_popen):
r = client.get("/api/finetune/run?model=deberta-small&epochs=1")
assert '{"type": "cancelled"}' in r.text
assert '"type": "error"' not in r.text
finally:
api_module._cancelled_jobs.discard("finetune")
assert "items" in r.json()
assert "total" in r.json()

187
tests/test_train.py Normal file
View file

@ -0,0 +1,187 @@
"""Tests for app/train/train.py -- /api/train/* endpoints."""
import json
import pytest
from fastapi.testclient import TestClient
from unittest.mock import MagicMock, patch
@pytest.fixture(autouse=True)
def reset_globals(tmp_path):
from app.train import train as train_module
train_module.set_db_path(tmp_path / "train_jobs.db")
train_module.set_models_dir(tmp_path / "models")
train_module._running_procs.clear()
yield
train_module._running_procs.clear()
@pytest.fixture
def client():
from app.api import app
return TestClient(app)
def _parse_sse(content: bytes) -> list[dict]:
events = []
for line in content.decode().splitlines():
if line.startswith("data: "):
events.append(json.loads(line[6:]))
return events
def test_list_jobs_empty(client):
r = client.get("/api/train/jobs")
assert r.status_code == 200
assert r.json() == []
def test_create_job_returns_queued_record(client):
r = client.post("/api/train/jobs",
json={"type": "classifier", "model_key": "deberta-small",
"config_json": {"epochs": 3}})
assert r.status_code == 200
data = r.json()
assert data["status"] == "queued"
assert data["type"] == "classifier"
assert data["model_key"] == "deberta-small"
assert "id" in data
def test_create_job_invalid_type_returns_400(client):
r = client.post("/api/train/jobs",
json={"type": "unknown-type", "model_key": "deberta-small"})
assert r.status_code == 400
def test_create_job_appears_in_list(client):
client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
r = client.get("/api/train/jobs")
assert r.status_code == 200
assert len(r.json()) == 1
def test_get_job_returns_record(client):
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
job_id = r.json()["id"]
r2 = client.get(f"/api/train/jobs/{job_id}")
assert r2.status_code == 200
assert r2.json()["id"] == job_id
def test_get_job_404_for_unknown(client):
r = client.get("/api/train/jobs/no-such-id")
assert r.status_code == 404
def test_cancel_queued_job(client):
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
job_id = r.json()["id"]
r2 = client.delete(f"/api/train/jobs/{job_id}/cancel")
assert r2.status_code == 200
assert r2.json()["status"] == "cancelled"
r3 = client.get(f"/api/train/jobs/{job_id}")
assert r3.json()["status"] == "cancelled"
def test_cancel_completed_job_returns_409(client):
from app.train import train as train_module
train_module._init_db()
with train_module._db() as conn:
conn.execute(
"INSERT INTO jobs (id, type, model_key, status, config_json, created_at) "
"VALUES ('abc', 'classifier', 'deberta-small', 'completed', '{}', '2026-05-01T00:00:00Z')"
)
r = client.delete("/api/train/jobs/abc/cancel")
assert r.status_code == 409
def test_cancel_terminates_running_proc(client):
from app.train import train as train_module
mock_proc = MagicMock()
mock_proc.wait = MagicMock()
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
job_id = r.json()["id"]
train_module._running_procs[job_id] = mock_proc
with train_module._db() as conn:
conn.execute("UPDATE jobs SET status='running' WHERE id=?", (job_id,))
r2 = client.delete(f"/api/train/jobs/{job_id}/cancel")
assert r2.status_code == 200
mock_proc.terminate.assert_called_once()
def test_run_job_streams_sse(client):
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
job_id = r.json()["id"]
mock_proc = MagicMock()
mock_proc.stdout = iter(["Epoch 1\n", "Done\n"])
mock_proc.returncode = 0
mock_proc.wait = MagicMock()
with patch("app.train.train._subprocess.Popen", return_value=mock_proc):
r2 = client.get(f"/api/train/jobs/{job_id}/run")
assert r2.status_code == 200
assert "text/event-stream" in r2.headers.get("content-type", "")
events = _parse_sse(r2.content)
assert any(e["type"] == "complete" for e in events)
def test_run_job_marks_completed_in_db(client):
from app.train import train as train_module
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
job_id = r.json()["id"]
mock_proc = MagicMock()
mock_proc.stdout = iter([])
mock_proc.returncode = 0
mock_proc.wait = MagicMock()
with patch("app.train.train._subprocess.Popen", return_value=mock_proc):
client.get(f"/api/train/jobs/{job_id}/run")
r2 = client.get(f"/api/train/jobs/{job_id}")
assert r2.json()["status"] == "completed"
def test_run_job_marks_failed_on_nonzero_exit(client):
r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"})
job_id = r.json()["id"]
mock_proc = MagicMock()
mock_proc.stdout = iter([])
mock_proc.returncode = 1
mock_proc.wait = MagicMock()
with patch("app.train.train._subprocess.Popen", return_value=mock_proc):
client.get(f"/api/train/jobs/{job_id}/run")
r2 = client.get(f"/api/train/jobs/{job_id}")
assert r2.json()["status"] == "failed"
def test_run_nonqueued_job_returns_409(client):
from app.train import train as train_module
train_module._init_db()
with train_module._db() as conn:
conn.execute(
"INSERT INTO jobs (id, type, model_key, status, config_json, created_at) "
"VALUES ('xyz', 'classifier', 'deberta-small', 'running', '{}', '2026-05-01T00:00:00Z')"
)
r = client.get("/api/train/jobs/xyz/run")
assert r.status_code == 409
def test_run_unknown_job_returns_404(client):
r = client.get("/api/train/jobs/no-such/run")
assert r.status_code == 404
def test_results_empty_when_no_models_dir(client):
r = client.get("/api/train/results")
assert r.status_code == 200
assert r.json() == []
def test_results_returns_training_info(client, tmp_path):
from app.train import train as train_module
models_dir = tmp_path / "models" / "avocet-deberta-small"
models_dir.mkdir(parents=True)
train_module.set_models_dir(tmp_path / "models")
info = {"name": "avocet-deberta-small", "val_macro_f1": 0.712, "sample_count": 401}
(models_dir / "training_info.json").write_text(json.dumps(info))
r = client.get("/api/train/results")
assert r.status_code == 200
data = r.json()
assert any(d["name"] == "avocet-deberta-small" for d in data)