feat: move SFT corrections API into app/data/corrections.py
This commit is contained in:
parent
2054866ff1
commit
99ea39fe38
3 changed files with 374 additions and 366 deletions
335
app/data/corrections.py
Normal file
335
app/data/corrections.py
Normal file
|
|
@ -0,0 +1,335 @@
|
|||
"""Avocet -- SFT candidate corrections API (moved from app/sft.py).
|
||||
|
||||
All endpoints are registered on `router` (a FastAPI APIRouter).
|
||||
api.py includes this router with prefix="/api/sft".
|
||||
|
||||
Module-level globals (_DATA_DIR, _CONFIG_DIR) follow the same
|
||||
testability pattern as api.py -- override them via set_data_dir() and
|
||||
set_config_dir() in test fixtures.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import yaml
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.utils import append_jsonl, read_jsonl, write_jsonl
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent.parent
|
||||
_DATA_DIR: Path = _ROOT / "data"
|
||||
_CONFIG_DIR: Path | None = None
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# -- Testability seams ---------------------------------------------------------
|
||||
|
||||
def set_data_dir(path: Path) -> None:
|
||||
global _DATA_DIR
|
||||
_DATA_DIR = path
|
||||
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
global _CONFIG_DIR
|
||||
_CONFIG_DIR = path
|
||||
|
||||
|
||||
# -- Internal helpers ----------------------------------------------------------
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _CONFIG_DIR is not None:
|
||||
return _CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
_DEFAULT_BENCH_RESULTS_DIR = "/Library/Development/CircuitForge/circuitforge-orch/scripts/bench_results"
|
||||
|
||||
|
||||
def set_default_bench_results_dir(path: str) -> None:
|
||||
"""Override the default bench_results_dir -- used by tests to avoid real filesystem."""
|
||||
global _DEFAULT_BENCH_RESULTS_DIR
|
||||
_DEFAULT_BENCH_RESULTS_DIR = path
|
||||
|
||||
|
||||
def _get_bench_results_dir() -> Path:
|
||||
f = _config_file()
|
||||
if f.exists():
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
d = raw.get("sft", {}).get("bench_results_dir", "")
|
||||
if d:
|
||||
return Path(d)
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse SFT config %s: %s", f, exc)
|
||||
return Path(_DEFAULT_BENCH_RESULTS_DIR)
|
||||
|
||||
|
||||
def _candidates_file() -> Path:
|
||||
return _DATA_DIR / "sft_candidates.jsonl"
|
||||
|
||||
|
||||
def _approved_file() -> Path:
|
||||
return _DATA_DIR / "sft_approved.jsonl"
|
||||
|
||||
|
||||
def _read_candidates() -> list[dict]:
|
||||
return read_jsonl(_candidates_file())
|
||||
|
||||
|
||||
def _write_candidates(records: list[dict]) -> None:
|
||||
write_jsonl(_candidates_file(), records)
|
||||
|
||||
|
||||
def _is_exportable(r: dict) -> bool:
|
||||
"""Return True if an approved record is ready to include in SFT export."""
|
||||
return (
|
||||
r.get("status") == "approved"
|
||||
and bool(r.get("corrected_response"))
|
||||
and str(r["corrected_response"]).strip() != ""
|
||||
)
|
||||
|
||||
|
||||
# -- GET /runs -----------------------------------------------------------------
|
||||
|
||||
@router.get("/runs")
|
||||
def get_runs():
|
||||
"""List available benchmark runs in the configured bench_results_dir."""
|
||||
from scripts.sft_import import discover_runs
|
||||
bench_dir = _get_bench_results_dir()
|
||||
existing = _read_candidates()
|
||||
# benchmark_run_id in each record equals the run's directory name by cf-orch convention
|
||||
imported_run_ids = {
|
||||
r["benchmark_run_id"]
|
||||
for r in existing
|
||||
if r.get("benchmark_run_id") is not None
|
||||
}
|
||||
runs = discover_runs(bench_dir)
|
||||
return [
|
||||
{
|
||||
"run_id": r["run_id"],
|
||||
"timestamp": r["timestamp"],
|
||||
"candidate_count": r["candidate_count"],
|
||||
"already_imported": r["run_id"] in imported_run_ids,
|
||||
}
|
||||
for r in runs
|
||||
]
|
||||
|
||||
|
||||
# -- POST /import --------------------------------------------------------------
|
||||
|
||||
class ImportRequest(BaseModel):
|
||||
run_id: str
|
||||
|
||||
|
||||
@router.post("/import")
|
||||
def post_import(req: ImportRequest):
|
||||
"""Import one benchmark run's sft_candidates.jsonl into the local data dir."""
|
||||
from scripts.sft_import import discover_runs, import_run
|
||||
bench_dir = _get_bench_results_dir()
|
||||
runs = discover_runs(bench_dir)
|
||||
run = next((r for r in runs if r["run_id"] == req.run_id), None)
|
||||
if run is None:
|
||||
raise HTTPException(404, f"Run {req.run_id!r} not found in bench_results_dir")
|
||||
return import_run(run["sft_path"], _DATA_DIR)
|
||||
|
||||
|
||||
# -- GET /queue ----------------------------------------------------------------
|
||||
|
||||
@router.get("/queue")
|
||||
def get_queue(page: int = 1, per_page: int = 20):
|
||||
"""Return paginated needs_review candidates."""
|
||||
records = _read_candidates()
|
||||
pending = [r for r in records if r.get("status") == "needs_review"]
|
||||
start = (page - 1) * per_page
|
||||
return {
|
||||
"items": pending[start:start + per_page],
|
||||
"total": len(pending),
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
}
|
||||
|
||||
|
||||
# -- POST /submit --------------------------------------------------------------
|
||||
|
||||
FailureCategory = Literal[
|
||||
"scoring_artifact",
|
||||
"style_violation",
|
||||
"partial_answer",
|
||||
"wrong_answer",
|
||||
"format_error",
|
||||
"hallucination",
|
||||
]
|
||||
|
||||
|
||||
class SubmitRequest(BaseModel):
|
||||
id: str
|
||||
action: Literal["correct", "discard", "flag"]
|
||||
corrected_response: str | None = None
|
||||
failure_category: FailureCategory | None = None
|
||||
|
||||
|
||||
@router.post("/submit")
|
||||
def post_submit(req: SubmitRequest):
|
||||
"""Record a reviewer decision for one SFT candidate."""
|
||||
if req.action == "correct":
|
||||
if not req.corrected_response or not req.corrected_response.strip():
|
||||
raise HTTPException(422, "corrected_response must be non-empty when action is 'correct'")
|
||||
|
||||
records = _read_candidates()
|
||||
idx = next((i for i, r in enumerate(records) if r.get("id") == req.id), None)
|
||||
if idx is None:
|
||||
raise HTTPException(404, f"Record {req.id!r} not found")
|
||||
|
||||
record = records[idx]
|
||||
if record.get("status") != "needs_review":
|
||||
raise HTTPException(409, f"Record is not in needs_review state (current: {record.get('status')})")
|
||||
|
||||
if req.action == "correct":
|
||||
records[idx] = {
|
||||
**record,
|
||||
"status": "approved",
|
||||
"corrected_response": req.corrected_response,
|
||||
"failure_category": req.failure_category,
|
||||
}
|
||||
_write_candidates(records)
|
||||
append_jsonl(_approved_file(), records[idx])
|
||||
elif req.action == "discard":
|
||||
records[idx] = {**record, "status": "discarded"}
|
||||
_write_candidates(records)
|
||||
else: # flag
|
||||
records[idx] = {**record, "status": "model_rejected"}
|
||||
_write_candidates(records)
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# -- POST /undo ----------------------------------------------------------------
|
||||
|
||||
class UndoRequest(BaseModel):
|
||||
id: str
|
||||
|
||||
|
||||
@router.post("/undo")
|
||||
def post_undo(req: UndoRequest):
|
||||
"""Restore a previously actioned candidate back to needs_review."""
|
||||
records = _read_candidates()
|
||||
idx = next((i for i, r in enumerate(records) if r.get("id") == req.id), None)
|
||||
if idx is None:
|
||||
raise HTTPException(404, f"Record {req.id!r} not found")
|
||||
|
||||
record = records[idx]
|
||||
old_status = record.get("status")
|
||||
if old_status == "needs_review":
|
||||
raise HTTPException(409, "Record is already in needs_review state")
|
||||
|
||||
records[idx] = {**record, "status": "needs_review", "corrected_response": None}
|
||||
_write_candidates(records)
|
||||
|
||||
# If it was approved, remove from the approved file too
|
||||
if old_status == "approved":
|
||||
approved = read_jsonl(_approved_file())
|
||||
write_jsonl(_approved_file(), [r for r in approved if r.get("id") != req.id])
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# -- GET /export ---------------------------------------------------------------
|
||||
|
||||
@router.get("/export")
|
||||
def get_export() -> StreamingResponse:
|
||||
"""Stream approved records as SFT-ready JSONL for download."""
|
||||
exportable = [r for r in read_jsonl(_approved_file()) if _is_exportable(r)]
|
||||
|
||||
def generate():
|
||||
for r in exportable:
|
||||
record = {
|
||||
"messages": r.get("prompt_messages", []) + [
|
||||
{"role": "assistant", "content": r["corrected_response"]}
|
||||
]
|
||||
}
|
||||
yield json.dumps(record) + "\n"
|
||||
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="application/x-ndjson",
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="sft_export_{timestamp}.jsonl"'
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# -- GET /stats ----------------------------------------------------------------
|
||||
|
||||
@router.get("/stats")
|
||||
def get_stats() -> dict[str, object]:
|
||||
"""Return counts by status, model, and task type."""
|
||||
records = _read_candidates()
|
||||
by_status: dict[str, int] = {}
|
||||
by_model: dict[str, int] = {}
|
||||
by_task_type: dict[str, int] = {}
|
||||
|
||||
for r in records:
|
||||
status = r.get("status", "unknown")
|
||||
by_status[status] = by_status.get(status, 0) + 1
|
||||
model = r.get("model_name", "unknown")
|
||||
by_model[model] = by_model.get(model, 0) + 1
|
||||
task_type = r.get("task_type", "unknown")
|
||||
by_task_type[task_type] = by_task_type.get(task_type, 0) + 1
|
||||
|
||||
approved = read_jsonl(_approved_file())
|
||||
export_ready = sum(1 for r in approved if _is_exportable(r))
|
||||
|
||||
return {
|
||||
"total": len(records),
|
||||
"by_status": by_status,
|
||||
"by_model": by_model,
|
||||
"by_task_type": by_task_type,
|
||||
"export_ready": export_ready,
|
||||
}
|
||||
|
||||
|
||||
# -- GET /config ---------------------------------------------------------------
|
||||
|
||||
@router.get("/config")
|
||||
def get_sft_config() -> dict:
|
||||
"""Return the current SFT configuration (bench_results_dir)."""
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return {"bench_results_dir": ""}
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError:
|
||||
return {"bench_results_dir": ""}
|
||||
sft_section = raw.get("sft") or {}
|
||||
return {"bench_results_dir": sft_section.get("bench_results_dir", "")}
|
||||
|
||||
|
||||
class SftConfigPayload(BaseModel):
|
||||
bench_results_dir: str
|
||||
|
||||
|
||||
@router.post("/config")
|
||||
def post_sft_config(payload: SftConfigPayload) -> dict:
|
||||
"""Write the bench_results_dir setting to the config file."""
|
||||
f = _config_file()
|
||||
f.parent.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) if f.exists() else {}
|
||||
raw = raw or {}
|
||||
except yaml.YAMLError:
|
||||
raw = {}
|
||||
raw["sft"] = {"bench_results_dir": payload.bench_results_dir}
|
||||
tmp = f.with_suffix(".tmp")
|
||||
tmp.write_text(yaml.dump(raw, allow_unicode=True, sort_keys=False), encoding="utf-8")
|
||||
tmp.rename(f)
|
||||
return {"ok": True}
|
||||
341
app/sft.py
341
app/sft.py
|
|
@ -1,335 +1,8 @@
|
|||
"""Avocet — SFT candidate import and correction API.
|
||||
|
||||
All endpoints are registered on `router` (a FastAPI APIRouter).
|
||||
api.py includes this router with prefix="/api/sft".
|
||||
|
||||
Module-level globals (_SFT_DATA_DIR, _SFT_CONFIG_DIR) follow the same
|
||||
testability pattern as api.py — override them via set_sft_data_dir() and
|
||||
set_sft_config_dir() in test fixtures.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import yaml
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.utils import append_jsonl, read_jsonl, write_jsonl
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_SFT_DATA_DIR: Path = _ROOT / "data"
|
||||
_SFT_CONFIG_DIR: Path | None = None
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
# ── Testability seams ──────────────────────────────────────────────────────
|
||||
|
||||
def set_sft_data_dir(path: Path) -> None:
|
||||
global _SFT_DATA_DIR
|
||||
_SFT_DATA_DIR = path
|
||||
|
||||
|
||||
def set_sft_config_dir(path: Path | None) -> None:
|
||||
global _SFT_CONFIG_DIR
|
||||
_SFT_CONFIG_DIR = path
|
||||
|
||||
|
||||
# ── Internal helpers ───────────────────────────────────────────────────────
|
||||
|
||||
def _config_file() -> Path:
|
||||
if _SFT_CONFIG_DIR is not None:
|
||||
return _SFT_CONFIG_DIR / "label_tool.yaml"
|
||||
return _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
|
||||
_DEFAULT_BENCH_RESULTS_DIR = "/Library/Development/CircuitForge/circuitforge-orch/scripts/bench_results"
|
||||
|
||||
|
||||
def set_default_bench_results_dir(path: str) -> None:
|
||||
"""Override the default bench_results_dir — used by tests to avoid real filesystem."""
|
||||
global _DEFAULT_BENCH_RESULTS_DIR
|
||||
_DEFAULT_BENCH_RESULTS_DIR = path
|
||||
|
||||
|
||||
def _get_bench_results_dir() -> Path:
|
||||
f = _config_file()
|
||||
if f.exists():
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
d = raw.get("sft", {}).get("bench_results_dir", "")
|
||||
if d:
|
||||
return Path(d)
|
||||
except yaml.YAMLError as exc:
|
||||
logger.warning("Failed to parse SFT config %s: %s", f, exc)
|
||||
return Path(_DEFAULT_BENCH_RESULTS_DIR)
|
||||
|
||||
|
||||
def _candidates_file() -> Path:
|
||||
return _SFT_DATA_DIR / "sft_candidates.jsonl"
|
||||
|
||||
|
||||
def _approved_file() -> Path:
|
||||
return _SFT_DATA_DIR / "sft_approved.jsonl"
|
||||
|
||||
|
||||
def _read_candidates() -> list[dict]:
|
||||
return read_jsonl(_candidates_file())
|
||||
|
||||
|
||||
def _write_candidates(records: list[dict]) -> None:
|
||||
write_jsonl(_candidates_file(), records)
|
||||
|
||||
|
||||
def _is_exportable(r: dict) -> bool:
|
||||
"""Return True if an approved record is ready to include in SFT export."""
|
||||
return (
|
||||
r.get("status") == "approved"
|
||||
and bool(r.get("corrected_response"))
|
||||
and str(r["corrected_response"]).strip() != ""
|
||||
"""Backward-compat shim -- logic moved to app/data/corrections.py."""
|
||||
from app.data.corrections import ( # noqa: F401
|
||||
router,
|
||||
set_data_dir as set_sft_data_dir,
|
||||
set_config_dir as set_sft_config_dir,
|
||||
set_default_bench_results_dir,
|
||||
_DEFAULT_BENCH_RESULTS_DIR,
|
||||
)
|
||||
|
||||
|
||||
# ── GET /runs ──────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/runs")
|
||||
def get_runs():
|
||||
"""List available benchmark runs in the configured bench_results_dir."""
|
||||
from scripts.sft_import import discover_runs
|
||||
bench_dir = _get_bench_results_dir()
|
||||
existing = _read_candidates()
|
||||
# benchmark_run_id in each record equals the run's directory name by cf-orch convention
|
||||
imported_run_ids = {
|
||||
r["benchmark_run_id"]
|
||||
for r in existing
|
||||
if r.get("benchmark_run_id") is not None
|
||||
}
|
||||
runs = discover_runs(bench_dir)
|
||||
return [
|
||||
{
|
||||
"run_id": r["run_id"],
|
||||
"timestamp": r["timestamp"],
|
||||
"candidate_count": r["candidate_count"],
|
||||
"already_imported": r["run_id"] in imported_run_ids,
|
||||
}
|
||||
for r in runs
|
||||
]
|
||||
|
||||
|
||||
# ── POST /import ───────────────────────────────────────────────────────────
|
||||
|
||||
class ImportRequest(BaseModel):
|
||||
run_id: str
|
||||
|
||||
|
||||
@router.post("/import")
|
||||
def post_import(req: ImportRequest):
|
||||
"""Import one benchmark run's sft_candidates.jsonl into the local data dir."""
|
||||
from scripts.sft_import import discover_runs, import_run
|
||||
bench_dir = _get_bench_results_dir()
|
||||
runs = discover_runs(bench_dir)
|
||||
run = next((r for r in runs if r["run_id"] == req.run_id), None)
|
||||
if run is None:
|
||||
raise HTTPException(404, f"Run {req.run_id!r} not found in bench_results_dir")
|
||||
return import_run(run["sft_path"], _SFT_DATA_DIR)
|
||||
|
||||
|
||||
# ── GET /queue ─────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/queue")
|
||||
def get_queue(page: int = 1, per_page: int = 20):
|
||||
"""Return paginated needs_review candidates."""
|
||||
records = _read_candidates()
|
||||
pending = [r for r in records if r.get("status") == "needs_review"]
|
||||
start = (page - 1) * per_page
|
||||
return {
|
||||
"items": pending[start:start + per_page],
|
||||
"total": len(pending),
|
||||
"page": page,
|
||||
"per_page": per_page,
|
||||
}
|
||||
|
||||
|
||||
# ── POST /submit ───────────────────────────────────────────────────────────
|
||||
|
||||
FailureCategory = Literal[
|
||||
"scoring_artifact",
|
||||
"style_violation",
|
||||
"partial_answer",
|
||||
"wrong_answer",
|
||||
"format_error",
|
||||
"hallucination",
|
||||
]
|
||||
|
||||
|
||||
class SubmitRequest(BaseModel):
|
||||
id: str
|
||||
action: Literal["correct", "discard", "flag"]
|
||||
corrected_response: str | None = None
|
||||
failure_category: FailureCategory | None = None
|
||||
|
||||
|
||||
@router.post("/submit")
|
||||
def post_submit(req: SubmitRequest):
|
||||
"""Record a reviewer decision for one SFT candidate."""
|
||||
if req.action == "correct":
|
||||
if not req.corrected_response or not req.corrected_response.strip():
|
||||
raise HTTPException(422, "corrected_response must be non-empty when action is 'correct'")
|
||||
|
||||
records = _read_candidates()
|
||||
idx = next((i for i, r in enumerate(records) if r.get("id") == req.id), None)
|
||||
if idx is None:
|
||||
raise HTTPException(404, f"Record {req.id!r} not found")
|
||||
|
||||
record = records[idx]
|
||||
if record.get("status") != "needs_review":
|
||||
raise HTTPException(409, f"Record is not in needs_review state (current: {record.get('status')})")
|
||||
|
||||
if req.action == "correct":
|
||||
records[idx] = {
|
||||
**record,
|
||||
"status": "approved",
|
||||
"corrected_response": req.corrected_response,
|
||||
"failure_category": req.failure_category,
|
||||
}
|
||||
_write_candidates(records)
|
||||
append_jsonl(_approved_file(), records[idx])
|
||||
elif req.action == "discard":
|
||||
records[idx] = {**record, "status": "discarded"}
|
||||
_write_candidates(records)
|
||||
else: # flag
|
||||
records[idx] = {**record, "status": "model_rejected"}
|
||||
_write_candidates(records)
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# ── POST /undo ─────────────────────────────────────────────────────────────
|
||||
|
||||
class UndoRequest(BaseModel):
|
||||
id: str
|
||||
|
||||
|
||||
@router.post("/undo")
|
||||
def post_undo(req: UndoRequest):
|
||||
"""Restore a previously actioned candidate back to needs_review."""
|
||||
records = _read_candidates()
|
||||
idx = next((i for i, r in enumerate(records) if r.get("id") == req.id), None)
|
||||
if idx is None:
|
||||
raise HTTPException(404, f"Record {req.id!r} not found")
|
||||
|
||||
record = records[idx]
|
||||
old_status = record.get("status")
|
||||
if old_status == "needs_review":
|
||||
raise HTTPException(409, "Record is already in needs_review state")
|
||||
|
||||
records[idx] = {**record, "status": "needs_review", "corrected_response": None}
|
||||
_write_candidates(records)
|
||||
|
||||
# If it was approved, remove from the approved file too
|
||||
if old_status == "approved":
|
||||
approved = read_jsonl(_approved_file())
|
||||
write_jsonl(_approved_file(), [r for r in approved if r.get("id") != req.id])
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# ── GET /export ─────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/export")
|
||||
def get_export() -> StreamingResponse:
|
||||
"""Stream approved records as SFT-ready JSONL for download."""
|
||||
exportable = [r for r in read_jsonl(_approved_file()) if _is_exportable(r)]
|
||||
|
||||
def generate():
|
||||
for r in exportable:
|
||||
record = {
|
||||
"messages": r.get("prompt_messages", []) + [
|
||||
{"role": "assistant", "content": r["corrected_response"]}
|
||||
]
|
||||
}
|
||||
yield json.dumps(record) + "\n"
|
||||
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="application/x-ndjson",
|
||||
headers={
|
||||
"Content-Disposition": f'attachment; filename="sft_export_{timestamp}.jsonl"'
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ── GET /stats ──────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/stats")
|
||||
def get_stats() -> dict[str, object]:
|
||||
"""Return counts by status, model, and task type."""
|
||||
records = _read_candidates()
|
||||
by_status: dict[str, int] = {}
|
||||
by_model: dict[str, int] = {}
|
||||
by_task_type: dict[str, int] = {}
|
||||
|
||||
for r in records:
|
||||
status = r.get("status", "unknown")
|
||||
by_status[status] = by_status.get(status, 0) + 1
|
||||
model = r.get("model_name", "unknown")
|
||||
by_model[model] = by_model.get(model, 0) + 1
|
||||
task_type = r.get("task_type", "unknown")
|
||||
by_task_type[task_type] = by_task_type.get(task_type, 0) + 1
|
||||
|
||||
approved = read_jsonl(_approved_file())
|
||||
export_ready = sum(1 for r in approved if _is_exportable(r))
|
||||
|
||||
return {
|
||||
"total": len(records),
|
||||
"by_status": by_status,
|
||||
"by_model": by_model,
|
||||
"by_task_type": by_task_type,
|
||||
"export_ready": export_ready,
|
||||
}
|
||||
|
||||
|
||||
# ── GET /config ─────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/config")
|
||||
def get_sft_config() -> dict:
|
||||
"""Return the current SFT configuration (bench_results_dir)."""
|
||||
f = _config_file()
|
||||
if not f.exists():
|
||||
return {"bench_results_dir": ""}
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError:
|
||||
return {"bench_results_dir": ""}
|
||||
sft_section = raw.get("sft") or {}
|
||||
return {"bench_results_dir": sft_section.get("bench_results_dir", "")}
|
||||
|
||||
|
||||
class SftConfigPayload(BaseModel):
|
||||
bench_results_dir: str
|
||||
|
||||
|
||||
@router.post("/config")
|
||||
def post_sft_config(payload: SftConfigPayload) -> dict:
|
||||
"""Write the bench_results_dir setting to the config file."""
|
||||
f = _config_file()
|
||||
f.parent.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
raw = yaml.safe_load(f.read_text(encoding="utf-8")) if f.exists() else {}
|
||||
raw = raw or {}
|
||||
except yaml.YAMLError:
|
||||
raw = {}
|
||||
raw["sft"] = {"bench_results_dir": payload.bench_results_dir}
|
||||
tmp = f.with_suffix(".tmp")
|
||||
tmp.write_text(yaml.dump(raw, allow_unicode=True, sort_keys=False), encoding="utf-8")
|
||||
tmp.rename(f)
|
||||
return {"ok": True}
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
"""API integration tests for app/sft.py — /api/sft/* endpoints."""
|
||||
"""API integration tests for app/sft.py -- /api/sft/* endpoints."""
|
||||
import json
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
|
|
@ -7,17 +7,17 @@ from pathlib import Path
|
|||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_sft_globals(tmp_path):
|
||||
from app import sft as sft_module
|
||||
_prev_data = sft_module._SFT_DATA_DIR
|
||||
_prev_cfg = sft_module._SFT_CONFIG_DIR
|
||||
_prev_default = sft_module._DEFAULT_BENCH_RESULTS_DIR
|
||||
sft_module.set_sft_data_dir(tmp_path)
|
||||
sft_module.set_sft_config_dir(tmp_path)
|
||||
sft_module.set_default_bench_results_dir(str(tmp_path / "bench_results"))
|
||||
from app.data import corrections as corr_module
|
||||
_prev_data = corr_module._DATA_DIR
|
||||
_prev_cfg = corr_module._CONFIG_DIR
|
||||
_prev_default = corr_module._DEFAULT_BENCH_RESULTS_DIR
|
||||
corr_module.set_data_dir(tmp_path)
|
||||
corr_module.set_config_dir(tmp_path)
|
||||
corr_module.set_default_bench_results_dir(str(tmp_path / "bench_results"))
|
||||
yield
|
||||
sft_module.set_sft_data_dir(_prev_data)
|
||||
sft_module.set_sft_config_dir(_prev_cfg)
|
||||
sft_module.set_default_bench_results_dir(_prev_default)
|
||||
corr_module.set_data_dir(_prev_data)
|
||||
corr_module.set_config_dir(_prev_cfg)
|
||||
corr_module.set_default_bench_results_dir(_prev_default)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -63,7 +63,7 @@ def _write_config(tmp_path, bench_results_dir: Path) -> None:
|
|||
)
|
||||
|
||||
|
||||
# ── /api/sft/runs ──────────────────────────────────────────────────────────
|
||||
# -- /api/sft/runs -------------------------------------------------------------
|
||||
|
||||
def test_runs_returns_empty_when_no_config(client):
|
||||
r = client.get("/api/sft/runs")
|
||||
|
|
@ -86,7 +86,7 @@ def test_runs_returns_available_runs(client, tmp_path):
|
|||
def test_runs_marks_already_imported(client, tmp_path):
|
||||
_write_run(tmp_path, "2026-04-07-143022", [_make_record("a")])
|
||||
_write_config(tmp_path, tmp_path / "bench_results")
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
candidates = sft_module._candidates_file()
|
||||
candidates.parent.mkdir(parents=True, exist_ok=True)
|
||||
candidates.write_text(
|
||||
|
|
@ -97,7 +97,7 @@ def test_runs_marks_already_imported(client, tmp_path):
|
|||
assert r.json()[0]["already_imported"] is True
|
||||
|
||||
|
||||
# ── /api/sft/import ─────────────────────────────────────────────────────────
|
||||
# -- /api/sft/import -----------------------------------------------------------
|
||||
|
||||
def test_import_adds_records(client, tmp_path):
|
||||
_write_run(tmp_path, "2026-04-07-143022", [_make_record("a"), _make_record("b")])
|
||||
|
|
@ -121,10 +121,10 @@ def test_import_unknown_run_returns_404(client, tmp_path):
|
|||
assert r.status_code == 404
|
||||
|
||||
|
||||
# ── /api/sft/queue ──────────────────────────────────────────────────────────
|
||||
# -- /api/sft/queue ------------------------------------------------------------
|
||||
|
||||
def _populate_candidates(tmp_path, records: list[dict]) -> None:
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
path = sft_module._candidates_file()
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
path.write_text(
|
||||
|
|
@ -164,7 +164,7 @@ def test_queue_empty_when_no_file(client):
|
|||
assert r.json() == {"items": [], "total": 0, "page": 1, "per_page": 20}
|
||||
|
||||
|
||||
# ── /api/sft/submit ─────────────────────────────────────────────────────────
|
||||
# -- /api/sft/submit -----------------------------------------------------------
|
||||
|
||||
def test_submit_correct_sets_approved(client, tmp_path):
|
||||
_populate_candidates(tmp_path, [_make_record("a")])
|
||||
|
|
@ -173,7 +173,7 @@ def test_submit_correct_sets_approved(client, tmp_path):
|
|||
"corrected_response": "def add(a, b): return a + b",
|
||||
})
|
||||
assert r.status_code == 200
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
records = sft_module._read_candidates()
|
||||
assert records[0]["status"] == "approved"
|
||||
assert records[0]["corrected_response"] == "def add(a, b): return a + b"
|
||||
|
|
@ -185,7 +185,7 @@ def test_submit_correct_also_appends_to_approved_file(client, tmp_path):
|
|||
"id": "a", "action": "correct",
|
||||
"corrected_response": "def add(a, b): return a + b",
|
||||
})
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
from app.utils import read_jsonl
|
||||
approved = read_jsonl(sft_module._approved_file())
|
||||
assert len(approved) == 1
|
||||
|
|
@ -196,7 +196,7 @@ def test_submit_discard_sets_discarded(client, tmp_path):
|
|||
_populate_candidates(tmp_path, [_make_record("a")])
|
||||
r = client.post("/api/sft/submit", json={"id": "a", "action": "discard"})
|
||||
assert r.status_code == 200
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
assert sft_module._read_candidates()[0]["status"] == "discarded"
|
||||
|
||||
|
||||
|
|
@ -204,7 +204,7 @@ def test_submit_flag_sets_model_rejected(client, tmp_path):
|
|||
_populate_candidates(tmp_path, [_make_record("a")])
|
||||
r = client.post("/api/sft/submit", json={"id": "a", "action": "flag"})
|
||||
assert r.status_code == 200
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
assert sft_module._read_candidates()[0]["status"] == "model_rejected"
|
||||
|
||||
|
||||
|
|
@ -243,7 +243,7 @@ def test_submit_correct_stores_failure_category(client, tmp_path):
|
|||
"failure_category": "style_violation",
|
||||
})
|
||||
assert r.status_code == 200
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
records = sft_module._read_candidates()
|
||||
assert records[0]["failure_category"] == "style_violation"
|
||||
|
||||
|
|
@ -255,7 +255,7 @@ def test_submit_correct_null_failure_category(client, tmp_path):
|
|||
"corrected_response": "def add(a, b): return a + b",
|
||||
})
|
||||
assert r.status_code == 200
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
records = sft_module._read_candidates()
|
||||
assert records[0]["failure_category"] is None
|
||||
|
||||
|
|
@ -270,14 +270,14 @@ def test_submit_invalid_failure_category_returns_422(client, tmp_path):
|
|||
assert r.status_code == 422
|
||||
|
||||
|
||||
# ── /api/sft/undo ────────────────────────────────────────────────────────────
|
||||
# -- /api/sft/undo -------------------------------------------------------------
|
||||
|
||||
def test_undo_restores_discarded_to_needs_review(client, tmp_path):
|
||||
_populate_candidates(tmp_path, [_make_record("a")])
|
||||
client.post("/api/sft/submit", json={"id": "a", "action": "discard"})
|
||||
r = client.post("/api/sft/undo", json={"id": "a"})
|
||||
assert r.status_code == 200
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
assert sft_module._read_candidates()[0]["status"] == "needs_review"
|
||||
|
||||
|
||||
|
|
@ -288,7 +288,7 @@ def test_undo_removes_approved_from_approved_file(client, tmp_path):
|
|||
"corrected_response": "def add(a, b): return a + b",
|
||||
})
|
||||
client.post("/api/sft/undo", json={"id": "a"})
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
from app.utils import read_jsonl
|
||||
approved = read_jsonl(sft_module._approved_file())
|
||||
assert not any(r["id"] == "a" for r in approved)
|
||||
|
|
@ -300,10 +300,10 @@ def test_undo_already_needs_review_returns_409(client, tmp_path):
|
|||
assert r.status_code == 409
|
||||
|
||||
|
||||
# ── /api/sft/export ──────────────────────────────────────────────────────────
|
||||
# -- /api/sft/export -----------------------------------------------------------
|
||||
|
||||
def test_export_returns_approved_as_sft_jsonl(client, tmp_path):
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
from app.utils import write_jsonl
|
||||
approved = {
|
||||
**_make_record("a"),
|
||||
|
|
@ -331,7 +331,7 @@ def test_export_returns_approved_as_sft_jsonl(client, tmp_path):
|
|||
|
||||
|
||||
def test_export_excludes_non_approved(client, tmp_path):
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
from app.utils import write_jsonl
|
||||
records = [
|
||||
{**_make_record("a"), "status": "discarded", "corrected_response": None},
|
||||
|
|
@ -348,10 +348,10 @@ def test_export_empty_when_no_approved_file(client):
|
|||
assert r.text.strip() == ""
|
||||
|
||||
|
||||
# ── /api/sft/stats ───────────────────────────────────────────────────────────
|
||||
# -- /api/sft/stats ------------------------------------------------------------
|
||||
|
||||
def test_stats_counts_by_status(client, tmp_path):
|
||||
from app import sft as sft_module
|
||||
from app.data import corrections as sft_module
|
||||
from app.utils import write_jsonl
|
||||
records = [
|
||||
_make_record("a"),
|
||||
|
|
|
|||
Loading…
Reference in a new issue