diff --git a/app/api.py b/app/api.py
index a4c5af2..3ce74c3 100644
--- a/app/api.py
+++ b/app/api.py
@@ -142,6 +142,9 @@ def _normalize(item: dict) -> dict:
 
 app = FastAPI(title="Avocet API")
 
+from app.sft import router as sft_router
+app.include_router(sft_router, prefix="/api/sft")
+
 # In-memory last-action store (single user, local tool — in-memory is fine)
 _last_action: dict | None = None
 
diff --git a/app/sft.py b/app/sft.py
new file mode 100644
index 0000000..80ed061
--- /dev/null
+++ b/app/sft.py
@@ -0,0 +1,134 @@
+"""Avocet — SFT candidate import and correction API.
+
+All endpoints are registered on `router` (a FastAPI APIRouter).
+api.py includes this router with prefix="/api/sft".
+
+Module-level globals (_SFT_DATA_DIR, _SFT_CONFIG_DIR) follow the same
+testability pattern as api.py — override them via set_sft_data_dir() and
+set_sft_config_dir() in test fixtures.
+"""
+from __future__ import annotations
+
+import json
+from pathlib import Path
+
+import yaml
+from fastapi import APIRouter, HTTPException
+from pydantic import BaseModel
+
+_ROOT = Path(__file__).parent.parent
+_SFT_DATA_DIR: Path = _ROOT / "data"
+_SFT_CONFIG_DIR: Path | None = None
+
+router = APIRouter()
+
+
+# ── Testability seams ──────────────────────────────────────────────────────
+
+def set_sft_data_dir(path: Path) -> None:
+    global _SFT_DATA_DIR
+    _SFT_DATA_DIR = path
+
+
+def set_sft_config_dir(path: Path | None) -> None:
+    global _SFT_CONFIG_DIR
+    _SFT_CONFIG_DIR = path
+
+
+# ── Internal helpers ───────────────────────────────────────────────────────
+
+def _config_file() -> Path:
+    if _SFT_CONFIG_DIR is not None:
+        return _SFT_CONFIG_DIR / "label_tool.yaml"
+    return _ROOT / "config" / "label_tool.yaml"
+
+
+def _get_bench_results_dir() -> Path:
+    f = _config_file()
+    if not f.exists():
+        return Path("/nonexistent-bench-results")
+    raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
+    d = raw.get("sft", {}).get("bench_results_dir", "")
+    return Path(d) if d else Path("/nonexistent-bench-results")
+
+
+def _candidates_file() -> Path:
+    return _SFT_DATA_DIR / "sft_candidates.jsonl"
+
+
+def _approved_file() -> Path:
+    return _SFT_DATA_DIR / "sft_approved.jsonl"
+
+
+def _read_jsonl(path: Path) -> list[dict]:
+    if not path.exists():
+        return []
+    records: list[dict] = []
+    for line in path.read_text(encoding="utf-8").splitlines():
+        line = line.strip()
+        if not line:
+            continue
+        try:
+            records.append(json.loads(line))
+        except json.JSONDecodeError:
+            pass
+    return records
+
+
+def _write_jsonl(path: Path, records: list[dict]) -> None:
+    path.parent.mkdir(parents=True, exist_ok=True)
+    content = "\n".join(json.dumps(r) for r in records)
+    path.write_text(content + ("\n" if records else ""), encoding="utf-8")
+
+
+def _append_jsonl(path: Path, record: dict) -> None:
+    path.parent.mkdir(parents=True, exist_ok=True)
+    with open(path, "a", encoding="utf-8") as fh:
+        fh.write(json.dumps(record) + "\n")
+
+
+def _read_candidates() -> list[dict]:
+    return _read_jsonl(_candidates_file())
+
+
+def _write_candidates(records: list[dict]) -> None:
+    _write_jsonl(_candidates_file(), records)
+
+
+# ── GET /runs ──────────────────────────────────────────────────────────────
+
+@router.get("/runs")
+def get_runs():
+    """List available benchmark runs in the configured bench_results_dir."""
+    from scripts.sft_import import discover_runs
+    bench_dir = _get_bench_results_dir()
+    existing = _read_candidates()
+    imported_run_ids = {r.get("benchmark_run_id") for r in existing}
+    runs = discover_runs(bench_dir)
+    return [
+        {
+            "run_id": r["run_id"],
+            "timestamp": r["timestamp"],
+            "candidate_count": r["candidate_count"],
+            "already_imported": r["run_id"] in imported_run_ids,
+        }
+        for r in runs
+    ]
+
+
+# ── POST /import ───────────────────────────────────────────────────────────
+
+class ImportRequest(BaseModel):
+    run_id: str
+
+
+@router.post("/import")
+def post_import(req: ImportRequest):
+    """Import one benchmark run's sft_candidates.jsonl into the local data dir."""
+    from scripts.sft_import import discover_runs, import_run
+    bench_dir = _get_bench_results_dir()
+    runs = discover_runs(bench_dir)
+    run = next((r for r in runs if r["run_id"] == req.run_id), None)
+    if run is None:
+        raise HTTPException(404, f"Run {req.run_id!r} not found in bench_results_dir")
+    return import_run(run["sft_path"], _SFT_DATA_DIR)
diff --git a/tests/test_sft.py b/tests/test_sft.py
new file mode 100644
index 0000000..a66da27
--- /dev/null
+++ b/tests/test_sft.py
@@ -0,0 +1,116 @@
+"""API integration tests for app/sft.py — /api/sft/* endpoints."""
+import json
+import pytest
+from fastapi.testclient import TestClient
+from pathlib import Path
+
+
+@pytest.fixture(autouse=True)
+def reset_sft_globals(tmp_path):
+    from app import sft as sft_module
+    sft_module.set_sft_data_dir(tmp_path)
+    sft_module.set_sft_config_dir(tmp_path)
+    yield
+    sft_module.set_sft_data_dir(Path(__file__).parent.parent / "data")
+    sft_module.set_sft_config_dir(None)
+
+
+@pytest.fixture
+def client():
+    from app.api import app
+    return TestClient(app)
+
+
+def _make_record(id: str, run_id: str = "2026-04-07-143022") -> dict:
+    return {
+        "id": id, "source": "cf-orch-benchmark",
+        "benchmark_run_id": run_id, "timestamp": "2026-04-07T10:00:00Z",
+        "status": "needs_review",
+        "prompt_messages": [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "Write a Python function that adds two numbers."},
+        ],
+        "model_response": "def add(a, b): return a - b",
+        "corrected_response": None,
+        "quality_score": 0.2, "failure_reason": "pattern_match: 0/2 matched",
+        "task_id": "code-fn", "task_type": "code",
+        "task_name": "Code: Write a Python function",
+        "model_id": "Qwen/Qwen2.5-3B", "model_name": "Qwen2.5-3B",
+        "node_id": "heimdall", "gpu_id": 0, "tokens_per_sec": 38.4,
+    }
+
+
+def _write_run(tmp_path, run_id: str, records: list[dict]) -> Path:
+    run_dir = tmp_path / "bench_results" / run_id
+    run_dir.mkdir(parents=True)
+    sft_path = run_dir / "sft_candidates.jsonl"
+    sft_path.write_text(
+        "\n".join(json.dumps(r) for r in records) + "\n", encoding="utf-8"
+    )
+    return sft_path
+
+
+def _write_config(tmp_path, bench_results_dir: Path) -> None:
+    import yaml
+    cfg = {"sft": {"bench_results_dir": str(bench_results_dir)}}
+    (tmp_path / "label_tool.yaml").write_text(
+        yaml.dump(cfg, allow_unicode=True), encoding="utf-8"
+    )
+
+
+# ── /api/sft/runs ──────────────────────────────────────────────────────────
+
+def test_runs_returns_empty_when_no_config(client):
+    r = client.get("/api/sft/runs")
+    assert r.status_code == 200
+    assert r.json() == []
+
+
+def test_runs_returns_available_runs(client, tmp_path):
+    _write_run(tmp_path, "2026-04-07-143022", [_make_record("a"), _make_record("b")])
+    _write_config(tmp_path, tmp_path / "bench_results")
+    r = client.get("/api/sft/runs")
+    assert r.status_code == 200
+    data = r.json()
+    assert len(data) == 1
+    assert data[0]["run_id"] == "2026-04-07-143022"
+    assert data[0]["candidate_count"] == 2
+    assert data[0]["already_imported"] is False
+
+
+def test_runs_marks_already_imported(client, tmp_path):
+    _write_run(tmp_path, "2026-04-07-143022", [_make_record("a")])
+    _write_config(tmp_path, tmp_path / "bench_results")
+    from app import sft as sft_module
+    candidates = sft_module._candidates_file()
+    candidates.parent.mkdir(parents=True, exist_ok=True)
+    candidates.write_text(
+        json.dumps(_make_record("a", run_id="2026-04-07-143022")) + "\n",
+        encoding="utf-8"
+    )
+    r = client.get("/api/sft/runs")
+    assert r.json()[0]["already_imported"] is True
+
+
+# ── /api/sft/import ─────────────────────────────────────────────────────────
+
+def test_import_adds_records(client, tmp_path):
+    _write_run(tmp_path, "2026-04-07-143022", [_make_record("a"), _make_record("b")])
+    _write_config(tmp_path, tmp_path / "bench_results")
+    r = client.post("/api/sft/import", json={"run_id": "2026-04-07-143022"})
+    assert r.status_code == 200
+    assert r.json() == {"imported": 2, "skipped": 0}
+
+
+def test_import_is_idempotent(client, tmp_path):
+    _write_run(tmp_path, "2026-04-07-143022", [_make_record("a")])
+    _write_config(tmp_path, tmp_path / "bench_results")
+    client.post("/api/sft/import", json={"run_id": "2026-04-07-143022"})
+    r = client.post("/api/sft/import", json={"run_id": "2026-04-07-143022"})
+    assert r.json() == {"imported": 0, "skipped": 1}
+
+
+def test_import_unknown_run_returns_404(client, tmp_path):
+    _write_config(tmp_path, tmp_path / "bench_results")
+    r = client.post("/api/sft/import", json={"run_id": "nonexistent"})
+    assert r.status_code == 404