feat: move SFT corrections API into app/data/corrections.py

This commit is contained in:
pyr0ball 2026-05-01 22:02:22 -07:00
parent 2054866ff1
commit 99ea39fe38
3 changed files with 374 additions and 366 deletions

335
app/data/corrections.py Normal file
View file

@ -0,0 +1,335 @@
"""Avocet -- SFT candidate corrections API (moved from app/sft.py).
All endpoints are registered on `router` (a FastAPI APIRouter).
api.py includes this router with prefix="/api/sft".
Module-level globals (_DATA_DIR, _CONFIG_DIR) follow the same
testability pattern as api.py -- override them via set_data_dir() and
set_config_dir() in test fixtures.
"""
from __future__ import annotations
import json
import logging
from datetime import datetime, timezone
from pathlib import Path
from typing import Literal
import yaml
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from app.utils import append_jsonl, read_jsonl, write_jsonl
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Three directory levels above this file; base for the default data/ and
# config/ locations used below.
_ROOT = Path(__file__).parent.parent.parent
# Directory holding the candidates/approved JSONL files; rebound by set_data_dir().
_DATA_DIR: Path = _ROOT / "data"
# Optional config-directory override; None means "use _ROOT / 'config'".
_CONFIG_DIR: Path | None = None
# Router on which every endpoint defined in this module is registered.
router = APIRouter()
# -- Testability seams ---------------------------------------------------------
def set_data_dir(path: Path) -> None:
    """Rebind the module-level data directory.

    Args:
        path: Directory that should hold the candidates and approved JSONL
            files from now on.
    """
    global _DATA_DIR
    _DATA_DIR = path
def set_config_dir(path: Path | None) -> None:
global _CONFIG_DIR
_CONFIG_DIR = path
# -- Internal helpers ----------------------------------------------------------
def _config_file() -> Path:
    """Return the path of ``label_tool.yaml``.

    The file is looked up in the override directory when one has been set
    via ``set_config_dir()``; otherwise in ``config/`` under ``_ROOT``.
    """
    config_dir = _ROOT / "config" if _CONFIG_DIR is None else _CONFIG_DIR
    return config_dir / "label_tool.yaml"
# Fallback benchmark-results location, used when the config file names none.
# NOTE(review): this absolute path looks specific to one development machine -- confirm.
_DEFAULT_BENCH_RESULTS_DIR = "/Library/Development/CircuitForge/circuitforge-orch/scripts/bench_results"


def set_default_bench_results_dir(path: str) -> None:
    """Replace the fallback bench_results_dir.

    Intended for tests, so they never touch the real filesystem location.

    Args:
        path: New fallback directory, as a string.
    """
    global _DEFAULT_BENCH_RESULTS_DIR
    _DEFAULT_BENCH_RESULTS_DIR = path
def _get_bench_results_dir() -> Path:
    """Resolve the directory that benchmark runs are discovered in.

    Reads ``sft.bench_results_dir`` from the config file. Falls back to
    ``_DEFAULT_BENCH_RESULTS_DIR`` when the file is missing, cannot be
    parsed, has no usable ``sft`` mapping, or leaves the setting empty.

    Returns:
        The configured directory, or the default one.
    """
    cfg_path = _config_file()
    if cfg_path.exists():
        try:
            raw = yaml.safe_load(cfg_path.read_text(encoding="utf-8"))
        except yaml.YAMLError as exc:
            logger.warning("Failed to parse SFT config %s: %s", cfg_path, exc)
        else:
            # An empty file parses to None and `sft:` with no value parses to
            # a null section; neither is a mapping, so guard both levels
            # instead of calling .get() on them.
            sft_section = raw.get("sft") if isinstance(raw, dict) else None
            if isinstance(sft_section, dict):
                configured = sft_section.get("bench_results_dir")
                if configured:
                    return Path(configured)
    return Path(_DEFAULT_BENCH_RESULTS_DIR)
def _candidates_file() -> Path:
    """Return the path of the candidates JSONL inside the current data dir."""
    return _DATA_DIR.joinpath("sft_candidates.jsonl")
def _approved_file() -> Path:
    """Return the path of the approved-records JSONL inside the current data dir."""
    return _DATA_DIR.joinpath("sft_approved.jsonl")
def _read_candidates() -> list[dict]:
    """Load every record from the candidates file via ``read_jsonl``."""
    candidates_path = _candidates_file()
    return read_jsonl(candidates_path)
def _write_candidates(records: list[dict]) -> None:
    """Write ``records`` to the candidates file via ``write_jsonl``."""
    candidates_path = _candidates_file()
    write_jsonl(candidates_path, records)
