feat: Corrections tab — SFT candidate import, review, and JSONL export #15
3 changed files with 253 additions and 0 deletions
|
|
@ -142,6 +142,9 @@ def _normalize(item: dict) -> dict:
|
|||
|
||||
app = FastAPI(title="Avocet API")

# Mount the SFT correction endpoints under /api/sft (router defined in app/sft.py).
from app.sft import router as sft_router
app.include_router(sft_router, prefix="/api/sft")

# In-memory last-action store (single user, local tool — in-memory is fine)
_last_action: dict | None = None
|
||||
|
||||
|
|
|
|||
134
app/sft.py
Normal file
134
app/sft.py
Normal file
|
|
@ -0,0 +1,134 @@
|
|||
"""Avocet — SFT candidate import and correction API.
|
||||
|
||||
All endpoints are registered on `router` (a FastAPI APIRouter).
|
||||
api.py includes this router with prefix="/api/sft".
|
||||
|
||||
Module-level globals (_SFT_DATA_DIR, _SFT_CONFIG_DIR) follow the same
|
||||
testability pattern as api.py — override them via set_sft_data_dir() and
|
||||
set_sft_config_dir() in test fixtures.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Repo root: two levels up from this file (app/sft.py -> repo root).
_ROOT = Path(__file__).parent.parent
# Where the candidate/approved JSONL files live; tests override via set_sft_data_dir().
_SFT_DATA_DIR: Path = _ROOT / "data"
# When not None, overrides the default config dir; tests set it via set_sft_config_dir().
_SFT_CONFIG_DIR: Path | None = None

# Included by app/api.py with prefix="/api/sft".
router = APIRouter()
|
||||
|
||||
|
||||
# ── Testability seams ──────────────────────────────────────────────────────
|
||||
|
||||
def set_sft_data_dir(path: Path) -> None:
    """Redirect the SFT data directory (test seam; see module docstring)."""
    global _SFT_DATA_DIR
    _SFT_DATA_DIR = path
|
||||
|
||||
|
||||
def set_sft_config_dir(path: Path | None) -> None:
|
||||
global _SFT_CONFIG_DIR
|
||||
_SFT_CONFIG_DIR = path
|
||||
|
||||
|
||||
# ── Internal helpers ───────────────────────────────────────────────────────
|
||||
|
||||
def _config_file() -> Path:
    """Resolve the label_tool.yaml path, honoring the test-seam override."""
    base = _SFT_CONFIG_DIR if _SFT_CONFIG_DIR is not None else _ROOT / "config"
    return base / "label_tool.yaml"
|
||||
|
||||
|
||||
def _get_bench_results_dir() -> Path:
    """Read sft.bench_results_dir from the config file.

    Returns a deliberately nonexistent sentinel path when the config file is
    missing or the key is unset, so downstream discovery finds no runs.
    """
    cfg_path = _config_file()
    if not cfg_path.exists():
        return Path("/nonexistent-bench-results")
    cfg = yaml.safe_load(cfg_path.read_text(encoding="utf-8")) or {}
    configured = cfg.get("sft", {}).get("bench_results_dir", "")
    if not configured:
        return Path("/nonexistent-bench-results")
    return Path(configured)
|
||||
|
||||
|
||||
def _candidates_file() -> Path:
    """Location of the imported-candidates JSONL file."""
    return _SFT_DATA_DIR / "sft_candidates.jsonl"
|
||||
|
||||
|
||||
def _approved_file() -> Path:
    """Location of the approved-records JSONL file."""
    return _SFT_DATA_DIR / "sft_approved.jsonl"
|
||||
|
||||
|
||||
def _read_jsonl(path: Path) -> list[dict]:
|
||||
if not path.exists():
|
||||
return []
|
||||
records: list[dict] = []
|
||||
for line in path.read_text(encoding="utf-8").splitlines():
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
records.append(json.loads(line))
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
return records
|
||||
|
||||
|
||||
def _write_jsonl(path: Path, records: list[dict]) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
content = "\n".join(json.dumps(r) for r in records)
|
||||
path.write_text(content + ("\n" if records else ""), encoding="utf-8")
|
||||
|
||||
|
||||
def _append_jsonl(path: Path, record: dict) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with open(path, "a", encoding="utf-8") as fh:
|
||||
fh.write(json.dumps(record) + "\n")
|
||||
|
||||
|
||||
def _read_candidates() -> list[dict]:
    """All currently-imported SFT candidate records."""
    return _read_jsonl(_candidates_file())
|
||||
|
||||
|
||||
def _write_candidates(records: list[dict]) -> None:
    """Replace the candidates file with *records*."""
    _write_jsonl(_candidates_file(), records)
|
||||
|
||||
|
||||
# ── GET /runs ──────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/runs")
def get_runs():
    """List benchmark runs in bench_results_dir, flagging already-imported ones."""
    # Local import keeps app startup independent of the scripts/ package.
    from scripts.sft_import import discover_runs

    seen_run_ids = {c.get("benchmark_run_id") for c in _read_candidates()}
    listing = []
    for run in discover_runs(_get_bench_results_dir()):
        listing.append(
            {
                "run_id": run["run_id"],
                "timestamp": run["timestamp"],
                "candidate_count": run["candidate_count"],
                "already_imported": run["run_id"] in seen_run_ids,
            }
        )
    return listing
|
||||
|
||||
|
||||
# ── POST /import ───────────────────────────────────────────────────────────
|
||||
|
||||
class ImportRequest(BaseModel):
    """Request body for POST /import: which benchmark run to pull in."""
    run_id: str
|
||||
|
||||
|
||||
@router.post("/import")
def post_import(req: ImportRequest):
    """Import one benchmark run's sft_candidates.jsonl into the local data dir.

    Raises HTTP 404 when the requested run id is not found on disk.
    """
    # Local import keeps app startup independent of the scripts/ package.
    from scripts.sft_import import discover_runs, import_run

    for run in discover_runs(_get_bench_results_dir()):
        if run["run_id"] == req.run_id:
            return import_run(run["sft_path"], _SFT_DATA_DIR)
    raise HTTPException(404, f"Run {req.run_id!r} not found in bench_results_dir")
|
||||
116
tests/test_sft.py
Normal file
116
tests/test_sft.py
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
"""API integration tests for app/sft.py — /api/sft/* endpoints."""
|
||||
import json
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def reset_sft_globals(tmp_path):
    """Point the SFT module's data/config dirs at tmp_path, then restore defaults."""
    from app import sft as sft_module

    sft_module.set_sft_data_dir(tmp_path)
    sft_module.set_sft_config_dir(tmp_path)
    yield
    default_data_dir = Path(__file__).parent.parent / "data"
    sft_module.set_sft_data_dir(default_data_dir)
    sft_module.set_sft_config_dir(None)
|
||||
|
||||
|
||||
@pytest.fixture
def client():
    """TestClient bound to the full FastAPI app, routers included."""
    from app.api import app as fastapi_app

    return TestClient(fastapi_app)
|
||||
|
||||
|
||||
def _make_record(id: str, run_id: str = "2026-04-07-143022") -> dict:
|
||||
return {
|
||||
"id": id, "source": "cf-orch-benchmark",
|
||||
"benchmark_run_id": run_id, "timestamp": "2026-04-07T10:00:00Z",
|
||||
"status": "needs_review",
|
||||
"prompt_messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Write a Python function that adds two numbers."},
|
||||
],
|
||||
"model_response": "def add(a, b): return a - b",
|
||||
"corrected_response": None,
|
||||
"quality_score": 0.2, "failure_reason": "pattern_match: 0/2 matched",
|
||||
"task_id": "code-fn", "task_type": "code",
|
||||
"task_name": "Code: Write a Python function",
|
||||
"model_id": "Qwen/Qwen2.5-3B", "model_name": "Qwen2.5-3B",
|
||||
"node_id": "heimdall", "gpu_id": 0, "tokens_per_sec": 38.4,
|
||||
}
|
||||
|
||||
|
||||
def _write_run(tmp_path, run_id: str, records: list[dict]) -> Path:
|
||||
run_dir = tmp_path / "bench_results" / run_id
|
||||
run_dir.mkdir(parents=True)
|
||||
sft_path = run_dir / "sft_candidates.jsonl"
|
||||
sft_path.write_text(
|
||||
"\n".join(json.dumps(r) for r in records) + "\n", encoding="utf-8"
|
||||
)
|
||||
return sft_path
|
||||
|
||||
|
||||
def _write_config(tmp_path, bench_results_dir: Path) -> None:
    """Write a label_tool.yaml pointing sft.bench_results_dir at the given dir."""
    import yaml

    payload = {"sft": {"bench_results_dir": str(bench_results_dir)}}
    text = yaml.dump(payload, allow_unicode=True)
    (tmp_path / "label_tool.yaml").write_text(text, encoding="utf-8")
|
||||
|
||||
|
||||
# ── /api/sft/runs ──────────────────────────────────────────────────────────
|
||||
|
||||
def test_runs_returns_empty_when_no_config(client):
    """With no label_tool.yaml present, /runs reports no runs at all."""
    resp = client.get("/api/sft/runs")
    assert resp.status_code == 200
    assert resp.json() == []
|
||||
|
||||
|
||||
def test_runs_returns_available_runs(client, tmp_path):
    """A run on disk shows up with its candidate count and an unimported flag."""
    _write_run(tmp_path, "2026-04-07-143022", [_make_record("a"), _make_record("b")])
    _write_config(tmp_path, tmp_path / "bench_results")

    resp = client.get("/api/sft/runs")
    assert resp.status_code == 200

    runs = resp.json()
    assert len(runs) == 1
    run = runs[0]
    assert run["run_id"] == "2026-04-07-143022"
    assert run["candidate_count"] == 2
    assert run["already_imported"] is False
|
||||
|
||||
|
||||
def test_runs_marks_already_imported(client, tmp_path):
    """Runs whose ids already appear in the candidates file get flagged."""
    from app import sft as sft_module

    _write_run(tmp_path, "2026-04-07-143022", [_make_record("a")])
    _write_config(tmp_path, tmp_path / "bench_results")

    # Pre-seed the candidates file with a record from the same run.
    candidates = sft_module._candidates_file()
    candidates.parent.mkdir(parents=True, exist_ok=True)
    seeded = json.dumps(_make_record("a", run_id="2026-04-07-143022")) + "\n"
    candidates.write_text(seeded, encoding="utf-8")

    resp = client.get("/api/sft/runs")
    assert resp.json()[0]["already_imported"] is True
|
||||
|
||||
|
||||
# ── /api/sft/import ─────────────────────────────────────────────────────────
|
||||
|
||||
def test_import_adds_records(client, tmp_path):
    """Importing a fresh run reports every record as imported, none skipped."""
    _write_run(tmp_path, "2026-04-07-143022", [_make_record("a"), _make_record("b")])
    _write_config(tmp_path, tmp_path / "bench_results")

    resp = client.post("/api/sft/import", json={"run_id": "2026-04-07-143022"})

    assert resp.status_code == 200
    assert resp.json() == {"imported": 2, "skipped": 0}
|
||||
|
||||
|
||||
def test_import_is_idempotent(client, tmp_path):
    """A second import of the same run skips everything it already holds."""
    _write_run(tmp_path, "2026-04-07-143022", [_make_record("a")])
    _write_config(tmp_path, tmp_path / "bench_results")

    body = {"run_id": "2026-04-07-143022"}
    client.post("/api/sft/import", json=body)
    resp = client.post("/api/sft/import", json=body)
    assert resp.json() == {"imported": 0, "skipped": 1}
|
||||
|
||||
|
||||
def test_import_unknown_run_returns_404(client, tmp_path):
    """Requesting a run id that does not exist on disk yields HTTP 404."""
    _write_config(tmp_path, tmp_path / "bench_results")
    resp = client.post("/api/sft/import", json={"run_id": "nonexistent"})
    assert resp.status_code == 404
|
||||
Loading…
Reference in a new issue