feat: sft router — /queue, /submit, /undo endpoints
This commit is contained in:
parent
b330e84111
commit
f19cab60f7
2 changed files with 230 additions and 0 deletions
86
app/sft.py
86
app/sft.py
|
|
@ -118,3 +118,89 @@ def post_import(req: ImportRequest):
|
|||
if run is None:
|
||||
raise HTTPException(404, f"Run {req.run_id!r} not found in bench_results_dir")
|
||||
return import_run(run["sft_path"], _SFT_DATA_DIR)
|
||||
|
||||
|
||||
# ── GET /queue ─────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/queue")
def get_queue(page: int = 1, per_page: int = 20):
    """Return paginated needs_review candidates.

    Args:
        page: 1-based page number.
        per_page: Maximum number of items to return for the page.

    Returns:
        A dict with the requested window of pending records (``items``),
        the count of all pending records (``total``), and the echoed
        ``page`` and ``per_page`` values.

    Raises:
        HTTPException: 422 when ``page`` is below 1 or ``per_page`` is
            negative.
    """
    # Negative slice bounds would silently wrap around and return records
    # from the wrong end of the list, so reject them explicitly.
    if page < 1 or per_page < 0:
        raise HTTPException(422, "page must be >= 1 and per_page must be >= 0")

    records = _read_candidates()
    pending = [r for r in records if r.get("status") == "needs_review"]
    start = (page - 1) * per_page
    return {
        "items": pending[start:start + per_page],
        "total": len(pending),
        "page": page,
        "per_page": per_page,
    }
|
||||
|
||||
|
||||
# ── POST /submit ───────────────────────────────────────────────────────────
|
||||
|
||||
class SubmitRequest(BaseModel):
    # Request body for POST /submit.

    # Identifier of the candidate record being reviewed.
    id: str
    # One of "correct" | "discard" | "flag"; the handler rejects anything else.
    action: str
    # Replacement text; only read by the handler when action is "correct".
    corrected_response: str | None = None
|
||||
|
||||
|
||||
@router.post("/submit")
def post_submit(req: SubmitRequest):
    """Record a reviewer decision for one SFT candidate."""
    # Maps each accepted action to the status stored on the record.
    status_for_action = {
        "correct": "approved",
        "discard": "discarded",
        "flag": "model_rejected",
    }

    # Request-level validation happens before any file is read.
    if req.action not in status_for_action:
        raise HTTPException(422, f"Unknown action {req.action!r}")
    is_correction = req.action == "correct"
    if is_correction:
        text = req.corrected_response
        if not (text and text.strip()):
            raise HTTPException(422, "corrected_response must be non-empty when action is 'correct'")

    # Locate the first record carrying the requested id.
    candidates = _read_candidates()
    for position, candidate in enumerate(candidates):
        if candidate.get("id") == req.id:
            break
    else:
        raise HTTPException(404, f"Record {req.id!r} not found")

    # Only records still awaiting review may be actioned.
    current_status = candidate.get("status")
    if current_status != "needs_review":
        raise HTTPException(409, f"Record is not in needs_review state (current: {current_status})")

    updated = {**candidate, "status": status_for_action[req.action]}
    if is_correction:
        updated["corrected_response"] = req.corrected_response
    candidates[position] = updated
    _write_candidates(candidates)

    # Corrections are additionally appended to the approved file.
    if is_correction:
        append_jsonl(_approved_file(), updated)

    return {"ok": True}
|
||||
|
||||
|
||||
# ── POST /undo ─────────────────────────────────────────────────────────────
|
||||
|
||||
class UndoRequest(BaseModel):
    # Request body for POST /undo.

    # Identifier of the candidate record to put back into review.
    id: str
|
||||
|
||||
|
||||
@router.post("/undo")
def post_undo(req: UndoRequest):
    """Restore a previously actioned candidate back to needs_review."""
    candidates = _read_candidates()

    # Find the position of the first record with a matching id.
    position = None
    for i, candidate in enumerate(candidates):
        if candidate.get("id") == req.id:
            position = i
            break
    if position is None:
        raise HTTPException(404, f"Record {req.id!r} not found")

    previous = candidates[position]
    previous_status = previous.get("status")
    if previous_status == "needs_review":
        raise HTTPException(409, "Record is already in needs_review state")

    # Reset the status and clear any stored correction.
    candidates[position] = {
        **previous,
        "status": "needs_review",
        "corrected_response": None,
    }
    _write_candidates(candidates)

    if previous_status != "approved":
        return {"ok": True}

    # The record was approved, so drop it from the approved file as well.
    remaining = [
        entry for entry in read_jsonl(_approved_file())
        if entry.get("id") != req.id
    ]
    write_jsonl(_approved_file(), remaining)
    return {"ok": True}
