feat: sft router — /export and /stats endpoints

This commit is contained in:
pyr0ball 2026-04-08 14:46:08 -07:00
parent 4ad2907ae8
commit 07807f0d05
2 changed files with 148 additions and 0 deletions

View file

@ -9,12 +9,15 @@ set_sft_config_dir() in test fixtures.
"""
from __future__ import annotations
import json
import logging
from datetime import datetime, timezone
from pathlib import Path
from typing import Literal
import yaml
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from app.utils import append_jsonl, read_jsonl, write_jsonl
@ -203,3 +206,68 @@ def post_undo(req: UndoRequest):
write_jsonl(_approved_file(), [r for r in approved if r.get("id") != req.id])
return {"ok": True}
# ── GET /export ─────────────────────────────────────────────────────────────
@router.get("/export")
def get_export():
    """Stream approved records as SFT-ready JSONL for download.

    A record from the approved file is exported only when its ``status`` is
    ``"approved"`` and its ``corrected_response`` is non-empty after
    stripping whitespace. Each emitted line is a JSON object of the form
    ``{"messages": [...]}``: the record's ``prompt_messages`` followed by a
    single assistant message carrying the corrected response.

    Returns:
        StreamingResponse: an ``application/x-ndjson`` body served as an
        attachment named ``sft_export_<UTC timestamp>.jsonl``.
    """
    approved = read_jsonl(_approved_file())
    exportable = [
        r for r in approved
        if r.get("status") == "approved"
        and r.get("corrected_response")
        and str(r["corrected_response"]).strip()
    ]

    def generate():
        for r in exportable:
            # `prompt_messages` may be present but null. `.get(key, [])` would
            # then hand back None and the list concatenation would raise
            # while the response is being streamed, so fall back explicitly.
            prompt_messages = r.get("prompt_messages") or []
            record = {
                "messages": prompt_messages + [
                    {"role": "assistant", "content": r["corrected_response"]}
                ]
            }
            yield json.dumps(record) + "\n"

    # UTC timestamp keeps download names unambiguous across time zones.
    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    return StreamingResponse(
        generate(),
        media_type="application/x-ndjson",
        headers={
            "Content-Disposition": f'attachment; filename="sft_export_{timestamp}.jsonl"'
        },
    )
# ── GET /stats ──────────────────────────────────────────────────────────────
@router.get("/stats")
def get_stats():
    """Return counts by status, model, and task type.

    Candidate records are tallied by their ``status``, ``model_name`` and
    ``task_type`` fields; a record missing one of those fields is counted
    under ``"unknown"`` for that dimension.

    ``export_ready`` is computed from the approved file using the same
    filter as the ``/export`` endpoint (``status == "approved"`` and a
    non-blank ``corrected_response``), so the number reported here matches
    the number of lines an export would produce.

    Returns:
        dict: ``total``, ``by_status``, ``by_model``, ``by_task_type`` and
        ``export_ready``.
    """
    records = _read_candidates()
    by_status: dict[str, int] = {}
    by_model: dict[str, int] = {}
    by_task_type: dict[str, int] = {}

    for r in records:
        status = r.get("status", "unknown")
        by_status[status] = by_status.get(status, 0) + 1
        model = r.get("model_name", "unknown")
        by_model[model] = by_model.get(model, 0) + 1
        task_type = r.get("task_type", "unknown")
        by_task_type[task_type] = by_task_type.get(task_type, 0) + 1

    approved = read_jsonl(_approved_file())
    # Keep this predicate in step with /export: without the status check the
    # count could include records that an export would silently drop.
    export_ready = sum(
        1 for r in approved
        if r.get("status") == "approved"
        and r.get("corrected_response")
        and str(r["corrected_response"]).strip()
    )

    return {
        "total": len(records),
        "by_status": by_status,
        "by_model": by_model,
        "by_task_type": by_task_type,
        "export_ready": export_ready,
    }

View file

@ -260,3 +260,83 @@ def test_undo_already_needs_review_returns_409(client, tmp_path):
_populate_candidates(tmp_path, [_make_record("a")])
r = client.post("/api/sft/undo", json={"id": "a"})
assert r.status_code == 409
# ── /api/sft/export ──────────────────────────────────────────────────────────
def test_export_returns_approved_as_sft_jsonl(client, tmp_path):
    """An approved record is exported as exactly one chat-format JSONL line."""
    from app import sft as sft_module
    from app.utils import write_jsonl

    corrected = "def add(a, b): return a + b"
    prompt = [
        {"role": "system", "content": "You are a coding assistant."},
        {"role": "user", "content": "Write a Python add function."},
    ]
    approved_record = dict(
        _make_record("a"),
        status="approved",
        corrected_response=corrected,
        prompt_messages=prompt,
    )
    write_jsonl(sft_module._approved_file(), [approved_record])
    _populate_candidates(tmp_path, [approved_record])

    response = client.get("/api/sft/export")

    assert response.status_code == 200
    assert "application/x-ndjson" in response.headers["content-type"]

    # Ignore blank lines so a trailing newline does not count as a record.
    payload_lines = [line for line in response.text.splitlines() if line.strip()]
    assert len(payload_lines) == 1

    messages = json.loads(payload_lines[0])["messages"]
    assert messages[-1] == {"role": "assistant", "content": corrected}
    assert messages[0]["role"] == "system"
    assert messages[1]["role"] == "user"
def test_export_excludes_non_approved(client, tmp_path):
    """Records whose status is not ``approved`` never appear in the export.

    The third fixture carries a non-empty ``corrected_response`` on purpose:
    with only null responses the export would be empty because of the
    response check alone, and the status filter would go untested.
    """
    from app import sft as sft_module
    from app.utils import write_jsonl

    records = [
        {**_make_record("a"), "status": "discarded", "corrected_response": None},
        {**_make_record("b"), "status": "needs_review", "corrected_response": None},
        {
            **_make_record("c"),
            "status": "discarded",
            "corrected_response": "must not be exported",
        },
    ]
    write_jsonl(sft_module._approved_file(), records)

    r = client.get("/api/sft/export")

    assert r.status_code == 200
    assert r.text.strip() == ""
def test_export_empty_when_no_approved_file(client):
    """With nothing written to the approved file, export is a 200 with an empty body."""
    response = client.get("/api/sft/export")

    assert response.status_code == 200
    assert not response.text.strip()
# ── /api/sft/stats ───────────────────────────────────────────────────────────
def test_stats_counts_by_status(client, tmp_path):
    """Stats reports one record per status and a single export-ready record."""
    from app import sft as sft_module
    from app.utils import write_jsonl

    records = [
        _make_record("a"),
        {**_make_record("b"), "status": "approved", "corrected_response": "ok"},
        {**_make_record("c"), "status": "discarded"},
        {**_make_record("d"), "status": "model_rejected"},
    ]
    approved_only = [records[1]]
    _populate_candidates(tmp_path, records)
    write_jsonl(sft_module._approved_file(), approved_only)

    response = client.get("/api/sft/stats")
    assert response.status_code == 200

    data = response.json()
    assert data["total"] == 4
    # Checked one status at a time, in the same order as the fixtures above.
    for status in ("needs_review", "approved", "discarded", "model_rejected"):
        assert data["by_status"][status] == 1
    assert data["export_ready"] == 1
def test_stats_empty_when_no_data(client):
    """With no fixtures written, stats reports zero records and zero export-ready."""
    response = client.get("/api/sft/stats")
    assert response.status_code == 200

    payload = response.json()
    assert payload["total"] == 0
    assert payload["export_ready"] == 0