feat: sft router — /queue, /submit, /undo endpoints

This commit is contained in:
pyr0ball 2026-04-08 14:22:06 -07:00
parent b330e84111
commit f19cab60f7
2 changed files with 230 additions and 0 deletions

View file

@ -118,3 +118,89 @@ def post_import(req: ImportRequest):
if run is None:
raise HTTPException(404, f"Run {req.run_id!r} not found in bench_results_dir")
return import_run(run["sft_path"], _SFT_DATA_DIR)
# ── GET /queue ─────────────────────────────────────────────────────────────
@router.get("/queue")
def get_queue(page: int = 1, per_page: int = 20):
    """Return one page of candidates whose status is ``needs_review``.

    Args:
        page: 1-based page number.
        per_page: Maximum number of items returned in the page.

    Returns:
        A dict holding the page's ``items``, the ``total`` number of pending
        candidates across all pages, and the echoed ``page`` / ``per_page``.

    Raises:
        HTTPException: 422 when ``page`` or ``per_page`` is less than 1.
    """
    # Reject non-positive values up front: a negative slice bound would
    # silently index from the end of the list and return the wrong records.
    if page < 1:
        raise HTTPException(422, "page must be >= 1")
    if per_page < 1:
        raise HTTPException(422, "per_page must be >= 1")

    pending = [r for r in _read_candidates() if r.get("status") == "needs_review"]
    start = (page - 1) * per_page
    return {
        "items": pending[start:start + per_page],
        "total": len(pending),
        "page": page,
        "per_page": per_page,
    }
# ── POST /submit ───────────────────────────────────────────────────────────
class SubmitRequest(BaseModel):
    """Request body for POST /submit: a reviewer decision on one candidate."""

    # Identifier of the candidate record being reviewed.
    id: str
    # Reviewer decision; the handler accepts "correct" | "discard" | "flag".
    action: str
    # Replacement response text; the handler requires it for "correct".
    corrected_response: str | None = None
@router.post("/submit")
def post_submit(req: SubmitRequest):
    """Apply a reviewer's decision to a single SFT candidate.

    The candidate must currently be in ``needs_review``. ``correct`` marks it
    ``approved``, stores the corrected response and appends the record to the
    approved file; ``discard`` marks it ``discarded``; ``flag`` marks it
    ``model_rejected``.

    Raises:
        HTTPException: 422 for an unknown action or a blank correction,
            404 when no record has the given id, 409 when the record is not
            awaiting review.
    """
    status_for_action = {
        "correct": "approved",
        "discard": "discarded",
        "flag": "model_rejected",
    }
    if req.action not in status_for_action:
        raise HTTPException(422, f"Unknown action {req.action!r}")

    is_correction = req.action == "correct"
    if is_correction and not (req.corrected_response and req.corrected_response.strip()):
        raise HTTPException(422, "corrected_response must be non-empty when action is 'correct'")

    records = _read_candidates()
    for position, candidate in enumerate(records):
        if candidate.get("id") == req.id:
            break
    else:
        # Loop finished without a match (or there were no records at all).
        raise HTTPException(404, f"Record {req.id!r} not found")

    current_status = candidate.get("status")
    if current_status != "needs_review":
        raise HTTPException(409, f"Record is not in needs_review state (current: {current_status})")

    updated = {**candidate, "status": status_for_action[req.action]}
    if is_correction:
        updated["corrected_response"] = req.corrected_response
    records[position] = updated
    _write_candidates(records)

    # Only corrected records are mirrored into the approved file.
    if is_correction:
        append_jsonl(_approved_file(), updated)
    return {"ok": True}
# ── POST /undo ─────────────────────────────────────────────────────────────
class UndoRequest(BaseModel):
    """Request body for POST /undo."""

    # Identifier of the candidate record to send back to review.
    id: str