|
||||
|
|
|
|||
|
|
@ -116,3 +116,147 @@ def test_import_unknown_run_returns_404(client, tmp_path):
|
|||
_write_config(tmp_path, tmp_path / "bench_results")
|
||||
r = client.post("/api/sft/import", json={"run_id": "nonexistent"})
|
||||
assert r.status_code == 404
|
||||
|
||||
|
||||
# ── /api/sft/queue ──────────────────────────────────────────────────────────
|
||||
|
||||
def _populate_candidates(tmp_path, records: list[dict]) -> None:
    """Write *records* to the SFT candidates file, one JSON object per line.

    Args:
        tmp_path: Not used by this helper; the destination comes from
            ``sft_module._candidates_file()``.
        records: Candidate dicts to serialise.
    """
    from app import sft as sft_module

    path = sft_module._candidates_file()
    path.parent.mkdir(parents=True, exist_ok=True)
    # Terminate each record individually so that an empty list produces an
    # empty file instead of a single blank line.
    path.write_text(
        "".join(json.dumps(r) + "\n" for r in records), encoding="utf-8"
    )
|
||||
|
||||
|
||||
def test_queue_returns_needs_review_only(client, tmp_path):
    """The queue lists record "a" only; approved and discarded ones are hidden."""
    pending = _make_record("a")  # needs_review
    approved = {**_make_record("b"), "status": "approved"}  # should not appear
    discarded = {**_make_record("c"), "status": "discarded"}  # should not appear
    _populate_candidates(tmp_path, [pending, approved, discarded])

    response = client.get("/api/sft/queue")

    assert response.status_code == 200
    payload = response.json()
    assert payload["total"] == 1
    assert len(payload["items"]) == 1
    assert payload["items"][0]["id"] == "a"
|
||||
|
||||
|
||||
def test_queue_pagination(client, tmp_path):
    """With 25 records and per_page=10, page 1 has 10 items and page 3 has 5."""
    _populate_candidates(tmp_path, [_make_record(str(n)) for n in range(25)])

    first_page = client.get("/api/sft/queue?page=1&per_page=10").json()
    assert first_page["total"] == 25
    assert len(first_page["items"]) == 10

    last_page = client.get("/api/sft/queue?page=3&per_page=10").json()
    assert len(last_page["items"]) == 5
|
||||
|
||||
|
||||
def test_queue_empty_when_no_file(client):
    """Without populated candidates, the queue answers 200 with an empty default page."""
    expected = {"items": [], "total": 0, "page": 1, "per_page": 20}

    response = client.get("/api/sft/queue")

    assert response.status_code == 200
    assert response.json() == expected
|
||||
|
||||
|
||||
# ── /api/sft/submit ─────────────────────────────────────────────────────────
|
||||
|
||||
def test_submit_correct_sets_approved(client, tmp_path):
    """A "correct" submission marks the record approved and stores the new text."""
    fixed_text = "def add(a, b): return a + b"
    _populate_candidates(tmp_path, [_make_record("a")])

    payload = {"id": "a", "action": "correct", "corrected_response": fixed_text}
    response = client.post("/api/sft/submit", json=payload)
    assert response.status_code == 200

    from app import sft as sft_module

    stored = sft_module._read_candidates()[0]
    assert stored["status"] == "approved"
    assert stored["corrected_response"] == fixed_text
|
||||
|
||||
|
||||
def test_submit_correct_also_appends_to_approved_file(client, tmp_path):
    """A "correct" submission leaves exactly one entry, id "a", in the approved file."""
    _populate_candidates(tmp_path, [_make_record("a")])
    payload = {
        "id": "a",
        "action": "correct",
        "corrected_response": "def add(a, b): return a + b",
    }
    client.post("/api/sft/submit", json=payload)

    from app import sft as sft_module
    from app.utils import read_jsonl

    approved_entries = read_jsonl(sft_module._approved_file())
    assert len(approved_entries) == 1
    assert approved_entries[0]["id"] == "a"
