feat: sft router — /export and /stats endpoints
This commit is contained in:
parent
4ad2907ae8
commit
07807f0d05
2 changed files with 148 additions and 0 deletions
68
app/sft.py
68
app/sft.py
|
|
@ -9,12 +9,15 @@ set_sft_config_dir() in test fixtures.
|
||||||
"""
|
"""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from fastapi import APIRouter, HTTPException
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from app.utils import append_jsonl, read_jsonl, write_jsonl
|
from app.utils import append_jsonl, read_jsonl, write_jsonl
|
||||||
|
|
@ -203,3 +206,68 @@ def post_undo(req: UndoRequest):
|
||||||
write_jsonl(_approved_file(), [r for r in approved if r.get("id") != req.id])
|
write_jsonl(_approved_file(), [r for r in approved if r.get("id") != req.id])
|
||||||
|
|
||||||
return {"ok": True}
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /export ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/export")
def get_export():
    """Stream approved records as SFT-ready JSONL for download.

    Loads the approved file and keeps only records whose ``status`` is
    ``"approved"`` and whose ``corrected_response`` is non-blank after
    stripping. Each exported line is a JSON object of the form
    ``{"messages": [...]}``: the record's ``prompt_messages`` followed by a
    single assistant message carrying the corrected response.

    Returns:
        A ``StreamingResponse`` with media type ``application/x-ndjson`` and
        a ``Content-Disposition`` attachment filename that embeds a UTC
        timestamp.
    """
    approved = read_jsonl(_approved_file())
    exportable = [
        r for r in approved
        if r.get("status") == "approved"
        and r.get("corrected_response")
        and str(r["corrected_response"]).strip()
    ]

    def generate():
        for r in exportable:
            # ``prompt_messages`` may be absent or stored as an explicit
            # null; both are treated as an empty prompt so that one such
            # record cannot abort the stream with a TypeError.
            prompt_messages = r.get("prompt_messages") or []
            record = {
                "messages": prompt_messages + [
                    {"role": "assistant", "content": r["corrected_response"]}
                ]
            }
            yield json.dumps(record) + "\n"

    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    return StreamingResponse(
        generate(),
        media_type="application/x-ndjson",
        headers={
            "Content-Disposition": f'attachment; filename="sft_export_{timestamp}.jsonl"'
        },
    )
|
||||||
|
|
||||||
|
|
||||||
|
# ── GET /stats ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/stats")
def get_stats():
    """Return counts by status, model, and task type.

    Tallies the records returned by ``_read_candidates()`` on their
    ``status``, ``model_name`` and ``task_type`` fields; a record missing a
    field is counted under ``"unknown"`` for that field.

    ``export_ready`` is computed from the approved file and uses the same
    criteria as the ``/export`` endpoint (status ``"approved"`` plus a
    non-blank ``corrected_response``), so the number reported here matches
    the number of lines an export would produce.

    Returns:
        A dict with keys ``total``, ``by_status``, ``by_model``,
        ``by_task_type`` and ``export_ready``.
    """
    records = _read_candidates()
    by_status: dict[str, int] = {}
    by_model: dict[str, int] = {}
    by_task_type: dict[str, int] = {}

    # One pass over the records, updating all three tallies.
    tallies = (
        ("status", by_status),
        ("model_name", by_model),
        ("task_type", by_task_type),
    )
    for r in records:
        for field, bucket in tallies:
            key = r.get(field, "unknown")
            bucket[key] = bucket.get(key, 0) + 1

    approved = read_jsonl(_approved_file())
    export_ready = sum(
        1 for r in approved
        # Mirror the /export filter, including the status check, so the
        # count cannot drift from what is actually exported.
        if r.get("status") == "approved"
        and r.get("corrected_response")
        and str(r["corrected_response"]).strip()
    )

    return {
        "total": len(records),
        "by_status": by_status,
        "by_model": by_model,
        "by_task_type": by_task_type,
        "export_ready": export_ready,
    }