@router.post("/undo")
def post_undo(req: UndoRequest):
    """Put an already-actioned candidate back into ``needs_review``.

    The record's ``corrected_response`` is reset to ``None``. If the record
    was ``approved``, its entry is also filtered out of the approved file.

    Raises:
        HTTPException: 404 when no record has the given id, 409 when the
            record is already in ``needs_review``.
    """
    records = _read_candidates()
    for position, candidate in enumerate(records):
        if candidate.get("id") == req.id:
            break
    else:
        # Loop finished without a match (or there were no records at all).
        raise HTTPException(404, f"Record {req.id!r} not found")

    previous_status = candidate.get("status")
    if previous_status == "needs_review":
        raise HTTPException(409, "Record is already in needs_review state")

    restored = dict(candidate)
    restored.update(status="needs_review", corrected_response=None)
    records[position] = restored
    _write_candidates(records)

    # An approved record was also appended to the approved file; drop it there.
    if previous_status == "approved":
        still_approved = [
            entry for entry in read_jsonl(_approved_file()) if entry.get("id") != req.id
        ]
        write_jsonl(_approved_file(), still_approved)
    return {"ok": True}

View file

@ -116,3 +116,147 @@ def test_import_unknown_run_returns_404(client, tmp_path):
_write_config(tmp_path, tmp_path / "bench_results")
r = client.post("/api/sft/import", json={"run_id": "nonexistent"})
assert r.status_code == 404
# ── /api/sft/queue ──────────────────────────────────────────────────────────
def _populate_candidates(tmp_path, records: list[dict]) -> None:
    """Write *records* to the candidates file as JSONL, one object per line.

    The parent directory is created if needed. ``tmp_path`` is not used in
    the body.
    """
    from app import sft as sft_module

    target = sft_module._candidates_file()
    target.parent.mkdir(parents=True, exist_ok=True)
    lines = [json.dumps(record) for record in records]
    target.write_text("\n".join(lines) + "\n", encoding="utf-8")
def test_queue_returns_needs_review_only(client, tmp_path):
    """The queue lists only the needs_review record, not approved/discarded ones."""
    pending = _make_record("a")  # needs_review
    approved = {**_make_record("b"), "status": "approved"}  # should not appear
    discarded = {**_make_record("c"), "status": "discarded"}  # should not appear
    _populate_candidates(tmp_path, [pending, approved, discarded])

    response = client.get("/api/sft/queue")

    assert response.status_code == 200
    payload = response.json()
    assert payload["total"] == 1
    assert len(payload["items"]) == 1
    assert payload["items"][0]["id"] == "a"
def test_queue_pagination(client, tmp_path):
    """With 25 records and per_page=10, page 1 has 10 items and page 3 has 5."""
    _populate_candidates(tmp_path, [_make_record(str(n)) for n in range(25)])

    first_page = client.get("/api/sft/queue?page=1&per_page=10").json()
    assert first_page["total"] == 25
    assert len(first_page["items"]) == 10

    last_page = client.get("/api/sft/queue?page=3&per_page=10").json()
    assert len(last_page["items"]) == 5
def test_queue_empty_when_no_file(client):
    """Without populating any candidates, the queue is an empty first page."""
    response = client.get("/api/sft/queue")
    expected = {"items": [], "total": 0, "page": 1, "per_page": 20}

    assert response.status_code == 200
    assert response.json() == expected
# ── /api/sft/submit ─────────────────────────────────────────────────────────
def test_submit_correct_sets_approved(client, tmp_path):
    """A 'correct' submission marks the record approved and stores the fix."""
    from app import sft as sft_module

    correction = "def add(a, b): return a + b"
    _populate_candidates(tmp_path, [_make_record("a")])

    response = client.post(
        "/api/sft/submit",
        json={"id": "a", "action": "correct", "corrected_response": correction},
    )

    assert response.status_code == 200
    stored = sft_module._read_candidates()[0]
    assert stored["status"] == "approved"
    assert stored["corrected_response"] == "def add(a, b): return a + b"
def test_submit_correct_also_appends_to_approved_file(client, tmp_path):
    """A 'correct' submission leaves exactly one entry in the approved file."""
    from app import sft as sft_module
    from app.utils import read_jsonl

    _populate_candidates(tmp_path, [_make_record("a")])
    body = {
        "id": "a",
        "action": "correct",
        "corrected_response": "def add(a, b): return a + b",
    }
    client.post("/api/sft/submit", json=body)

    approved_entries = read_jsonl(sft_module._approved_file())
    assert len(approved_entries) == 1
    assert approved_entries[0]["id"] == "a"
