feat: SFT failure_category — classify why a model response was wrong #17

Merged
pyr0ball merged 1 commit from feat/sft-failure-category into main 2026-04-08 22:19:20 -07:00
6 changed files with 323 additions and 21 deletions

View file

@ -151,10 +151,21 @@ def get_queue(page: int = 1, per_page: int = 20):
# ── POST /submit ─────────────────────────────────────────────────────────── # ── POST /submit ───────────────────────────────────────────────────────────
# Closed set of labels a reviewer can attach to explain why a model response
# was wrong. Used as the type of SubmitRequest.failure_category; the /submit
# tests expect any value outside this list to be rejected with HTTP 422.
# The frontend declares the same six values as SftFailureCategory — keep the
# two lists in sync when adding or renaming a category.
FailureCategory = Literal[
"scoring_artifact",
"style_violation",
"partial_answer",
"wrong_answer",
"format_error",
"hallucination",
]
class SubmitRequest(BaseModel): class SubmitRequest(BaseModel):
id: str id: str
action: Literal["correct", "discard", "flag"] action: Literal["correct", "discard", "flag"]
corrected_response: str | None = None corrected_response: str | None = None
failure_category: FailureCategory | None = None
@router.post("/submit") @router.post("/submit")
@ -174,7 +185,12 @@ def post_submit(req: SubmitRequest):
raise HTTPException(409, f"Record is not in needs_review state (current: {record.get('status')})") raise HTTPException(409, f"Record is not in needs_review state (current: {record.get('status')})")
if req.action == "correct": if req.action == "correct":
records[idx] = {**record, "status": "approved", "corrected_response": req.corrected_response} records[idx] = {
**record,
"status": "approved",
"corrected_response": req.corrected_response,
"failure_category": req.failure_category,
}
_write_candidates(records) _write_candidates(records)
append_jsonl(_approved_file(), records[idx]) append_jsonl(_approved_file(), records[idx])
elif req.action == "discard": elif req.action == "discard":

View file

@ -232,6 +232,41 @@ def test_submit_already_approved_returns_409(client, tmp_path):
assert r.status_code == 409 assert r.status_code == 409
def test_submit_correct_stores_failure_category(client, tmp_path):
    """A 'correct' submission persists the chosen failure_category on the record."""
    _populate_candidates(tmp_path, [_make_record("a")])
    payload = {
        "id": "a",
        "action": "correct",
        "corrected_response": "def add(a, b): return a + b",
        "failure_category": "style_violation",
    }

    response = client.post("/api/sft/submit", json=payload)
    assert response.status_code == 200

    # Re-read the candidates file to check what was actually written.
    from app import sft as sft_module

    stored = sft_module._read_candidates()
    assert stored[0]["failure_category"] == "style_violation"
def test_submit_correct_null_failure_category(client, tmp_path):
    """Omitting failure_category on a 'correct' submission stores it as None."""
    _populate_candidates(tmp_path, [_make_record("a")])
    payload = {
        "id": "a",
        "action": "correct",
        "corrected_response": "def add(a, b): return a + b",
    }

    response = client.post("/api/sft/submit", json=payload)
    assert response.status_code == 200

    # Re-read the candidates file to check what was actually written.
    from app import sft as sft_module

    stored = sft_module._read_candidates()
    assert stored[0]["failure_category"] is None
def test_submit_invalid_failure_category_returns_422(client, tmp_path):
    """A failure_category outside the allowed set is rejected with HTTP 422."""
    _populate_candidates(tmp_path, [_make_record("a")])
    payload = {
        "id": "a",
        "action": "correct",
        "corrected_response": "def add(a, b): return a + b",
        "failure_category": "nonsense",
    }

    response = client.post("/api/sft/submit", json=payload)
    assert response.status_code == 422
# ── /api/sft/undo ──────────────────────────────────────────────────────────── # ── /api/sft/undo ────────────────────────────────────────────────────────────
def test_undo_restores_discarded_to_needs_review(client, tmp_path): def test_undo_restores_discarded_to_needs_review(client, tmp_path):

View file

@ -13,6 +13,7 @@ const LOW_QUALITY_ITEM: SftQueueItem = {
model_response: 'def add(a, b): return a - b', model_response: 'def add(a, b): return a - b',
corrected_response: null, quality_score: 0.2, corrected_response: null, quality_score: 0.2,
failure_reason: 'pattern_match: 0/2 matched', failure_reason: 'pattern_match: 0/2 matched',
failure_category: null,
task_id: 'code-fn', task_type: 'code', task_name: 'Code: Write a function', task_id: 'code-fn', task_type: 'code', task_name: 'Code: Write a function',
model_id: 'Qwen/Qwen2.5-3B', model_name: 'Qwen2.5-3B', model_id: 'Qwen/Qwen2.5-3B', model_name: 'Qwen2.5-3B',
node_id: 'heimdall', gpu_id: 0, tokens_per_sec: 38.4, node_id: 'heimdall', gpu_id: 0, tokens_per_sec: 38.4,
@ -68,15 +69,17 @@ describe('SftCard', () => {
expect(w.emitted('correct')).toBeTruthy() expect(w.emitted('correct')).toBeTruthy()
}) })
it('clicking Discard button emits discard', async () => { it('clicking Discard button then confirming emits discard', async () => {
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } }) const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
await w.find('[data-testid="discard-btn"]').trigger('click') await w.find('[data-testid="discard-btn"]').trigger('click')
await w.find('[data-testid="confirm-pending-btn"]').trigger('click')
expect(w.emitted('discard')).toBeTruthy() expect(w.emitted('discard')).toBeTruthy()
}) })
it('clicking Flag Model button emits flag', async () => { it('clicking Flag Model button then confirming emits flag', async () => {
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } }) const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
await w.find('[data-testid="flag-btn"]').trigger('click') await w.find('[data-testid="flag-btn"]').trigger('click')
await w.find('[data-testid="confirm-pending-btn"]').trigger('click')
expect(w.emitted('flag')).toBeTruthy() expect(w.emitted('flag')).toBeTruthy()
}) })
@ -95,4 +98,82 @@ describe('SftCard', () => {
const w = mount(SftCard, { props: { item } }) const w = mount(SftCard, { props: { item } })
expect(w.find('.failure-reason').exists()).toBe(false) expect(w.find('.failure-reason').exists()).toBe(false)
}) })
// ── Failure category chip-group ───────────────────────────────────
// Visibility: the section is rendered only while the card is in correcting
// mode or while a discard/flag action is waiting for confirmation.
it('failure category section hidden when not correcting and no pending action', () => {
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
expect(w.find('[data-testid="failure-category-section"]').exists()).toBe(false)
})
it('failure category section shown when correcting prop is true', () => {
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM, correcting: true } })
expect(w.find('[data-testid="failure-category-section"]').exists()).toBe(true)
})
// One chip is expected per failure category value (six in total).
it('renders all six category chips when correcting', () => {
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM, correcting: true } })
const chips = w.findAll('.category-chip')
expect(chips).toHaveLength(6)
})
// Selection: chips form a single-select group where clicking the active
// chip clears the selection.
it('clicking a category chip selects it (adds active class)', async () => {
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM, correcting: true } })
const chip = w.find('[data-testid="category-chip-wrong_answer"]')
await chip.trigger('click')
expect(chip.classes()).toContain('category-chip--active')
})
it('clicking the active chip again deselects it', async () => {
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM, correcting: true } })
const chip = w.find('[data-testid="category-chip-hallucination"]')
await chip.trigger('click')
expect(chip.classes()).toContain('category-chip--active')
await chip.trigger('click')
expect(chip.classes()).not.toContain('category-chip--active')
})
it('only one chip can be active at a time', async () => {
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM, correcting: true } })
await w.find('[data-testid="category-chip-wrong_answer"]').trigger('click')
await w.find('[data-testid="category-chip-hallucination"]').trigger('click')
const active = w.findAll('.category-chip--active')
expect(active).toHaveLength(1)
expect(active[0].attributes('data-testid')).toBe('category-chip-hallucination')
})
// Two-step discard/flag: the first click only opens the pending-action row
// (and with it the category section); nothing is emitted until confirmation.
it('clicking Discard shows pending action row with category section', async () => {
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
await w.find('[data-testid="discard-btn"]').trigger('click')
expect(w.find('[data-testid="failure-category-section"]').exists()).toBe(true)
expect(w.find('[data-testid="pending-action-row"]').exists()).toBe(true)
})
it('clicking Flag shows pending action row', async () => {
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
await w.find('[data-testid="flag-btn"]').trigger('click')
expect(w.find('[data-testid="pending-action-row"]').exists()).toBe(true)
})
// Emitted payload: the event carries exactly one argument — the selected
// category, or null when no chip was chosen.
it('confirming discard emits discard with null when no category selected', async () => {
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
await w.find('[data-testid="discard-btn"]').trigger('click')
await w.find('[data-testid="confirm-pending-btn"]').trigger('click')
expect(w.emitted('discard')).toBeTruthy()
expect(w.emitted('discard')![0]).toEqual([null])
})
it('confirming discard emits discard with selected category', async () => {
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
await w.find('[data-testid="discard-btn"]').trigger('click')
await w.find('[data-testid="category-chip-scoring_artifact"]').trigger('click')
await w.find('[data-testid="confirm-pending-btn"]').trigger('click')
expect(w.emitted('discard')![0]).toEqual(['scoring_artifact'])
})
// Cancelling backs out of the pending action and removes the confirm row.
it('cancelling pending action hides the pending row', async () => {
const w = mount(SftCard, { props: { item: LOW_QUALITY_ITEM } })
await w.find('[data-testid="discard-btn"]').trigger('click')
await w.find('[data-testid="cancel-pending-btn"]').trigger('click')
expect(w.find('[data-testid="pending-action-row"]').exists()).toBe(false)
})
}) })

View file

@ -57,21 +57,52 @@
<button <button
data-testid="discard-btn" data-testid="discard-btn"
class="btn-discard" class="btn-discard"
@click="$emit('discard')" @click="emitWithCategory('discard')"
> Discard</button> > Discard</button>
<button <button
data-testid="flag-btn" data-testid="flag-btn"
class="btn-flag" class="btn-flag"
@click="$emit('flag')" @click="emitWithCategory('flag')"
> Flag Model</button> > Flag Model</button>
</div> </div>
<!-- Failure category selector (shown when correcting or while a discard/flag awaits confirmation) -->
<div
  v-if="correcting || pendingAction"
  class="failure-category-section"
  data-testid="failure-category-section"
>
  <p class="section-label">Failure category <span class="optional-label">(optional)</span></p>
  <div class="category-chips" role="group" aria-label="Failure category">
    <!--
      Chips are toggle buttons, so aria-pressed is always rendered as "true" or
      "false". Removing the attribute for unselected chips would make assistive
      technology announce them as plain buttons and drop their toggle state.
    -->
    <button
      v-for="cat in FAILURE_CATEGORIES"
      :key="cat.value"
      type="button"
      class="category-chip"
      :class="{ 'category-chip--active': selectedCategory === cat.value }"
      :aria-pressed="selectedCategory === cat.value"
      :data-testid="'category-chip-' + cat.value"
      @click="toggleCategory(cat.value)"
    >{{ cat.label }}</button>
  </div>

  <!-- Pending discard/flag confirm row; the event is only emitted on Confirm -->
  <div v-if="pendingAction" class="pending-action-row" data-testid="pending-action-row">
    <!-- type="button" matches the chips and keeps these from acting as submit buttons inside a form -->
    <button type="button" class="btn-confirm" @click="confirmPendingAction" data-testid="confirm-pending-btn">
      Confirm {{ pendingAction }}
    </button>
    <button type="button" class="btn-cancel-pending" @click="cancelPendingAction" data-testid="cancel-pending-btn">
      Cancel
    </button>
  </div>
</div>
<!-- Correction area (shown when correcting = true) --> <!-- Correction area (shown when correcting = true) -->
<div v-if="correcting" data-testid="correction-area"> <div v-if="correcting" data-testid="correction-area">
<SftCorrectionArea <SftCorrectionArea
ref="correctionAreaEl" ref="correctionAreaEl"
:described-by="'sft-failure-' + item.id" :described-by="'sft-failure-' + item.id"
@submit="$emit('submit-correction', $event)" @submit="handleSubmitCorrection"
@cancel="$emit('cancel-correction')" @cancel="$emit('cancel-correction')"
/> />
</div> </div>
@ -80,21 +111,32 @@
<script setup lang="ts"> <script setup lang="ts">
import { ref, computed } from 'vue' import { ref, computed } from 'vue'
import type { SftQueueItem } from '../stores/sft' import type { SftQueueItem, SftFailureCategory } from '../stores/sft'
import SftCorrectionArea from './SftCorrectionArea.vue' import SftCorrectionArea from './SftCorrectionArea.vue'
const props = defineProps<{ item: SftQueueItem; correcting?: boolean }>() const props = defineProps<{ item: SftQueueItem; correcting?: boolean }>()
const emit = defineEmits<{ const emit = defineEmits<{
correct: [] correct: []
discard: [] discard: [category: SftFailureCategory | null]
flag: [] flag: [category: SftFailureCategory | null]
'submit-correction': [text: string] 'submit-correction': [text: string, category: SftFailureCategory | null]
'cancel-correction': [] 'cancel-correction': []
}>() }>()
// Chip definitions for the failure-category selector, rendered in this order.
// `value` is what gets stored in selectedCategory and passed along with the
// emitted events; `label` is the text shown on the chip.
const FAILURE_CATEGORIES: { value: SftFailureCategory; label: string }[] = [
{ value: 'scoring_artifact', label: 'Scoring artifact' },
{ value: 'style_violation', label: 'Style violation' },
{ value: 'partial_answer', label: 'Partial answer' },
{ value: 'wrong_answer', label: 'Wrong answer' },
{ value: 'format_error', label: 'Format error' },
{ value: 'hallucination', label: 'Hallucination' },
]
const promptExpanded = ref(false) const promptExpanded = ref(false)
const correctionAreaEl = ref<InstanceType<typeof SftCorrectionArea> | null>(null) const correctionAreaEl = ref<InstanceType<typeof SftCorrectionArea> | null>(null)
const selectedCategory = ref<SftFailureCategory | null>(null)
const pendingAction = ref<'discard' | 'flag' | null>(null)
const qualityClass = computed(() => { const qualityClass = computed(() => {
const s = props.item.quality_score const s = props.item.quality_score
@ -110,8 +152,34 @@ const qualityLabel = computed(() => {
return 'acceptable' return 'acceptable'
}) })
/**
 * Chip click handler: select `cat`, or clear the selection when `cat` is
 * already the selected category.
 */
function toggleCategory(cat: SftFailureCategory) {
  if (selectedCategory.value === cat) {
    selectedCategory.value = null
  } else {
    selectedCategory.value = cat
  }
}
// Despite the name, this does not emit anything itself: it only records which
// action is awaiting confirmation, which makes the category section and the
// confirm/cancel row appear. The event is emitted by confirmPendingAction().
function emitWithCategory(action: 'discard' | 'flag') {
pendingAction.value = action
}
/**
 * Emit the staged discard/flag event together with the selected category
 * (null when no chip is selected), then clear both pieces of local state.
 * Does nothing when no action is pending.
 */
function confirmPendingAction() {
  const action = pendingAction.value
  if (!action) return

  const category = selectedCategory.value
  emit(action, category)

  pendingAction.value = null
  selectedCategory.value = null
}
// Back out of a staged discard/flag without emitting anything.
// selectedCategory is left untouched here.
// NOTE(review): confirm that keeping the chip selection after a cancel is intended.
function cancelPendingAction() {
pendingAction.value = null
}
/**
 * Forward the correction text from SftCorrectionArea to the parent, attaching
 * the selected failure category (null when none), then clear the selection.
 */
function handleSubmitCorrection(text: string) {
  const category = selectedCategory.value
  emit('submit-correction', text, category)
  selectedCategory.value = null
}
function resetCorrection() { function resetCorrection() {
correctionAreaEl.value?.reset() correctionAreaEl.value?.reset()
selectedCategory.value = null
pendingAction.value = null
} }
defineExpose({ resetCorrection }) defineExpose({ resetCorrection })
@ -243,4 +311,83 @@ defineExpose({ resetCorrection })
.btn-flag { border-color: var(--color-warning); color: var(--color-warning); } .btn-flag { border-color: var(--color-warning); color: var(--color-warning); }
.btn-flag:hover { background: color-mix(in srgb, var(--color-warning) 10%, transparent); } .btn-flag:hover { background: color-mix(in srgb, var(--color-warning) 10%, transparent); }
/* ── Failure category selector ─────────────────── */
/* Stacks the section label, the chip row and the confirm/cancel row vertically. */
.failure-category-section {
display: flex;
flex-direction: column;
gap: var(--space-2);
}
/* Smaller, muted "(optional)" suffix inside the section label. */
.optional-label {
font-size: 0.75rem;
font-weight: 400;
color: var(--color-text-muted);
}
/* Chips sit in a row and wrap onto additional lines when they run out of space. */
.category-chips {
display: flex;
flex-wrap: wrap;
gap: var(--space-2);
}
/* Base (unselected) chip appearance; colours animate on hover/selection. */
.category-chip {
padding: var(--space-1) var(--space-3);
border-radius: var(--radius-full);
border: 1px solid var(--color-border);
background: var(--color-surface-alt);
color: var(--color-text-muted);
font-size: 0.78rem;
font-weight: 500;
cursor: pointer;
transition: background var(--transition), color var(--transition), border-color var(--transition);
}
.category-chip:hover {
border-color: var(--color-accent);
color: var(--color-accent);
background: var(--color-accent-light);
}
/* Applied to the currently selected chip; same accent colours as hover, plus bold text. */
.category-chip--active {
background: var(--color-accent-light);
border-color: var(--color-accent);
color: var(--color-accent);
font-weight: 700;
}
/* Confirm / Cancel buttons shown while a discard or flag is pending. */
.pending-action-row {
display: flex;
gap: var(--space-2);
margin-top: var(--space-1);
}
/* Accent-coloured primary button of the pending row. */
.btn-confirm {
padding: var(--space-1) var(--space-3);
border-radius: var(--radius-md);
border: 1px solid var(--color-accent);
background: var(--color-accent-light);
color: var(--color-accent);
font-size: 0.85rem;
font-weight: 600;
cursor: pointer;
}
.btn-confirm:hover {
background: color-mix(in srgb, var(--color-accent) 15%, transparent);
}
/* Low-emphasis secondary button of the pending row. */
.btn-cancel-pending {
padding: var(--space-1) var(--space-3);
border-radius: var(--radius-md);
border: 1px solid var(--color-border);
background: none;
color: var(--color-text-muted);
font-size: 0.85rem;
cursor: pointer;
}
.btn-cancel-pending:hover {
background: var(--color-surface-alt);
}
</style> </style>

View file

@ -2,6 +2,14 @@
import { defineStore } from 'pinia' import { defineStore } from 'pinia'
import { computed, ref } from 'vue' import { computed, ref } from 'vue'
// Labels a reviewer can attach to a queue item to say why the model response
// was wrong. The backend's FailureCategory Literal lists the same six values;
// keep the two in sync, since the API tests expect unknown values to be
// rejected with HTTP 422.
export type SftFailureCategory =
| 'scoring_artifact'
| 'style_violation'
| 'partial_answer'
| 'wrong_answer'
| 'format_error'
| 'hallucination'
export interface SftQueueItem { export interface SftQueueItem {
id: string id: string
source: 'cf-orch-benchmark' source: 'cf-orch-benchmark'
@ -13,6 +21,7 @@ export interface SftQueueItem {
corrected_response: string | null corrected_response: string | null
quality_score: number // 0.0 to 1.0 quality_score: number // 0.0 to 1.0
failure_reason: string | null failure_reason: string | null
failure_category: SftFailureCategory | null
task_id: string task_id: string
task_type: string task_type: string
task_name: string task_name: string
@ -26,6 +35,7 @@ export interface SftQueueItem {
export interface SftLastAction { export interface SftLastAction {
type: 'correct' | 'discard' | 'flag' type: 'correct' | 'discard' | 'flag'
item: SftQueueItem item: SftQueueItem
failure_category?: SftFailureCategory | null
} }
export const useSftStore = defineStore('sft', () => { export const useSftStore = defineStore('sft', () => {
@ -39,8 +49,12 @@ export const useSftStore = defineStore('sft', () => {
queue.value.shift() queue.value.shift()
} }
function setLastAction(type: SftLastAction['type'], item: SftQueueItem) { function setLastAction(
lastAction.value = { type, item } type: SftLastAction['type'],
item: SftQueueItem,
failure_category?: SftFailureCategory | null,
) {
lastAction.value = { type, item, failure_category }
} }
function clearLastAction() { function clearLastAction() {

View file

@ -36,6 +36,7 @@
@flag="handleFlag" @flag="handleFlag"
@submit-correction="handleCorrect" @submit-correction="handleCorrect"
@cancel-correction="correcting = false" @cancel-correction="correcting = false"
ref="sftCardEl"
/> />
</div> </div>
</template> </template>
@ -67,6 +68,7 @@
<script setup lang="ts"> <script setup lang="ts">
import { ref, onMounted } from 'vue' import { ref, onMounted } from 'vue'
import { useSftStore } from '../stores/sft' import { useSftStore } from '../stores/sft'
import type { SftFailureCategory } from '../stores/sft'
import { useSftKeyboard } from '../composables/useSftKeyboard' import { useSftKeyboard } from '../composables/useSftKeyboard'
import SftCard from '../components/SftCard.vue' import SftCard from '../components/SftCard.vue'
@ -76,6 +78,7 @@ const apiError = ref(false)
const correcting = ref(false) const correcting = ref(false)
const stats = ref<Record<string, any> | null>(null) const stats = ref<Record<string, any> | null>(null)
const exportUrl = '/api/sft/export' const exportUrl = '/api/sft/export'
const sftCardEl = ref<InstanceType<typeof SftCard> | null>(null)
useSftKeyboard({ useSftKeyboard({
onCorrect: () => { if (store.current && !correcting.value) correcting.value = true }, onCorrect: () => { if (store.current && !correcting.value) correcting.value = true },
@ -113,19 +116,21 @@ function startCorrection() {
correcting.value = true correcting.value = true
} }
async function handleCorrect(text: string) { async function handleCorrect(text: string, category: SftFailureCategory | null = null) {
if (!store.current) return if (!store.current) return
const item = store.current const item = store.current
correcting.value = false correcting.value = false
try { try {
const body: Record<string, unknown> = { id: item.id, action: 'correct', corrected_response: text }
if (category != null) body.failure_category = category
const res = await fetch('/api/sft/submit', { const res = await fetch('/api/sft/submit', {
method: 'POST', method: 'POST',
headers: { 'Content-Type': 'application/json' }, headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ id: item.id, action: 'correct', corrected_response: text }), body: JSON.stringify(body),
}) })
if (!res.ok) throw new Error(`HTTP ${res.status}`) if (!res.ok) throw new Error(`HTTP ${res.status}`)
store.removeCurrentFromQueue() store.removeCurrentFromQueue()
store.setLastAction('correct', item) store.setLastAction('correct', item, category)
store.totalRemaining = Math.max(0, store.totalRemaining - 1) store.totalRemaining = Math.max(0, store.totalRemaining - 1)
fetchStats() fetchStats()
if (store.queue.length < 5) fetchBatch() if (store.queue.length < 5) fetchBatch()
@ -134,18 +139,20 @@ async function handleCorrect(text: string) {
} }
} }
async function handleDiscard() { async function handleDiscard(category: SftFailureCategory | null = null) {
if (!store.current) return if (!store.current) return
const item = store.current const item = store.current
try { try {
const body: Record<string, unknown> = { id: item.id, action: 'discard' }
if (category != null) body.failure_category = category
const res = await fetch('/api/sft/submit', { const res = await fetch('/api/sft/submit', {
method: 'POST', method: 'POST',
headers: { 'Content-Type': 'application/json' }, headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ id: item.id, action: 'discard' }), body: JSON.stringify(body),
}) })
if (!res.ok) throw new Error(`HTTP ${res.status}`) if (!res.ok) throw new Error(`HTTP ${res.status}`)
store.removeCurrentFromQueue() store.removeCurrentFromQueue()
store.setLastAction('discard', item) store.setLastAction('discard', item, category)
store.totalRemaining = Math.max(0, store.totalRemaining - 1) store.totalRemaining = Math.max(0, store.totalRemaining - 1)
fetchStats() fetchStats()
if (store.queue.length < 5) fetchBatch() if (store.queue.length < 5) fetchBatch()
@ -154,18 +161,20 @@ async function handleDiscard() {
} }
} }
async function handleFlag() { async function handleFlag(category: SftFailureCategory | null = null) {
if (!store.current) return if (!store.current) return
const item = store.current const item = store.current
try { try {
const body: Record<string, unknown> = { id: item.id, action: 'flag' }
if (category != null) body.failure_category = category
const res = await fetch('/api/sft/submit', { const res = await fetch('/api/sft/submit', {
method: 'POST', method: 'POST',
headers: { 'Content-Type': 'application/json' }, headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ id: item.id, action: 'flag' }), body: JSON.stringify(body),
}) })
if (!res.ok) throw new Error(`HTTP ${res.status}`) if (!res.ok) throw new Error(`HTTP ${res.status}`)
store.removeCurrentFromQueue() store.removeCurrentFromQueue()
store.setLastAction('flag', item) store.setLastAction('flag', item, category)
store.totalRemaining = Math.max(0, store.totalRemaining - 1) store.totalRemaining = Math.max(0, store.totalRemaining - 1)
fetchStats() fetchStats()
if (store.queue.length < 5) fetchBatch() if (store.queue.length < 5) fetchBatch()