|
||||
|
||||
|
||||
def test_submit_discard_sets_discarded(client, tmp_path):
    """A "discard" submission stores the status "discarded"."""
    _populate_candidates(tmp_path, [_make_record("a")])

    payload = {"id": "a", "action": "discard"}
    response = client.post("/api/sft/submit", json=payload)
    assert response.status_code == 200

    from app import sft as sft_module

    stored = sft_module._read_candidates()[0]
    assert stored["status"] == "discarded"
|
||||
|
||||
|
||||
def test_submit_flag_sets_model_rejected(client, tmp_path):
    """A "flag" submission stores the status "model_rejected"."""
    _populate_candidates(tmp_path, [_make_record("a")])

    payload = {"id": "a", "action": "flag"}
    response = client.post("/api/sft/submit", json=payload)
    assert response.status_code == 200

    from app import sft as sft_module

    stored = sft_module._read_candidates()[0]
    assert stored["status"] == "model_rejected"
|
||||
|
||||
|
||||
def test_submit_correct_empty_response_returns_422(client, tmp_path):
    """A whitespace-only corrected_response on "correct" is rejected with 422."""
    _populate_candidates(tmp_path, [_make_record("a")])
    payload = {"id": "a", "action": "correct", "corrected_response": "   "}

    response = client.post("/api/sft/submit", json=payload)

    assert response.status_code == 422
|
||||
|
||||
|
||||
def test_submit_correct_null_response_returns_422(client, tmp_path):
    """A null corrected_response on "correct" is rejected with 422."""
    _populate_candidates(tmp_path, [_make_record("a")])
    payload = {"id": "a", "action": "correct", "corrected_response": None}

    response = client.post("/api/sft/submit", json=payload)

    assert response.status_code == 422
|
||||
|
||||
|
||||
def test_submit_unknown_id_returns_404(client, tmp_path):
    """Submitting for an id that was never populated yields 404."""
    payload = {"id": "nope", "action": "discard"}

    response = client.post("/api/sft/submit", json=payload)

    assert response.status_code == 404
|
||||
|
||||
|
||||
def test_submit_already_approved_returns_409(client, tmp_path):
    """Submitting against a record that is already approved yields 409."""
    already_approved = {**_make_record("a"), "status": "approved"}
    _populate_candidates(tmp_path, [already_approved])

    payload = {"id": "a", "action": "discard"}
    response = client.post("/api/sft/submit", json=payload)

    assert response.status_code == 409
|
||||
|
||||
|
||||
# ── /api/sft/undo ────────────────────────────────────────────────────────────
|
||||
|
||||
def test_undo_restores_discarded_to_needs_review(client, tmp_path):
    """Undoing a discard puts the record back into needs_review."""
    _populate_candidates(tmp_path, [_make_record("a")])
    client.post("/api/sft/submit", json={"id": "a", "action": "discard"})

    response = client.post("/api/sft/undo", json={"id": "a"})
    assert response.status_code == 200

    from app import sft as sft_module

    stored = sft_module._read_candidates()[0]
    assert stored["status"] == "needs_review"
|
||||
|
||||
|
||||
def test_undo_removes_approved_from_approved_file(client, tmp_path):
    """Undoing an approval removes the record from the approved file."""
    _populate_candidates(tmp_path, [_make_record("a")])
    approve_payload = {
        "id": "a",
        "action": "correct",
        "corrected_response": "def add(a, b): return a + b",
    }
    client.post("/api/sft/submit", json=approve_payload)
    client.post("/api/sft/undo", json={"id": "a"})

    from app import sft as sft_module
    from app.utils import read_jsonl

    approved_entries = read_jsonl(sft_module._approved_file())
    assert not any(entry["id"] == "a" for entry in approved_entries)
|
||||
|
||||
|
||||
def test_undo_already_needs_review_returns_409(client, tmp_path):
    """Undo on a record that has not been actioned yields 409."""
    _populate_candidates(tmp_path, [_make_record("a")])

    response = client.post("/api/sft/undo", json={"id": "a"})

    assert response.status_code == 409
|
||||
|
|
|
|||
Loading…
Reference in a new issue