def _is_exportable(r: dict) -> bool:
"""Return True if an approved record is ready to include in SFT export."""
return (
r.get("status") == "approved"
and bool(r.get("corrected_response"))
and str(r["corrected_response"]).strip() != ""
)
# -- GET /runs -----------------------------------------------------------------
@router.get("/runs")
def get_runs():
    """List available benchmark runs in the configured bench_results_dir."""
    # Function-scope import: scripts.sft_import is only loaded when the
    # endpoint actually runs.
    from scripts.sft_import import discover_runs

    bench_dir = _get_bench_results_dir()

    # benchmark_run_id in each record equals the run's directory name by cf-orch convention
    imported_run_ids = set()
    for candidate in _read_candidates():
        run_id = candidate.get("benchmark_run_id")
        if run_id is not None:
            imported_run_ids.add(run_id)

    # One summary dict per discovered run, flagged when local candidates
    # already reference that run.
    summaries = []
    for run in discover_runs(bench_dir):
        summaries.append(
            {
                "run_id": run["run_id"],
                "timestamp": run["timestamp"],
                "candidate_count": run["candidate_count"],
                "already_imported": run["run_id"] in imported_run_ids,
            }
        )
    return summaries
# -- POST /import --------------------------------------------------------------
# Request body for POST /import.
class ImportRequest(BaseModel):
    # Identifier of the run to import; compared against the "run_id" of each
    # entry returned by discover_runs().
    run_id: str
@router.post("/import")
def post_import(req: ImportRequest):
    """Import one benchmark run's sft_candidates.jsonl into the local data dir."""
    # Function-scope import: scripts.sft_import is only loaded when the
    # endpoint actually runs.
    from scripts.sft_import import discover_runs, import_run

    # Hand the first run whose id matches to import_run(); its return value
    # is the response. No match means 404.
    for candidate_run in discover_runs(_get_bench_results_dir()):
        if candidate_run["run_id"] == req.run_id:
            return import_run(candidate_run["sft_path"], _DATA_DIR)
    raise HTTPException(404, f"Run {req.run_id!r} not found in bench_results_dir")
# -- GET /queue ----------------------------------------------------------------
@router.get("/queue")
def get_queue(page: int = 1, per_page: int = 20):
    """Return paginated needs_review candidates."""
    # Reject non-positive paging values up front: page=0 would produce a
    # negative slice start and silently return items from the END of the
    # queue, and per_page <= 0 yields empty or wrapped-around slices.
    if page < 1:
        raise HTTPException(422, "page must be >= 1")
    if per_page < 1:
        raise HTTPException(422, "per_page must be >= 1")

    pending = [r for r in _read_candidates() if r.get("status") == "needs_review"]
    start = (page - 1) * per_page
    return {
        "items": pending[start:start + per_page],
        # Total counts the whole pending queue, not just this page.
        "total": len(pending),
        "page": page,
        "per_page": per_page,
    }
# -- POST /submit --------------------------------------------------------------
# Closed set of values accepted for SubmitRequest.failure_category.
FailureCategory = Literal[
    "scoring_artifact",
    "style_violation",
    "partial_answer",
    "wrong_answer",
    "format_error",
    "hallucination",
]
# Request body for POST /submit.
class SubmitRequest(BaseModel):
    # "id" of the candidate record the decision applies to.
    id: str
    # Reviewer decision; post_submit maps it to the statuses "approved",
    # "discarded" and "model_rejected" respectively.
    action: Literal["correct", "discard", "flag"]
    # Replacement response text; must be non-blank when action is "correct".
    corrected_response: str | None = None
    # Optional failure label; stored on the record only for "correct".
    failure_category: FailureCategory | None = None
@router.post("/submit")
def post_submit(req: SubmitRequest):
    """Record a reviewer decision for one SFT candidate."""
    is_correction = req.action == "correct"
    # A correction without usable text is rejected before any file is read.
    if is_correction and not (req.corrected_response or "").strip():
        raise HTTPException(422, "corrected_response must be non-empty when action is 'correct'")

    records = _read_candidates()
    position = None
    for i, candidate in enumerate(records):
        if candidate.get("id") == req.id:
            position = i
            break
    if position is None:
        raise HTTPException(404, f"Record {req.id!r} not found")

    current = records[position]
    # Only records still awaiting review may be actioned.
    if current.get("status") != "needs_review":
        raise HTTPException(409, f"Record is not in needs_review state (current: {current.get('status')})")

    if is_correction:
        changes = {
            "status": "approved",
            "corrected_response": req.corrected_response,
            "failure_category": req.failure_category,
        }
    elif req.action == "discard":
        changes = {"status": "discarded"}
    else:  # flag
        changes = {"status": "model_rejected"}

    updated = {**current, **changes}
    records[position] = updated
    _write_candidates(records)
    # Approved records are additionally appended to the approved file,
    # after the candidates file has been rewritten.
    if is_correction:
        append_jsonl(_approved_file(), updated)
    return {"ok": True}