def test_submit_discard_sets_discarded(client, tmp_path):
    """A 'discard' submission sets the record's status to discarded."""
    from app import sft as sft_module

    _populate_candidates(tmp_path, [_make_record("a")])

    response = client.post("/api/sft/submit", json={"id": "a", "action": "discard"})

    assert response.status_code == 200
    stored = sft_module._read_candidates()[0]
    assert stored["status"] == "discarded"
def test_submit_flag_sets_model_rejected(client, tmp_path):
    """A 'flag' submission sets the record's status to model_rejected."""
    from app import sft as sft_module

    _populate_candidates(tmp_path, [_make_record("a")])

    response = client.post("/api/sft/submit", json={"id": "a", "action": "flag"})

    assert response.status_code == 200
    stored = sft_module._read_candidates()[0]
    assert stored["status"] == "model_rejected"
def test_submit_correct_empty_response_returns_422(client, tmp_path):
    """A whitespace-only corrected_response is rejected with 422."""
    _populate_candidates(tmp_path, [_make_record("a")])
    body = {"id": "a", "action": "correct", "corrected_response": " "}

    response = client.post("/api/sft/submit", json=body)

    assert response.status_code == 422
def test_submit_correct_null_response_returns_422(client, tmp_path):
    """A null corrected_response with action 'correct' is rejected with 422."""
    _populate_candidates(tmp_path, [_make_record("a")])
    body = {"id": "a", "action": "correct", "corrected_response": None}

    response = client.post("/api/sft/submit", json=body)

    assert response.status_code == 422
def test_submit_unknown_id_returns_404(client, tmp_path):
    """Submitting for an id that was never populated yields 404."""
    body = {"id": "nope", "action": "discard"}

    response = client.post("/api/sft/submit", json=body)

    assert response.status_code == 404
def test_submit_already_approved_returns_409(client, tmp_path):
    """Submitting against a record that is already approved yields 409."""
    already_approved = {**_make_record("a"), "status": "approved"}
    _populate_candidates(tmp_path, [already_approved])

    response = client.post("/api/sft/submit", json={"id": "a", "action": "discard"})

    assert response.status_code == 409
# ── /api/sft/undo ────────────────────────────────────────────────────────────
def test_undo_restores_discarded_to_needs_review(client, tmp_path):
    """Undoing a discard puts the record back into needs_review."""
    from app import sft as sft_module

    _populate_candidates(tmp_path, [_make_record("a")])
    client.post("/api/sft/submit", json={"id": "a", "action": "discard"})

    response = client.post("/api/sft/undo", json={"id": "a"})

    assert response.status_code == 200
    stored = sft_module._read_candidates()[0]
    assert stored["status"] == "needs_review"
def test_undo_removes_approved_from_approved_file(client, tmp_path):
    """Undoing an approved record removes its entry from the approved file.

    The submit and undo responses are checked, and the record is confirmed
    present in the approved file before the undo. Without those checks the
    final "not in approved file" assertion would also hold if either request
    had failed, letting the test pass vacuously.
    """
    from app import sft as sft_module
    from app.utils import read_jsonl

    _populate_candidates(tmp_path, [_make_record("a")])
    submitted = client.post("/api/sft/submit", json={
        "id": "a", "action": "correct",
        "corrected_response": "def add(a, b): return a + b",
    })
    assert submitted.status_code == 200
    # Precondition: the record really was approved before we undo it.
    assert any(r["id"] == "a" for r in read_jsonl(sft_module._approved_file()))

    undone = client.post("/api/sft/undo", json={"id": "a"})
    assert undone.status_code == 200

    approved = read_jsonl(sft_module._approved_file())
    assert not any(r["id"] == "a" for r in approved)
def test_undo_already_needs_review_returns_409(client, tmp_path):
    """Undoing a record that is still awaiting review yields 409."""
    _populate_candidates(tmp_path, [_make_record("a")])

    response = client.post("/api/sft/undo", json={"id": "a"})

    assert response.status_code == 409