From f19cab60f7bfe8ba5c6f620441bb52627b6ac55f Mon Sep 17 00:00:00 2001 From: pyr0ball Date: Wed, 8 Apr 2026 14:22:06 -0700 Subject: [PATCH] =?UTF-8?q?feat:=20sft=20router=20=E2=80=94=20/queue,=20/s?= =?UTF-8?q?ubmit,=20/undo=20endpoints?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- app/sft.py | 86 +++++++++++++++++++++++++++ tests/test_sft.py | 144 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 230 insertions(+) diff --git a/app/sft.py b/app/sft.py index 1609e80..71cf04e 100644 --- a/app/sft.py +++ b/app/sft.py @@ -118,3 +118,89 @@ def post_import(req: ImportRequest): if run is None: raise HTTPException(404, f"Run {req.run_id!r} not found in bench_results_dir") return import_run(run["sft_path"], _SFT_DATA_DIR) + + +# ── GET /queue ───────────────────────────────────────────────────────────── + +@router.get("/queue") +def get_queue(page: int = 1, per_page: int = 20): + """Return paginated needs_review candidates.""" + records = _read_candidates() + pending = [r for r in records if r.get("status") == "needs_review"] + start = (page - 1) * per_page + return { + "items": pending[start:start + per_page], + "total": len(pending), + "page": page, + "per_page": per_page, + } + + +# ── POST /submit ─────────────────────────────────────────────────────────── + +class SubmitRequest(BaseModel): + id: str + action: str # "correct" | "discard" | "flag" + corrected_response: str | None = None + + +@router.post("/submit") +def post_submit(req: SubmitRequest): + """Record a reviewer decision for one SFT candidate.""" + if req.action not in ("correct", "discard", "flag"): + raise HTTPException(422, f"Unknown action {req.action!r}") + if req.action == "correct": + if not req.corrected_response or not req.corrected_response.strip(): + raise HTTPException(422, "corrected_response must be non-empty when action is 'correct'") + + records = _read_candidates() + idx = next((i for i, r in enumerate(records) if r.get("id") == req.id), None) + if idx is None: + raise HTTPException(404, f"Record {req.id!r} not found") + + record = records[idx] + if record.get("status") != "needs_review": + raise HTTPException(409, f"Record is not in needs_review state (current: {record.get('status')})") + + if req.action == "correct": + records[idx] = {**record, "status": "approved", "corrected_response": req.corrected_response} + _write_candidates(records) + append_jsonl(_approved_file(), records[idx]) + elif req.action == "discard": + records[idx] = {**record, "status": "discarded"} + _write_candidates(records) + else: # flag + records[idx] = {**record, "status": "model_rejected"} + _write_candidates(records) + + return {"ok": True} + + +# ── POST /undo ───────────────────────────────────────────────────────────── + +class UndoRequest(BaseModel): + id: str + + +@router.post("/undo") +def post_undo(req: UndoRequest): + """Restore a previously actioned candidate back to needs_review.""" + records = _read_candidates() + idx = next((i for i, r in enumerate(records) if r.get("id") == req.id), None) + if idx is None: + raise HTTPException(404, f"Record {req.id!r} not found") + + record = records[idx] + old_status = record.get("status") + if old_status == "needs_review": + raise HTTPException(409, "Record is already in needs_review state") + + records[idx] = {**record, "status": "needs_review", "corrected_response": None} + _write_candidates(records) + + # If it was approved, remove from the approved file too + if old_status == "approved": + approved = read_jsonl(_approved_file()) + write_jsonl(_approved_file(), [r for r in approved if r.get("id") != req.id]) + + return {"ok": True} diff --git a/tests/test_sft.py b/tests/test_sft.py index 7e8de2f..81c634d 100644 --- a/tests/test_sft.py +++ b/tests/test_sft.py @@ -116,3 +116,147 @@ def test_import_unknown_run_returns_404(client, tmp_path): _write_config(tmp_path, tmp_path / "bench_results") r = client.post("/api/sft/import", json={"run_id": "nonexistent"}) assert r.status_code == 404 + + +# ── /api/sft/queue ────────────────────────────────────────────────────────── + +def _populate_candidates(tmp_path, records: list[dict]) -> None: + from app import sft as sft_module + path = sft_module._candidates_file() + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text( + "\n".join(json.dumps(r) for r in records) + "\n", encoding="utf-8" + ) + + +def test_queue_returns_needs_review_only(client, tmp_path): + records = [ + _make_record("a"), # needs_review + {**_make_record("b"), "status": "approved"}, # should not appear + {**_make_record("c"), "status": "discarded"}, # should not appear + ] + _populate_candidates(tmp_path, records) + r = client.get("/api/sft/queue") + assert r.status_code == 200 + data = r.json() + assert data["total"] == 1 + assert len(data["items"]) == 1 + assert data["items"][0]["id"] == "a" + + +def test_queue_pagination(client, tmp_path): + records = [_make_record(str(i)) for i in range(25)] + _populate_candidates(tmp_path, records) + r = client.get("/api/sft/queue?page=1&per_page=10") + data = r.json() + assert data["total"] == 25 + assert len(data["items"]) == 10 + r2 = client.get("/api/sft/queue?page=3&per_page=10") + assert len(r2.json()["items"]) == 5 + + +def test_queue_empty_when_no_file(client): + r = client.get("/api/sft/queue") + assert r.status_code == 200 + assert r.json() == {"items": [], "total": 0, "page": 1, "per_page": 20} + + +# ── /api/sft/submit ───────────────────────────────────────────────────────── + +def test_submit_correct_sets_approved(client, tmp_path): + _populate_candidates(tmp_path, [_make_record("a")]) + r = client.post("/api/sft/submit", json={ + "id": "a", "action": "correct", + "corrected_response": "def add(a, b): return a + b", + }) + assert r.status_code == 200 + from app import sft as sft_module + records = sft_module._read_candidates() + assert records[0]["status"] == "approved" + assert records[0]["corrected_response"] == "def add(a, b): return a + b" + + +def test_submit_correct_also_appends_to_approved_file(client, tmp_path): + _populate_candidates(tmp_path, [_make_record("a")]) + client.post("/api/sft/submit", json={ + "id": "a", "action": "correct", + "corrected_response": "def add(a, b): return a + b", + }) + from app import sft as sft_module + from app.utils import read_jsonl + approved = read_jsonl(sft_module._approved_file()) + assert len(approved) == 1 + assert approved[0]["id"] == "a" + + +def test_submit_discard_sets_discarded(client, tmp_path): + _populate_candidates(tmp_path, [_make_record("a")]) + r = client.post("/api/sft/submit", json={"id": "a", "action": "discard"}) + assert r.status_code == 200 + from app import sft as sft_module + assert sft_module._read_candidates()[0]["status"] == "discarded" + + +def test_submit_flag_sets_model_rejected(client, tmp_path): + _populate_candidates(tmp_path, [_make_record("a")]) + r = client.post("/api/sft/submit", json={"id": "a", "action": "flag"}) + assert r.status_code == 200 + from app import sft as sft_module + assert sft_module._read_candidates()[0]["status"] == "model_rejected" + + +def test_submit_correct_empty_response_returns_422(client, tmp_path): + _populate_candidates(tmp_path, [_make_record("a")]) + r = client.post("/api/sft/submit", json={ + "id": "a", "action": "correct", "corrected_response": " ", + }) + assert r.status_code == 422 + + +def test_submit_correct_null_response_returns_422(client, tmp_path): + _populate_candidates(tmp_path, [_make_record("a")]) + r = client.post("/api/sft/submit", json={ + "id": "a", "action": "correct", "corrected_response": None, + }) + assert r.status_code == 422 + + +def test_submit_unknown_id_returns_404(client, tmp_path): + r = client.post("/api/sft/submit", json={"id": "nope", "action": "discard"}) + assert r.status_code == 404 + + +def test_submit_already_approved_returns_409(client, tmp_path): + _populate_candidates(tmp_path, [{**_make_record("a"), "status": "approved"}]) + r = client.post("/api/sft/submit", json={"id": "a", "action": "discard"}) + assert r.status_code == 409 + + +# ── /api/sft/undo ──────────────────────────────────────────────────────────── + +def test_undo_restores_discarded_to_needs_review(client, tmp_path): + _populate_candidates(tmp_path, [_make_record("a")]) + client.post("/api/sft/submit", json={"id": "a", "action": "discard"}) + r = client.post("/api/sft/undo", json={"id": "a"}) + assert r.status_code == 200 + from app import sft as sft_module + assert sft_module._read_candidates()[0]["status"] == "needs_review" + + +def test_undo_removes_approved_from_approved_file(client, tmp_path): + _populate_candidates(tmp_path, [_make_record("a")]) + client.post("/api/sft/submit", json={ + "id": "a", "action": "correct", + "corrected_response": "def add(a, b): return a + b", + }) + client.post("/api/sft/undo", json={"id": "a"}) + from app import sft as sft_module + from app.utils import read_jsonl + approved = read_jsonl(sft_module._approved_file()) + assert not any(r["id"] == "a" for r in approved) + + +def test_undo_already_needs_review_returns_409(client, tmp_path): + _populate_candidates(tmp_path, [_make_record("a")]) + r = client.post("/api/sft/undo", json={"id": "a"}) + assert r.status_code == 409