diff --git a/app/api.py b/app/api.py index f71ab50..2930571 100644 --- a/app/api.py +++ b/app/api.py @@ -22,11 +22,6 @@ _DATA_DIR: Path = _ROOT / "data" # overridable in tests via set_data_dir() _MODELS_DIR: Path = _ROOT / "models" # overridable in tests via set_models_dir() _CONFIG_DIR: Path | None = None # None = use real path -# Process registry for running jobs — used by cancel endpoints. -# Keys: "benchmark" | "finetune". Values: the live Popen object. -_running_procs: dict = {} -_cancelled_jobs: set = set() - def set_data_dir(path: Path) -> None: """Override data directory — used by tests.""" @@ -34,34 +29,6 @@ def set_data_dir(path: Path) -> None: _DATA_DIR = path -def _best_cuda_device() -> str: - """Return the index of the GPU with the most free VRAM as a string. - - Uses nvidia-smi so it works in the job-seeker env (no torch). Returns "" - if nvidia-smi is unavailable or no GPUs are found. Restricting the - training subprocess to a single GPU via CUDA_VISIBLE_DEVICES prevents - PyTorch DataParallel from replicating the model across all GPUs, which - would OOM the GPU with less headroom. - """ - try: - out = _subprocess.check_output( - ["nvidia-smi", "--query-gpu=index,memory.free", - "--format=csv,noheader,nounits"], - text=True, - timeout=5, - ) - best_idx, best_free = "", 0 - for line in out.strip().splitlines(): - parts = line.strip().split(", ") - if len(parts) == 2: - idx, free = parts[0].strip(), int(parts[1].strip()) - if free > best_free: - best_free, best_idx = free, idx - return best_idx - except Exception: - return "" - - def set_models_dir(path: Path) -> None: """Override models directory — used by tests.""" global _MODELS_DIR @@ -156,116 +123,8 @@ app.include_router(imitate_router, prefix="/api/imitate") from app.data.fetch import router as fetch_router app.include_router(fetch_router, prefix="/api") - -from fastapi.responses import StreamingResponse - - -# --------------------------------------------------------------------------- -# Finetune endpoints -# --------------------------------------------------------------------------- - -@app.get("/api/finetune/status") -def get_finetune_status(): - """Scan models/ for training_info.json files. Returns [] if none exist.""" - models_dir = _MODELS_DIR - if not models_dir.exists(): - return [] - results = [] - for sub in models_dir.iterdir(): - if not sub.is_dir(): - continue - info_path = sub / "training_info.json" - if not info_path.exists(): - continue - try: - info = json.loads(info_path.read_text(encoding="utf-8")) - results.append(info) - except Exception: - pass - return results - - -@app.get("/api/finetune/run") -def run_finetune_endpoint( - model: str = "deberta-small", - epochs: int = 5, - score: list[str] = Query(default=[]), -): - """Spawn finetune_classifier.py and stream stdout as SSE progress events.""" - python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python" - script = str(_ROOT / "scripts" / "finetune_classifier.py") - cmd = [python_bin, script, "--model", model, "--epochs", str(epochs)] - data_root = _DATA_DIR.resolve() - for score_file in score: - resolved = (_DATA_DIR / score_file).resolve() - if not str(resolved).startswith(str(data_root)): - raise HTTPException(400, f"Invalid score path: {score_file!r}") - cmd.extend(["--score", str(resolved)]) - - # Pick the GPU with the most free VRAM. Setting CUDA_VISIBLE_DEVICES to a - # single device prevents DataParallel from replicating the model across all - # GPUs, which would force a full copy onto the more memory-constrained device. - proc_env = {**os.environ, "PYTORCH_ALLOC_CONF": "expandable_segments:True"} - best_gpu = _best_cuda_device() - if best_gpu: - proc_env["CUDA_VISIBLE_DEVICES"] = best_gpu - - gpu_note = f"GPU {best_gpu}" if best_gpu else "CPU (no GPU found)" - - def generate(): - yield f"data: {json.dumps({'type': 'progress', 'message': f'[api] Using {gpu_note} (most free VRAM)'})}\n\n" - try: - proc = _subprocess.Popen( - cmd, - stdout=_subprocess.PIPE, - stderr=_subprocess.STDOUT, - text=True, - bufsize=1, - cwd=str(_ROOT), - env=proc_env, - ) - _running_procs["finetune"] = proc - _cancelled_jobs.discard("finetune") # clear any stale flag from a prior run - try: - for line in proc.stdout: - line = line.rstrip() - if line: - yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n" - proc.wait() - if proc.returncode == 0: - yield f"data: {json.dumps({'type': 'complete'})}\n\n" - elif "finetune" in _cancelled_jobs: - _cancelled_jobs.discard("finetune") - yield f"data: {json.dumps({'type': 'cancelled'})}\n\n" - else: - yield f"data: {json.dumps({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})}\n\n" - finally: - _running_procs.pop("finetune", None) - except Exception as exc: - yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n" - - return StreamingResponse( - generate(), - media_type="text/event-stream", - headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}, - ) - - - -@app.post("/api/finetune/cancel") -def cancel_finetune(): - """Kill the running fine-tune subprocess. 404 if none is running.""" - proc = _running_procs.get("finetune") - if proc is None: - raise HTTPException(404, "No finetune is running") - _cancelled_jobs.add("finetune") - proc.terminate() - try: - proc.wait(timeout=3) - except _subprocess.TimeoutExpired: - proc.kill() - return {"status": "cancelled"} - +from app.train.train import router as train_router +app.include_router(train_router, prefix="/api/train") # Static SPA — MUST be last (catches all unmatched paths) diff --git a/app/train/__init__.py b/app/train/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/train/train.py b/app/train/train.py new file mode 100644 index 0000000..9dfedd4 --- /dev/null +++ b/app/train/train.py @@ -0,0 +1,287 @@ +"""Avocet -- train job queue API. + +SQLite-backed job queue for finetune jobs. Replaces the ad-hoc +_running_procs dict in api.py with a persistent, inspectable queue. + +Routes (all under /api/train when api.py mounts with prefix="/api/train"): + GET /jobs -- list all jobs, newest first + POST /jobs -- create a new job + GET /jobs/{id} -- get one job by id + DELETE /jobs/{id}/cancel -- cancel a queued or running job + GET /jobs/{id}/run -- SSE: run the job, stream stdout + GET /results -- list completed models with training_info.json metrics + +SQLite schema: + CREATE TABLE IF NOT EXISTS jobs ( + id TEXT PRIMARY KEY, + type TEXT NOT NULL, -- 'classifier' | 'llm-sft' + model_key TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'queued', + config_json TEXT NOT NULL DEFAULT '{}', + created_at TEXT NOT NULL, + started_at TEXT, + completed_at TEXT, + error TEXT + ) + +Testability seam: _DB_PATH global, override via set_db_path(). +""" +from __future__ import annotations + +import json +import os +import sqlite3 +import subprocess as _subprocess +import uuid +from contextlib import contextmanager +from datetime import datetime, timezone +from pathlib import Path +from typing import Any, Generator + +from fastapi import APIRouter, HTTPException +from fastapi.responses import StreamingResponse +from pydantic import BaseModel + +_ROOT = Path(__file__).parent.parent.parent +_DB_PATH: Path = _ROOT / "data" / "train_jobs.db" +_MODELS_DIR: Path = _ROOT / "models" +_running_procs: dict[str, Any] = {} + +router = APIRouter() + + +# -- Testability seams ------------------------------------------------- + +def set_db_path(path: Path) -> None: + global _DB_PATH + _DB_PATH = path + +def set_models_dir(path: Path) -> None: + global _MODELS_DIR + _MODELS_DIR = path + + +# -- Database helpers -------------------------------------------------- + +@contextmanager +def _db() -> Generator[sqlite3.Connection, None, None]: + conn = sqlite3.connect(str(_DB_PATH)) + conn.row_factory = sqlite3.Row + try: + yield conn + conn.commit() + finally: + conn.close() + + +def _init_db() -> None: + """Create jobs table if it does not exist. Called lazily per request.""" + _DB_PATH.parent.mkdir(parents=True, exist_ok=True) + with _db() as conn: + conn.execute(""" + CREATE TABLE IF NOT EXISTS jobs ( + id TEXT PRIMARY KEY, + type TEXT NOT NULL, + model_key TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'queued', + config_json TEXT NOT NULL DEFAULT '{}', + created_at TEXT NOT NULL, + started_at TEXT, + completed_at TEXT, + error TEXT + ) + """) + + +def _row_to_dict(row: sqlite3.Row) -> dict: + return {k: row[k] for k in row.keys()} + + +# -- GPU selection (copied from api.py) -------------------------------- + +def _best_cuda_device() -> str: + """Return index of GPU with most free VRAM, or empty string.""" + try: + out = _subprocess.check_output( + ["nvidia-smi", "--query-gpu=index,memory.free", + "--format=csv,noheader,nounits"], + text=True, timeout=5, + ) + best_idx, best_free = "", 0 + for line in out.strip().splitlines(): + parts = line.strip().split(", ") + if len(parts) == 2: + idx, free = parts[0].strip(), int(parts[1].strip()) + if free > best_free: + best_free, best_idx = free, idx + return best_idx + except Exception: + return "" + + +# -- Pydantic models --------------------------------------------------- + +class CreateJobRequest(BaseModel): + type: str # "classifier" | "llm-sft" + model_key: str # e.g. "deberta-small" + config_json: dict = {} + + +# -- Routes ------------------------------------------------------------ + +@router.get("/jobs") +def list_jobs() -> list[dict]: + _init_db() + with _db() as conn: + rows = conn.execute("SELECT * FROM jobs ORDER BY created_at DESC").fetchall() + return [_row_to_dict(r) for r in rows] + + +@router.post("/jobs") +def create_job(req: CreateJobRequest) -> dict: + if req.type not in ("classifier", "llm-sft"): + raise HTTPException(400, f"Unknown job type: {req.type!r}. Must be 'classifier' or 'llm-sft'.") + _init_db() + job_id = str(uuid.uuid4()) + now = datetime.now(timezone.utc).isoformat() + with _db() as conn: + conn.execute( + "INSERT INTO jobs (id, type, model_key, status, config_json, created_at) " + "VALUES (?, ?, ?, 'queued', ?, ?)", + (job_id, req.type, req.model_key, json.dumps(req.config_json), now), + ) + return {"id": job_id, "type": req.type, "model_key": req.model_key, + "status": "queued", "config_json": req.config_json, + "created_at": now, "started_at": None, "completed_at": None, "error": None} + + +@router.get("/jobs/{job_id}") +def get_job(job_id: str) -> dict: + _init_db() + with _db() as conn: + row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone() + if row is None: + raise HTTPException(404, f"Job {job_id!r} not found") + return _row_to_dict(row) + + +@router.delete("/jobs/{job_id}/cancel") +def cancel_job(job_id: str) -> dict: + _init_db() + with _db() as conn: + row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone() + if row is None: + raise HTTPException(404, f"Job {job_id!r} not found") + if row["status"] not in ("queued", "running"): + raise HTTPException(409, f"Job is {row['status']} -- cannot cancel") + now = datetime.now(timezone.utc).isoformat() + conn.execute("UPDATE jobs SET status='cancelled', completed_at=? WHERE id=?", (now, job_id)) + proc = _running_procs.pop(job_id, None) + if proc is not None: + try: + proc.terminate() + proc.wait(timeout=3) + except _subprocess.TimeoutExpired: + proc.kill() + return {"status": "cancelled"} + + +@router.get("/jobs/{job_id}/run") +def run_job(job_id: str) -> StreamingResponse: + _init_db() + with _db() as conn: + row = conn.execute("SELECT * FROM jobs WHERE id = ?", (job_id,)).fetchone() + if row is None: + raise HTTPException(404, f"Job {job_id!r} not found") + if row["status"] != "queued": + raise HTTPException(409, f"Job is {row['status']} -- only queued jobs can be run") + job = _row_to_dict(row) + + def generate(): + python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python" + config = json.loads(job["config_json"] or "{}") + model_key = job["model_key"] + epochs = config.get("epochs", 5) + + if job["type"] == "classifier": + script = str(_ROOT / "scripts" / "finetune_classifier.py") + cmd = [python_bin, script, "--model", model_key, "--epochs", str(epochs)] + data_dir = _ROOT / "data" + for sf in config.get("score_files", []): + resolved = (data_dir / sf).resolve() + if str(resolved).startswith(str(data_dir.resolve())): + cmd.extend(["--score", str(resolved)]) + elif job["type"] == "llm-sft": + script = str(_ROOT / "scripts" / "finetune_sft.py") + cmd = [python_bin, script, "--model", model_key, "--epochs", str(epochs)] + else: + yield f"data: {json.dumps({'type': 'error', 'message': f'Unknown job type: {job['type']}'})}\n\n" + return + + proc_env = {**os.environ, "PYTORCH_ALLOC_CONF": "expandable_segments:True"} + best_gpu = _best_cuda_device() + if best_gpu: + proc_env["CUDA_VISIBLE_DEVICES"] = best_gpu + + gpu_note = f"GPU {best_gpu}" if best_gpu else "CPU (no GPU found)" + yield f"data: {json.dumps({'type': 'progress', 'message': f'[train] Using {gpu_note}'})}\n\n" + + now = datetime.now(timezone.utc).isoformat() + with _db() as conn: + conn.execute("UPDATE jobs SET status='running', started_at=? WHERE id=?", (now, job_id)) + + try: + proc = _subprocess.Popen( + cmd, stdout=_subprocess.PIPE, stderr=_subprocess.STDOUT, + text=True, bufsize=1, cwd=str(_ROOT), env=proc_env, + ) + _running_procs[job_id] = proc + try: + for line in proc.stdout: + line = line.rstrip() + if line: + yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n" + proc.wait() + finished_at = datetime.now(timezone.utc).isoformat() + if proc.returncode == 0: + with _db() as conn: + conn.execute( + "UPDATE jobs SET status='completed', completed_at=? WHERE id=?", + (finished_at, job_id)) + yield f"data: {json.dumps({'type': 'complete'})}\n\n" + else: + err = f"Process exited with code {proc.returncode}" + with _db() as conn: + conn.execute( + "UPDATE jobs SET status='failed', completed_at=?, error=? WHERE id=?", + (finished_at, err, job_id)) + yield f"data: {json.dumps({'type': 'error', 'message': err})}\n\n" + finally: + _running_procs.pop(job_id, None) + except Exception as exc: + err = str(exc) + with _db() as conn: + conn.execute("UPDATE jobs SET status='failed', error=? WHERE id=?", (err, job_id)) + yield f"data: {json.dumps({'type': 'error', 'message': err})}\n\n" + + return StreamingResponse(generate(), media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"}) + + +@router.get("/results") +def list_results() -> list[dict]: + if not _MODELS_DIR.exists(): + return [] + results = [] + for sub in _MODELS_DIR.iterdir(): + if not sub.is_dir(): + continue + info_path = sub / "training_info.json" + if not info_path.exists(): + continue + try: + info = json.loads(info_path.read_text(encoding="utf-8")) + results.append(info) + except Exception: + pass + return results diff --git a/tests/test_api.py b/tests/test_api.py index 693098d..45b0ee9 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,29 +1,33 @@ -import json +"""Smoke tests for the app factory (app/api.py). +Detailed route tests live in test_data_label.py, test_data_fetch.py, +test_data_corrections.py, test_train.py, and test_dashboard.py. +""" import pytest -from app import api as api_module # noqa: F401 - - -@pytest.fixture(autouse=True) -def reset_globals(tmp_path): - from app import api - from app.data import label as label_module - from app.data import fetch as fetch_module - api.set_data_dir(tmp_path) - label_module.set_data_dir(tmp_path) - label_module.set_config_dir(tmp_path) - label_module.reset_last_action() - fetch_module.set_data_dir(tmp_path) - fetch_module.set_config_dir(tmp_path) - yield - label_module.reset_last_action() +from fastapi.testclient import TestClient def test_import(): from app import api # noqa: F401 -from fastapi.testclient import TestClient +def test_app_has_required_routes(): + from app.api import app + paths = {r.path for r in app.routes} + # Label routes + assert "/api/queue" in paths + assert "/api/label" in paths + assert "/api/skip" in paths + assert "/api/discard" in paths + assert "/api/label/undo" in paths + assert "/api/config/labels" in paths + assert "/api/stats" in paths + # Fetch routes + assert "/api/accounts/test" in paths + assert "/api/fetch/stream" in paths + # Train routes + assert "/api/train/jobs" in paths + assert "/api/train/results" in paths @pytest.fixture @@ -32,470 +36,8 @@ def client(): return TestClient(app) -@pytest.fixture -def queue_with_items(): - """Write 3 test emails to the queue file.""" - from app import api as api_module - items = [ - {"id": f"id{i}", "subject": f"Subject {i}", "body": f"Body {i}", - "from": "test@example.com", "date": "2026-03-01", "source": "imap:test"} - for i in range(3) - ] - queue_path = api_module._DATA_DIR / "email_label_queue.jsonl" - queue_path.write_text("\n".join(json.dumps(x) for x in items) + "\n") - return items - - -def test_queue_returns_items(client, queue_with_items): - r = client.get("/api/queue?limit=2") - assert r.status_code == 200 - data = r.json() - assert len(data["items"]) == 2 - assert data["total"] == 3 - - -def test_queue_empty_when_no_file(client): +def test_queue_endpoint_reachable(client): r = client.get("/api/queue") assert r.status_code == 200 - assert r.json() == {"items": [], "total": 0} - - -def test_label_appends_to_score(client, queue_with_items): - from app import api as api_module - r = client.post("/api/label", json={"id": "id0", "label": "interview_scheduled"}) - assert r.status_code == 200 - records = api_module._read_jsonl(api_module._score_file()) - assert len(records) == 1 - assert records[0]["id"] == "id0" - assert records[0]["label"] == "interview_scheduled" - assert "labeled_at" in records[0] - -def test_label_removes_from_queue(client, queue_with_items): - from app import api as api_module - client.post("/api/label", json={"id": "id0", "label": "rejected"}) - queue = api_module._read_jsonl(api_module._queue_file()) - assert not any(x["id"] == "id0" for x in queue) - -def test_label_unknown_id_returns_404(client, queue_with_items): - r = client.post("/api/label", json={"id": "unknown", "label": "neutral"}) - assert r.status_code == 404 - -def test_skip_moves_to_back(client, queue_with_items): - from app import api as api_module - r = client.post("/api/skip", json={"id": "id0"}) - assert r.status_code == 200 - queue = api_module._read_jsonl(api_module._queue_file()) - assert queue[-1]["id"] == "id0" - assert queue[0]["id"] == "id1" - -def test_skip_unknown_id_returns_404(client, queue_with_items): - r = client.post("/api/skip", json={"id": "nope"}) - assert r.status_code == 404 - - -# --- Part A: POST /api/discard --- - -def test_discard_writes_to_discarded_file(client, queue_with_items): - from app import api as api_module - r = client.post("/api/discard", json={"id": "id1"}) - assert r.status_code == 200 - discarded = api_module._read_jsonl(api_module._discarded_file()) - assert len(discarded) == 1 - assert discarded[0]["id"] == "id1" - assert discarded[0]["label"] == "__discarded__" - -def test_discard_removes_from_queue(client, queue_with_items): - from app import api as api_module - client.post("/api/discard", json={"id": "id1"}) - queue = api_module._read_jsonl(api_module._queue_file()) - assert not any(x["id"] == "id1" for x in queue) - - -# --- Part B: DELETE /api/label/undo --- - -def test_undo_label_removes_from_score(client, queue_with_items): - from app import api as api_module - client.post("/api/label", json={"id": "id0", "label": "neutral"}) - r = client.delete("/api/label/undo") - assert r.status_code == 200 - data = r.json() - assert data["undone"]["type"] == "label" - score = api_module._read_jsonl(api_module._score_file()) - assert score == [] - # Item should be restored to front of queue - queue = api_module._read_jsonl(api_module._queue_file()) - assert queue[0]["id"] == "id0" - -def test_undo_discard_removes_from_discarded(client, queue_with_items): - from app import api as api_module - client.post("/api/discard", json={"id": "id0"}) - r = client.delete("/api/label/undo") - assert r.status_code == 200 - discarded = api_module._read_jsonl(api_module._discarded_file()) - assert discarded == [] - -def test_undo_skip_restores_to_front(client, queue_with_items): - from app import api as api_module - client.post("/api/skip", json={"id": "id0"}) - r = client.delete("/api/label/undo") - assert r.status_code == 200 - queue = api_module._read_jsonl(api_module._queue_file()) - assert queue[0]["id"] == "id0" - -def test_undo_with_no_action_returns_404(client): - r = client.delete("/api/label/undo") - assert r.status_code == 404 - - -# --- Part C: GET /api/config/labels --- - -def test_config_labels_returns_metadata(client): - r = client.get("/api/config/labels") - assert r.status_code == 200 - labels = r.json() - assert len(labels) == 10 - assert labels[0]["key"] == "1" - assert "emoji" in labels[0] - assert "color" in labels[0] - assert "name" in labels[0] - - -# ── /api/config ────────────────────────────────────────────────────────────── - -@pytest.fixture -def config_dir(tmp_path): - """Give the API a writable config directory.""" - from app import api as api_module - from app.data import label as label_module - api_module.set_config_dir(tmp_path) - label_module.set_config_dir(tmp_path) - yield tmp_path - api_module.set_config_dir(None) # reset to default - label_module.set_config_dir(None) - - -@pytest.fixture -def data_dir(): - """Expose the current _DATA_DIR set by the autouse reset_globals fixture.""" - from app import api as api_module - return api_module._DATA_DIR - - -def test_get_config_returns_empty_when_no_file(client, config_dir): - r = client.get("/api/config") - assert r.status_code == 200 - data = r.json() - assert data["accounts"] == [] - assert data["max_per_account"] == 500 - - -def test_post_config_writes_yaml(client, config_dir): - import yaml - payload = { - "accounts": [{"name": "Test", "host": "imap.test.com", "port": 993, - "use_ssl": True, "username": "u@t.com", "password": "pw", - "folder": "INBOX", "days_back": 30}], - "max_per_account": 200, - } - r = client.post("/api/config", json=payload) - assert r.status_code == 200 - assert r.json()["ok"] is True - cfg_file = config_dir / "label_tool.yaml" - assert cfg_file.exists() - saved = yaml.safe_load(cfg_file.read_text()) - assert saved["max_per_account"] == 200 - assert saved["accounts"][0]["name"] == "Test" - - -def test_get_config_round_trips(client, config_dir): - payload = {"accounts": [{"name": "R", "host": "h", "port": 993, "use_ssl": True, - "username": "u", "password": "p", "folder": "INBOX", - "days_back": 90}], "max_per_account": 300} - client.post("/api/config", json=payload) - r = client.get("/api/config") - data = r.json() - assert data["max_per_account"] == 300 - assert data["accounts"][0]["name"] == "R" - - -# ── /api/stats ─────────────────────────────────────────────────────────────── - -@pytest.fixture -def score_with_labels(tmp_path, data_dir): - """Write a score file with 3 labels for stats tests.""" - score_path = data_dir / "email_score.jsonl" - records = [ - {"id": "a", "label": "interview_scheduled"}, - {"id": "b", "label": "interview_scheduled"}, - {"id": "c", "label": "rejected"}, - ] - score_path.write_text("\n".join(json.dumps(r) for r in records) + "\n") - return records - - -def test_stats_returns_counts(client, score_with_labels): - r = client.get("/api/stats") - assert r.status_code == 200 - data = r.json() - assert data["total"] == 3 - assert data["counts"]["interview_scheduled"] == 2 - assert data["counts"]["rejected"] == 1 - - -def test_stats_empty_when_no_file(client, data_dir): - r = client.get("/api/stats") - assert r.status_code == 200 - data = r.json() - assert data["total"] == 0 - assert data["counts"] == {} - assert data["score_file_bytes"] == 0 - - -def test_stats_download_returns_file(client, score_with_labels): - r = client.get("/api/stats/download") - assert r.status_code == 200 - assert "jsonlines" in r.headers.get("content-type", "") - - -def test_stats_download_404_when_no_file(client, data_dir): - r = client.get("/api/stats/download") - assert r.status_code == 404 - - -# ── /api/accounts/test ─────────────────────────────────────────────────────── - -def test_account_test_missing_fields(client): - r = client.post("/api/accounts/test", json={"account": {"host": "", "username": "", "password": ""}}) - assert r.status_code == 200 - data = r.json() - assert data["ok"] is False - assert "required" in data["message"].lower() - - -def test_account_test_success(client): - from unittest.mock import MagicMock, patch - mock_conn = MagicMock() - mock_conn.select.return_value = ("OK", [b"99"]) - with patch("app.data.fetch.imaplib.IMAP4_SSL", return_value=mock_conn): - r = client.post("/api/accounts/test", json={"account": { - "host": "imap.example.com", "port": 993, "use_ssl": True, - "username": "u@example.com", "password": "pw", "folder": "INBOX", - }}) - assert r.status_code == 200 - data = r.json() - assert data["ok"] is True - assert data["count"] == 99 - - -# ── /api/fetch/stream (SSE) ────────────────────────────────────────────────── - -def _parse_sse(content: bytes) -> list[dict]: - """Parse SSE response body into list of event dicts.""" - events = [] - for line in content.decode().splitlines(): - if line.startswith("data: "): - events.append(json.loads(line[6:])) - return events - - -def test_fetch_stream_no_accounts_configured(client, config_dir): - """With no config, stream should immediately complete with 0 added.""" - r = client.get("/api/fetch/stream?accounts=NoSuchAccount&days_back=30&limit=10") - assert r.status_code == 200 - events = _parse_sse(r.content) - complete = next((e for e in events if e["type"] == "complete"), None) - assert complete is not None - assert complete["total_added"] == 0 - - -def test_fetch_stream_with_mock_imap(client, config_dir, data_dir): - """With one configured account, stream should yield start/done/complete events.""" - import yaml - from unittest.mock import MagicMock, patch - - # Write a config with one account - cfg = {"accounts": [{"name": "Mock", "host": "h", "port": 993, "use_ssl": True, - "username": "u", "password": "p", "folder": "INBOX", - "days_back": 30}], "max_per_account": 50} - (config_dir / "label_tool.yaml").write_text(yaml.dump(cfg)) - - raw_msg = (b"Subject: Interview\r\nFrom: a@b.com\r\n" - b"Date: Mon, 1 Mar 2026 12:00:00 +0000\r\n\r\nBody") - mock_conn = MagicMock() - mock_conn.search.return_value = ("OK", [b"1"]) - mock_conn.fetch.return_value = ("OK", [(b"1 (RFC822 {N})", raw_msg)]) - - with patch("app.data.fetch.imaplib.IMAP4_SSL", return_value=mock_conn): - r = client.get("/api/fetch/stream?accounts=Mock&days_back=30&limit=50") - - assert r.status_code == 200 - events = _parse_sse(r.content) - types = [e["type"] for e in events] - assert "start" in types - assert "done" in types - assert "complete" in types - - -# ---- /api/finetune/status tests ---- - -def test_finetune_status_returns_empty_when_no_models_dir(client): - """GET /api/finetune/status must return [] if models/ does not exist.""" - r = client.get("/api/finetune/status") - assert r.status_code == 200 - assert r.json() == [] - - -def test_finetune_status_returns_training_info(client, tmp_path): - """GET /api/finetune/status must return one entry per training_info.json found.""" - import json as _json - from app import api as api_module - - models_dir = tmp_path / "models" / "avocet-deberta-small" - models_dir.mkdir(parents=True) - info = { - "name": "avocet-deberta-small", - "base_model_id": "cross-encoder/nli-deberta-v3-small", - "val_macro_f1": 0.712, - "timestamp": "2026-03-15T12:00:00Z", - "sample_count": 401, - } - (models_dir / "training_info.json").write_text(_json.dumps(info)) - - api_module.set_models_dir(tmp_path / "models") - try: - r = client.get("/api/finetune/status") - assert r.status_code == 200 - data = r.json() - assert any(d["name"] == "avocet-deberta-small" for d in data) - finally: - api_module.set_models_dir(api_module._ROOT / "models") - - -def test_finetune_run_streams_sse_events(client): - """GET /api/finetune/run must return text/event-stream content type.""" - from unittest.mock import patch, MagicMock - - mock_proc = MagicMock() - mock_proc.stdout = iter(["Training epoch 1\n", "Done\n"]) - mock_proc.returncode = 0 - mock_proc.wait = MagicMock() - - with patch("app.api._subprocess.Popen",return_value=mock_proc): - r = client.get("/api/finetune/run?model=deberta-small&epochs=1") - - assert r.status_code == 200 - assert "text/event-stream" in r.headers.get("content-type", "") - - -def test_finetune_run_emits_complete_on_success(client): - """GET /api/finetune/run must emit a complete event on clean exit.""" - from unittest.mock import patch, MagicMock - - mock_proc = MagicMock() - mock_proc.stdout = iter(["progress line\n"]) - mock_proc.returncode = 0 - mock_proc.wait = MagicMock() - - with patch("app.api._subprocess.Popen",return_value=mock_proc): - r = client.get("/api/finetune/run?model=deberta-small&epochs=1") - - assert '{"type": "complete"}' in r.text - - -def test_finetune_run_emits_error_on_nonzero_exit(client): - """GET /api/finetune/run must emit an error event on non-zero exit.""" - from unittest.mock import patch, MagicMock - - mock_proc = MagicMock() - mock_proc.stdout = iter([]) - mock_proc.returncode = 1 - mock_proc.wait = MagicMock() - - with patch("app.api._subprocess.Popen",return_value=mock_proc): - r = client.get("/api/finetune/run?model=deberta-small&epochs=1") - - assert '"type": "error"' in r.text - - -def test_finetune_run_passes_score_files_to_subprocess(client): - """GET /api/finetune/run?score=file1&score=file2 must pass --score args to subprocess.""" - from unittest.mock import patch, MagicMock - - captured_cmd = [] - - def mock_popen(cmd, **kwargs): - captured_cmd.extend(cmd) - m = MagicMock() - m.stdout = iter([]) - m.returncode = 0 - m.wait = MagicMock() - return m - - with patch("app.api._subprocess.Popen",side_effect=mock_popen): - client.get("/api/finetune/run?model=deberta-small&epochs=1&score=run1.jsonl&score=run2.jsonl") - - assert "--score" in captured_cmd - assert captured_cmd.count("--score") == 2 - # Paths are resolved to absolute — check filenames are present as substrings - assert any("run1.jsonl" in arg for arg in captured_cmd) - assert any("run2.jsonl" in arg for arg in captured_cmd) - - -# ---- Cancel endpoint tests ---- - - -def test_finetune_cancel_returns_404_when_not_running(client): - """POST /api/finetune/cancel must return 404 if no finetune is running.""" - from app import api as api_module - api_module._running_procs.pop("finetune", None) - r = client.post("/api/finetune/cancel") - assert r.status_code == 404 - - - -def test_finetune_cancel_terminates_running_process(client): - """POST /api/finetune/cancel must call terminate() on the running process.""" - from unittest.mock import MagicMock - from app import api as api_module - - mock_proc = MagicMock() - mock_proc.wait = MagicMock() - api_module._running_procs["finetune"] = mock_proc - - try: - r = client.post("/api/finetune/cancel") - assert r.status_code == 200 - assert r.json()["status"] == "cancelled" - mock_proc.terminate.assert_called_once() - finally: - api_module._running_procs.pop("finetune", None) - api_module._cancelled_jobs.discard("finetune") - - - -def test_finetune_run_emits_cancelled_event(client): - """GET /api/finetune/run must emit cancelled (not error) when job was cancelled.""" - from unittest.mock import patch, MagicMock - from app import api as api_module - - mock_proc = MagicMock() - mock_proc.stdout = iter([]) - mock_proc.returncode = -15 # SIGTERM - - def mock_wait(): - # Simulate cancel being called while the process is running (after discard clears stale flag) - api_module._cancelled_jobs.add("finetune") - - mock_proc.wait = mock_wait - - def mock_popen(cmd, **kwargs): - return mock_proc - - try: - with patch("app.api._subprocess.Popen",side_effect=mock_popen): - r = client.get("/api/finetune/run?model=deberta-small&epochs=1") - assert '{"type": "cancelled"}' in r.text - assert '"type": "error"' not in r.text - finally: - api_module._cancelled_jobs.discard("finetune") - + assert "items" in r.json() + assert "total" in r.json() diff --git a/tests/test_train.py b/tests/test_train.py new file mode 100644 index 0000000..756fcd9 --- /dev/null +++ b/tests/test_train.py @@ -0,0 +1,187 @@ +"""Tests for app/train/train.py -- /api/train/* endpoints.""" +import json +import pytest +from fastapi.testclient import TestClient +from unittest.mock import MagicMock, patch + + +@pytest.fixture(autouse=True) +def reset_globals(tmp_path): + from app.train import train as train_module + train_module.set_db_path(tmp_path / "train_jobs.db") + train_module.set_models_dir(tmp_path / "models") + train_module._running_procs.clear() + yield + train_module._running_procs.clear() + + +@pytest.fixture +def client(): + from app.api import app + return TestClient(app) + + +def _parse_sse(content: bytes) -> list[dict]: + events = [] + for line in content.decode().splitlines(): + if line.startswith("data: "): + events.append(json.loads(line[6:])) + return events + + +def test_list_jobs_empty(client): + r = client.get("/api/train/jobs") + assert r.status_code == 200 + assert r.json() == [] + + +def test_create_job_returns_queued_record(client): + r = client.post("/api/train/jobs", + json={"type": "classifier", "model_key": "deberta-small", + "config_json": {"epochs": 3}}) + assert r.status_code == 200 + data = r.json() + assert data["status"] == "queued" + assert data["type"] == "classifier" + assert data["model_key"] == "deberta-small" + assert "id" in data + + +def test_create_job_invalid_type_returns_400(client): + r = client.post("/api/train/jobs", + json={"type": "unknown-type", "model_key": "deberta-small"}) + assert r.status_code == 400 + + +def test_create_job_appears_in_list(client): + client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"}) + r = client.get("/api/train/jobs") + assert r.status_code == 200 + assert len(r.json()) == 1 + + +def test_get_job_returns_record(client): + r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"}) + job_id = r.json()["id"] + r2 = client.get(f"/api/train/jobs/{job_id}") + assert r2.status_code == 200 + assert r2.json()["id"] == job_id + + +def test_get_job_404_for_unknown(client): + r = client.get("/api/train/jobs/no-such-id") + assert r.status_code == 404 + + +def test_cancel_queued_job(client): + r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"}) + job_id = r.json()["id"] + r2 = client.delete(f"/api/train/jobs/{job_id}/cancel") + assert r2.status_code == 200 + assert r2.json()["status"] == "cancelled" + r3 = client.get(f"/api/train/jobs/{job_id}") + assert r3.json()["status"] == "cancelled" + + +def test_cancel_completed_job_returns_409(client): + from app.train import train as train_module + train_module._init_db() + with train_module._db() as conn: + conn.execute( + "INSERT INTO jobs (id, type, model_key, status, config_json, created_at) " + "VALUES ('abc', 'classifier', 'deberta-small', 'completed', '{}', '2026-05-01T00:00:00Z')" + ) + r = client.delete("/api/train/jobs/abc/cancel") + assert r.status_code == 409 + + +def test_cancel_terminates_running_proc(client): + from app.train import train as train_module + mock_proc = MagicMock() + mock_proc.wait = MagicMock() + r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"}) + job_id = r.json()["id"] + train_module._running_procs[job_id] = mock_proc + with train_module._db() as conn: + conn.execute("UPDATE jobs SET status='running' WHERE id=?", (job_id,)) + r2 = client.delete(f"/api/train/jobs/{job_id}/cancel") + assert r2.status_code == 200 + mock_proc.terminate.assert_called_once() + + +def test_run_job_streams_sse(client): + r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"}) + job_id = r.json()["id"] + mock_proc = MagicMock() + mock_proc.stdout = iter(["Epoch 1\n", "Done\n"]) + mock_proc.returncode = 0 + mock_proc.wait = MagicMock() + with patch("app.train.train._subprocess.Popen", return_value=mock_proc): + r2 = client.get(f"/api/train/jobs/{job_id}/run") + assert r2.status_code == 200 + assert "text/event-stream" in r2.headers.get("content-type", "") + events = _parse_sse(r2.content) + assert any(e["type"] == "complete" for e in events) + + +def test_run_job_marks_completed_in_db(client): + from app.train import train as train_module + r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"}) + job_id = r.json()["id"] + mock_proc = MagicMock() + mock_proc.stdout = iter([]) + mock_proc.returncode = 0 + mock_proc.wait = MagicMock() + with patch("app.train.train._subprocess.Popen", return_value=mock_proc): + client.get(f"/api/train/jobs/{job_id}/run") + r2 = client.get(f"/api/train/jobs/{job_id}") + assert r2.json()["status"] == "completed" + + +def test_run_job_marks_failed_on_nonzero_exit(client): + r = client.post("/api/train/jobs", json={"type": "classifier", "model_key": "deberta-small"}) + job_id = r.json()["id"] + mock_proc = MagicMock() + mock_proc.stdout = iter([]) + mock_proc.returncode = 1 + mock_proc.wait = MagicMock() + with patch("app.train.train._subprocess.Popen", return_value=mock_proc): + client.get(f"/api/train/jobs/{job_id}/run") + r2 = client.get(f"/api/train/jobs/{job_id}") + assert r2.json()["status"] == "failed" + + +def test_run_nonqueued_job_returns_409(client): + from app.train import train as train_module + train_module._init_db() + with train_module._db() as conn: + conn.execute( + "INSERT INTO jobs (id, type, model_key, status, config_json, created_at) " + "VALUES ('xyz', 'classifier', 'deberta-small', 'running', '{}', '2026-05-01T00:00:00Z')" + ) + r = client.get("/api/train/jobs/xyz/run") + assert r.status_code == 409 + + +def test_run_unknown_job_returns_404(client): + r = client.get("/api/train/jobs/no-such/run") + assert r.status_code == 404 + + +def test_results_empty_when_no_models_dir(client): + r = client.get("/api/train/results") + assert r.status_code == 200 + assert r.json() == [] + + +def test_results_returns_training_info(client, tmp_path): + from app.train import train as train_module + models_dir = tmp_path / "models" / "avocet-deberta-small" + models_dir.mkdir(parents=True) + train_module.set_models_dir(tmp_path / "models") + info = {"name": "avocet-deberta-small", "val_macro_f1": 0.712, "sample_count": 401} + (models_dir / "training_info.json").write_text(json.dumps(info)) + r = client.get("/api/train/results") + assert r.status_code == 200 + data = r.json() + assert any(d["name"] == "avocet-deberta-small" for d in data)