# -- POST /undo ----------------------------------------------------------------
# Request body for POST /undo.
class UndoRequest(BaseModel):
    # "id" of the candidate record to put back into needs_review.
    id: str
@router.post("/undo")
def post_undo(req: UndoRequest):
    """Restore a previously actioned candidate back to needs_review."""
    records = _read_candidates()
    position = None
    for i, candidate in enumerate(records):
        if candidate.get("id") == req.id:
            position = i
            break
    if position is None:
        raise HTTPException(404, f"Record {req.id!r} not found")

    target = records[position]
    previous_status = target.get("status")
    if previous_status == "needs_review":
        raise HTTPException(409, "Record is already in needs_review state")

    # Reset the status and drop any stored correction, then persist.
    records[position] = {**target, "status": "needs_review", "corrected_response": None}
    _write_candidates(records)

    # If it was approved, remove from the approved file too
    if previous_status == "approved":
        remaining = [r for r in read_jsonl(_approved_file()) if r.get("id") != req.id]
        write_jsonl(_approved_file(), remaining)
    return {"ok": True}
# -- GET /export ---------------------------------------------------------------
@router.get("/export")
def get_export() -> StreamingResponse:
    """Stream approved records as SFT-ready JSONL for download."""
    # Materialised before streaming starts, so the response body reflects
    # the approved file as it was when the request arrived.
    exportable = [r for r in read_jsonl(_approved_file()) if _is_exportable(r)]

    def generate():
        """Yield one JSON line per exportable record."""
        for r in exportable:
            # `or []` also covers a record whose prompt_messages key is
            # present but null; `None + [...]` would raise mid-stream, after
            # the 200 status and headers have already been sent.
            prompt_messages = r.get("prompt_messages") or []
            messages = [
                *prompt_messages,
                {"role": "assistant", "content": r["corrected_response"]},
            ]
            yield json.dumps({"messages": messages}) + "\n"

    # UTC timestamp keeps download filenames unique and sortable.
    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    return StreamingResponse(
        generate(),
        media_type="application/x-ndjson",
        headers={
            "Content-Disposition": f'attachment; filename="sft_export_{timestamp}.jsonl"'
        },
    )
# -- GET /stats ----------------------------------------------------------------
@router.get("/stats")
def get_stats() -> dict[str, object]:
    """Return counts by status, model, and task type."""
    records = _read_candidates()

    by_status: dict[str, int] = {}
    by_model: dict[str, int] = {}
    by_task_type: dict[str, int] = {}
    # (record field, tally it feeds) -- records lacking the field are
    # counted under "unknown".
    breakdowns = (
        ("status", by_status),
        ("model_name", by_model),
        ("task_type", by_task_type),
    )
    for record in records:
        for field, tally in breakdowns:
            key = record.get(field, "unknown")
            tally[key] = tally.get(key, 0) + 1

    # export_ready is computed from the approved file, not the candidates.
    export_ready = len([r for r in read_jsonl(_approved_file()) if _is_exportable(r)])

    return {
        "total": len(records),
        "by_status": by_status,
        "by_model": by_model,
        "by_task_type": by_task_type,
        "export_ready": export_ready,
    }
# -- GET /config ---------------------------------------------------------------
@router.get("/config")
def get_sft_config() -> dict:
    """Return the current SFT configuration (bench_results_dir)."""
    cfg_path = _config_file()
    if not cfg_path.exists():
        return {"bench_results_dir": ""}
    try:
        raw = yaml.safe_load(cfg_path.read_text(encoding="utf-8"))
    except yaml.YAMLError as exc:
        # Logged the same way _get_bench_results_dir() reports a broken
        # config, instead of failing silently.
        logger.warning("Failed to parse SFT config %s: %s", cfg_path, exc)
        return {"bench_results_dir": ""}

    # The top level and the `sft` value may each be something other than a
    # mapping (empty file, list, scalar, null); treat all of those as "no
    # setting" rather than calling .get() on them.
    sft_section = raw.get("sft") if isinstance(raw, dict) else None
    if not isinstance(sft_section, dict):
        return {"bench_results_dir": ""}
    # `or ""` keeps the response value a string when the key is present but null.
    return {"bench_results_dir": sft_section.get("bench_results_dir") or ""}
# Request body for POST /config.
class SftConfigPayload(BaseModel):
    # Value to store under sft.bench_results_dir in the config file.
    bench_results_dir: str
@router.post("/config")
def post_sft_config(payload: SftConfigPayload) -> dict:
    """Write the bench_results_dir setting to the config file."""
    cfg_path = _config_file()
    cfg_path.parent.mkdir(parents=True, exist_ok=True)

    # Start from the existing document so unrelated settings survive. Anything
    # that is not a mapping (empty file, list, scalar) or does not parse is
    # replaced by a fresh mapping.
    raw: dict = {}
    if cfg_path.exists():
        try:
            loaded = yaml.safe_load(cfg_path.read_text(encoding="utf-8"))
        except yaml.YAMLError as exc:
            logger.warning("Overwriting unparsable SFT config %s: %s", cfg_path, exc)
        else:
            if isinstance(loaded, dict):
                raw = loaded

    # Update only bench_results_dir and keep any other keys already stored
    # in the `sft` section.
    sft_section = raw.get("sft")
    if not isinstance(sft_section, dict):
        sft_section = {}
    raw["sft"] = {**sft_section, "bench_results_dir": payload.bench_results_dir}

    # Write to a sibling temp file, then swap it in. Path.replace() overwrites
    # an existing target on every platform, whereas Path.rename() raises on
    # Windows when the target already exists.
    tmp = cfg_path.with_suffix(".tmp")
    tmp.write_text(yaml.dump(raw, allow_unicode=True, sort_keys=False), encoding="utf-8")
    tmp.replace(cfg_path)
    return {"ok": True}

View file

@ -1,335 +1,8 @@
"""Avocet — SFT candidate import and correction API. """Backward-compat shim -- logic moved to app/data/corrections.py."""
from app.data.corrections import ( # noqa: F401
All endpoints are registered on `router` (a FastAPI APIRouter). router,
api.py includes this router with prefix="/api/sft". set_data_dir as set_sft_data_dir,
set_config_dir as set_sft_config_dir,
Module-level globals (_SFT_DATA_DIR, _SFT_CONFIG_DIR) follow the same set_default_bench_results_dir,
testability pattern as api.py override them via set_sft_data_dir() and _DEFAULT_BENCH_RESULTS_DIR,
set_sft_config_dir() in test fixtures. )
"""
from __future__ import annotations
import json
import logging
from datetime import datetime, timezone
from pathlib import Path
from typing import Literal
import yaml
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from app.utils import append_jsonl, read_jsonl, write_jsonl
logger = logging.getLogger(__name__)
_ROOT = Path(__file__).parent.parent
_SFT_DATA_DIR: Path = _ROOT / "data"
_SFT_CONFIG_DIR: Path | None = None
router = APIRouter()
# ── Testability seams ──────────────────────────────────────────────────────
def set_sft_data_dir(path: Path) -> None:
global _SFT_DATA_DIR
_SFT_DATA_DIR = path
def set_sft_config_dir(path: Path | None) -> None:
global _SFT_CONFIG_DIR
_SFT_CONFIG_DIR = path
# ── Internal helpers ───────────────────────────────────────────────────────
def _config_file() -> Path:
if _SFT_CONFIG_DIR is not None:
return _SFT_CONFIG_DIR / "label_tool.yaml"
return _ROOT / "config" / "label_tool.yaml"
_DEFAULT_BENCH_RESULTS_DIR = "/Library/Development/CircuitForge/circuitforge-orch/scripts/bench_results"
def set_default_bench_results_dir(path: str) -> None:
"""Override the default bench_results_dir — used by tests to avoid real filesystem."""
global _DEFAULT_BENCH_RESULTS_DIR
_DEFAULT_BENCH_RESULTS_DIR = path
def _get_bench_results_dir() -> Path:
f = _config_file()
if f.exists():
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
d = raw.get("sft", {}).get("bench_results_dir", "")
if d:
return Path(d)
except yaml.YAMLError as exc:
logger.warning("Failed to parse SFT config %s: %s", f, exc)
return Path(_DEFAULT_BENCH_RESULTS_DIR)
def _candidates_file() -> Path:
return _SFT_DATA_DIR / "sft_candidates.jsonl"
def _approved_file() -> Path:
return _SFT_DATA_DIR / "sft_approved.jsonl"
def _read_candidates() -> list[dict]:
return read_jsonl(_candidates_file())
def _write_candidates(records: list[dict]) -> None:
write_jsonl(_candidates_file(), records)
def _is_exportable(r: dict) -> bool:
"""Return True if an approved record is ready to include in SFT export."""
return (
r.get("status") == "approved"
and bool(r.get("corrected_response"))
and str(r["corrected_response"]).strip() != ""
)
# ── GET /runs ──────────────────────────────────────────────────────────────
@router.get("/runs")
def get_runs():
"""List available benchmark runs in the configured bench_results_dir."""
from scripts.sft_import import discover_runs
bench_dir = _get_bench_results_dir()
existing = _read_candidates()
# benchmark_run_id in each record equals the run's directory name by cf-orch convention
imported_run_ids = {
r["benchmark_run_id"]
for r in existing
if r.get("benchmark_run_id") is not None
}
runs = discover_runs(bench_dir)
return [
{
"run_id": r["run_id"],
"timestamp": r["timestamp"],
"candidate_count": r["candidate_count"],
"already_imported": r["run_id"] in imported_run_ids,
}
for r in runs
]
# ── POST /import ───────────────────────────────────────────────────────────
class ImportRequest(BaseModel):
run_id: str
@router.post("/import")
def post_import(req: ImportRequest):
"""Import one benchmark run's sft_candidates.jsonl into the local data dir."""
from scripts.sft_import import discover_runs, import_run
bench_dir = _get_bench_results_dir()
runs = discover_runs(bench_dir)
run = next((r for r in runs if r["run_id"] == req.run_id), None)
if run is None:
raise HTTPException(404, f"Run {req.run_id!r} not found in bench_results_dir")
return import_run(run["sft_path"], _SFT_DATA_DIR)
# ── GET /queue ─────────────────────────────────────────────────────────────
@router.get("/queue")
def get_queue(page: int = 1, per_page: int = 20):
"""Return paginated needs_review candidates."""
records = _read_candidates()
pending = [r for r in records if r.get("status") == "needs_review"]
start = (page - 1) * per_page
return {
"items": pending[start:start + per_page],
"total": len(pending),
"page": page,
"per_page": per_page,
}
# ── POST /submit ───────────────────────────────────────────────────────────
FailureCategory = Literal[
"scoring_artifact",
"style_violation",
"partial_answer",
"wrong_answer",
"format_error",
"hallucination",
]
class SubmitRequest(BaseModel):
id: str
action: Literal["correct", "discard", "flag"]
corrected_response: str | None = None
failure_category: FailureCategory | None = None
@router.post("/submit")
def post_submit(req: SubmitRequest):
"""Record a reviewer decision for one SFT candidate."""
if req.action == "correct":
if not req.corrected_response or not req.corrected_response.strip():
raise HTTPException(422, "corrected_response must be non-empty when action is 'correct'")
records = _read_candidates()
idx = next((i for i, r in enumerate(records) if r.get("id") == req.id), None)
if idx is None:
raise HTTPException(404, f"Record {req.id!r} not found")
record = records[idx]
if record.get("status") != "needs_review":
raise HTTPException(409, f"Record is not in needs_review state (current: {record.get('status')})")
if req.action == "correct":
records[idx] = {
**record,
"status": "approved",
"corrected_response": req.corrected_response,
"failure_category": req.failure_category,
}
_write_candidates(records)
append_jsonl(_approved_file(), records[idx])
elif req.action == "discard":
records[idx] = {**record, "status": "discarded"}
_write_candidates(records)
else: # flag
records[idx] = {**record, "status": "model_rejected"}
_write_candidates(records)
return {"ok": True}
# ── POST /undo ─────────────────────────────────────────────────────────────
class UndoRequest(BaseModel):
id: str
@router.post("/undo")
def post_undo(req: UndoRequest):
"""Restore a previously actioned candidate back to needs_review."""
records = _read_candidates()
idx = next((i for i, r in enumerate(records) if r.get("id") == req.id), None)
if idx is None:
raise HTTPException(404, f"Record {req.id!r} not found")
record = records[idx]
old_status = record.get("status")
if old_status == "needs_review":
raise HTTPException(409, "Record is already in needs_review state")
records[idx] = {**record, "status": "needs_review", "corrected_response": None}
_write_candidates(records)
# If it was approved, remove from the approved file too
if old_status == "approved":
approved = read_jsonl(_approved_file())
write_jsonl(_approved_file(), [r for r in approved if r.get("id") != req.id])
return {"ok": True}
# ── GET /export ─────────────────────────────────────────────────────────────
@router.get("/export")
def get_export() -> StreamingResponse:
"""Stream approved records as SFT-ready JSONL for download."""
exportable = [r for r in read_jsonl(_approved_file()) if _is_exportable(r)]
def generate():
for r in exportable:
record = {
"messages": r.get("prompt_messages", []) + [
{"role": "assistant", "content": r["corrected_response"]}
]
}
yield json.dumps(record) + "\n"
timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
return StreamingResponse(
generate(),
media_type="application/x-ndjson",
headers={
"Content-Disposition": f'attachment; filename="sft_export_{timestamp}.jsonl"'
},
)
# ── GET /stats ──────────────────────────────────────────────────────────────
@router.get("/stats")
def get_stats() -> dict[str, object]:
"""Return counts by status, model, and task type."""
records = _read_candidates()
by_status: dict[str, int] = {}
by_model: dict[str, int] = {}
by_task_type: dict[str, int] = {}
for r in records:
status = r.get("status", "unknown")
by_status[status] = by_status.get(status, 0) + 1
model = r.get("model_name", "unknown")
by_model[model] = by_model.get(model, 0) + 1
task_type = r.get("task_type", "unknown")
by_task_type[task_type] = by_task_type.get(task_type, 0) + 1
approved = read_jsonl(_approved_file())
export_ready = sum(1 for r in approved if _is_exportable(r))
return {
"total": len(records),
"by_status": by_status,
"by_model": by_model,
"by_task_type": by_task_type,
"export_ready": export_ready,
}
# ── GET /config ─────────────────────────────────────────────────────────────
@router.get("/config")
def get_sft_config() -> dict:
"""Return the current SFT configuration (bench_results_dir)."""
f = _config_file()
if not f.exists():
return {"bench_results_dir": ""}
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) or {}
except yaml.YAMLError:
return {"bench_results_dir": ""}
sft_section = raw.get("sft") or {}
return {"bench_results_dir": sft_section.get("bench_results_dir", "")}
class SftConfigPayload(BaseModel):
bench_results_dir: str
@router.post("/config")
def post_sft_config(payload: SftConfigPayload) -> dict:
"""Write the bench_results_dir setting to the config file."""
f = _config_file()
f.parent.mkdir(parents=True, exist_ok=True)
try:
raw = yaml.safe_load(f.read_text(encoding="utf-8")) if f.exists() else {}
raw = raw or {}
except yaml.YAMLError:
raw = {}
raw["sft"] = {"bench_results_dir": payload.bench_results_dir}
tmp = f.with_suffix(".tmp")
tmp.write_text(yaml.dump(raw, allow_unicode=True, sort_keys=False), encoding="utf-8")
tmp.rename(f)
return {"ok": True}

View file

@ -1,4 +1,4 @@
"""API integration tests for app/sft.py /api/sft/* endpoints.""" """API integration tests for app/sft.py -- /api/sft/* endpoints."""
import json import json
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
@ -7,17 +7,17 @@ from pathlib import Path
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def reset_sft_globals(tmp_path): def reset_sft_globals(tmp_path):
from app import sft as sft_module from app.data import corrections as corr_module
_prev_data = sft_module._SFT_DATA_DIR _prev_data = corr_module._DATA_DIR
_prev_cfg = sft_module._SFT_CONFIG_DIR _prev_cfg = corr_module._CONFIG_DIR
_prev_default = sft_module._DEFAULT_BENCH_RESULTS_DIR _prev_default = corr_module._DEFAULT_BENCH_RESULTS_DIR
sft_module.set_sft_data_dir(tmp_path) corr_module.set_data_dir(tmp_path)
sft_module.set_sft_config_dir(tmp_path) corr_module.set_config_dir(tmp_path)
sft_module.set_default_bench_results_dir(str(tmp_path / "bench_results")) corr_module.set_default_bench_results_dir(str(tmp_path / "bench_results"))
yield yield
sft_module.set_sft_data_dir(_prev_data) corr_module.set_data_dir(_prev_data)
sft_module.set_sft_config_dir(_prev_cfg) corr_module.set_config_dir(_prev_cfg)
sft_module.set_default_bench_results_dir(_prev_default) corr_module.set_default_bench_results_dir(_prev_default)
@pytest.fixture @pytest.fixture
@ -63,7 +63,7 @@ def _write_config(tmp_path, bench_results_dir: Path) -> None:
) )
# ── /api/sft/runs ────────────────────────────────────────────────────────── # -- /api/sft/runs -------------------------------------------------------------
def test_runs_returns_empty_when_no_config(client): def test_runs_returns_empty_when_no_config(client):
r = client.get("/api/sft/runs") r = client.get("/api/sft/runs")
@ -86,7 +86,7 @@ def test_runs_returns_available_runs(client, tmp_path):
def test_runs_marks_already_imported(client, tmp_path): def test_runs_marks_already_imported(client, tmp_path):
_write_run(tmp_path, "2026-04-07-143022", [_make_record("a")]) _write_run(tmp_path, "2026-04-07-143022", [_make_record("a")])
_write_config(tmp_path, tmp_path / "bench_results") _write_config(tmp_path, tmp_path / "bench_results")
from app import sft as sft_module from app.data import corrections as sft_module
candidates = sft_module._candidates_file() candidates = sft_module._candidates_file()
candidates.parent.mkdir(parents=True, exist_ok=True) candidates.parent.mkdir(parents=True, exist_ok=True)
candidates.write_text( candidates.write_text(
@ -97,7 +97,7 @@ def test_runs_marks_already_imported(client, tmp_path):
assert r.json()[0]["already_imported"] is True assert r.json()[0]["already_imported"] is True
# ── /api/sft/import ───────────────────────────────────────────────────────── # -- /api/sft/import -----------------------------------------------------------
def test_import_adds_records(client, tmp_path): def test_import_adds_records(client, tmp_path):
_write_run(tmp_path, "2026-04-07-143022", [_make_record("a"), _make_record("b")]) _write_run(tmp_path, "2026-04-07-143022", [_make_record("a"), _make_record("b")])
@ -121,10 +121,10 @@ def test_import_unknown_run_returns_404(client, tmp_path):
assert r.status_code == 404 assert r.status_code == 404
# ── /api/sft/queue ────────────────────────────────────────────────────────── # -- /api/sft/queue ------------------------------------------------------------
def _populate_candidates(tmp_path, records: list[dict]) -> None: def _populate_candidates(tmp_path, records: list[dict]) -> None:
from app import sft as sft_module from app.data import corrections as sft_module
path = sft_module._candidates_file() path = sft_module._candidates_file()
path.parent.mkdir(parents=True, exist_ok=True) path.parent.mkdir(parents=True, exist_ok=True)
path.write_text( path.write_text(
@ -164,7 +164,7 @@ def test_queue_empty_when_no_file(client):
assert r.json() == {"items": [], "total": 0, "page": 1, "per_page": 20} assert r.json() == {"items": [], "total": 0, "page": 1, "per_page": 20}
# ── /api/sft/submit ───────────────────────────────────────────────────────── # -- /api/sft/submit -----------------------------------------------------------
def test_submit_correct_sets_approved(client, tmp_path): def test_submit_correct_sets_approved(client, tmp_path):
_populate_candidates(tmp_path, [_make_record("a")]) _populate_candidates(tmp_path, [_make_record("a")])
@ -173,7 +173,7 @@ def test_submit_correct_sets_approved(client, tmp_path):
"corrected_response": "def add(a, b): return a + b", "corrected_response": "def add(a, b): return a + b",
}) })
assert r.status_code == 200 assert r.status_code == 200
from app import sft as sft_module from app.data import corrections as sft_module
records = sft_module._read_candidates() records = sft_module._read_candidates()
assert records[0]["status"] == "approved" assert records[0]["status"] == "approved"
assert records[0]["corrected_response"] == "def add(a, b): return a + b" assert records[0]["corrected_response"] == "def add(a, b): return a + b"
@ -185,7 +185,7 @@ def test_submit_correct_also_appends_to_approved_file(client, tmp_path):
"id": "a", "action": "correct", "id": "a", "action": "correct",
"corrected_response": "def add(a, b): return a + b", "corrected_response": "def add(a, b): return a + b",
}) })
from app import sft as sft_module from app.data import corrections as sft_module
from app.utils import read_jsonl from app.utils import read_jsonl
approved = read_jsonl(sft_module._approved_file()) approved = read_jsonl(sft_module._approved_file())
assert len(approved) == 1 assert len(approved) == 1
@ -196,7 +196,7 @@ def test_submit_discard_sets_discarded(client, tmp_path):
_populate_candidates(tmp_path, [_make_record("a")]) _populate_candidates(tmp_path, [_make_record("a")])
r = client.post("/api/sft/submit", json={"id": "a", "action": "discard"}) r = client.post("/api/sft/submit", json={"id": "a", "action": "discard"})
assert r.status_code == 200 assert r.status_code == 200
from app import sft as sft_module from app.data import corrections as sft_module
assert sft_module._read_candidates()[0]["status"] == "discarded" assert sft_module._read_candidates()[0]["status"] == "discarded"
@ -204,7 +204,7 @@ def test_submit_flag_sets_model_rejected(client, tmp_path):
_populate_candidates(tmp_path, [_make_record("a")]) _populate_candidates(tmp_path, [_make_record("a")])
r = client.post("/api/sft/submit", json={"id": "a", "action": "flag"}) r = client.post("/api/sft/submit", json={"id": "a", "action": "flag"})
assert r.status_code == 200 assert r.status_code == 200
from app import sft as sft_module from app.data import corrections as sft_module
assert sft_module._read_candidates()[0]["status"] == "model_rejected" assert sft_module._read_candidates()[0]["status"] == "model_rejected"
@ -243,7 +243,7 @@ def test_submit_correct_stores_failure_category(client, tmp_path):
"failure_category": "style_violation", "failure_category": "style_violation",
}) })
assert r.status_code == 200 assert r.status_code == 200
from app import sft as sft_module from app.data import corrections as sft_module
records = sft_module._read_candidates() records = sft_module._read_candidates()
assert records[0]["failure_category"] == "style_violation" assert records[0]["failure_category"] == "style_violation"
@ -255,7 +255,7 @@ def test_submit_correct_null_failure_category(client, tmp_path):
"corrected_response": "def add(a, b): return a + b", "corrected_response": "def add(a, b): return a + b",
}) })
assert r.status_code == 200 assert r.status_code == 200
from app import sft as sft_module from app.data import corrections as sft_module
records = sft_module._read_candidates() records = sft_module._read_candidates()
assert records[0]["failure_category"] is None assert records[0]["failure_category"] is None
@ -270,14 +270,14 @@ def test_submit_invalid_failure_category_returns_422(client, tmp_path):
assert r.status_code == 422 assert r.status_code == 422
# ── /api/sft/undo ──────────────────────────────────────────────────────────── # -- /api/sft/undo -------------------------------------------------------------
def test_undo_restores_discarded_to_needs_review(client, tmp_path): def test_undo_restores_discarded_to_needs_review(client, tmp_path):
_populate_candidates(tmp_path, [_make_record("a")]) _populate_candidates(tmp_path, [_make_record("a")])
client.post("/api/sft/submit", json={"id": "a", "action": "discard"}) client.post("/api/sft/submit", json={"id": "a", "action": "discard"})
r = client.post("/api/sft/undo", json={"id": "a"}) r = client.post("/api/sft/undo", json={"id": "a"})
assert r.status_code == 200 assert r.status_code == 200
from app import sft as sft_module from app.data import corrections as sft_module
assert sft_module._read_candidates()[0]["status"] == "needs_review" assert sft_module._read_candidates()[0]["status"] == "needs_review"
@ -288,7 +288,7 @@ def test_undo_removes_approved_from_approved_file(client, tmp_path):
"corrected_response": "def add(a, b): return a + b", "corrected_response": "def add(a, b): return a + b",
}) })
client.post("/api/sft/undo", json={"id": "a"}) client.post("/api/sft/undo", json={"id": "a"})
from app import sft as sft_module from app.data import corrections as sft_module
from app.utils import read_jsonl from app.utils import read_jsonl
approved = read_jsonl(sft_module._approved_file()) approved = read_jsonl(sft_module._approved_file())
assert not any(r["id"] == "a" for r in approved) assert not any(r["id"] == "a" for r in approved)
@ -300,10 +300,10 @@ def test_undo_already_needs_review_returns_409(client, tmp_path):
assert r.status_code == 409 assert r.status_code == 409
# ── /api/sft/export ────────────────────────────────────────────────────────── # -- /api/sft/export -----------------------------------------------------------
def test_export_returns_approved_as_sft_jsonl(client, tmp_path): def test_export_returns_approved_as_sft_jsonl(client, tmp_path):
from app import sft as sft_module from app.data import corrections as sft_module
from app.utils import write_jsonl from app.utils import write_jsonl
approved = { approved = {
**_make_record("a"), **_make_record("a"),
@ -331,7 +331,7 @@ def test_export_returns_approved_as_sft_jsonl(client, tmp_path):
def test_export_excludes_non_approved(client, tmp_path): def test_export_excludes_non_approved(client, tmp_path):
from app import sft as sft_module from app.data import corrections as sft_module
from app.utils import write_jsonl from app.utils import write_jsonl
records = [ records = [
{**_make_record("a"), "status": "discarded", "corrected_response": None}, {**_make_record("a"), "status": "discarded", "corrected_response": None},
@ -348,10 +348,10 @@ def test_export_empty_when_no_approved_file(client):
assert r.text.strip() == "" assert r.text.strip() == ""
# ── /api/sft/stats ─────────────────────────────────────────────────────────── # -- /api/sft/stats ------------------------------------------------------------
def test_stats_counts_by_status(client, tmp_path): def test_stats_counts_by_status(client, tmp_path):
from app import sft as sft_module from app.data import corrections as sft_module
from app.utils import write_jsonl from app.utils import write_jsonl
records = [ records = [
_make_record("a"), _make_record("a"),