feat: add make_corrections_router for LLM output correction collection
Shared router factory for storing user thumbs-up/down + correction text on LLM outputs. Used by Linnet initially; designed to wire into any CF product. JSONL export endpoint feeds Avocet SFT pipeline. Only opted_in=1 rows export (consent gate for correction text).
This commit is contained in:
parent
7623c3edaf
commit
fc52d32574
2 changed files with 201 additions and 1 deletions
|
|
@ -1,3 +1,4 @@
|
|||
from circuitforge_core.api.feedback import make_feedback_router
|
||||
from circuitforge_core.api.corrections import make_corrections_router, CORRECTIONS_MIGRATION_SQL
|
||||
|
||||
__all__ = ["make_feedback_router"]
|
||||
__all__ = ["make_feedback_router", "make_corrections_router", "CORRECTIONS_MIGRATION_SQL"]
|
||||
|
|
|
|||
199
circuitforge_core/api/corrections.py
Normal file
199
circuitforge_core/api/corrections.py
Normal file
|
|
@ -0,0 +1,199 @@
|
|||
"""
|
||||
Shared corrections router — stores user corrections to LLM output for SFT training.
|
||||
|
||||
Products include this with make_corrections_router(get_db=..., product=...).
|
||||
Corrections are stored locally in each product's SQLite DB and exported as JSONL
|
||||
for the Avocet SFT pipeline. Separate from the bug-feedback→Forgejo-issue path.
|
||||
|
||||
Required DB migration (add to product migrations dir):
|
||||
-- From circuitforge_core.api.corrections import CORRECTIONS_MIGRATION_SQL
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime, timezone
|
||||
from typing import Iterator, Literal
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Drop this SQL into a product's migrations directory (e.g. 020_corrections.sql).
|
||||
CORRECTIONS_MIGRATION_SQL = """\
|
||||
CREATE TABLE IF NOT EXISTS corrections (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
item_id TEXT NOT NULL DEFAULT '',
|
||||
product TEXT NOT NULL,
|
||||
correction_type TEXT NOT NULL,
|
||||
input_text TEXT NOT NULL,
|
||||
original_output TEXT NOT NULL,
|
||||
corrected_output TEXT NOT NULL DEFAULT '',
|
||||
rating TEXT NOT NULL DEFAULT 'down',
|
||||
context TEXT NOT NULL DEFAULT '{}',
|
||||
opted_in INTEGER NOT NULL DEFAULT 0,
|
||||
created_at TEXT NOT NULL DEFAULT (datetime('now'))
|
||||
);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_corrections_product
|
||||
ON corrections (product);
|
||||
|
||||
CREATE INDEX IF NOT EXISTS idx_corrections_opted_in
|
||||
ON corrections (opted_in);
|
||||
"""
|
||||
|
||||
|
||||
class CorrectionRequest(BaseModel):
|
||||
item_id: str = ""
|
||||
product: str
|
||||
correction_type: str
|
||||
input_text: str
|
||||
original_output: str
|
||||
corrected_output: str = ""
|
||||
rating: Literal["up", "down"] = "down"
|
||||
context: dict = Field(default_factory=dict)
|
||||
opted_in: bool = False
|
||||
|
||||
|
||||
class CorrectionResponse(BaseModel):
|
||||
id: int
|
||||
saved: bool
|
||||
|
||||
|
||||
class CorrectionRecord(BaseModel):
|
||||
id: int
|
||||
item_id: str
|
||||
product: str
|
||||
correction_type: str
|
||||
input_text: str
|
||||
original_output: str
|
||||
corrected_output: str
|
||||
rating: str
|
||||
context: dict
|
||||
opted_in: bool
|
||||
created_at: str
|
||||
|
||||
|
||||
def make_corrections_router(
|
||||
get_db: Callable[[], Iterator[sqlite3.Connection]],
|
||||
product: str,
|
||||
) -> APIRouter:
|
||||
"""Return a configured corrections APIRouter.
|
||||
|
||||
Args:
|
||||
get_db: FastAPI dependency that yields a sqlite3.Connection.
|
||||
product: Product slug injected into every correction row (e.g. "linnet").
|
||||
"""
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("", response_model=CorrectionResponse)
|
||||
def submit_correction(
|
||||
payload: CorrectionRequest,
|
||||
conn: sqlite3.Connection = Depends(get_db),
|
||||
) -> CorrectionResponse:
|
||||
"""Store a user correction to an LLM output."""
|
||||
# Thumbs-up with no corrected text is a valid positive signal.
|
||||
if payload.rating == "down" and not payload.corrected_output.strip():
|
||||
raise HTTPException(
|
||||
status_code=422,
|
||||
detail="corrected_output is required when rating is 'down'.",
|
||||
)
|
||||
|
||||
row_id = conn.execute(
|
||||
"""
|
||||
INSERT INTO corrections
|
||||
(item_id, product, correction_type, input_text, original_output,
|
||||
corrected_output, rating, context, opted_in)
|
||||
VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
|
||||
""",
|
||||
(
|
||||
payload.item_id,
|
||||
product,
|
||||
payload.correction_type,
|
||||
payload.input_text,
|
||||
payload.original_output,
|
||||
payload.corrected_output,
|
||||
payload.rating,
|
||||
json.dumps(payload.context),
|
||||
int(payload.opted_in),
|
||||
),
|
||||
).lastrowid
|
||||
conn.commit()
|
||||
return CorrectionResponse(id=row_id, saved=True)
|
||||
|
||||
@router.get("", response_model=list[CorrectionRecord])
|
||||
def list_corrections(
|
||||
opted_in_only: bool = False,
|
||||
limit: int = 200,
|
||||
conn: sqlite3.Connection = Depends(get_db),
|
||||
) -> list[CorrectionRecord]:
|
||||
"""List stored corrections, optionally filtered to opted-in rows only."""
|
||||
conn.row_factory = sqlite3.Row
|
||||
query = "SELECT * FROM corrections"
|
||||
params: list = []
|
||||
if opted_in_only:
|
||||
query += " WHERE opted_in = 1"
|
||||
query += " ORDER BY created_at DESC LIMIT ?"
|
||||
params.append(max(1, min(limit, 1000)))
|
||||
rows = conn.execute(query, params).fetchall()
|
||||
return [
|
||||
CorrectionRecord(
|
||||
id=r["id"],
|
||||
item_id=r["item_id"],
|
||||
product=r["product"],
|
||||
correction_type=r["correction_type"],
|
||||
input_text=r["input_text"],
|
||||
original_output=r["original_output"],
|
||||
corrected_output=r["corrected_output"],
|
||||
rating=r["rating"],
|
||||
context=json.loads(r["context"] or "{}"),
|
||||
opted_in=bool(r["opted_in"]),
|
||||
created_at=r["created_at"],
|
||||
)
|
||||
for r in rows
|
||||
]
|
||||
|
||||
@router.get("/export")
|
||||
def export_corrections(
|
||||
opted_in_only: bool = True,
|
||||
conn: sqlite3.Connection = Depends(get_db),
|
||||
) -> StreamingResponse:
|
||||
"""Stream corrections as JSONL for the Avocet SFT pipeline.
|
||||
|
||||
Each line is a JSON object with the fields expected by avocet's
|
||||
SFT candidate importer. opted_in_only=True (default) — only rows
|
||||
where the user consented to share are exported.
|
||||
"""
|
||||
conn.row_factory = sqlite3.Row
|
||||
query = "SELECT * FROM corrections"
|
||||
if opted_in_only:
|
||||
query += " WHERE opted_in = 1"
|
||||
query += " ORDER BY created_at ASC"
|
||||
rows = conn.execute(query).fetchall()
|
||||
|
||||
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
|
||||
filename = f"corrections_{product}_{timestamp}.jsonl"
|
||||
|
||||
def generate() -> Iterator[str]:
|
||||
for r in rows:
|
||||
record = {
|
||||
"input": r["input_text"],
|
||||
"output": r["original_output"],
|
||||
"correction": r["corrected_output"],
|
||||
"rating": r["rating"],
|
||||
"correction_type": r["correction_type"],
|
||||
"product": r["product"],
|
||||
"item_id": r["item_id"],
|
||||
"context": json.loads(r["context"] or "{}"),
|
||||
"created_at": r["created_at"],
|
||||
}
|
||||
yield json.dumps(record, ensure_ascii=False) + "\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="application/x-ndjson",
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
|
||||
return router
|
||||
Loading…
Reference in a new issue