feat: move SFT corrections API into app/data/corrections.py
This commit is contained in:
parent
2054866ff1
commit
99ea39fe38
3 changed files with 374 additions and 366 deletions
335
app/data/corrections.py
Normal file
335
app/data/corrections.py
Normal file
|
|
@ -0,0 +1,335 @@
|
||||||
|
"""Avocet -- SFT candidate corrections API (moved from app/sft.py).
|
||||||
|
|
||||||
|
All endpoints are registered on `router` (a FastAPI APIRouter).
|
||||||
|
api.py includes this router with prefix="/api/sft".
|
||||||
|
|
||||||
|
Module-level globals (_DATA_DIR, _CONFIG_DIR) follow the same
|
||||||
|
testability pattern as api.py -- override them via set_data_dir() and
|
||||||
|
set_config_dir() in test fixtures.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
from fastapi import APIRouter, HTTPException
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from app.utils import append_jsonl, read_jsonl, write_jsonl
|
||||||
|
|
||||||
|
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Three directory levels above this file (presumably the repository root,
# given this module lives at app/data/corrections.py -- confirm).
_ROOT = Path(__file__).parent.parent.parent
# Directory holding sft_candidates.jsonl / sft_approved.jsonl; rebound by set_data_dir().
_DATA_DIR: Path = _ROOT / "data"
# Optional override for the config directory; None makes _config_file() use _ROOT / "config".
_CONFIG_DIR: Path | None = None

# Router on which every endpoint in this module is registered.
router = APIRouter()
|
||||||
|
|
||||||
|
|
||||||
|
# -- Testability seams ---------------------------------------------------------
|
||||||
|
|
||||||
|
def set_data_dir(path: Path) -> None:
    """Rebind the module-level data directory.

    The candidates/approved file helpers read ``_DATA_DIR`` on every call,
    so the new location takes effect immediately.

    Args:
        path: Directory that should hold the SFT JSONL files.
    """
    global _DATA_DIR
    _DATA_DIR = path
|
||||||
|
|
||||||
|
|
||||||
|
def set_config_dir(path: Path | None) -> None:
|
||||||
|
global _CONFIG_DIR
|
||||||
|
_CONFIG_DIR = path
|
||||||
|
|
||||||
|
|
||||||
|
# -- Internal helpers ----------------------------------------------------------
|
||||||
|
|
||||||
|
def _config_file() -> Path:
|
||||||
|
if _CONFIG_DIR is not None:
|
||||||
|
return _CONFIG_DIR / "label_tool.yaml"
|
||||||
|
return _ROOT / "config" / "label_tool.yaml"
|
||||||
|
|
||||||
|
|
||||||
|
# Fallback bench results directory returned by _get_bench_results_dir() when the
# config file supplies no value; rebound by set_default_bench_results_dir().
_DEFAULT_BENCH_RESULTS_DIR = "/Library/Development/CircuitForge/circuitforge-orch/scripts/bench_results"
|
||||||
|
|
||||||
|
|
||||||
|
def set_default_bench_results_dir(path: str) -> None:
    """Rebind the fallback bench results directory.

    Lets tests substitute a temporary location for the hard-coded default
    so they never touch the real filesystem path.

    Args:
        path: New fallback directory, as a string.
    """
    global _DEFAULT_BENCH_RESULTS_DIR
    _DEFAULT_BENCH_RESULTS_DIR = path
|
||||||
|
|
||||||
|
|
||||||
|
def _get_bench_results_dir() -> Path:
    """Resolve the directory that holds benchmark run output.

    Reads ``sft.bench_results_dir`` from the config file when that file
    exists. Falls back to ``_DEFAULT_BENCH_RESULTS_DIR`` when the file is
    missing, cannot be parsed, or does not provide a non-empty value.

    Returns:
        The configured directory, or the module-level default.
    """
    f = _config_file()
    if f.exists():
        try:
            raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
        except yaml.YAMLError as exc:
            logger.warning("Failed to parse SFT config %s: %s", f, exc)
        else:
            # A bare ``sft:`` key loads as None, and the document root need not
            # be a mapping; treat both as "no section" instead of crashing.
            sft_section = (raw.get("sft") if isinstance(raw, dict) else None) or {}
            d = sft_section.get("bench_results_dir", "") if isinstance(sft_section, dict) else ""
            if d:
                return Path(d)
    return Path(_DEFAULT_BENCH_RESULTS_DIR)
|
||||||
|
|
||||||
|
|
||||||
|
def _candidates_file() -> Path:
|
||||||
|
return _DATA_DIR / "sft_candidates.jsonl"
|
||||||
|
|
||||||
|
|
||||||
|
def _approved_file() -> Path:
|
||||||
|
return _DATA_DIR / "sft_approved.jsonl"
|
||||||
|
|
||||||
|
|
||||||
|
def _read_candidates() -> list[dict]:
    """Load every record from the candidates file via ``read_jsonl``."""
    candidates_path = _candidates_file()
    return read_jsonl(candidates_path)
|
||||||
|
|
||||||
|
|
||||||
|
def _write_candidates(records: list[dict]) -> None:
    """Write ``records`` to the candidates file via ``write_jsonl``."""
    candidates_path = _candidates_file()
    write_jsonl(candidates_path, records)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_exportable(r: dict) -> bool:
|
||||||
|
"""Return True if an approved record is ready to include in SFT export."""
|
||||||
|
return (
|
||||||
|
r.get("status") == "approved"
|
||||||
|
and bool(r.get("corrected_response"))
|
||||||
|
and str(r["corrected_response"]).strip() != ""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# -- GET /runs -----------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.get("/runs")
def get_runs():
    """List the benchmark runs found in the configured bench results directory.

    Each entry carries the run's id, timestamp and candidate count, plus an
    ``already_imported`` flag that is true when some stored candidate has
    that run id as its ``benchmark_run_id``.
    """
    from scripts.sft_import import discover_runs

    bench_dir = _get_bench_results_dir()

    # By cf-orch convention a record's benchmark_run_id equals the run's directory name
    seen_run_ids = set()
    for candidate in _read_candidates():
        run_id = candidate.get("benchmark_run_id")
        if run_id is not None:
            seen_run_ids.add(run_id)

    summaries = []
    for run in discover_runs(bench_dir):
        summaries.append(
            {
                "run_id": run["run_id"],
                "timestamp": run["timestamp"],
                "candidate_count": run["candidate_count"],
                "already_imported": run["run_id"] in seen_run_ids,
            }
        )
    return summaries
|
||||||
|
|
||||||
|
|
||||||
|
# -- POST /import --------------------------------------------------------------
|
||||||
|
|
||||||
|
# Request body for POST /import.
class ImportRequest(BaseModel):
    # Id of the run to import; compared against each discovered run's "run_id".
    run_id: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/import")
def post_import(req: ImportRequest):
    """Import the candidates of one benchmark run into the local data directory.

    Looks the requested run up among the discovered runs and hands its
    ``sft_path`` to ``import_run``; the result of that call is returned.

    Raises:
        HTTPException: 404 when no discovered run has the requested id.
    """
    from scripts.sft_import import discover_runs, import_run

    for candidate_run in discover_runs(_get_bench_results_dir()):
        if candidate_run["run_id"] == req.run_id:
            return import_run(candidate_run["sft_path"], _DATA_DIR)
    raise HTTPException(404, f"Run {req.run_id!r} not found in bench_results_dir")
|
||||||
|
|
||||||
|
|
||||||
|
# -- GET /queue ----------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.get("/queue")
def get_queue(page: int = 1, per_page: int = 20):
    """Return one page of candidates whose status is ``needs_review``.

    Args:
        page: 1-based page number.
        per_page: Maximum number of items on the page.

    Returns:
        A dict with the page's ``items``, the ``total`` number of pending
        candidates, and the ``page`` / ``per_page`` values that were used.

    Raises:
        HTTPException: 422 when ``page`` or ``per_page`` is less than 1.
    """
    # Non-positive values would yield negative slice bounds, which Python
    # resolves relative to the end of the list and so returns the wrong rows.
    if page < 1:
        raise HTTPException(422, "page must be >= 1")
    if per_page < 1:
        raise HTTPException(422, "per_page must be >= 1")

    records = _read_candidates()
    pending = [r for r in records if r.get("status") == "needs_review"]
    start = (page - 1) * per_page
    return {
        "items": pending[start:start + per_page],
        "total": len(pending),
        "page": page,
        "per_page": per_page,
    }
|
||||||
|
|
||||||
|
|
||||||
|
# -- POST /submit --------------------------------------------------------------
|
||||||
|
|
||||||
|
# Closed set of labels accepted for SubmitRequest.failure_category.
FailureCategory = Literal[
    "scoring_artifact",
    "style_violation",
    "partial_answer",
    "wrong_answer",
    "format_error",
    "hallucination",
]
|
||||||
|
|
||||||
|
|
||||||
|
# Request body for POST /submit.
class SubmitRequest(BaseModel):
    # Id of the candidate record being decided on.
    id: str
    # Reviewer decision; post_submit maps correct -> "approved",
    # discard -> "discarded", flag -> "model_rejected".
    action: Literal["correct", "discard", "flag"]
    # Replacement response; must be non-blank when action is "correct".
    corrected_response: str | None = None
    # Optional label; stored on the record only when action is "correct".
    failure_category: FailureCategory | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/submit")
def post_submit(req: SubmitRequest):
    """Apply a reviewer's decision to a single SFT candidate.

    ``correct`` marks the record approved, stores the corrected response and
    failure category, and appends the updated record to the approved file.
    ``discard`` and ``flag`` only change the status (to ``discarded`` and
    ``model_rejected`` respectively).

    Raises:
        HTTPException: 422 when ``correct`` is submitted with a blank
            response, 404 when the id is unknown, 409 when the record is
            not currently in ``needs_review``.
    """
    is_correction = req.action == "correct"
    if is_correction and not (req.corrected_response and req.corrected_response.strip()):
        raise HTTPException(422, "corrected_response must be non-empty when action is 'correct'")

    records = _read_candidates()
    position = None
    for i, candidate in enumerate(records):
        if candidate.get("id") == req.id:
            position = i
            break
    if position is None:
        raise HTTPException(404, f"Record {req.id!r} not found")

    record = records[position]
    if record.get("status") != "needs_review":
        raise HTTPException(409, f"Record is not in needs_review state (current: {record.get('status')})")

    # Work out which fields the chosen action overrides on the stored record
    if is_correction:
        changes = {
            "status": "approved",
            "corrected_response": req.corrected_response,
            "failure_category": req.failure_category,
        }
    elif req.action == "discard":
        changes = {"status": "discarded"}
    else:  # flag
        changes = {"status": "model_rejected"}

    records[position] = {**record, **changes}
    _write_candidates(records)
    if is_correction:
        append_jsonl(_approved_file(), records[position])

    return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
# -- POST /undo ----------------------------------------------------------------
|
||||||
|
|
||||||
|
# Request body for POST /undo.
class UndoRequest(BaseModel):
    # Id of the candidate record to put back into needs_review.
    id: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/undo")
def post_undo(req: UndoRequest):
    """Put an already-decided candidate back into ``needs_review``.

    The record's ``corrected_response`` is cleared. When the record had been
    approved, every entry with the same id is also dropped from the approved
    file.

    Raises:
        HTTPException: 404 when the id is unknown, 409 when the record is
            already in ``needs_review``.
    """
    records = _read_candidates()
    matches = [i for i, r in enumerate(records) if r.get("id") == req.id]
    if not matches:
        raise HTTPException(404, f"Record {req.id!r} not found")
    idx = matches[0]

    previous_status = records[idx].get("status")
    if previous_status == "needs_review":
        raise HTTPException(409, "Record is already in needs_review state")

    restored = dict(records[idx])
    restored["status"] = "needs_review"
    restored["corrected_response"] = None
    records[idx] = restored
    _write_candidates(records)

    # An approved record also lives in the approved file; take it out of there
    if previous_status == "approved":
        remaining = [r for r in read_jsonl(_approved_file()) if r.get("id") != req.id]
        write_jsonl(_approved_file(), remaining)

    return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
# -- GET /export ---------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.get("/export")
def get_export() -> StreamingResponse:
    """Stream the exportable approved records as a JSONL download.

    Each output line is a JSON object whose ``messages`` list is the record's
    ``prompt_messages`` followed by one assistant message holding the
    corrected response. The download filename embeds the current UTC time.
    """
    approved_rows = read_jsonl(_approved_file())
    exportable = list(filter(_is_exportable, approved_rows))

    # Lazily rendered so lines are serialized as the response is consumed
    lines = (
        json.dumps(
            {
                "messages": row.get("prompt_messages", []) + [
                    {"role": "assistant", "content": row["corrected_response"]}
                ]
            }
        ) + "\n"
        for row in exportable
    )

    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    disposition = f'attachment; filename="sft_export_{timestamp}.jsonl"'
    return StreamingResponse(
        lines,
        media_type="application/x-ndjson",
        headers={"Content-Disposition": disposition},
    )
|
||||||
|
|
||||||
|
|
||||||
|
# -- GET /stats ----------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.get("/stats")
def get_stats() -> dict[str, object]:
    """Summarize the candidate records.

    Returns:
        The total record count, per-value counts for ``status``,
        ``model_name`` and ``task_type`` (records missing a field are
        counted under ``"unknown"``), and the number of approved-file
        records that are ready for export.
    """
    records = _read_candidates()

    # One tally per record field, filled in a single pass over the records
    tallies: dict[str, dict[str, int]] = {"status": {}, "model_name": {}, "task_type": {}}
    for rec in records:
        for field, bucket in tallies.items():
            value = rec.get(field, "unknown")
            bucket[value] = bucket.get(value, 0) + 1

    export_ready = len([r for r in read_jsonl(_approved_file()) if _is_exportable(r)])

    return {
        "total": len(records),
        "by_status": tallies["status"],
        "by_model": tallies["model_name"],
        "by_task_type": tallies["task_type"],
        "export_ready": export_ready,
    }
|
||||||
|
|
||||||
|
|
||||||
|
# -- GET /config ---------------------------------------------------------------
|
||||||
|
|
||||||
|
@router.get("/config")
def get_sft_config() -> dict:
    """Report the ``bench_results_dir`` stored in the config file.

    Returns:
        ``{"bench_results_dir": <value>}``; the value is an empty string when
        the config file is missing, fails to parse, or lacks the setting.
    """
    config_path = _config_file()
    if not config_path.exists():
        return {"bench_results_dir": ""}

    try:
        text = config_path.read_text(encoding="utf-8")
        loaded = yaml.safe_load(text) or {}
    except yaml.YAMLError:
        return {"bench_results_dir": ""}

    section = loaded.get("sft") or {}
    configured = section.get("bench_results_dir", "")
    return {"bench_results_dir": configured}
|
||||||
|
|
||||||
|
|
||||||
|
# Request body for POST /config.
class SftConfigPayload(BaseModel):
    # Value written to sft.bench_results_dir in the config file.
    bench_results_dir: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/config")
def post_sft_config(payload: SftConfigPayload) -> dict:
    """Persist the ``bench_results_dir`` setting to the config file.

    Existing content of the config file is kept: only the
    ``bench_results_dir`` key inside the ``sft`` section is changed. A file
    that is missing, unparseable, or whose root is not a mapping is treated
    as empty. The new content is written to a sibling ``.tmp`` file first
    and then moved over the real file.

    Args:
        payload: Carries the new ``bench_results_dir`` value.

    Returns:
        ``{"ok": True}`` once the file has been written.
    """
    f = _config_file()
    f.parent.mkdir(parents=True, exist_ok=True)
    try:
        raw = yaml.safe_load(f.read_text(encoding="utf-8")) if f.exists() else {}
    except yaml.YAMLError:
        raw = {}
    # An empty file loads as None, and a scalar/list root cannot hold an "sft" key
    if not isinstance(raw, dict):
        raw = {}

    sft_section = raw.get("sft")
    if not isinstance(sft_section, dict):
        sft_section = {}
    # Touch only our own key so other settings under "sft" survive the write
    sft_section["bench_results_dir"] = payload.bench_results_dir
    raw["sft"] = sft_section

    tmp = f.with_suffix(".tmp")
    tmp.write_text(yaml.dump(raw, allow_unicode=True, sort_keys=False), encoding="utf-8")
    # Path.replace overwrites an existing target on every platform, whereas
    # Path.rename raises on Windows when the target already exists.
    tmp.replace(f)
    return {"ok": True}
|
||||||
343
app/sft.py
343
app/sft.py
|
|
@ -1,335 +1,8 @@
|
||||||
"""Avocet — SFT candidate import and correction API.
|
"""Backward-compat shim -- logic moved to app/data/corrections.py."""
|
||||||
|
from app.data.corrections import ( # noqa: F401
|
||||||
All endpoints are registered on `router` (a FastAPI APIRouter).
|
router,
|
||||||
api.py includes this router with prefix="/api/sft".
|
set_data_dir as set_sft_data_dir,
|
||||||
|
set_config_dir as set_sft_config_dir,
|
||||||
Module-level globals (_SFT_DATA_DIR, _SFT_CONFIG_DIR) follow the same
|
set_default_bench_results_dir,
|
||||||
testability pattern as api.py — override them via set_sft_data_dir() and
|
_DEFAULT_BENCH_RESULTS_DIR,
|
||||||
set_sft_config_dir() in test fixtures.
|
)
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
import yaml
|
|
||||||
from fastapi import APIRouter, HTTPException
|
|
||||||
from fastapi.responses import StreamingResponse
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from app.utils import append_jsonl, read_jsonl, write_jsonl
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_ROOT = Path(__file__).parent.parent
|
|
||||||
_SFT_DATA_DIR: Path = _ROOT / "data"
|
|
||||||
_SFT_CONFIG_DIR: Path | None = None
|
|
||||||
|
|
||||||
router = APIRouter()
|
|
||||||
|
|
||||||
|
|
||||||
# ── Testability seams ──────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def set_sft_data_dir(path: Path) -> None:
|
|
||||||
global _SFT_DATA_DIR
|
|
||||||
_SFT_DATA_DIR = path
|
|
||||||
|
|
||||||
|
|
||||||
def set_sft_config_dir(path: Path | None) -> None:
|
|
||||||
global _SFT_CONFIG_DIR
|
|
||||||
_SFT_CONFIG_DIR = path
|
|
||||||
|
|
||||||
|
|
||||||
# ── Internal helpers ───────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _config_file() -> Path:
|
|
||||||
if _SFT_CONFIG_DIR is not None:
|
|
||||||
return _SFT_CONFIG_DIR / "label_tool.yaml"
|
|
||||||
return _ROOT / "config" / "label_tool.yaml"
|
|
||||||
|
|
||||||
|
|
||||||
_DEFAULT_BENCH_RESULTS_DIR = "/Library/Development/CircuitForge/circuitforge-orch/scripts/bench_results"
|
|
||||||
|
|
||||||
|
|
||||||
def set_default_bench_results_dir(path: str) -> None:
|
|
||||||
"""Override the default bench_results_dir — used by tests to avoid real filesystem."""
|
|
||||||
global _DEFAULT_BENCH_RESULTS_DIR
|
|
||||||
_DEFAULT_BENCH_RESULTS_DIR = path
|
|
||||||
|
|
||||||
|
|
||||||
def _get_bench_results_dir() -> Path:
|
|
||||||
f = _config_file()
|
|
||||||
if f.exists():
|
|
||||||
try:
|
|
||||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
|
||||||
d = raw.get("sft", {}).get("bench_results_dir", "")
|
|
||||||
if d:
|
|
||||||
return Path(d)
|
|
||||||
except yaml.YAMLError as exc:
|
|
||||||
logger.warning("Failed to parse SFT config %s: %s", f, exc)
|
|
||||||
return Path(_DEFAULT_BENCH_RESULTS_DIR)
|
|
||||||
|
|
||||||
|
|
||||||
def _candidates_file() -> Path:
|
|
||||||
return _SFT_DATA_DIR / "sft_candidates.jsonl"
|
|
||||||
|
|
||||||
|
|
||||||
def _approved_file() -> Path:
|
|
||||||
return _SFT_DATA_DIR / "sft_approved.jsonl"
|
|
||||||
|
|
||||||
|
|
||||||
def _read_candidates() -> list[dict]:
|
|
||||||
return read_jsonl(_candidates_file())
|
|
||||||
|
|
||||||
|
|
||||||
def _write_candidates(records: list[dict]) -> None:
|
|
||||||
write_jsonl(_candidates_file(), records)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_exportable(r: dict) -> bool:
|
|
||||||
"""Return True if an approved record is ready to include in SFT export."""
|
|
||||||
return (
|
|
||||||
r.get("status") == "approved"
|
|
||||||
and bool(r.get("corrected_response"))
|
|
||||||
and str(r["corrected_response"]).strip() != ""
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── GET /runs ──────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.get("/runs")
|
|
||||||
def get_runs():
|
|
||||||
"""List available benchmark runs in the configured bench_results_dir."""
|
|
||||||
from scripts.sft_import import discover_runs
|
|
||||||
bench_dir = _get_bench_results_dir()
|
|
||||||
existing = _read_candidates()
|
|
||||||
# benchmark_run_id in each record equals the run's directory name by cf-orch convention
|
|
||||||
imported_run_ids = {
|
|
||||||
r["benchmark_run_id"]
|
|
||||||
for r in existing
|
|
||||||
if r.get("benchmark_run_id") is not None
|
|
||||||
}
|
|
||||||
runs = discover_runs(bench_dir)
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"run_id": r["run_id"],
|
|
||||||
"timestamp": r["timestamp"],
|
|
||||||
"candidate_count": r["candidate_count"],
|
|
||||||
"already_imported": r["run_id"] in imported_run_ids,
|
|
||||||
}
|
|
||||||
for r in runs
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# ── POST /import ───────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class ImportRequest(BaseModel):
|
|
||||||
run_id: str
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/import")
|
|
||||||
def post_import(req: ImportRequest):
|
|
||||||
"""Import one benchmark run's sft_candidates.jsonl into the local data dir."""
|
|
||||||
from scripts.sft_import import discover_runs, import_run
|
|
||||||
bench_dir = _get_bench_results_dir()
|
|
||||||
runs = discover_runs(bench_dir)
|
|
||||||
run = next((r for r in runs if r["run_id"] == req.run_id), None)
|
|
||||||
if run is None:
|
|
||||||
raise HTTPException(404, f"Run {req.run_id!r} not found in bench_results_dir")
|
|
||||||
return import_run(run["sft_path"], _SFT_DATA_DIR)
|
|
||||||
|
|
||||||
|
|
||||||
# ── GET /queue ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.get("/queue")
|
|
||||||
def get_queue(page: int = 1, per_page: int = 20):
|
|
||||||
"""Return paginated needs_review candidates."""
|
|
||||||
records = _read_candidates()
|
|
||||||
pending = [r for r in records if r.get("status") == "needs_review"]
|
|
||||||
start = (page - 1) * per_page
|
|
||||||
return {
|
|
||||||
"items": pending[start:start + per_page],
|
|
||||||
"total": len(pending),
|
|
||||||
"page": page,
|
|
||||||
"per_page": per_page,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ── POST /submit ───────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
FailureCategory = Literal[
|
|
||||||
"scoring_artifact",
|
|
||||||
"style_violation",
|
|
||||||
"partial_answer",
|
|
||||||
"wrong_answer",
|
|
||||||
"format_error",
|
|
||||||
"hallucination",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
class SubmitRequest(BaseModel):
|
|
||||||
id: str
|
|
||||||
action: Literal["correct", "discard", "flag"]
|
|
||||||
corrected_response: str | None = None
|
|
||||||
failure_category: FailureCategory | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/submit")
|
|
||||||
def post_submit(req: SubmitRequest):
|
|
||||||
"""Record a reviewer decision for one SFT candidate."""
|
|
||||||
if req.action == "correct":
|
|
||||||
if not req.corrected_response or not req.corrected_response.strip():
|
|
||||||
raise HTTPException(422, "corrected_response must be non-empty when action is 'correct'")
|
|
||||||
|
|
||||||
records = _read_candidates()
|
|
||||||
idx = next((i for i, r in enumerate(records) if r.get("id") == req.id), None)
|
|
||||||
if idx is None:
|
|
||||||
raise HTTPException(404, f"Record {req.id!r} not found")
|
|
||||||
|
|
||||||
record = records[idx]
|
|
||||||
if record.get("status") != "needs_review":
|
|
||||||
raise HTTPException(409, f"Record is not in needs_review state (current: {record.get('status')})")
|
|
||||||
|
|
||||||
if req.action == "correct":
|
|
||||||
records[idx] = {
|
|
||||||
**record,
|
|
||||||
"status": "approved",
|
|
||||||
"corrected_response": req.corrected_response,
|
|
||||||
"failure_category": req.failure_category,
|
|
||||||
}
|
|
||||||
_write_candidates(records)
|
|
||||||
append_jsonl(_approved_file(), records[idx])
|
|
||||||
elif req.action == "discard":
|
|
||||||
records[idx] = {**record, "status": "discarded"}
|
|
||||||
_write_candidates(records)
|
|
||||||
else: # flag
|
|
||||||
records[idx] = {**record, "status": "model_rejected"}
|
|
||||||
_write_candidates(records)
|
|
||||||
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
||||||
|
|
||||||
# ── POST /undo ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class UndoRequest(BaseModel):
|
|
||||||
id: str
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/undo")
|
|
||||||
def post_undo(req: UndoRequest):
|
|
||||||
"""Restore a previously actioned candidate back to needs_review."""
|
|
||||||
records = _read_candidates()
|
|
||||||
idx = next((i for i, r in enumerate(records) if r.get("id") == req.id), None)
|
|
||||||
if idx is None:
|
|
||||||
raise HTTPException(404, f"Record {req.id!r} not found")
|
|
||||||
|
|
||||||
record = records[idx]
|
|
||||||
old_status = record.get("status")
|
|
||||||
if old_status == "needs_review":
|
|
||||||
raise HTTPException(409, "Record is already in needs_review state")
|
|
||||||
|
|
||||||
records[idx] = {**record, "status": "needs_review", "corrected_response": None}
|
|
||||||
_write_candidates(records)
|
|
||||||
|
|
||||||
# If it was approved, remove from the approved file too
|
|
||||||
if old_status == "approved":
|
|
||||||
approved = read_jsonl(_approved_file())
|
|
||||||
write_jsonl(_approved_file(), [r for r in approved if r.get("id") != req.id])
|
|
||||||
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
||||||
|
|
||||||
# ── GET /export ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.get("/export")
|
|
||||||
def get_export() -> StreamingResponse:
|
|
||||||
"""Stream approved records as SFT-ready JSONL for download."""
|
|
||||||
exportable = [r for r in read_jsonl(_approved_file()) if _is_exportable(r)]
|
|
||||||
|
|
||||||
def generate():
|
|
||||||
for r in exportable:
|
|
||||||
record = {
|
|
||||||
"messages": r.get("prompt_messages", []) + [
|
|
||||||
{"role": "assistant", "content": r["corrected_response"]}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
yield json.dumps(record) + "\n"
|
|
||||||
|
|
||||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
|
|
||||||
return StreamingResponse(
|
|
||||||
generate(),
|
|
||||||
media_type="application/x-ndjson",
|
|
||||||
headers={
|
|
||||||
"Content-Disposition": f'attachment; filename="sft_export_{timestamp}.jsonl"'
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── GET /stats ──────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.get("/stats")
|
|
||||||
def get_stats() -> dict[str, object]:
|
|
||||||
"""Return counts by status, model, and task type."""
|
|
||||||
records = _read_candidates()
|
|
||||||
by_status: dict[str, int] = {}
|
|
||||||
by_model: dict[str, int] = {}
|
|
||||||
by_task_type: dict[str, int] = {}
|
|
||||||
|
|
||||||
for r in records:
|
|
||||||
status = r.get("status", "unknown")
|
|
||||||
by_status[status] = by_status.get(status, 0) + 1
|
|
||||||
model = r.get("model_name", "unknown")
|
|
||||||
by_model[model] = by_model.get(model, 0) + 1
|
|
||||||
task_type = r.get("task_type", "unknown")
|
|
||||||
by_task_type[task_type] = by_task_type.get(task_type, 0) + 1
|
|
||||||
|
|
||||||
approved = read_jsonl(_approved_file())
|
|
||||||
export_ready = sum(1 for r in approved if _is_exportable(r))
|
|
||||||
|
|
||||||
return {
|
|
||||||
"total": len(records),
|
|
||||||
"by_status": by_status,
|
|
||||||
"by_model": by_model,
|
|
||||||
"by_task_type": by_task_type,
|
|
||||||
"export_ready": export_ready,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ── GET /config ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.get("/config")
|
|
||||||
def get_sft_config() -> dict:
|
|
||||||
"""Return the current SFT configuration (bench_results_dir)."""
|
|
||||||
f = _config_file()
|
|
||||||
if not f.exists():
|
|
||||||
return {"bench_results_dir": ""}
|
|
||||||
try:
|
|
||||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
|
||||||
except yaml.YAMLError:
|
|
||||||
return {"bench_results_dir": ""}
|
|
||||||
sft_section = raw.get("sft") or {}
|
|
||||||
return {"bench_results_dir": sft_section.get("bench_results_dir", "")}
|
|
||||||
|
|
||||||
|
|
||||||
class SftConfigPayload(BaseModel):
|
|
||||||
bench_results_dir: str
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/config")
|
|
||||||
def post_sft_config(payload: SftConfigPayload) -> dict:
|
|
||||||
"""Write the bench_results_dir setting to the config file."""
|
|
||||||
f = _config_file()
|
|
||||||
f.parent.mkdir(parents=True, exist_ok=True)
|
|
||||||
try:
|
|
||||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) if f.exists() else {}
|
|
||||||
raw = raw or {}
|
|
||||||
except yaml.YAMLError:
|
|
||||||
raw = {}
|
|
||||||
raw["sft"] = {"bench_results_dir": payload.bench_results_dir}
|
|
||||||
tmp = f.with_suffix(".tmp")
|
|
||||||
tmp.write_text(yaml.dump(raw, allow_unicode=True, sort_keys=False), encoding="utf-8")
|
|
||||||
tmp.rename(f)
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
"""API integration tests for app/sft.py — /api/sft/* endpoints."""
|
"""API integration tests for app/sft.py -- /api/sft/* endpoints."""
|
||||||
import json
|
import json
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
|
@ -7,17 +7,17 @@ from pathlib import Path
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def reset_sft_globals(tmp_path):
|
def reset_sft_globals(tmp_path):
|
||||||
from app import sft as sft_module
|
from app.data import corrections as corr_module
|
||||||
_prev_data = sft_module._SFT_DATA_DIR
|
_prev_data = corr_module._DATA_DIR
|
||||||
_prev_cfg = sft_module._SFT_CONFIG_DIR
|
_prev_cfg = corr_module._CONFIG_DIR
|
||||||
_prev_default = sft_module._DEFAULT_BENCH_RESULTS_DIR
|
_prev_default = corr_module._DEFAULT_BENCH_RESULTS_DIR
|
||||||
sft_module.set_sft_data_dir(tmp_path)
|
corr_module.set_data_dir(tmp_path)
|
||||||
sft_module.set_sft_config_dir(tmp_path)
|
corr_module.set_config_dir(tmp_path)
|
||||||
sft_module.set_default_bench_results_dir(str(tmp_path / "bench_results"))
|
corr_module.set_default_bench_results_dir(str(tmp_path / "bench_results"))
|
||||||
yield
|
yield
|
||||||
sft_module.set_sft_data_dir(_prev_data)
|
corr_module.set_data_dir(_prev_data)
|
||||||
sft_module.set_sft_config_dir(_prev_cfg)
|
corr_module.set_config_dir(_prev_cfg)
|
||||||
sft_module.set_default_bench_results_dir(_prev_default)
|
corr_module.set_default_bench_results_dir(_prev_default)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
@ -63,7 +63,7 @@ def _write_config(tmp_path, bench_results_dir: Path) -> None:
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── /api/sft/runs ──────────────────────────────────────────────────────────
|
# -- /api/sft/runs -------------------------------------------------------------
|
||||||
|
|
||||||
def test_runs_returns_empty_when_no_config(client):
|
def test_runs_returns_empty_when_no_config(client):
|
||||||
r = client.get("/api/sft/runs")
|
r = client.get("/api/sft/runs")
|
||||||
|
|
@ -86,7 +86,7 @@ def test_runs_returns_available_runs(client, tmp_path):
|
||||||
def test_runs_marks_already_imported(client, tmp_path):
|
def test_runs_marks_already_imported(client, tmp_path):
|
||||||
_write_run(tmp_path, "2026-04-07-143022", [_make_record("a")])
|
_write_run(tmp_path, "2026-04-07-143022", [_make_record("a")])
|
||||||
_write_config(tmp_path, tmp_path / "bench_results")
|
_write_config(tmp_path, tmp_path / "bench_results")
|
||||||
from app import sft as sft_module
|
from app.data import corrections as sft_module
|
||||||
candidates = sft_module._candidates_file()
|
candidates = sft_module._candidates_file()
|
||||||
candidates.parent.mkdir(parents=True, exist_ok=True)
|
candidates.parent.mkdir(parents=True, exist_ok=True)
|
||||||
candidates.write_text(
|
candidates.write_text(
|
||||||
|
|
@ -97,7 +97,7 @@ def test_runs_marks_already_imported(client, tmp_path):
|
||||||
assert r.json()[0]["already_imported"] is True
|
assert r.json()[0]["already_imported"] is True
|
||||||
|
|
||||||
|
|
||||||
# ── /api/sft/import ─────────────────────────────────────────────────────────
|
# -- /api/sft/import -----------------------------------------------------------
|
||||||
|
|
||||||
def test_import_adds_records(client, tmp_path):
|
def test_import_adds_records(client, tmp_path):
|
||||||
_write_run(tmp_path, "2026-04-07-143022", [_make_record("a"), _make_record("b")])
|
_write_run(tmp_path, "2026-04-07-143022", [_make_record("a"), _make_record("b")])
|
||||||
|
|
@ -121,10 +121,10 @@ def test_import_unknown_run_returns_404(client, tmp_path):
|
||||||
assert r.status_code == 404
|
assert r.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
# ── /api/sft/queue ──────────────────────────────────────────────────────────
|
# -- /api/sft/queue ------------------------------------------------------------
|
||||||
|
|
||||||
def _populate_candidates(tmp_path, records: list[dict]) -> None:
|
def _populate_candidates(tmp_path, records: list[dict]) -> None:
|
||||||
from app import sft as sft_module
|
from app.data import corrections as sft_module
|
||||||
path = sft_module._candidates_file()
|
path = sft_module._candidates_file()
|
||||||
path.parent.mkdir(parents=True, exist_ok=True)
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
path.write_text(
|
path.write_text(
|
||||||
|
|
@ -164,7 +164,7 @@ def test_queue_empty_when_no_file(client):
|
||||||
assert r.json() == {"items": [], "total": 0, "page": 1, "per_page": 20}
|
assert r.json() == {"items": [], "total": 0, "page": 1, "per_page": 20}
|
||||||
|
|
||||||
|
|
||||||
# ── /api/sft/submit ─────────────────────────────────────────────────────────
|
# -- /api/sft/submit -----------------------------------------------------------
|
||||||
|
|
||||||
def test_submit_correct_sets_approved(client, tmp_path):
|
def test_submit_correct_sets_approved(client, tmp_path):
|
||||||
_populate_candidates(tmp_path, [_make_record("a")])
|
_populate_candidates(tmp_path, [_make_record("a")])
|
||||||
|
|
@ -173,7 +173,7 @@ def test_submit_correct_sets_approved(client, tmp_path):
|
||||||
"corrected_response": "def add(a, b): return a + b",
|
"corrected_response": "def add(a, b): return a + b",
|
||||||
})
|
})
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
from app import sft as sft_module
|
from app.data import corrections as sft_module
|
||||||
records = sft_module._read_candidates()
|
records = sft_module._read_candidates()
|
||||||
assert records[0]["status"] == "approved"
|
assert records[0]["status"] == "approved"
|
||||||
assert records[0]["corrected_response"] == "def add(a, b): return a + b"
|
assert records[0]["corrected_response"] == "def add(a, b): return a + b"
|
||||||
|
|
@ -185,7 +185,7 @@ def test_submit_correct_also_appends_to_approved_file(client, tmp_path):
|
||||||
"id": "a", "action": "correct",
|
"id": "a", "action": "correct",
|
||||||
"corrected_response": "def add(a, b): return a + b",
|
"corrected_response": "def add(a, b): return a + b",
|
||||||
})
|
})
|
||||||
from app import sft as sft_module
|
from app.data import corrections as sft_module
|
||||||
from app.utils import read_jsonl
|
from app.utils import read_jsonl
|
||||||
approved = read_jsonl(sft_module._approved_file())
|
approved = read_jsonl(sft_module._approved_file())
|
||||||
assert len(approved) == 1
|
assert len(approved) == 1
|
||||||
|
|
@ -196,7 +196,7 @@ def test_submit_discard_sets_discarded(client, tmp_path):
|
||||||
_populate_candidates(tmp_path, [_make_record("a")])
|
_populate_candidates(tmp_path, [_make_record("a")])
|
||||||
r = client.post("/api/sft/submit", json={"id": "a", "action": "discard"})
|
r = client.post("/api/sft/submit", json={"id": "a", "action": "discard"})
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
from app import sft as sft_module
|
from app.data import corrections as sft_module
|
||||||
assert sft_module._read_candidates()[0]["status"] == "discarded"
|
assert sft_module._read_candidates()[0]["status"] == "discarded"
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -204,7 +204,7 @@ def test_submit_flag_sets_model_rejected(client, tmp_path):
|
||||||
_populate_candidates(tmp_path, [_make_record("a")])
|
_populate_candidates(tmp_path, [_make_record("a")])
|
||||||
r = client.post("/api/sft/submit", json={"id": "a", "action": "flag"})
|
r = client.post("/api/sft/submit", json={"id": "a", "action": "flag"})
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
from app import sft as sft_module
|
from app.data import corrections as sft_module
|
||||||
assert sft_module._read_candidates()[0]["status"] == "model_rejected"
|
assert sft_module._read_candidates()[0]["status"] == "model_rejected"
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -243,7 +243,7 @@ def test_submit_correct_stores_failure_category(client, tmp_path):
|
||||||
"failure_category": "style_violation",
|
"failure_category": "style_violation",
|
||||||
})
|
})
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
from app import sft as sft_module
|
from app.data import corrections as sft_module
|
||||||
records = sft_module._read_candidates()
|
records = sft_module._read_candidates()
|
||||||
assert records[0]["failure_category"] == "style_violation"
|
assert records[0]["failure_category"] == "style_violation"
|
||||||
|
|
||||||
|
|
@ -255,7 +255,7 @@ def test_submit_correct_null_failure_category(client, tmp_path):
|
||||||
"corrected_response": "def add(a, b): return a + b",
|
"corrected_response": "def add(a, b): return a + b",
|
||||||
})
|
})
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
from app import sft as sft_module
|
from app.data import corrections as sft_module
|
||||||
records = sft_module._read_candidates()
|
records = sft_module._read_candidates()
|
||||||
assert records[0]["failure_category"] is None
|
assert records[0]["failure_category"] is None
|
||||||
|
|
||||||
|
|
@ -270,14 +270,14 @@ def test_submit_invalid_failure_category_returns_422(client, tmp_path):
|
||||||
assert r.status_code == 422
|
assert r.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
# ── /api/sft/undo ────────────────────────────────────────────────────────────
|
# -- /api/sft/undo -------------------------------------------------------------
|
||||||
|
|
||||||
def test_undo_restores_discarded_to_needs_review(client, tmp_path):
|
def test_undo_restores_discarded_to_needs_review(client, tmp_path):
|
||||||
_populate_candidates(tmp_path, [_make_record("a")])
|
_populate_candidates(tmp_path, [_make_record("a")])
|
||||||
client.post("/api/sft/submit", json={"id": "a", "action": "discard"})
|
client.post("/api/sft/submit", json={"id": "a", "action": "discard"})
|
||||||
r = client.post("/api/sft/undo", json={"id": "a"})
|
r = client.post("/api/sft/undo", json={"id": "a"})
|
||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
from app import sft as sft_module
|
from app.data import corrections as sft_module
|
||||||
assert sft_module._read_candidates()[0]["status"] == "needs_review"
|
assert sft_module._read_candidates()[0]["status"] == "needs_review"
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -288,7 +288,7 @@ def test_undo_removes_approved_from_approved_file(client, tmp_path):
|
||||||
"corrected_response": "def add(a, b): return a + b",
|
"corrected_response": "def add(a, b): return a + b",
|
||||||
})
|
})
|
||||||
client.post("/api/sft/undo", json={"id": "a"})
|
client.post("/api/sft/undo", json={"id": "a"})
|
||||||
from app import sft as sft_module
|
from app.data import corrections as sft_module
|
||||||
from app.utils import read_jsonl
|
from app.utils import read_jsonl
|
||||||
approved = read_jsonl(sft_module._approved_file())
|
approved = read_jsonl(sft_module._approved_file())
|
||||||
assert not any(r["id"] == "a" for r in approved)
|
assert not any(r["id"] == "a" for r in approved)
|
||||||
|
|
@ -300,10 +300,10 @@ def test_undo_already_needs_review_returns_409(client, tmp_path):
|
||||||
assert r.status_code == 409
|
assert r.status_code == 409
|
||||||
|
|
||||||
|
|
||||||
# ── /api/sft/export ──────────────────────────────────────────────────────────
|
# -- /api/sft/export -----------------------------------------------------------
|
||||||
|
|
||||||
def test_export_returns_approved_as_sft_jsonl(client, tmp_path):
|
def test_export_returns_approved_as_sft_jsonl(client, tmp_path):
|
||||||
from app import sft as sft_module
|
from app.data import corrections as sft_module
|
||||||
from app.utils import write_jsonl
|
from app.utils import write_jsonl
|
||||||
approved = {
|
approved = {
|
||||||
**_make_record("a"),
|
**_make_record("a"),
|
||||||
|
|
@ -331,7 +331,7 @@ def test_export_returns_approved_as_sft_jsonl(client, tmp_path):
|
||||||
|
|
||||||
|
|
||||||
def test_export_excludes_non_approved(client, tmp_path):
|
def test_export_excludes_non_approved(client, tmp_path):
|
||||||
from app import sft as sft_module
|
from app.data import corrections as sft_module
|
||||||
from app.utils import write_jsonl
|
from app.utils import write_jsonl
|
||||||
records = [
|
records = [
|
||||||
{**_make_record("a"), "status": "discarded", "corrected_response": None},
|
{**_make_record("a"), "status": "discarded", "corrected_response": None},
|
||||||
|
|
@ -348,10 +348,10 @@ def test_export_empty_when_no_approved_file(client):
|
||||||
assert r.text.strip() == ""
|
assert r.text.strip() == ""
|
||||||
|
|
||||||
|
|
||||||
# ── /api/sft/stats ───────────────────────────────────────────────────────────
|
# -- /api/sft/stats ------------------------------------------------------------
|
||||||
|
|
||||||
def test_stats_counts_by_status(client, tmp_path):
|
def test_stats_counts_by_status(client, tmp_path):
|
||||||
from app import sft as sft_module
|
from app.data import corrections as sft_module
|
||||||
from app.utils import write_jsonl
|
from app.utils import write_jsonl
|
||||||
records = [
|
records = [
|
||||||
_make_record("a"),
|
_make_record("a"),
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue