Adds IngestRequest model and POST /api/sft/ingest route to app/data/corrections.py. Sibling CF products (Peregrine, Kiwi, etc.) can push pre-approved corrections via Bearer token auth (AVOCET_INGESTION_SECRET). Records land as status=approved in both sft_candidates.jsonl and sft_approved.jsonl immediately. 7 tests in tests/test_data_corrections.py cover 503 (secret unset), 401 (missing/malformed header), 403 (wrong secret), happy-path writes to both files, and optional label field.
393 lines
12 KiB
Python
393 lines
12 KiB
Python
"""Avocet -- SFT candidate corrections API (moved from app/sft.py).
|
|
|
|
All endpoints are registered on `router` (a FastAPI APIRouter).
|
|
api.py includes this router with prefix="/api/sft".
|
|
|
|
Module-level globals (_DATA_DIR, _CONFIG_DIR) follow the same
|
|
testability pattern as api.py -- override them via set_data_dir() and
|
|
set_config_dir() in test fixtures.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import logging
|
|
import os
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Literal
|
|
|
|
import yaml
|
|
from fastapi import APIRouter, Header, HTTPException
|
|
from fastapi.responses import StreamingResponse
|
|
from pydantic import BaseModel
|
|
|
|
from app.utils import append_jsonl, read_jsonl, write_jsonl
|
|
|
|
logger = logging.getLogger(__name__)

# Three directory levels above this file.
# NOTE(review): presumably the repository root, given the module lives at
# app/data/corrections.py -- confirm if the module is moved again.
_ROOT = Path(__file__).parent.parent.parent
# Directory holding sft_candidates.jsonl and sft_approved.jsonl; replaced at
# runtime by set_data_dir().
_DATA_DIR: Path = _ROOT / "data"
# Optional directory containing label_tool.yaml; None means _ROOT / "config"
# is used instead (see _config_file()).
_CONFIG_DIR: Path | None = None

# All endpoints in this module are registered on this router.
router = APIRouter()
|
|
|
|
|
|
# -- Testability seams ---------------------------------------------------------
|
|
|
|
def set_data_dir(path: Path) -> None:
    """Point the module at a different data directory.

    Rebinds the module-level ``_DATA_DIR``, which the candidates/approved
    file helpers read on every call.
    """
    global _DATA_DIR
    _DATA_DIR = path
|
|
|
|
|
|
def set_config_dir(path: Path | None) -> None:
    """Point the module at a different config directory.

    Rebinds the module-level ``_CONFIG_DIR``. Passing ``None`` restores the
    default lookup under ``_ROOT / "config"``.
    """
    global _CONFIG_DIR
    _CONFIG_DIR = path
|
|
|
|
|
|
# -- Internal helpers ----------------------------------------------------------
|
|
|
|
def _config_file() -> Path:
    """Return the path of label_tool.yaml, honoring the _CONFIG_DIR override."""
    config_dir = _ROOT / "config" if _CONFIG_DIR is None else _CONFIG_DIR
    return config_dir / "label_tool.yaml"
|
|
|
|
|
|
# Fallback returned by _get_bench_results_dir() when the config file supplies
# no sft.bench_results_dir value.
# NOTE(review): this is an absolute, machine-specific path -- confirm it is
# really meant to be the shipped default.
_DEFAULT_BENCH_RESULTS_DIR = "/Library/Development/CircuitForge/circuitforge-orch/scripts/bench_results"


def set_default_bench_results_dir(path: str) -> None:
    """Override the default bench_results_dir -- used by tests to avoid real filesystem."""
    global _DEFAULT_BENCH_RESULTS_DIR
    _DEFAULT_BENCH_RESULTS_DIR = path
|
|
|
|
|
|
def _get_bench_results_dir() -> Path:
    """Resolve the directory that holds benchmark run output.

    Reads ``sft.bench_results_dir`` from the label-tool config file when that
    file exists and the value is non-empty; otherwise falls back to the
    module-level ``_DEFAULT_BENCH_RESULTS_DIR``.

    Returns:
        The configured directory, or the default when the config file is
        missing, cannot be parsed, is not a mapping, or has no usable
        ``sft.bench_results_dir`` entry.
    """
    config_path = _config_file()
    if config_path.exists():
        try:
            raw = yaml.safe_load(config_path.read_text(encoding="utf-8"))
        except yaml.YAMLError as exc:
            logger.warning("Failed to parse SFT config %s: %s", config_path, exc)
        else:
            # An empty file loads as None, a bare ``sft:`` key maps to None, and
            # a document may be a list or scalar. None of those may crash the
            # lookup -- they all mean "not configured".
            sft_section = raw.get("sft") if isinstance(raw, dict) else None
            if isinstance(sft_section, dict):
                configured = sft_section.get("bench_results_dir")
                if configured:
                    return Path(configured)
    return Path(_DEFAULT_BENCH_RESULTS_DIR)
|
|
|
|
|
|
def _candidates_file() -> Path:
    """Return the path of sft_candidates.jsonl inside the current data dir."""
    return _DATA_DIR.joinpath("sft_candidates.jsonl")
|
|
|
|
|
|
def _approved_file() -> Path:
    """Return the path of sft_approved.jsonl inside the current data dir."""
    return _DATA_DIR.joinpath("sft_approved.jsonl")
|
|
|
|
|
|
def _read_candidates() -> list[dict]:
    """Load the records stored in the candidates file via read_jsonl()."""
    source = _candidates_file()
    return read_jsonl(source)
|
|
|
|
|
|
def _write_candidates(records: list[dict]) -> None:
    """Hand *records* to write_jsonl() targeting the candidates file."""
    destination = _candidates_file()
    write_jsonl(destination, records)
|
|
|
|
|
|
def _is_exportable(r: dict) -> bool:
    """Tell whether a record qualifies for the SFT export.

    A record qualifies when its status is ``"approved"`` and its
    ``corrected_response`` is present, truthy, and not whitespace-only.
    """
    if r.get("status") != "approved":
        return False
    corrected = r.get("corrected_response")
    if not corrected:
        return False
    # str() keeps non-string truthy values from breaking .strip().
    return str(corrected).strip() != ""
|
|
|
|
|
|
# -- GET /runs -----------------------------------------------------------------
|
|
|
|
@router.get("/runs")
def get_runs():
    """List benchmark runs found in the configured bench_results_dir.

    Each entry reports the run id, timestamp, candidate count, and whether a
    local candidate record already references that run id.
    """
    # Function-scope import: scripts.sft_import is only loaded when the route runs.
    from scripts.sft_import import discover_runs

    bench_dir = _get_bench_results_dir()

    # benchmark_run_id in each record equals the run's directory name by cf-orch convention
    imported_run_ids = set()
    for candidate in _read_candidates():
        candidate_run_id = candidate.get("benchmark_run_id")
        if candidate_run_id is not None:
            imported_run_ids.add(candidate_run_id)

    listing = []
    for run in discover_runs(bench_dir):
        run_id = run["run_id"]
        listing.append(
            {
                "run_id": run_id,
                "timestamp": run["timestamp"],
                "candidate_count": run["candidate_count"],
                "already_imported": run_id in imported_run_ids,
            }
        )
    return listing
|
|
|
|
|
|
# -- POST /import --------------------------------------------------------------
|
|
|
|
class ImportRequest(BaseModel):
    # Request body for POST /import.
    # Compared against the "run_id" of each run returned by discover_runs().
    run_id: str
|
|
|
|
|
|
@router.post("/import")
def post_import(req: ImportRequest):
    """Import the sft_candidates.jsonl of a single benchmark run.

    Looks the requested run up among the discovered runs and delegates to
    import_run(); responds 404 when no run has the requested id.
    """
    # Function-scope import: scripts.sft_import is only loaded when the route runs.
    from scripts.sft_import import discover_runs, import_run

    bench_dir = _get_bench_results_dir()
    for run in discover_runs(bench_dir):
        # First match wins, mirroring a next(...) lookup.
        if run["run_id"] == req.run_id:
            return import_run(run["sft_path"], _DATA_DIR)

    raise HTTPException(404, f"Run {req.run_id!r} not found in bench_results_dir")
|
|
|
|
|
|
# -- GET /queue ----------------------------------------------------------------
|
|
|
|
@router.get("/queue")
def get_queue(page: int = 1, per_page: int = 20):
    """Return one page of candidates whose status is ``needs_review``.

    Args:
        page: 1-based page number.
        per_page: Number of items per page; must be at least 1.

    Returns:
        A dict with the page's ``items``, the ``total`` number of pending
        records, and the echoed ``page`` / ``per_page`` values.

    Raises:
        HTTPException: 422 when ``page`` or ``per_page`` is below 1.
    """
    # Without this guard a zero/negative page produces a negative slice start,
    # which silently returns an arbitrary window from the END of the list.
    if page < 1 or per_page < 1:
        raise HTTPException(422, "page and per_page must be >= 1")

    records = _read_candidates()
    pending = [r for r in records if r.get("status") == "needs_review"]
    start = (page - 1) * per_page
    return {
        "items": pending[start:start + per_page],
        "total": len(pending),
        "page": page,
        "per_page": per_page,
    }
|
|
|
|
|
|
# -- POST /submit --------------------------------------------------------------
|
|
|
|
# Closed set of values accepted for SubmitRequest.failure_category.
FailureCategory = Literal[
    "scoring_artifact",
    "style_violation",
    "partial_answer",
    "wrong_answer",
    "format_error",
    "hallucination",
]
|
|
|
|
|
|
class SubmitRequest(BaseModel):
    # Request body for POST /submit.
    # "id" of the candidate record being decided.
    id: str
    # post_submit maps these to statuses: "correct" -> approved,
    # "discard" -> discarded, "flag" -> model_rejected.
    action: Literal["correct", "discard", "flag"]
    # Must be non-blank when action is "correct"; not used by the other actions.
    corrected_response: str | None = None
    # Stored on the record only when action is "correct".
    failure_category: FailureCategory | None = None
|
|
|
|
|
|
@router.post("/submit")
def post_submit(req: SubmitRequest):
    """Apply a reviewer's decision to a single SFT candidate.

    The candidate must currently be in ``needs_review``. A "correct" action
    marks it approved (and appends it to the approved file), "discard" marks
    it discarded, and "flag" marks it model_rejected.

    Raises:
        HTTPException: 422 for a blank correction on a "correct" action,
            404 when the id is unknown, 409 when the record is not in
            ``needs_review``.
    """
    is_correction = req.action == "correct"
    if is_correction and not (req.corrected_response and req.corrected_response.strip()):
        raise HTTPException(422, "corrected_response must be non-empty when action is 'correct'")

    records = _read_candidates()
    idx = None
    for position, candidate in enumerate(records):
        if candidate.get("id") == req.id:
            idx = position
            break
    if idx is None:
        raise HTTPException(404, f"Record {req.id!r} not found")

    record = records[idx]
    current_status = record.get("status")
    if current_status != "needs_review":
        raise HTTPException(409, f"Record is not in needs_review state (current: {current_status})")

    if is_correction:
        updated = {
            **record,
            "status": "approved",
            "corrected_response": req.corrected_response,
            "failure_category": req.failure_category,
        }
    elif req.action == "discard":
        updated = {**record, "status": "discarded"}
    else:  # flag
        updated = {**record, "status": "model_rejected"}

    records[idx] = updated
    _write_candidates(records)
    # Only approved corrections are mirrored into the approved file, and only
    # after the candidates file has been rewritten.
    if is_correction:
        append_jsonl(_approved_file(), updated)

    return {"ok": True}
|
|
|
|
|
|
# -- POST /undo ----------------------------------------------------------------
|
|
|
|
class UndoRequest(BaseModel):
    # Request body for POST /undo.
    # "id" of the candidate record to move back to needs_review.
    id: str
|
|
|
|
|
|
@router.post("/undo")
def post_undo(req: UndoRequest):
    """Move an already-decided candidate back to ``needs_review``.

    Clears ``corrected_response`` on the record and, when the record had been
    approved, also drops it from the approved file. Other fields (including
    ``failure_category``) are left as they were.

    Raises:
        HTTPException: 404 when the id is unknown, 409 when the record is
            already in ``needs_review``.
    """
    records = _read_candidates()
    matches = [i for i, r in enumerate(records) if r.get("id") == req.id]
    if not matches:
        raise HTTPException(404, f"Record {req.id!r} not found")
    idx = matches[0]

    previous = records[idx]
    old_status = previous.get("status")
    if old_status == "needs_review":
        raise HTTPException(409, "Record is already in needs_review state")

    restored = dict(previous)
    restored["status"] = "needs_review"
    restored["corrected_response"] = None
    records[idx] = restored
    _write_candidates(records)

    # An approved record was also mirrored into the approved file; remove it there.
    if old_status == "approved":
        approved_path = _approved_file()
        remaining = [r for r in read_jsonl(approved_path) if r.get("id") != req.id]
        write_jsonl(approved_path, remaining)

    return {"ok": True}
|
|
|
|
|
|
# -- GET /export ---------------------------------------------------------------
|
|
|
|
@router.get("/export")
def get_export() -> StreamingResponse:
    """Serve the exportable approved records as a JSONL download.

    Each output line is ``{"messages": [...]}``: the record's
    ``prompt_messages`` followed by one assistant message holding the
    corrected response. The attachment filename carries a UTC timestamp.
    """
    approved = read_jsonl(_approved_file())
    exportable = [row for row in approved if _is_exportable(row)]

    def generate():
        for row in exportable:
            assistant_turn = {"role": "assistant", "content": row["corrected_response"]}
            messages = row.get("prompt_messages", []) + [assistant_turn]
            yield json.dumps({"messages": messages}) + "\n"

    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d-%H%M%S")
    disposition = f'attachment; filename="sft_export_{timestamp}.jsonl"'
    return StreamingResponse(
        generate(),
        media_type="application/x-ndjson",
        headers={"Content-Disposition": disposition},
    )
|
|
|
|
|
|
# -- GET /stats ----------------------------------------------------------------
|
|
|
|
@router.get("/stats")
def get_stats() -> dict[str, object]:
    """Summarize the candidate records.

    Returns the total record count, per-value counts for ``status``,
    ``model_name`` and ``task_type`` (missing values count as "unknown"), and
    the number of records in the approved file that are export-ready.
    """
    records = _read_candidates()

    def tally(field: str) -> dict[str, int]:
        # Count records per value of *field*, in first-seen order.
        counts: dict[str, int] = {}
        for rec in records:
            value = rec.get(field, "unknown")
            counts[value] = counts.get(value, 0) + 1
        return counts

    by_status = tally("status")
    by_model = tally("model_name")
    by_task_type = tally("task_type")

    approved = read_jsonl(_approved_file())
    export_ready = len([rec for rec in approved if _is_exportable(rec)])

    return {
        "total": len(records),
        "by_status": by_status,
        "by_model": by_model,
        "by_task_type": by_task_type,
        "export_ready": export_ready,
    }
|
|
|
|
|
|
# -- GET /config ---------------------------------------------------------------
|
|
|
|
@router.get("/config")
def get_sft_config() -> dict:
    """Report the configured bench_results_dir.

    Responds with an empty string when the config file is missing, cannot be
    parsed as YAML, or has no ``sft.bench_results_dir`` entry.
    """
    not_configured = {"bench_results_dir": ""}

    config_path = _config_file()
    if not config_path.exists():
        return not_configured

    try:
        raw = yaml.safe_load(config_path.read_text(encoding="utf-8")) or {}
    except yaml.YAMLError:
        return not_configured

    # ``sft:`` with no value loads as None, hence the ``or {}``.
    sft_section = raw.get("sft") or {}
    return {"bench_results_dir": sft_section.get("bench_results_dir", "")}
|
|
|
|
|
|
class SftConfigPayload(BaseModel):
    # Request body for POST /config.
    # Written verbatim to sft.bench_results_dir in label_tool.yaml.
    bench_results_dir: str
|
|
|
|
|
|
@router.post("/config")
def post_sft_config(payload: SftConfigPayload) -> dict:
    """Write the bench_results_dir setting to the config file.

    Existing top-level keys in the file are kept; the ``sft`` section is
    replaced by ``{"bench_results_dir": ...}``. An unparseable or non-mapping
    file is treated as empty. The new content is written to a sibling
    ``.tmp`` file and then moved over the real file.

    Returns:
        ``{"ok": True}`` on success.
    """
    f = _config_file()
    f.parent.mkdir(parents=True, exist_ok=True)
    try:
        raw = yaml.safe_load(f.read_text(encoding="utf-8")) if f.exists() else {}
    except yaml.YAMLError:
        raw = {}
    # Empty files load as None and a document may be a list/scalar; item
    # assignment below needs a mapping.
    if not isinstance(raw, dict):
        raw = {}

    # NOTE(review): this replaces the whole ``sft`` section, so any other keys
    # stored under ``sft`` are dropped -- confirm that is intended.
    raw["sft"] = {"bench_results_dir": payload.bench_results_dir}

    tmp = f.with_suffix(".tmp")
    tmp.write_text(yaml.dump(raw, allow_unicode=True, sort_keys=False), encoding="utf-8")
    # Path.replace overwrites an existing target on every platform;
    # Path.rename raises FileExistsError on Windows when the config exists.
    tmp.replace(f)
    return {"ok": True}
|
|
|
|
|
|
# -- POST /ingest --------------------------------------------------------------
|
|
|
|
class IngestRequest(BaseModel):
    # Request body for POST /ingest.
    source: str               # e.g. "peregrine", "kiwi"
    task_type: str            # e.g. "email_classification", "recipe_suggestion"
    prompt: str               # the prompt that was sent to the LLM; stored as a single user message
    response: str             # the LLM's original response; stored as "model_response"
    correction: str           # the human-corrected response; stored as "corrected_response"
    label: str | None = None  # optional label/category
|
|
|
|
|
|
@router.post("/ingest")
def post_ingest(
    req: IngestRequest,
    authorization: str | None = Header(default=None),
) -> dict:
    """Ingest a correction from a sibling CF product.

    Authentication: Authorization: Bearer <AVOCET_INGESTION_SECRET>

    Creates a sft_candidates record with status='approved' (pre-approved by
    the calling product -- human review already happened upstream). Also writes
    to sft_approved.jsonl so it is immediately included in export counts.

    Returns {"ok": True, "id": "<uuid>"}.

    Raises:
        HTTPException: 503 when AVOCET_INGESTION_SECRET is unset or empty,
            401 when the Authorization header is missing or does not start
            with "Bearer ", 403 when the presented token does not match.
    """
    # Function-scope import, like this module's other lazily imported helpers.
    import hmac

    expected_secret = os.environ.get("AVOCET_INGESTION_SECRET", "")
    if not expected_secret:
        raise HTTPException(503, "Ingestion not configured -- AVOCET_INGESTION_SECRET not set")

    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(401, "Missing or malformed Authorization header")

    token = authorization.removeprefix("Bearer ").strip()
    # Constant-time comparison so response timing does not leak how much of
    # the secret matched. Compared as bytes because compare_digest rejects
    # non-ASCII str arguments.
    if not hmac.compare_digest(token.encode("utf-8"), expected_secret.encode("utf-8")):
        raise HTTPException(403, "Invalid ingestion secret")

    record_id = str(uuid.uuid4())
    now = datetime.now(timezone.utc).isoformat()
    record = {
        "id": record_id,
        "source": req.source,
        "task_type": req.task_type,
        "status": "approved",
        "prompt_messages": [{"role": "user", "content": req.prompt}],
        "model_response": req.response,
        "corrected_response": req.correction,
        "label": req.label,
        "timestamp": now,
        # Ingested records do not come from a benchmark run.
        "benchmark_run_id": None,
    }
    append_jsonl(_candidates_file(), record)
    append_jsonl(_approved_file(), record)
    return {"ok": True, "id": record_id}
|