|
||||||
|
|
|
||||||
|
|
@ -260,3 +260,83 @@ def test_undo_already_needs_review_returns_409(client, tmp_path):
|
||||||
_populate_candidates(tmp_path, [_make_record("a")])
|
_populate_candidates(tmp_path, [_make_record("a")])
|
||||||
r = client.post("/api/sft/undo", json={"id": "a"})
|
r = client.post("/api/sft/undo", json={"id": "a"})
|
||||||
assert r.status_code == 409
|
assert r.status_code == 409
|
||||||
|
|
||||||
|
|
||||||
|
# ── /api/sft/export ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_export_returns_approved_as_sft_jsonl(client, tmp_path):
    """An approved record is exported as a single chat-format JSONL line."""
    from app import sft as sft_module
    from app.utils import write_jsonl

    base = _make_record("a")
    approved_record = {
        **base,
        "status": "approved",
        "corrected_response": "def add(a, b): return a + b",
        "prompt_messages": [
            {"role": "system", "content": "You are a coding assistant."},
            {"role": "user", "content": "Write a Python add function."},
        ],
    }
    write_jsonl(sft_module._approved_file(), [approved_record])
    _populate_candidates(tmp_path, [approved_record])

    response = client.get("/api/sft/export")
    assert response.status_code == 200
    assert "application/x-ndjson" in response.headers["content-type"]

    # Ignore blank lines so a trailing newline does not count as a record.
    non_blank = [line for line in response.text.splitlines() if line.strip()]
    assert len(non_blank) == 1

    messages = json.loads(non_blank[0])["messages"]
    expected_last = {
        "role": "assistant", "content": "def add(a, b): return a + b"
    }
    assert messages[-1] == expected_last
    assert messages[0]["role"] == "system"
    assert messages[1]["role"] == "user"
|
||||||
|
|
||||||
|
|
||||||
|
def test_export_excludes_non_approved(client, tmp_path):
    """Records whose status is not "approved" never appear in the export."""
    from app import sft as sft_module
    from app.utils import write_jsonl

    non_approved = [
        {**_make_record(record_id), "status": status, "corrected_response": None}
        for record_id, status in (("a", "discarded"), ("b", "needs_review"))
    ]
    write_jsonl(sft_module._approved_file(), non_approved)

    response = client.get("/api/sft/export")
    assert response.text.strip() == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_export_empty_when_no_approved_file(client):
    """With nothing written to the approved file, the export body is empty."""
    response = client.get("/api/sft/export")

    assert response.status_code == 200
    body = response.text.strip()
    assert body == ""
|
||||||
|
|
||||||
|
|
||||||
|
# ── /api/sft/stats ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_stats_counts_by_status(client, tmp_path):
    """Stats report the total, one record per status, and one export-ready."""
    from app import sft as sft_module
    from app.utils import write_jsonl

    candidates = [
        _make_record("a"),
        {**_make_record("b"), "status": "approved", "corrected_response": "ok"},
        {**_make_record("c"), "status": "discarded"},
        {**_make_record("d"), "status": "model_rejected"},
    ]
    _populate_candidates(tmp_path, candidates)
    # Only the approved record ("b") goes into the approved file.
    approved_only = [candidates[1]]
    write_jsonl(sft_module._approved_file(), approved_only)

    response = client.get("/api/sft/stats")
    assert response.status_code == 200

    stats = response.json()
    assert stats["total"] == 4
    for status in ("needs_review", "approved", "discarded", "model_rejected"):
        assert stats["by_status"][status] == 1
    assert stats["export_ready"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_stats_empty_when_no_data(client):
    """With no data written, stats report zero total and zero export-ready."""
    response = client.get("/api/sft/stats")
    assert response.status_code == 200

    stats = response.json()
    assert stats["total"] == 0
    assert stats["export_ready"] == 0
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue