Compare commits
34 commits
e03d91ece9
...
1fa5b9e2b0
| Author | SHA1 | Date | |
|---|---|---|---|
| 1fa5b9e2b0 | |||
| 8e016d7fe6 | |||
| 30f19711ec | |||
| 753f8f5def | |||
| 5dee23f53c | |||
| 606917f90f | |||
| 74ba9103cd | |||
| 95eb238add | |||
| 60fe1231ce | |||
| ef8adfb035 | |||
| 64fd19a7b6 | |||
| 8ba34bb2d1 | |||
| f262b23cf5 | |||
| 5eb593569d | |||
| 2d795b9573 | |||
| 36117b35c4 | |||
| da8478082e | |||
| 7a4ca422ca | |||
| 71d0bfafe6 | |||
| 8c22dd62de | |||
| d3ae5b576a | |||
| 8c26cbe597 | |||
| fb1ce89244 | |||
| e01f743c39 | |||
| 6829bebdd4 | |||
| ddb6025c89 | |||
| d410fa5c80 | |||
| 6c98ee6d69 | |||
| 3197252c31 | |||
| 4bea1b6812 | |||
| d418a719f0 | |||
| 144a9b29b3 | |||
| 8af63d959b | |||
| cfa5ed2194 |
31 changed files with 5595 additions and 145 deletions
7
PRIVACY.md
Normal file
7
PRIVACY.md
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
# Privacy Policy
|
||||
|
||||
CircuitForge LLC's privacy policy applies to this product and is published at:
|
||||
|
||||
**<https://circuitforge.tech/privacy>**
|
||||
|
||||
Last reviewed: March 2026.
|
||||
226
app/api.py
226
app/api.py
|
|
@ -7,6 +7,8 @@ from __future__ import annotations
|
|||
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import subprocess as _subprocess
|
||||
import yaml
|
||||
from pathlib import Path
|
||||
|
||||
|
|
@ -16,8 +18,14 @@ from fastapi import FastAPI, HTTPException, Query
|
|||
from pydantic import BaseModel
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_DATA_DIR: Path = _ROOT / "data" # overridable in tests via set_data_dir()
|
||||
_CONFIG_DIR: Path | None = None # None = use real path
|
||||
_DATA_DIR: Path = _ROOT / "data" # overridable in tests via set_data_dir()
|
||||
_MODELS_DIR: Path = _ROOT / "models" # overridable in tests via set_models_dir()
|
||||
_CONFIG_DIR: Path | None = None # None = use real path
|
||||
|
||||
# Process registry for running jobs — used by cancel endpoints.
|
||||
# Keys: "benchmark" | "finetune". Values: the live Popen object.
|
||||
_running_procs: dict = {}
|
||||
_cancelled_jobs: set = set()
|
||||
|
||||
|
||||
def set_data_dir(path: Path) -> None:
|
||||
|
|
@ -26,6 +34,40 @@ def set_data_dir(path: Path) -> None:
|
|||
_DATA_DIR = path
|
||||
|
||||
|
||||
def _best_cuda_device() -> str:
|
||||
"""Return the index of the GPU with the most free VRAM as a string.
|
||||
|
||||
Uses nvidia-smi so it works in the job-seeker env (no torch). Returns ""
|
||||
if nvidia-smi is unavailable or no GPUs are found. Restricting the
|
||||
training subprocess to a single GPU via CUDA_VISIBLE_DEVICES prevents
|
||||
PyTorch DataParallel from replicating the model across all GPUs, which
|
||||
would OOM the GPU with less headroom.
|
||||
"""
|
||||
try:
|
||||
out = _subprocess.check_output(
|
||||
["nvidia-smi", "--query-gpu=index,memory.free",
|
||||
"--format=csv,noheader,nounits"],
|
||||
text=True,
|
||||
timeout=5,
|
||||
)
|
||||
best_idx, best_free = "", 0
|
||||
for line in out.strip().splitlines():
|
||||
parts = line.strip().split(", ")
|
||||
if len(parts) == 2:
|
||||
idx, free = parts[0].strip(), int(parts[1].strip())
|
||||
if free > best_free:
|
||||
best_free, best_idx = free, idx
|
||||
return best_idx
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
def set_models_dir(path: Path) -> None:
|
||||
"""Override models directory — used by tests."""
|
||||
global _MODELS_DIR
|
||||
_MODELS_DIR = path
|
||||
|
||||
|
||||
def set_config_dir(path: Path | None) -> None:
|
||||
"""Override config directory — used by tests."""
|
||||
global _CONFIG_DIR
|
||||
|
|
@ -287,6 +329,186 @@ def test_account(req: AccountTestRequest):
|
|||
from fastapi.responses import StreamingResponse
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Benchmark endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.get("/api/benchmark/results")
|
||||
def get_benchmark_results():
|
||||
"""Return the most recently saved benchmark results, or an empty envelope."""
|
||||
path = _DATA_DIR / "benchmark_results.json"
|
||||
if not path.exists():
|
||||
return {"models": {}, "sample_count": 0, "timestamp": None}
|
||||
return json.loads(path.read_text())
|
||||
|
||||
|
||||
@app.get("/api/benchmark/run")
|
||||
def run_benchmark(include_slow: bool = False):
|
||||
"""Spawn the benchmark script and stream stdout as SSE progress events."""
|
||||
python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
|
||||
script = str(_ROOT / "scripts" / "benchmark_classifier.py")
|
||||
cmd = [python_bin, script, "--score", "--save"]
|
||||
if include_slow:
|
||||
cmd.append("--include-slow")
|
||||
|
||||
def generate():
|
||||
try:
|
||||
proc = _subprocess.Popen(
|
||||
cmd,
|
||||
stdout=_subprocess.PIPE,
|
||||
stderr=_subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
cwd=str(_ROOT),
|
||||
)
|
||||
_running_procs["benchmark"] = proc
|
||||
_cancelled_jobs.discard("benchmark") # clear any stale flag from a prior run
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
line = line.rstrip()
|
||||
if line:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
|
||||
proc.wait()
|
||||
if proc.returncode == 0:
|
||||
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
|
||||
elif "benchmark" in _cancelled_jobs:
|
||||
_cancelled_jobs.discard("benchmark")
|
||||
yield f"data: {json.dumps({'type': 'cancelled'})}\n\n"
|
||||
else:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})}\n\n"
|
||||
finally:
|
||||
_running_procs.pop("benchmark", None)
|
||||
except Exception as exc:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Finetune endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.get("/api/finetune/status")
|
||||
def get_finetune_status():
|
||||
"""Scan models/ for training_info.json files. Returns [] if none exist."""
|
||||
models_dir = _MODELS_DIR
|
||||
if not models_dir.exists():
|
||||
return []
|
||||
results = []
|
||||
for sub in models_dir.iterdir():
|
||||
if not sub.is_dir():
|
||||
continue
|
||||
info_path = sub / "training_info.json"
|
||||
if not info_path.exists():
|
||||
continue
|
||||
try:
|
||||
info = json.loads(info_path.read_text(encoding="utf-8"))
|
||||
results.append(info)
|
||||
except Exception:
|
||||
pass
|
||||
return results
|
||||
|
||||
|
||||
@app.get("/api/finetune/run")
|
||||
def run_finetune_endpoint(
|
||||
model: str = "deberta-small",
|
||||
epochs: int = 5,
|
||||
score: list[str] = Query(default=[]),
|
||||
):
|
||||
"""Spawn finetune_classifier.py and stream stdout as SSE progress events."""
|
||||
python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
|
||||
script = str(_ROOT / "scripts" / "finetune_classifier.py")
|
||||
cmd = [python_bin, script, "--model", model, "--epochs", str(epochs)]
|
||||
data_root = _DATA_DIR.resolve()
|
||||
for score_file in score:
|
||||
resolved = (_DATA_DIR / score_file).resolve()
|
||||
if not str(resolved).startswith(str(data_root)):
|
||||
raise HTTPException(400, f"Invalid score path: {score_file!r}")
|
||||
cmd.extend(["--score", str(resolved)])
|
||||
|
||||
# Pick the GPU with the most free VRAM. Setting CUDA_VISIBLE_DEVICES to a
|
||||
# single device prevents DataParallel from replicating the model across all
|
||||
# GPUs, which would force a full copy onto the more memory-constrained device.
|
||||
proc_env = {**os.environ, "PYTORCH_ALLOC_CONF": "expandable_segments:True"}
|
||||
best_gpu = _best_cuda_device()
|
||||
if best_gpu:
|
||||
proc_env["CUDA_VISIBLE_DEVICES"] = best_gpu
|
||||
|
||||
gpu_note = f"GPU {best_gpu}" if best_gpu else "CPU (no GPU found)"
|
||||
|
||||
def generate():
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': f'[api] Using {gpu_note} (most free VRAM)'})}\n\n"
|
||||
try:
|
||||
proc = _subprocess.Popen(
|
||||
cmd,
|
||||
stdout=_subprocess.PIPE,
|
||||
stderr=_subprocess.STDOUT,
|
||||
text=True,
|
||||
bufsize=1,
|
||||
cwd=str(_ROOT),
|
||||
env=proc_env,
|
||||
)
|
||||
_running_procs["finetune"] = proc
|
||||
_cancelled_jobs.discard("finetune") # clear any stale flag from a prior run
|
||||
try:
|
||||
for line in proc.stdout:
|
||||
line = line.rstrip()
|
||||
if line:
|
||||
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
|
||||
proc.wait()
|
||||
if proc.returncode == 0:
|
||||
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
|
||||
elif "finetune" in _cancelled_jobs:
|
||||
_cancelled_jobs.discard("finetune")
|
||||
yield f"data: {json.dumps({'type': 'cancelled'})}\n\n"
|
||||
else:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})}\n\n"
|
||||
finally:
|
||||
_running_procs.pop("finetune", None)
|
||||
except Exception as exc:
|
||||
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
|
||||
|
||||
return StreamingResponse(
|
||||
generate(),
|
||||
media_type="text/event-stream",
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
||||
)
|
||||
|
||||
|
||||
@app.post("/api/benchmark/cancel")
|
||||
def cancel_benchmark():
|
||||
"""Kill the running benchmark subprocess. 404 if none is running."""
|
||||
proc = _running_procs.get("benchmark")
|
||||
if proc is None:
|
||||
raise HTTPException(404, "No benchmark is running")
|
||||
_cancelled_jobs.add("benchmark")
|
||||
proc.terminate()
|
||||
try:
|
||||
proc.wait(timeout=3)
|
||||
except _subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
return {"status": "cancelled"}
|
||||
|
||||
|
||||
@app.post("/api/finetune/cancel")
|
||||
def cancel_finetune():
|
||||
"""Kill the running fine-tune subprocess. 404 if none is running."""
|
||||
proc = _running_procs.get("finetune")
|
||||
if proc is None:
|
||||
raise HTTPException(404, "No finetune is running")
|
||||
_cancelled_jobs.add("finetune")
|
||||
proc.terminate()
|
||||
try:
|
||||
proc.wait(timeout=3)
|
||||
except _subprocess.TimeoutExpired:
|
||||
proc.kill()
|
||||
return {"status": "cancelled"}
|
||||
|
||||
|
||||
@app.get("/api/fetch/stream")
|
||||
def fetch_stream(
|
||||
accounts: str = Query(default=""),
|
||||
|
|
|
|||
95
docs/plans/2026-03-08-anime-animation-design.md
Normal file
95
docs/plans/2026-03-08-anime-animation-design.md
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
# Anime.js Animation Integration — Design
|
||||
|
||||
**Date:** 2026-03-08
|
||||
**Status:** Approved
|
||||
**Branch:** feat/vue-label-tab
|
||||
|
||||
## Problem
|
||||
|
||||
The current animation system mixes CSS keyframes, CSS transitions, and imperative inline-style bindings across three files. The seams between systems produce:
|
||||
|
||||
- Abrupt ball pickup (instant scale/borderRadius jump)
|
||||
- No spring snap-back on release to no target
|
||||
- Rigid CSS dismissals with no timing control
|
||||
- Bucket grid and badge pop on basic `@keyframes`
|
||||
|
||||
## Decision
|
||||
|
||||
Integrate **Anime.js v4** as a single animation layer. Vue reactive state is unchanged; Anime.js owns all DOM motion imperatively.
|
||||
|
||||
## Architecture
|
||||
|
||||
One new composable, minimal changes to two existing files, CSS cleanup in two files.
|
||||
|
||||
```
|
||||
web/src/composables/useCardAnimation.ts ← NEW
|
||||
web/src/components/EmailCardStack.vue ← modify
|
||||
web/src/views/LabelView.vue ← modify
|
||||
```
|
||||
|
||||
**Data flow:**
|
||||
```
|
||||
pointer events → Vue refs (isHeld, deltaX, deltaY, dismissType)
|
||||
↓ watched by
|
||||
useCardAnimation(cardEl, stackEl, isHeld, ...)
|
||||
↓ imperatively drives
|
||||
Anime.js → DOM transforms
|
||||
```
|
||||
|
||||
`useCardAnimation` is a pure side-effect composable — returns nothing to the template. The `cardStyle` computed in `EmailCardStack.vue` is removed; Anime.js owns the element's transform directly.
|
||||
|
||||
## Animation Surfaces
|
||||
|
||||
### Pickup morph
|
||||
```
|
||||
animate(cardEl, { scale: 0.55, borderRadius: '50%', y: -80 }, { duration: 200, ease: spring(1, 80, 10) })
|
||||
```
|
||||
Replaces the instant CSS transform jump on `onPointerDown`.
|
||||
|
||||
### Drag tracking
|
||||
Raw `cardEl.style.translate` update on `onPointerMove` — no animation, just position. Easing only at boundaries (pickup / release), not during active drag.
|
||||
|
||||
### Snap-back
|
||||
```
|
||||
animate(cardEl, { x: 0, y: 0, scale: 1, borderRadius: '1rem' }, { ease: spring(1, 80, 10) })
|
||||
```
|
||||
Fires on `onPointerUp` when no zone/bucket target was hit.
|
||||
|
||||
### Dismissals (replace CSS `@keyframes`)
|
||||
- **fileAway** — `animate(cardEl, { y: '-120%', scale: 0.85, opacity: 0 }, { duration: 280, ease: 'out(3)' })`
|
||||
- **crumple** — 2-step timeline: shrink + redden → `scale(0)` + rotate
|
||||
- **slideUnder** — `animate(cardEl, { x: '110%', rotate: 5, opacity: 0 }, { duration: 260 })`
|
||||
|
||||
### Bucket grid rise
|
||||
`animate(gridEl, { y: -8, opacity: 0.45 })` on `isHeld` → true; reversed on false. Spring easing.
|
||||
|
||||
### Badge pop
|
||||
`animate(badgeEl, { scale: [0.6, 1], opacity: [0, 1] }, { ease: spring(1.5, 80, 8), duration: 300 })` triggered on badge mount via Vue's `onMounted` lifecycle hook in a `BadgePop` wrapper component or `v-enter-active` transition hook.
|
||||
|
||||
## Constraints
|
||||
|
||||
### Reduced motion
|
||||
`useCardAnimation` checks `motion.rich.value` before firing any Anime.js call. If false, all animations are skipped — instant state changes only. Consistent with existing `useMotion` pattern.
|
||||
|
||||
### Bundle size
|
||||
Anime.js v4 core ~17KB gzipped. Only `animate`, `spring`, and `createTimeline` are imported — Vite ESM tree-shaking keeps footprint minimal. The `draggable` module is not used.
|
||||
|
||||
### Tests
|
||||
Existing `EmailCardStack.test.ts` tests emit behavior, not animation — they remain passing. Anime.js mocked at module level in Vitest via `vi.mock('animejs')` where needed.
|
||||
|
||||
### CSS cleanup
|
||||
Remove from `EmailCardStack.vue` and `LabelView.vue`:
|
||||
- `@keyframes fileAway`, `crumple`, `slideUnder`
|
||||
- `@keyframes badge-pop`
|
||||
- `.dismiss-label`, `.dismiss-skip`, `.dismiss-discard` classes (Anime.js fires on element refs directly)
|
||||
- The `dismissClass` computed in `EmailCardStack.vue`
|
||||
|
||||
## Files Changed
|
||||
|
||||
| File | Change |
|
||||
|------|--------|
|
||||
| `web/package.json` | Add `animejs` dependency |
|
||||
| `web/src/composables/useCardAnimation.ts` | New — all Anime.js animation logic |
|
||||
| `web/src/components/EmailCardStack.vue` | Remove `cardStyle` computed + dismiss classes; call `useCardAnimation` |
|
||||
| `web/src/views/LabelView.vue` | Badge pop + bucket grid rise via Anime.js |
|
||||
| `web/src/assets/avocet.css` | Remove any global animation keyframes if present |
|
||||
573
docs/plans/2026-03-08-anime-animation-plan.md
Normal file
573
docs/plans/2026-03-08-anime-animation-plan.md
Normal file
|
|
@ -0,0 +1,573 @@
|
|||
# Anime.js Animation Integration — Implementation Plan
|
||||
|
||||
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
||||
|
||||
**Goal:** Replace the current mixed CSS keyframes / inline-style animation system with Anime.js v4 for all card motion — pickup morph, drag tracking, spring snap-back, dismissals, bucket grid rise, and badge pop.
|
||||
|
||||
**Architecture:** A new `useCardAnimation` composable owns all Anime.js calls imperatively against DOM refs. Vue reactive state (`isHeld`, `deltaX`, `deltaY`, `dismissType`) is unchanged. `cardStyle` computed and `dismissClass` computed are deleted; Anime.js writes to the element directly.
|
||||
|
||||
**Tech Stack:** Anime.js v4 (`animejs`), Vue 3 Composition API, `@vue/test-utils` + Vitest for tests.
|
||||
|
||||
---
|
||||
|
||||
## Task 1: Install Anime.js
|
||||
|
||||
**Files:**
|
||||
- Modify: `web/package.json`
|
||||
|
||||
**Step 1: Install the package**
|
||||
|
||||
```bash
|
||||
cd /Library/Development/CircuitForge/avocet/web
|
||||
npm install animejs
|
||||
```
|
||||
|
||||
**Step 2: Verify the import resolves**
|
||||
|
||||
Create a throwaway check — open `web/src/main.ts` briefly and confirm:
|
||||
```ts
|
||||
import { animate, spring } from 'animejs'
|
||||
```
|
||||
resolves without error in the editor (TypeScript types ship with animejs v4).
|
||||
Remove the import immediately after verifying — do not commit it.
|
||||
|
||||
**Step 3: Commit**
|
||||
|
||||
```bash
|
||||
cd /Library/Development/CircuitForge/avocet/web
|
||||
git add package.json package-lock.json
|
||||
git commit -m "feat(avocet): add animejs v4 dependency"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Task 2: Create `useCardAnimation` composable
|
||||
|
||||
**Files:**
|
||||
- Create: `web/src/composables/useCardAnimation.ts`
|
||||
- Create: `web/src/composables/useCardAnimation.test.ts`
|
||||
|
||||
**Background — Anime.js v4 transform model:**
|
||||
Anime.js v4 tracks `x`, `y`, `scale`, `rotate`, etc. as separate transform components internally.
|
||||
Use `utils.set(el, props)` for instant (no-animation) property updates — this keeps the internal cache consistent.
|
||||
Never mix direct `el.style.transform = "..."` with Anime.js on the same element, or the cache desyncs.
|
||||
|
||||
**Step 1: Write the failing tests**
|
||||
|
||||
`web/src/composables/useCardAnimation.test.ts`:
|
||||
```ts
|
||||
import { ref, nextTick } from 'vue'
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
|
||||
// Mock animejs before importing the composable
|
||||
vi.mock('animejs', () => ({
|
||||
animate: vi.fn(),
|
||||
spring: vi.fn(() => 'mock-spring'),
|
||||
utils: { set: vi.fn() },
|
||||
}))
|
||||
|
||||
import { useCardAnimation } from './useCardAnimation'
|
||||
import { animate, utils } from 'animejs'
|
||||
|
||||
const mockAnimate = animate as ReturnType<typeof vi.fn>
|
||||
const mockSet = utils.set as ReturnType<typeof vi.fn>
|
||||
|
||||
function makeEl() {
|
||||
return document.createElement('div')
|
||||
}
|
||||
|
||||
describe('useCardAnimation', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('pickup() calls animate with ball shape', () => {
|
||||
const el = makeEl()
|
||||
const cardEl = ref<HTMLElement | null>(el)
|
||||
const motion = { rich: ref(true) }
|
||||
const { pickup } = useCardAnimation(cardEl, motion)
|
||||
pickup()
|
||||
expect(mockAnimate).toHaveBeenCalledWith(
|
||||
el,
|
||||
expect.objectContaining({ scale: 0.55, borderRadius: '50%' }),
|
||||
expect.anything(),
|
||||
)
|
||||
})
|
||||
|
||||
it('pickup() is a no-op when motion.rich is false', () => {
|
||||
const el = makeEl()
|
||||
const cardEl = ref<HTMLElement | null>(el)
|
||||
const motion = { rich: ref(false) }
|
||||
const { pickup } = useCardAnimation(cardEl, motion)
|
||||
pickup()
|
||||
expect(mockAnimate).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('setDragPosition() calls utils.set with translated coords', () => {
|
||||
const el = makeEl()
|
||||
const cardEl = ref<HTMLElement | null>(el)
|
||||
const motion = { rich: ref(true) }
|
||||
const { setDragPosition } = useCardAnimation(cardEl, motion)
|
||||
setDragPosition(50, 30)
|
||||
expect(mockSet).toHaveBeenCalledWith(el, expect.objectContaining({ x: 50, y: -50 }))
|
||||
// y = deltaY - 80 = 30 - 80 = -50
|
||||
})
|
||||
|
||||
it('snapBack() calls animate returning to card shape', () => {
|
||||
const el = makeEl()
|
||||
const cardEl = ref<HTMLElement | null>(el)
|
||||
const motion = { rich: ref(true) }
|
||||
const { snapBack } = useCardAnimation(cardEl, motion)
|
||||
snapBack()
|
||||
expect(mockAnimate).toHaveBeenCalledWith(
|
||||
el,
|
||||
expect.objectContaining({ x: 0, y: 0, scale: 1 }),
|
||||
expect.anything(),
|
||||
)
|
||||
})
|
||||
|
||||
it('animateDismiss("label") calls animate', () => {
|
||||
const el = makeEl()
|
||||
const cardEl = ref<HTMLElement | null>(el)
|
||||
const motion = { rich: ref(true) }
|
||||
const { animateDismiss } = useCardAnimation(cardEl, motion)
|
||||
animateDismiss('label')
|
||||
expect(mockAnimate).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('animateDismiss("discard") calls animate', () => {
|
||||
const el = makeEl()
|
||||
const cardEl = ref<HTMLElement | null>(el)
|
||||
const motion = { rich: ref(true) }
|
||||
const { animateDismiss } = useCardAnimation(cardEl, motion)
|
||||
animateDismiss('discard')
|
||||
expect(mockAnimate).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('animateDismiss("skip") calls animate', () => {
|
||||
const el = makeEl()
|
||||
const cardEl = ref<HTMLElement | null>(el)
|
||||
const motion = { rich: ref(true) }
|
||||
const { animateDismiss } = useCardAnimation(cardEl, motion)
|
||||
animateDismiss('skip')
|
||||
expect(mockAnimate).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('animateDismiss is a no-op when motion.rich is false', () => {
|
||||
const el = makeEl()
|
||||
const cardEl = ref<HTMLElement | null>(el)
|
||||
const motion = { rich: ref(false) }
|
||||
const { animateDismiss } = useCardAnimation(cardEl, motion)
|
||||
animateDismiss('label')
|
||||
expect(mockAnimate).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
```
|
||||
|
||||
**Step 2: Run tests to confirm they fail**
|
||||
|
||||
```bash
|
||||
cd /Library/Development/CircuitForge/avocet/web
|
||||
npm test -- useCardAnimation
|
||||
```
|
||||
|
||||
Expected: FAIL — "Cannot find module './useCardAnimation'"
|
||||
|
||||
**Step 3: Implement the composable**
|
||||
|
||||
`web/src/composables/useCardAnimation.ts`:
|
||||
```ts
|
||||
import { type Ref } from 'vue'
|
||||
import { animate, spring, utils } from 'animejs'
|
||||
|
||||
const BALL_SCALE = 0.55
|
||||
const BALL_RADIUS = '50%'
|
||||
const CARD_RADIUS = '1rem'
|
||||
const PICKUP_Y_OFFSET = 80 // px above finger
|
||||
const PICKUP_DURATION = 200
|
||||
// NOTE: animejs v4 — spring() takes an object, not positional args
|
||||
const SNAP_SPRING = spring({ mass: 1, stiffness: 80, damping: 10 })
|
||||
|
||||
interface Motion { rich: Ref<boolean> }
|
||||
|
||||
export function useCardAnimation(
|
||||
cardEl: Ref<HTMLElement | null>,
|
||||
motion: Motion,
|
||||
) {
|
||||
function pickup() {
|
||||
if (!motion.rich.value || !cardEl.value) return
|
||||
// NOTE: animejs v4 — animate() is 2-arg; timing options merge into the params object
|
||||
animate(cardEl.value, {
|
||||
scale: BALL_SCALE,
|
||||
borderRadius: BALL_RADIUS,
|
||||
y: -PICKUP_Y_OFFSET,
|
||||
duration: PICKUP_DURATION,
|
||||
ease: SNAP_SPRING,
|
||||
})
|
||||
}
|
||||
|
||||
function setDragPosition(dx: number, dy: number) {
|
||||
if (!cardEl.value) return
|
||||
utils.set(cardEl.value, { x: dx, y: dy - PICKUP_Y_OFFSET })
|
||||
}
|
||||
|
||||
function snapBack() {
|
||||
if (!motion.rich.value || !cardEl.value) return
|
||||
// No duration — spring physics determines settling time
|
||||
animate(cardEl.value, {
|
||||
x: 0,
|
||||
y: 0,
|
||||
scale: 1,
|
||||
borderRadius: CARD_RADIUS,
|
||||
ease: SNAP_SPRING,
|
||||
})
|
||||
}
|
||||
|
||||
function animateDismiss(type: 'label' | 'skip' | 'discard') {
|
||||
if (!motion.rich.value || !cardEl.value) return
|
||||
const el = cardEl.value
|
||||
if (type === 'label') {
|
||||
animate(el, { y: '-120%', scale: 0.85, opacity: 0, duration: 280, ease: 'out(3)' })
|
||||
} else if (type === 'discard') {
|
||||
// Two-step: crumple then shrink (keyframes array in params object)
|
||||
animate(el, { keyframes: [
|
||||
{ scale: 0.95, rotate: 2, filter: 'brightness(0.6) sepia(1) hue-rotate(-20deg)', duration: 140 },
|
||||
{ scale: 0, rotate: 8, opacity: 0, duration: 210 },
|
||||
])
|
||||
} else if (type === 'skip') {
|
||||
animate(el, { x: '110%', rotate: 5, opacity: 0 }, { duration: 260, ease: 'out(2)' })
|
||||
}
|
||||
}
|
||||
|
||||
return { pickup, setDragPosition, snapBack, animateDismiss }
|
||||
}
|
||||
```
|
||||
|
||||
**Step 4: Run tests — expect pass**
|
||||
|
||||
```bash
|
||||
npm test -- useCardAnimation
|
||||
```
|
||||
|
||||
Expected: All 8 tests PASS.
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add web/src/composables/useCardAnimation.ts web/src/composables/useCardAnimation.test.ts
|
||||
git commit -m "feat(avocet): add useCardAnimation composable with Anime.js"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Task 3: Wire `useCardAnimation` into `EmailCardStack.vue`
|
||||
|
||||
**Files:**
|
||||
- Modify: `web/src/components/EmailCardStack.vue`
|
||||
- Modify: `web/src/components/EmailCardStack.test.ts`
|
||||
|
||||
**What changes:**
|
||||
- Remove `cardStyle` computed and `:style="cardStyle"` binding
|
||||
- Remove `dismissClass` computed and `:class="[dismissClass, ...]"` binding (keep `is-held`)
|
||||
- Remove `deltaX`, `deltaY` reactive refs (position now owned by Anime.js)
|
||||
- Call `pickup()` in `onPointerDown`, `setDragPosition()` in `onPointerMove`, `snapBack()` in `onPointerUp` (no-target path)
|
||||
- Watch `props.dismissType` and call `animateDismiss()`
|
||||
- Remove CSS `@keyframes fileAway`, `crumple`, `slideUnder` and their `.dismiss-*` rule blocks from `<style>`
|
||||
|
||||
**Step 1: Update the tests that check dismiss classes**
|
||||
|
||||
In `EmailCardStack.test.ts`, the 5 tests checking `.dismiss-label`, `.dismiss-discard`, `.dismiss-skip` classes are testing implementation (CSS class name), not behavior. Replace them with a single test that verifies `animateDismiss` is called:
|
||||
|
||||
```ts
|
||||
// Add at the top of the file (after existing imports):
|
||||
vi.mock('../composables/useCardAnimation', () => ({
|
||||
useCardAnimation: vi.fn(() => ({
|
||||
pickup: vi.fn(),
|
||||
setDragPosition: vi.fn(),
|
||||
snapBack: vi.fn(),
|
||||
animateDismiss: vi.fn(),
|
||||
})),
|
||||
}))
|
||||
|
||||
import { useCardAnimation } from '../composables/useCardAnimation'
|
||||
```
|
||||
|
||||
Replace the five `dismissType` class tests (lines 25–46) with:
|
||||
|
||||
```ts
|
||||
it('calls animateDismiss with type when dismissType prop changes', async () => {
|
||||
const w = mount(EmailCardStack, { props: { item, isBucketMode: false, dismissType: null } })
|
||||
const { animateDismiss } = (useCardAnimation as ReturnType<typeof vi.fn>).mock.results[0].value
|
||||
await w.setProps({ dismissType: 'label' })
|
||||
await nextTick()
|
||||
expect(animateDismiss).toHaveBeenCalledWith('label')
|
||||
})
|
||||
```
|
||||
|
||||
Add `nextTick` import to the test file header if not already present:
|
||||
```ts
|
||||
import { nextTick } from 'vue'
|
||||
```
|
||||
|
||||
**Step 2: Run tests to confirm the replaced tests fail**
|
||||
|
||||
```bash
|
||||
npm test -- EmailCardStack
|
||||
```
|
||||
|
||||
Expected: FAIL — `animateDismiss` not called (not yet wired in component)
|
||||
|
||||
**Step 3: Modify `EmailCardStack.vue`**
|
||||
|
||||
Script section changes:
|
||||
|
||||
```ts
|
||||
// Remove:
|
||||
// import { ref, computed } from 'vue' → change to:
|
||||
import { ref, watch } from 'vue'
|
||||
|
||||
// Add import:
|
||||
import { useCardAnimation } from '../composables/useCardAnimation'
|
||||
|
||||
// Remove these refs:
|
||||
// const deltaX = ref(0)
|
||||
// const deltaY = ref(0)
|
||||
|
||||
// Add after const motion = useMotion():
|
||||
const { pickup, setDragPosition, snapBack, animateDismiss } = useCardAnimation(cardEl, motion)
|
||||
|
||||
// Add watcher:
|
||||
watch(() => props.dismissType, (type) => {
|
||||
if (type) animateDismiss(type)
|
||||
})
|
||||
|
||||
// Remove dismissClass computed entirely.
|
||||
|
||||
// In onPointerDown — add after isHeld.value = true:
|
||||
pickup()
|
||||
|
||||
// In onPointerMove — replace deltaX/deltaY assignments with:
|
||||
const dx = e.clientX - pickupX.value
|
||||
const dy = e.clientY - pickupY.value
|
||||
setDragPosition(dx, dy)
|
||||
// (keep the zone/bucket detection that uses e.clientX/e.clientY — those stay the same)
|
||||
|
||||
// In onPointerUp — in the snap-back else branch, replace:
|
||||
// deltaX.value = 0
|
||||
// deltaY.value = 0
|
||||
// with:
|
||||
snapBack()
|
||||
```
|
||||
|
||||
Template changes — on the `.card-wrapper` div:
|
||||
```html
|
||||
<!-- Remove: :class="[dismissClass, { 'is-held': isHeld }]" -->
|
||||
<!-- Replace with: -->
|
||||
:class="{ 'is-held': isHeld }"
|
||||
<!-- Remove: :style="cardStyle" -->
|
||||
```
|
||||
|
||||
CSS changes in `<style scoped>` — delete these entire blocks:
|
||||
```
|
||||
@keyframes fileAway { ... }
|
||||
@keyframes crumple { ... }
|
||||
@keyframes slideUnder { ... }
|
||||
.card-wrapper.dismiss-label { ... }
|
||||
.card-wrapper.dismiss-discard { ... }
|
||||
.card-wrapper.dismiss-skip { ... }
|
||||
```
|
||||
|
||||
Also delete `--card-dismiss` and `--card-skip` CSS var usages if present.
|
||||
|
||||
**Step 4: Run all tests**
|
||||
|
||||
```bash
|
||||
npm test
|
||||
```
|
||||
|
||||
Expected: All pass (both `useCardAnimation.test.ts` and `EmailCardStack.test.ts`).
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add web/src/components/EmailCardStack.vue web/src/components/EmailCardStack.test.ts
|
||||
git commit -m "feat(avocet): wire Anime.js card animation into EmailCardStack"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Task 4: Bucket grid rise animation
|
||||
|
||||
**Files:**
|
||||
- Modify: `web/src/views/LabelView.vue`
|
||||
|
||||
**What changes:**
|
||||
Replace the CSS class-toggle animation on `.bucket-grid-footer.grid-active` with an Anime.js watch in `LabelView.vue`. The `position: sticky → fixed` switch stays as a CSS class (can't animate position), but `translateY` and `opacity` move to Anime.js.
|
||||
|
||||
**Step 1: Add gridEl ref and import animate**
|
||||
|
||||
In `LabelView.vue` `<script setup>`:
|
||||
```ts
|
||||
// Add to imports:
|
||||
import { ref, onMounted, onUnmounted, watch } from 'vue'
|
||||
import { animate, spring } from 'animejs'
|
||||
|
||||
// Add ref:
|
||||
const gridEl = ref<HTMLElement | null>(null)
|
||||
```
|
||||
|
||||
**Step 2: Add watcher for isHeld**
|
||||
|
||||
```ts
|
||||
watch(isHeld, (held) => {
|
||||
if (!motion.rich.value || !gridEl.value) return
|
||||
// animejs v4: 2-arg animate, spring() takes object
|
||||
animate(gridEl.value,
|
||||
held
|
||||
? { y: -8, opacity: 0.45, ease: spring({ mass: 1, stiffness: 80, damping: 10 }), duration: 250 }
|
||||
: { y: 0, opacity: 1, ease: spring({ mass: 1, stiffness: 80, damping: 10 }), duration: 250 }
|
||||
)
|
||||
})
|
||||
```
|
||||
|
||||
**Step 3: Wire ref in template**
|
||||
|
||||
On the `.bucket-grid-footer` div:
|
||||
```html
|
||||
<div ref="gridEl" class="bucket-grid-footer" :class="{ 'grid-active': isHeld }">
|
||||
```
|
||||
|
||||
**Step 4: Remove CSS transition from `.bucket-grid-footer`**
|
||||
|
||||
In `LabelView.vue <style scoped>`, delete the `transition:` line from `.bucket-grid-footer`:
|
||||
```css
|
||||
/* DELETE this line: */
|
||||
transition: transform 250ms cubic-bezier(0.34, 1.56, 0.64, 1),
|
||||
opacity 200ms ease,
|
||||
background 200ms ease;
|
||||
```
|
||||
Keep the `transform: translateY(-8px)` and `opacity: 0.45` on `.bucket-grid-footer.grid-active` as fallback for reduced-motion users (no-JS fallback too).
|
||||
|
||||
Note: keep the `.grid-active` rules as-is — they are the no-motion path. The Anime.js `watch` guard (`if (!motion.rich.value)`) means reduced-motion users never hit Anime.js; the CSS class handles them.
|
||||
|
||||
**Step 5: Run tests**
|
||||
|
||||
```bash
|
||||
npm test
|
||||
```
|
||||
|
||||
Expected: All pass (LabelView has no dedicated tests, but full suite should be green).
|
||||
|
||||
**Step 6: Commit**
|
||||
|
||||
```bash
|
||||
git add web/src/views/LabelView.vue
|
||||
git commit -m "feat(avocet): animate bucket grid rise with Anime.js spring"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Task 5: Badge pop animation
|
||||
|
||||
**Files:**
|
||||
- Modify: `web/src/views/LabelView.vue`
|
||||
|
||||
**What changes:**
|
||||
Replace `@keyframes badge-pop` (scale + opacity keyframe) with a Vue `<Transition>` `@enter` hook that calls `animate()`. Badges already appear/disappear via `v-if`, so they have natural mount/unmount lifecycle.
|
||||
|
||||
**Step 1: Wrap each badge in a `<Transition>`**
|
||||
|
||||
In `LabelView.vue` template, each badge `<span v-if="...">` gets wrapped:
|
||||
|
||||
```html
|
||||
<Transition @enter="onBadgeEnter" :css="false">
|
||||
<span v-if="onRoll" class="badge badge-roll">🔥 On a roll!</span>
|
||||
</Transition>
|
||||
<Transition @enter="onBadgeEnter" :css="false">
|
||||
<span v-if="speedRound" class="badge badge-speed">⚡ Speed round!</span>
|
||||
</Transition>
|
||||
<!-- repeat for all 6 badges -->
|
||||
```
|
||||
|
||||
`:css="false"` tells Vue not to apply any CSS transition classes — Anime.js owns the enter animation entirely.
|
||||
|
||||
**Step 2: Add `onBadgeEnter` hook**
|
||||
|
||||
```ts
|
||||
function onBadgeEnter(el: Element, done: () => void) {
|
||||
if (!motion.rich.value) { done(); return }
|
||||
animate(el as HTMLElement,
|
||||
{ scale: [0.6, 1], opacity: [0, 1] },
|
||||
{ ease: spring({ mass: 1.5, stiffness: 80, damping: 8 }), duration: 300, onComplete: done }
|
||||
)
|
||||
}
|
||||
```
|
||||
|
||||
**Step 3: Remove `@keyframes badge-pop` from CSS**
|
||||
|
||||
In `LabelView.vue <style scoped>`:
|
||||
```css
|
||||
/* DELETE: */
|
||||
@keyframes badge-pop {
|
||||
from { transform: scale(0.6); opacity: 0; }
|
||||
to { transform: scale(1); opacity: 1; }
|
||||
}
|
||||
|
||||
/* DELETE animation line from .badge: */
|
||||
animation: badge-pop 0.3s cubic-bezier(0.34, 1.56, 0.64, 1);
|
||||
```
|
||||
|
||||
**Step 4: Run tests**
|
||||
|
||||
```bash
|
||||
npm test
|
||||
```
|
||||
|
||||
Expected: All pass.
|
||||
|
||||
**Step 5: Commit**
|
||||
|
||||
```bash
|
||||
git add web/src/views/LabelView.vue
|
||||
git commit -m "feat(avocet): badge pop via Anime.js spring transition hook"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Task 6: Build and smoke test
|
||||
|
||||
**Step 1: Build the SPA**
|
||||
|
||||
```bash
|
||||
cd /Library/Development/CircuitForge/avocet
|
||||
./manage.sh start-api
|
||||
```
|
||||
|
||||
(This builds Vue + starts FastAPI on port 8503.)
|
||||
|
||||
**Step 2: Open the app**
|
||||
|
||||
```bash
|
||||
./manage.sh open-api
|
||||
```
|
||||
|
||||
**Step 3: Manual smoke test checklist**
|
||||
|
||||
- [ ] Pick up a card — ball morph is smooth (not instant jump)
|
||||
- [ ] Drag ball around — follows finger with no lag
|
||||
- [ ] Release in center — springs back to card with bounce
|
||||
- [ ] Release in left zone — discard fires (card crumples)
|
||||
- [ ] Release in right zone — skip fires (card slides right)
|
||||
- [ ] Release on a bucket — label fires (card files up)
|
||||
- [ ] Fling left fast — discard fires
|
||||
- [ ] Bucket grid rises smoothly on pickup, falls on release
|
||||
- [ ] Badge (label 10 in a row for 🔥) pops in with spring
|
||||
- [ ] Reduced motion: toggle in system settings → no animations, instant behavior
|
||||
- [ ] Keyboard labels (1–9) still work (pointer events unchanged)
|
||||
|
||||
**Step 4: Final commit if all green**
|
||||
|
||||
```bash
|
||||
git add -A
|
||||
git commit -m "feat(avocet): complete Anime.js animation integration"
|
||||
```
|
||||
1861
docs/superpowers/plans/2026-03-15-finetune-classifier.md
Normal file
1861
docs/superpowers/plans/2026-03-15-finetune-classifier.md
Normal file
File diff suppressed because it is too large
Load diff
254
docs/superpowers/specs/2026-03-15-finetune-classifier-design.md
Normal file
254
docs/superpowers/specs/2026-03-15-finetune-classifier-design.md
Normal file
|
|
@ -0,0 +1,254 @@
|
|||
# Fine-tune Email Classifier — Design Spec
|
||||
|
||||
**Date:** 2026-03-15
|
||||
**Status:** Approved
|
||||
**Scope:** Avocet — `scripts/`, `app/api.py`, `web/src/views/BenchmarkView.vue`, `environment.yml`
|
||||
|
||||
---
|
||||
|
||||
## Problem
|
||||
|
||||
The benchmark baseline shows zero-shot macro-F1 of 0.366 for the best models (`deberta-zeroshot`, `deberta-base-anli`). Zero-shot inference cannot improve with more labeled data. Fine-tuning the fastest models (`deberta-small` at 111ms, `bge-m3` at 123ms) on the growing labeled dataset is the path to meaningful accuracy gains.
|
||||
|
||||
---
|
||||
|
||||
## Constraints
|
||||
|
||||
- 501 labeled samples after dropping 2 non-canonical `profile_alert` rows
|
||||
- Heavy class imbalance: `digest` 29%, `neutral` 26%, `new_lead` 2.6%, `survey_received` 3%
|
||||
- 8.2 GB VRAM (shared with Peregrine vLLM during dev)
|
||||
- Target models: `cross-encoder/nli-deberta-v3-small` (100M params), `MoritzLaurer/bge-m3-zeroshot-v2.0` (600M params)
|
||||
- Output: local `models/avocet-{name}/` directory
|
||||
- UI-triggerable via web interface (SSE streaming log)
|
||||
- Stack: transformers 4.57.3, torch 2.10.0, accelerate 1.12.0, sklearn, CUDA 8.2GB
|
||||
|
||||
---
|
||||
|
||||
## Environment changes
|
||||
|
||||
`environment.yml` must add:
|
||||
- `scikit-learn` — required for `train_test_split(stratify=...)` and `f1_score`
|
||||
- `peft` is NOT used by this spec; it is available in the env but not required here
|
||||
|
||||
---
|
||||
|
||||
## Architecture
|
||||
|
||||
### New file: `scripts/finetune_classifier.py`
|
||||
|
||||
CLI entry point for fine-tuning. All prints use `flush=True` so stdout is SSE-streamable.
|
||||
|
||||
```
|
||||
python scripts/finetune_classifier.py --model deberta-small [--epochs 5]
|
||||
```
|
||||
|
||||
Supported `--model` values: `deberta-small`, `bge-m3`
|
||||
|
||||
**Model registry** (internal to this script):
|
||||
|
||||
| Key | Base model ID | Max tokens | fp16 | Batch size | Grad accum steps | Gradient checkpointing |
|
||||
|-----|--------------|------------|------|------------|-----------------|----------------------|
|
||||
| `deberta-small` | `cross-encoder/nli-deberta-v3-small` | 512 | No | 8 | 2 | Yes |
|
||||
| `bge-m3` | `MoritzLaurer/bge-m3-zeroshot-v2.0` | 512 | Yes | 4 | 4 | Yes |
|
||||
|
||||
`bge-m3` uses `fp16=True` (halves optimizer state from ~4.8GB to ~2.4GB) with batch size 4 + gradient accumulation 4 = effective batch 16, matching `deberta-small`. These settings are required to fit within 8.2GB VRAM. Still stop Peregrine vLLM before running bge-m3 fine-tuning.
|
||||
|
||||
### Modified: `scripts/classifier_adapters.py`
|
||||
|
||||
Add `FineTunedAdapter(ClassifierAdapter)`:
|
||||
- Takes `model_dir: str` (path to a `models/avocet-*/` checkpoint)
|
||||
- Loads via `pipeline("text-classification", model=model_dir)`
|
||||
- `classify()` input format: **`f"{subject} [SEP] {body[:400]}"`** — must match the training format exactly. Do NOT use the zero-shot adapters' `f"Subject: {subject}\n\n{body[:600]}"` format; distribution shift will degrade accuracy.
|
||||
- Returns the top predicted label directly (single forward pass — no per-label NLI scoring loop)
|
||||
- Expected inference speed: ~10–20ms/email vs 111–338ms for zero-shot
|
||||
|
||||
### Modified: `scripts/benchmark_classifier.py`
|
||||
|
||||
At startup, scan `models/` for subdirectories containing `training_info.json`. Register each as a dynamic entry in the model registry using `FineTunedAdapter`. Silently skips if `models/` does not exist. Existing CLI behaviour unchanged.
|
||||
|
||||
### Modified: `app/api.py`
|
||||
|
||||
Two new GET endpoints (GET required for `EventSource` compatibility):
|
||||
|
||||
**`GET /api/finetune/status`**
|
||||
Scans `models/` for `training_info.json` files. Returns:
|
||||
```json
|
||||
[
|
||||
{
|
||||
"name": "avocet-deberta-small",
|
||||
"base_model": "cross-encoder/nli-deberta-v3-small",
|
||||
"val_macro_f1": 0.712,
|
||||
"timestamp": "2026-03-15T12:00:00Z",
|
||||
"sample_count": 401
|
||||
}
|
||||
]
|
||||
```
|
||||
Returns `[]` if no fine-tuned models exist.
|
||||
|
||||
**`GET /api/finetune/run?model=deberta-small&epochs=5`**
|
||||
Spawns `finetune_classifier.py` via the `job-seeker-classifiers` Python binary. Streams stdout as SSE `{"type":"progress","message":"..."}` events. Emits `{"type":"complete"}` on clean exit, `{"type":"error","message":"..."}` on non-zero exit. Same implementation pattern as `/api/benchmark/run`.
|
||||
|
||||
### Modified: `web/src/views/BenchmarkView.vue`
|
||||
|
||||
**Trained models badge row** (top of view, conditional on fine-tuned models existing):
|
||||
Shows each fine-tuned model name + val macro-F1 chip. Fetches from `/api/finetune/status` on mount.
|
||||
|
||||
**Fine-tune section** (collapsible, below benchmark charts):
|
||||
- Dropdown: `deberta-small` | `bge-m3`
|
||||
- Number input: epochs (default 5, range 1–20)
|
||||
- Run button → streams into existing log component
|
||||
- On `complete`: auto-triggers `/api/benchmark/run` (with `--save`) so charts update immediately
|
||||
|
||||
---
|
||||
|
||||
## Training Pipeline
|
||||
|
||||
### Data preparation
|
||||
|
||||
1. Load `data/email_score.jsonl`
|
||||
2. Drop rows where `label` not in canonical `LABELS` (removes `profile_alert` etc.)
|
||||
3. Check for classes with < 2 **total** samples (before any split). Drop those classes and warn. Additionally warn — but do not skip — classes with < 5 training samples, noting eval F1 for those classes will be unreliable.
|
||||
4. Input text: `f"{subject} [SEP] {body[:400]}"` — fits within 512 tokens for both target models
|
||||
5. Stratified 80/20 train/val split via `sklearn.model_selection.train_test_split(stratify=labels)`
|
||||
|
||||
### Class weighting
|
||||
|
||||
Compute per-class weights: `total_samples / (n_classes × class_count)`. Pass to a `WeightedTrainer` subclass:
|
||||
|
||||
```python
|
||||
class WeightedTrainer(Trainer):
|
||||
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
|
||||
# **kwargs is required — absorbs num_items_in_batch added in Transformers 4.38.
|
||||
# Do not remove it; removing it causes TypeError on the first training step.
|
||||
labels = inputs.pop("labels")
|
||||
outputs = model(**inputs)
|
||||
# Move class_weights to the same device as logits — required for GPU training.
|
||||
# class_weights is created on CPU; logits are on cuda:0 during training.
|
||||
weight = self.class_weights.to(outputs.logits.device)
|
||||
loss = F.cross_entropy(outputs.logits, labels, weight=weight)
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
```
|
||||
|
||||
### Model setup
|
||||
|
||||
```python
|
||||
AutoModelForSequenceClassification.from_pretrained(
|
||||
base_model_id,
|
||||
num_labels=10,
|
||||
ignore_mismatched_sizes=True, # see note below
|
||||
id2label=id2label,
|
||||
label2id=label2id,
|
||||
)
|
||||
```
|
||||
|
||||
**Note on `ignore_mismatched_sizes=True`:** The pretrained NLI head is a 3-class linear projection. It mismatches the 10-class head constructed by `num_labels=10`, so its weights are skipped during loading. PyTorch initializes the new head from scratch using the model's default init scheme. The backbone weights load normally. Do not set this to `False` — it will raise a shape error.
|
||||
|
||||
### Training config and `compute_metrics`
|
||||
|
||||
The Trainer requires a `compute_metrics` callback that takes an `EvalPrediction` (logits + label_ids) and returns a dict with a `macro_f1` key. This is distinct from the existing `compute_metrics` in `classifier_adapters.py` (which operates on string predictions):
|
||||
|
||||
```python
|
||||
def compute_metrics_for_trainer(eval_pred: EvalPrediction) -> dict:
|
||||
logits, labels = eval_pred
|
||||
preds = logits.argmax(axis=-1)
|
||||
return {
|
||||
"macro_f1": f1_score(labels, preds, average="macro", zero_division=0),
|
||||
"accuracy": accuracy_score(labels, preds),
|
||||
}
|
||||
```
|
||||
|
||||
`TrainingArguments` must include:
|
||||
- `load_best_model_at_end=True`
|
||||
- `metric_for_best_model="macro_f1"`
|
||||
- `greater_is_better=True`
|
||||
|
||||
These are required for `EarlyStoppingCallback` to work correctly. Without `load_best_model_at_end=True`, `EarlyStoppingCallback` raises `AssertionError` on init.
|
||||
|
||||
| Hyperparameter | deberta-small | bge-m3 |
|
||||
|---------------|--------------|--------|
|
||||
| Epochs | 5 (default, CLI-overridable) | 5 |
|
||||
| Batch size | 8 | 4 |
|
||||
| Gradient accumulation | 2 (effective batch = 16) | 4 (effective batch = 16) |
|
||||
| Learning rate | 2e-5 | 2e-5 |
|
||||
| LR schedule | Linear with 10% warmup | same |
|
||||
| Optimizer | AdamW | AdamW |
|
||||
| fp16 | No | Yes |
|
||||
| Gradient checkpointing | No | Yes |
|
||||
| Eval strategy | Every epoch | Every epoch |
|
||||
| Best checkpoint | By `macro_f1` | same |
|
||||
| Early stopping patience | 3 epochs | 3 epochs |
|
||||
|
||||
### Output
|
||||
|
||||
Saved to `models/avocet-{name}/`:
|
||||
- Model weights + tokenizer (standard HuggingFace format)
|
||||
- `training_info.json`:
|
||||
```json
|
||||
{
|
||||
"name": "avocet-deberta-small",
|
||||
"base_model_id": "cross-encoder/nli-deberta-v3-small",
|
||||
"timestamp": "2026-03-15T12:00:00Z",
|
||||
"epochs_run": 5,
|
||||
"val_macro_f1": 0.712,
|
||||
"val_accuracy": 0.798,
|
||||
"sample_count": 401,
|
||||
"label_counts": { "digest": 116, "neutral": 104, ... }
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Data Flow
|
||||
|
||||
```
|
||||
email_score.jsonl
|
||||
│
|
||||
▼
|
||||
finetune_classifier.py
|
||||
├── drop non-canonical labels
|
||||
├── check for < 2 total samples per class (drop + warn)
|
||||
├── stratified 80/20 split
|
||||
├── tokenize (subject [SEP] body[:400])
|
||||
├── compute class weights
|
||||
├── WeightedTrainer + EarlyStoppingCallback
|
||||
└── save → models/avocet-{name}/
|
||||
│
|
||||
├── FineTunedAdapter (classifier_adapters.py)
|
||||
│ ├── pipeline("text-classification")
|
||||
│ ├── input: subject [SEP] body[:400] ← must match training format
|
||||
│ └── ~10–20ms/email inference
|
||||
│
|
||||
└── training_info.json
|
||||
└── /api/finetune/status
|
||||
└── BenchmarkView badge row
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Error Handling
|
||||
|
||||
- **Insufficient data (< 2 total samples in a class):** Drop class before split, print warning with class name and count.
|
||||
- **Low data warning (< 5 training samples in a class):** Warn but continue; note eval F1 for that class will be unreliable.
|
||||
- **VRAM OOM on bge-m3:** Surface as clear SSE error message. Suggest stopping Peregrine vLLM first (it holds ~5.7GB).
|
||||
- **Missing score file:** Raise `FileNotFoundError` with actionable message (same pattern as `load_scoring_jsonl`).
|
||||
- **Model dir already exists:** Overwrite with a warning log line. Re-running always produces a fresh checkpoint.
|
||||
|
||||
---
|
||||
|
||||
## Testing
|
||||
|
||||
- Unit test `WeightedTrainer.compute_loss` with a mock model and known label distribution — verify weighted loss differs from unweighted; verify `**kwargs` does not raise `TypeError`
|
||||
- Unit test `compute_metrics_for_trainer` — verify `macro_f1` key in output, correct value on known inputs
|
||||
- Unit test `FineTunedAdapter.classify` with a mock pipeline — verify it returns a string from `LABELS` using `subject [SEP] body[:400]` format
|
||||
- Unit test auto-discovery in `benchmark_classifier.py` — mock `models/` dir with two `training_info.json` files, verify both appear in the active registry
|
||||
- Integration test: fine-tune on `data/email_score.jsonl.example` (8 samples, 5 of 10 labels represented, 1 epoch, `--model deberta-small`). The 5 missing labels trigger the `< 2 total samples` drop path — the test must verify the drop warning is emitted for each missing label rather than treating it as a failure. Verify `models/avocet-deberta-small/training_info.json` is written with correct keys.
|
||||
|
||||
---
|
||||
|
||||
## Out of Scope
|
||||
|
||||
- Pushing fine-tuned weights to HuggingFace Hub (future)
|
||||
- Cross-validation or k-fold evaluation (future — dataset too small to be meaningful now)
|
||||
- Hyperparameter search (future)
|
||||
- LoRA/PEFT adapter fine-tuning (future — relevant if model sizes grow beyond available VRAM)
|
||||
- Fine-tuning models other than `deberta-small` and `bge-m3`
|
||||
|
|
@ -14,6 +14,7 @@ dependencies:
|
|||
- transformers>=4.40
|
||||
- torch>=2.2
|
||||
- accelerate>=0.27
|
||||
- scikit-learn>=1.4
|
||||
|
||||
# Optional: GLiClass adapter
|
||||
# - gliclass
|
||||
|
|
|
|||
|
|
@ -96,6 +96,7 @@ usage() {
|
|||
echo " Vue API:"
|
||||
echo -e " ${GREEN}start-api${NC} Build Vue SPA + start FastAPI on port 8503"
|
||||
echo -e " ${GREEN}stop-api${NC} Stop FastAPI server"
|
||||
echo -e " ${GREEN}restart-api${NC} Stop + rebuild + restart FastAPI server"
|
||||
echo -e " ${GREEN}open-api${NC} Open Vue UI in browser (http://localhost:8503)"
|
||||
echo ""
|
||||
echo " Dev:"
|
||||
|
|
@ -305,6 +306,11 @@ case "$CMD" in
|
|||
fi
|
||||
;;
|
||||
|
||||
restart-api)
|
||||
bash "$0" stop-api
|
||||
exec bash "$0" start-api
|
||||
;;
|
||||
|
||||
open-api)
|
||||
URL="http://localhost:8503"
|
||||
info "Opening ${URL}"
|
||||
|
|
|
|||
|
|
@ -32,10 +32,14 @@ from typing import Any
|
|||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
_MODELS_DIR = _ROOT / "models"
|
||||
|
||||
from scripts.classifier_adapters import (
|
||||
LABELS,
|
||||
LABEL_DESCRIPTIONS,
|
||||
ClassifierAdapter,
|
||||
FineTunedAdapter,
|
||||
GLiClassAdapter,
|
||||
RerankerAdapter,
|
||||
ZeroShotAdapter,
|
||||
|
|
@ -150,8 +154,55 @@ def load_scoring_jsonl(path: str) -> list[dict[str, str]]:
|
|||
return rows
|
||||
|
||||
|
||||
def _active_models(include_slow: bool) -> dict[str, dict[str, Any]]:
|
||||
return {k: v for k, v in MODEL_REGISTRY.items() if v["default"] or include_slow}
|
||||
def discover_finetuned_models(models_dir: Path | None = None) -> list[dict]:
|
||||
"""Scan models/ for subdirs containing training_info.json.
|
||||
|
||||
Returns a list of training_info dicts, each with an added 'model_dir' key.
|
||||
Returns [] silently if models_dir does not exist.
|
||||
"""
|
||||
if models_dir is None:
|
||||
models_dir = _MODELS_DIR
|
||||
if not models_dir.exists():
|
||||
return []
|
||||
found = []
|
||||
for sub in models_dir.iterdir():
|
||||
if not sub.is_dir():
|
||||
continue
|
||||
info_path = sub / "training_info.json"
|
||||
if not info_path.exists():
|
||||
continue
|
||||
try:
|
||||
info = json.loads(info_path.read_text(encoding="utf-8"))
|
||||
except Exception as exc:
|
||||
print(f"[discover] WARN: skipping {info_path}: {exc}", flush=True)
|
||||
continue
|
||||
if "name" not in info:
|
||||
print(f"[discover] WARN: skipping {info_path}: missing 'name' key", flush=True)
|
||||
continue
|
||||
info["model_dir"] = str(sub)
|
||||
found.append(info)
|
||||
return found
|
||||
|
||||
|
||||
def _active_models(include_slow: bool = False) -> dict[str, dict[str, Any]]:
    """Build the active model registry, including discovered fine-tuned models.

    Starts from MODEL_REGISTRY — default entries only unless include_slow is
    True — attaching a freshly constructed adapter under 'adapter_instance'.
    Every fine-tuned checkpoint found on disk is then overlaid as an
    always-default entry backed by FineTunedAdapter (a fine-tuned model whose
    name collides with a registry key replaces that entry).
    """
    active: dict[str, dict[str, Any]] = {}
    for key, entry in MODEL_REGISTRY.items():
        if not (include_slow or entry.get("default", False)):
            continue
        adapter = entry["adapter"](key, entry["model_id"], **entry.get("kwargs", {}))
        active[key] = {**entry, "adapter_instance": adapter}

    for info in discover_finetuned_models():
        tuned_name = info["name"]
        active[tuned_name] = {
            "adapter_instance": FineTunedAdapter(tuned_name, info["model_dir"]),
            "params": "fine-tuned",
            "default": True,
        }
    return active
|
||||
|
||||
|
||||
def run_scoring(
|
||||
|
|
@ -163,7 +214,8 @@ def run_scoring(
|
|||
gold = [r["label"] for r in rows]
|
||||
results: dict[str, Any] = {}
|
||||
|
||||
for adapter in adapters:
|
||||
for i, adapter in enumerate(adapters, 1):
|
||||
print(f"[{i}/{len(adapters)}] Running {adapter.name} ({len(rows)} samples) …", flush=True)
|
||||
preds: list[str] = []
|
||||
t0 = time.monotonic()
|
||||
for row in rows:
|
||||
|
|
@ -177,6 +229,7 @@ def run_scoring(
|
|||
metrics = compute_metrics(preds, gold, LABELS)
|
||||
metrics["latency_ms"] = round(elapsed_ms / len(rows), 1)
|
||||
results[adapter.name] = metrics
|
||||
print(f" → macro-F1 {metrics['__macro_f1__']:.3f} accuracy {metrics['__accuracy__']:.3f} {metrics['latency_ms']:.1f} ms/email", flush=True)
|
||||
adapter.unload()
|
||||
|
||||
return results
|
||||
|
|
@ -345,10 +398,7 @@ def cmd_score(args: argparse.Namespace) -> None:
|
|||
if args.models:
|
||||
active = {k: v for k, v in active.items() if k in args.models}
|
||||
|
||||
adapters = [
|
||||
entry["adapter"](name, entry["model_id"], **entry.get("kwargs", {}))
|
||||
for name, entry in active.items()
|
||||
]
|
||||
adapters = [entry["adapter_instance"] for entry in active.values()]
|
||||
|
||||
print(f"\nScoring {len(adapters)} model(s) against {args.score_file} …\n")
|
||||
results = run_scoring(adapters, args.score_file)
|
||||
|
|
@ -375,6 +425,31 @@ def cmd_score(args: argparse.Namespace) -> None:
|
|||
print(row_str)
|
||||
print()
|
||||
|
||||
if args.save:
|
||||
import datetime
|
||||
rows = load_scoring_jsonl(args.score_file)
|
||||
save_data = {
|
||||
"timestamp": datetime.datetime.utcnow().isoformat() + "Z",
|
||||
"sample_count": len(rows),
|
||||
"models": {
|
||||
name: {
|
||||
"macro_f1": round(m["__macro_f1__"], 4),
|
||||
"accuracy": round(m["__accuracy__"], 4),
|
||||
"latency_ms": m["latency_ms"],
|
||||
"per_label": {
|
||||
label: {k: round(v, 4) for k, v in m[label].items()}
|
||||
for label in LABELS
|
||||
if label in m
|
||||
},
|
||||
}
|
||||
for name, m in results.items()
|
||||
},
|
||||
}
|
||||
save_path = Path(args.score_file).parent / "benchmark_results.json"
|
||||
with open(save_path, "w") as f:
|
||||
json.dump(save_data, f, indent=2)
|
||||
print(f"Results saved → {save_path}", flush=True)
|
||||
|
||||
|
||||
def cmd_compare(args: argparse.Namespace) -> None:
|
||||
active = _active_models(args.include_slow)
|
||||
|
|
@ -385,10 +460,7 @@ def cmd_compare(args: argparse.Namespace) -> None:
|
|||
emails = _fetch_imap_sample(args.limit, args.days)
|
||||
print(f"Fetched {len(emails)} emails. Loading {len(active)} model(s) …\n")
|
||||
|
||||
adapters = [
|
||||
entry["adapter"](name, entry["model_id"], **entry.get("kwargs", {}))
|
||||
for name, entry in active.items()
|
||||
]
|
||||
adapters = [entry["adapter_instance"] for entry in active.values()]
|
||||
model_names = [a.name for a in adapters]
|
||||
|
||||
col = 22
|
||||
|
|
@ -431,6 +503,8 @@ def main() -> None:
|
|||
parser.add_argument("--days", type=int, default=90, help="Days back for IMAP search")
|
||||
parser.add_argument("--include-slow", action="store_true", help="Include non-default heavy models")
|
||||
parser.add_argument("--models", nargs="+", help="Override: run only these model names")
|
||||
parser.add_argument("--save", action="store_true",
|
||||
help="Save results to data/benchmark_results.json (for the web UI)")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ __all__ = [
|
|||
"ZeroShotAdapter",
|
||||
"GLiClassAdapter",
|
||||
"RerankerAdapter",
|
||||
"FineTunedAdapter",
|
||||
]
|
||||
|
||||
LABELS: list[str] = [
|
||||
|
|
@ -263,3 +264,43 @@ class RerankerAdapter(ClassifierAdapter):
|
|||
pairs = [[text, LABEL_DESCRIPTIONS.get(label, label.replace("_", " "))] for label in LABELS]
|
||||
scores: list[float] = self._reranker.compute_score(pairs, normalize=True)
|
||||
return LABELS[scores.index(max(scores))]
|
||||
|
||||
|
||||
class FineTunedAdapter(ClassifierAdapter):
    """Adapter for a fine-tuned checkpoint stored under a local models/ directory.

    Classification is a single forward pass through a lazily created
    pipeline("text-classification") — no per-label NLI scoring loop — so
    inference is roughly 10–20ms/email versus 111–338ms for zero-shot.

    The input fed to the pipeline is 'subject [SEP] body[:400]', which must
    match the training-time format exactly; any drift degrades accuracy.
    """

    def __init__(self, name: str, model_dir: str) -> None:
        self._pipeline: Any = None  # created on first classify()/load()
        self._name = name
        self._model_dir = model_dir

    @property
    def name(self) -> str:
        """Registry key for this adapter."""
        return self._name

    @property
    def model_id(self) -> str:
        """The checkpoint directory doubles as the model identifier."""
        return self._model_dir

    def load(self) -> None:
        """Instantiate the HF pipeline, preferring GPU when CUDA is available."""
        # Read `pipeline` off the module object (not a direct name) so tests can
        # monkeypatch scripts.classifier_adapters.pipeline after import.
        import scripts.classifier_adapters as _mod  # noqa: PLC0415
        pipeline_factory = _mod.pipeline
        if pipeline_factory is None:
            raise ImportError("transformers not installed — run: pip install transformers")
        target_device = 0 if _cuda_available() else -1
        self._pipeline = pipeline_factory(
            "text-classification", model=self._model_dir, device=target_device
        )

    def unload(self) -> None:
        """Drop the pipeline so its weights can be garbage-collected."""
        self._pipeline = None

    def classify(self, subject: str, body: str) -> str:
        """Return the top predicted label for one email (lazy-loads on first use)."""
        if self._pipeline is None:
            self.load()
        payload = f"{subject} [SEP] {body[:400]}"
        scored = self._pipeline(payload)
        return scored[0]["label"]
|
||||
|
|
|
|||
416
scripts/finetune_classifier.py
Normal file
416
scripts/finetune_classifier.py
Normal file
|
|
@ -0,0 +1,416 @@
|
|||
"""Fine-tune email classifiers on the labeled dataset.
|
||||
|
||||
CLI entry point. All prints use flush=True so stdout is SSE-streamable.
|
||||
|
||||
Usage:
|
||||
python scripts/finetune_classifier.py --model deberta-small [--epochs 5]
|
||||
|
||||
Supported --model values: deberta-small, bge-m3
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
import json
|
||||
import sys
|
||||
from collections import Counter
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import Dataset as TorchDataset
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.metrics import f1_score, accuracy_score
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForSequenceClassification,
|
||||
EvalPrediction,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
EarlyStoppingCallback,
|
||||
)
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from scripts.classifier_adapters import LABELS
|
||||
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
|
||||
_MODEL_CONFIG: dict[str, dict[str, Any]] = {
|
||||
"deberta-small": {
|
||||
"base_model_id": "cross-encoder/nli-deberta-v3-small",
|
||||
"max_tokens": 512,
|
||||
# fp16 must stay OFF — DeBERTa-v3 disentangled attention overflows fp16.
|
||||
"fp16": False,
|
||||
# batch_size=8 + grad_accum=2 keeps effective batch of 16 while halving
|
||||
# per-step activation memory. gradient_checkpointing recomputes activations
|
||||
# on backward instead of storing them — ~60% less activation VRAM.
|
||||
"batch_size": 8,
|
||||
"grad_accum": 2,
|
||||
"gradient_checkpointing": True,
|
||||
},
|
||||
"bge-m3": {
|
||||
"base_model_id": "MoritzLaurer/bge-m3-zeroshot-v2.0",
|
||||
"max_tokens": 512,
|
||||
"fp16": True,
|
||||
"batch_size": 4,
|
||||
"grad_accum": 4,
|
||||
"gradient_checkpointing": True,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def load_and_prepare_data(score_files: Path | list[Path]) -> tuple[list[str], list[str]]:
    """Load labeled JSONL and return (texts, labels) filtered to canonical LABELS.

    score_files: a single Path or a list of Paths. When multiple files are given,
        rows are merged with last-write-wins deduplication keyed by content hash
        (MD5 of subject + body[:100]), so later labeling runs can correct
        earlier labels.

    Returns:
        (texts, labels) — each text is 'subject [SEP] body[:400]', the exact
        format FineTunedAdapter must reproduce at inference time.

    Raises:
        FileNotFoundError: if any score file is missing.

    Drops rows with non-canonical labels (with warning), and drops entire classes
    that have fewer than 2 total samples (required for stratified split).
    Warns (but continues) for classes with fewer than 5 samples.
    """
    # Normalise to list — backwards compatible with single-Path callers.
    if isinstance(score_files, Path):
        score_files = [score_files]

    for score_file in score_files:
        if not score_file.exists():
            raise FileNotFoundError(
                f"Labeled data not found: {score_file}\n"
                "Run the label tool first to generate email_score.jsonl."
            )

    label_set = set(LABELS)
    # Plain dict keyed by content hash; later entries overwrite earlier ones
    # (last-write wins).
    seen: dict[str, dict] = {}
    total = 0

    for score_file in score_files:
        with score_file.open() as fh:
            for line in fh:
                line = line.strip()
                if not line:
                    continue
                try:
                    r = json.loads(line)
                except json.JSONDecodeError:
                    # Tolerate an occasional corrupt line rather than aborting.
                    continue
                lbl = r.get("label", "")
                if lbl not in label_set:
                    print(
                        f"[data] WARNING: Dropping row with non-canonical label {lbl!r}",
                        flush=True,
                    )
                    continue
                content_hash = hashlib.md5(
                    (r.get("subject", "") + (r.get("body", "") or "")[:100]).encode(
                        "utf-8", errors="replace"
                    )
                ).hexdigest()
                seen[content_hash] = r
                total += 1

    kept = len(seen)
    dropped = total - kept
    if dropped > 0:
        print(
            f"[data] Deduped: kept {kept} of {total} rows (dropped {dropped} duplicates)",
            flush=True,
        )

    rows = list(seen.values())

    # Count samples per class.
    counts: Counter = Counter(r["label"] for r in rows)

    # Drop classes with < 2 total samples (cannot stratify-split).
    drop_classes: set[str] = set()
    for lbl, cnt in counts.items():
        if cnt < 2:
            print(
                f"[data] WARNING: Dropping class {lbl!r} — only {cnt} total "
                f"sample(s). Need at least 2 for stratified split.",
                flush=True,
            )
            drop_classes.add(lbl)

    # Warn for classes with < 5 samples (unreliable eval F1).
    for lbl, cnt in counts.items():
        if lbl not in drop_classes and cnt < 5:
            print(
                f"[data] WARNING: Class {lbl!r} has only {cnt} sample(s). "
                f"Eval F1 for this class will be unreliable.",
                flush=True,
            )

    # Filter rows.
    rows = [r for r in rows if r["label"] not in drop_classes]

    # Fix: build texts defensively, matching the hashing above. The previous
    # r['subject'] / r['body'][:400] raised KeyError on a missing key and
    # TypeError when body was null in the JSONL.
    texts = [
        f"{r.get('subject', '')} [SEP] {(r.get('body') or '')[:400]}" for r in rows
    ]
    labels = [r["label"] for r in rows]

    return texts, labels
|
||||
|
||||
|
||||
def compute_class_weights(label_ids: list[int], n_classes: int) -> torch.Tensor:
    """Compute inverse-frequency class weights.

    Each class gets weight ``total / (n_classes * class_count)``, so minority
    classes are up-weighted relative to majority classes. Classes absent from
    ``label_ids`` are treated as having one sample, which avoids a division
    by zero.

    Returns a CPU float32 tensor of shape ``(n_classes,)``.
    """
    freq = Counter(label_ids)
    n_samples = len(label_ids)
    # freq.get(cls, 1): pretend unseen classes have one sample (no div-by-zero).
    per_class = [
        n_samples / (n_classes * freq.get(cls, 1)) for cls in range(n_classes)
    ]
    return torch.tensor(per_class, dtype=torch.float32)
|
||||
|
||||
|
||||
def compute_metrics_for_trainer(eval_pred: EvalPrediction) -> dict:
    """Compute macro F1 and accuracy from an ``EvalPrediction``.

    Invoked by the Hugging Face Trainer at each evaluation step.
    """
    predicted = eval_pred.predictions.argmax(axis=-1)
    truth = eval_pred.label_ids
    return {
        "macro_f1": float(f1_score(truth, predicted, average="macro", zero_division=0)),
        "accuracy": float(accuracy_score(truth, predicted)),
    }
|
||||
|
||||
|
||||
class WeightedTrainer(Trainer):
    """Trainer subclass that applies per-class weights to the cross-entropy loss."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs must stay: it absorbs num_items_in_batch, which Transformers
        # >= 4.38 passes in. Dropping it raises TypeError on the first step.
        targets = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        # class_weights is created on CPU; move it to the logits' device
        # (cuda:0 during GPU training) before computing the weighted loss.
        loss = F.cross_entropy(
            logits, targets, weight=self.class_weights.to(logits.device)
        )
        if return_outputs:
            return loss, outputs
        return loss
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Training dataset wrapper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _EmailDataset(TorchDataset):
|
||||
def __init__(self, encodings: dict, label_ids: list[int]) -> None:
|
||||
self.encodings = encodings
|
||||
self.label_ids = label_ids
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.label_ids)
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
|
||||
item["labels"] = torch.tensor(self.label_ids[idx], dtype=torch.long)
|
||||
return item
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main training function
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def run_finetune(model_key: str, epochs: int = 5, score_files: list[Path] | None = None) -> None:
    """Fine-tune the specified model on labeled data.

    Args:
        model_key: Key into ``_MODEL_CONFIG`` selecting the base model and its
            batch/precision settings.
        epochs: Maximum number of training epochs (early stopping with
            patience 3 may end training sooner).
        score_files: List of score JSONL paths to merge.
            Defaults to ``[_ROOT / "data" / "email_score.jsonl"]``.

    Raises:
        ValueError: If ``model_key`` is not in ``_MODEL_CONFIG``.

    Saves model + tokenizer + training_info.json to models/avocet-{model_key}/.
    All prints use flush=True for SSE streaming.
    """
    if model_key not in _MODEL_CONFIG:
        raise ValueError(f"Unknown model key: {model_key!r}. Choose from: {list(_MODEL_CONFIG)}")

    if score_files is None:
        score_files = [_ROOT / "data" / "email_score.jsonl"]

    config = _MODEL_CONFIG[model_key]
    base_model_id = config["base_model_id"]
    output_dir = _ROOT / "models" / f"avocet-{model_key}"

    print(f"[finetune] Model: {model_key} ({base_model_id})", flush=True)
    print(f"[finetune] Score files: {[str(f) for f in score_files]}", flush=True)
    print(f"[finetune] Output: {output_dir}", flush=True)
    if output_dir.exists():
        print(f"[finetune] WARNING: {output_dir} already exists — will overwrite.", flush=True)

    # --- Data ---
    print(f"[finetune] Loading data ...", flush=True)
    texts, str_labels = load_and_prepare_data(score_files)

    # Label vocabulary is derived from the labels actually present in the data,
    # sorted so id assignment is deterministic across runs.
    present_labels = sorted(set(str_labels))
    label2id = {l: i for i, l in enumerate(present_labels)}
    id2label = {i: l for l, i in label2id.items()}
    n_classes = len(present_labels)
    label_ids = [label2id[l] for l in str_labels]

    print(f"[finetune] {len(texts)} samples, {n_classes} classes", flush=True)

    # Stratified 80/20 split — ensure val set has at least n_classes samples.
    # For very small datasets (e.g. example data) we may need to give the val set
    # more than 20% so every class appears at least once in eval.
    desired_test = max(int(len(texts) * 0.2), n_classes)
    # test_size must leave at least n_classes samples for train too
    # NOTE(review): if len(texts) < 2 * n_classes this clamp pushes desired_test
    # below n_classes and the stratified split can fail — confirm callers
    # guarantee enough data per class.
    desired_test = min(desired_test, len(texts) - n_classes)
    (train_texts, val_texts,
     train_label_ids, val_label_ids) = train_test_split(
        texts, label_ids,
        test_size=desired_test,
        stratify=label_ids,
        random_state=42,
    )
    print(f"[finetune] Train: {len(train_texts)}, Val: {len(val_texts)}", flush=True)

    # Warn for classes with < 5 training samples
    train_counts = Counter(train_label_ids)
    for cls_id, cnt in train_counts.items():
        if cnt < 5:
            print(
                f"[finetune] WARNING: Class {id2label[cls_id]!r} has {cnt} training sample(s). "
                "Eval F1 for this class will be unreliable.",
                flush=True,
            )

    # --- Tokenize ---
    print(f"[finetune] Loading tokenizer ...", flush=True)
    tokenizer = AutoTokenizer.from_pretrained(base_model_id)

    train_enc = tokenizer(train_texts, truncation=True,
                          max_length=config["max_tokens"], padding=True)
    val_enc = tokenizer(val_texts, truncation=True,
                        max_length=config["max_tokens"], padding=True)

    train_dataset = _EmailDataset(train_enc, train_label_ids)
    val_dataset = _EmailDataset(val_enc, val_label_ids)

    # --- Class weights ---
    # Computed from the TRAIN split only, so the val distribution never
    # influences the loss weighting.
    class_weights = compute_class_weights(train_label_ids, n_classes)
    print(f"[finetune] Class weights computed", flush=True)

    # --- Model ---
    print(f"[finetune] Loading model ...", flush=True)
    model = AutoModelForSequenceClassification.from_pretrained(
        base_model_id,
        num_labels=n_classes,
        ignore_mismatched_sizes=True,  # NLI head (3-class) → new head (n_classes)
        id2label=id2label,
        label2id=label2id,
    )
    if config["gradient_checkpointing"]:
        # use_reentrant=False avoids "backward through graph a second time" errors
        # when Accelerate's gradient accumulation context is layered on top.
        # Reentrant checkpointing (the default) conflicts with Accelerate ≥ 0.27.
        model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )

    # --- TrainingArguments ---
    training_args = TrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=epochs,
        per_device_train_batch_size=config["batch_size"],
        per_device_eval_batch_size=config["batch_size"],
        gradient_accumulation_steps=config["grad_accum"],
        learning_rate=2e-5,
        lr_scheduler_type="linear",
        warmup_ratio=0.1,
        fp16=config["fp16"],
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="macro_f1",
        greater_is_better=True,
        logging_steps=10,
        report_to="none",
        save_total_limit=2,
    )

    trainer = WeightedTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics_for_trainer,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    )
    # WeightedTrainer.compute_loss reads this attribute; it is set after
    # construction because Trainer.__init__ does not accept extra kwargs.
    trainer.class_weights = class_weights

    # --- Train ---
    print(f"[finetune] Starting training ({epochs} epochs) ...", flush=True)
    train_result = trainer.train()
    print(f"[finetune] Training complete. Steps: {train_result.global_step}", flush=True)

    # --- Evaluate ---
    # load_best_model_at_end=True means this evaluates the best (by macro_f1)
    # checkpoint, not necessarily the last epoch's weights.
    print(f"[finetune] Evaluating best checkpoint ...", flush=True)
    metrics = trainer.evaluate()
    val_macro_f1 = metrics.get("eval_macro_f1", 0.0)
    val_accuracy = metrics.get("eval_accuracy", 0.0)
    print(f"[finetune] Val macro-F1: {val_macro_f1:.4f}, Accuracy: {val_accuracy:.4f}", flush=True)

    # --- Save model + tokenizer ---
    print(f"[finetune] Saving model to {output_dir} ...", flush=True)
    trainer.save_model(str(output_dir))
    tokenizer.save_pretrained(str(output_dir))

    # --- Write training_info.json ---
    # NOTE(review): "epochs_run" records the *requested* epoch count; with early
    # stopping the actual number of epochs trained may be lower.
    label_counts = dict(Counter(str_labels))
    info = {
        "name": f"avocet-{model_key}",
        "base_model_id": base_model_id,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "epochs_run": epochs,
        "val_macro_f1": round(val_macro_f1, 4),
        "val_accuracy": round(val_accuracy, 4),
        "sample_count": len(texts),
        "train_sample_count": len(train_texts),
        "label_counts": label_counts,
        "score_files": [str(f) for f in score_files],
    }
    info_path = output_dir / "training_info.json"
    info_path.write_text(json.dumps(info, indent=2), encoding="utf-8")
    print(f"[finetune] Saved training_info.json: val_macro_f1={val_macro_f1:.4f}", flush=True)
    print(f"[finetune] Done.", flush=True)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: parse arguments and hand off to run_finetune().
    cli = argparse.ArgumentParser(description="Fine-tune an email classifier")
    cli.add_argument(
        "--model",
        choices=list(_MODEL_CONFIG),
        required=True,
        help="Model key to fine-tune",
    )
    cli.add_argument(
        "--epochs",
        type=int,
        default=5,
        help="Number of training epochs (default: 5)",
    )
    cli.add_argument(
        "--score",
        dest="score_files",
        type=Path,
        action="append",
        metavar="FILE",
        help="Score JSONL file to include (repeatable; defaults to data/email_score.jsonl)",
    )
    ns = cli.parse_args()
    # An absent/empty --score list becomes None so run_finetune applies its default.
    run_finetune(ns.model, ns.epochs, score_files=ns.score_files or None)
|
||||
|
|
@ -325,3 +325,237 @@ def test_fetch_stream_with_mock_imap(client, config_dir, data_dir):
|
|||
assert "start" in types
|
||||
assert "done" in types
|
||||
assert "complete" in types
|
||||
|
||||
|
||||
# ---- /api/finetune/status tests ----
|
||||
|
||||
def test_finetune_status_returns_empty_when_no_models_dir(client):
    """GET /api/finetune/status must return [] if models/ does not exist."""
    resp = client.get("/api/finetune/status")
    assert resp.status_code == 200
    assert resp.json() == []
|
||||
|
||||
|
||||
def test_finetune_status_returns_training_info(client, tmp_path):
    """GET /api/finetune/status must return one entry per training_info.json found."""
    import json as _json
    from app import api as api_module

    model_dir = tmp_path / "models" / "avocet-deberta-small"
    model_dir.mkdir(parents=True)
    (model_dir / "training_info.json").write_text(_json.dumps({
        "name": "avocet-deberta-small",
        "base_model_id": "cross-encoder/nli-deberta-v3-small",
        "val_macro_f1": 0.712,
        "timestamp": "2026-03-15T12:00:00Z",
        "sample_count": 401,
    }))

    api_module.set_models_dir(tmp_path / "models")
    try:
        resp = client.get("/api/finetune/status")
        assert resp.status_code == 200
        assert any(entry["name"] == "avocet-deberta-small" for entry in resp.json())
    finally:
        # Always restore the real models dir so later tests see production state.
        api_module.set_models_dir(api_module._ROOT / "models")
|
||||
|
||||
|
||||
def test_finetune_run_streams_sse_events(client):
    """GET /api/finetune/run must return text/event-stream content type."""
    from unittest.mock import patch, MagicMock

    fake_proc = MagicMock()
    fake_proc.stdout = iter(["Training epoch 1\n", "Done\n"])
    fake_proc.returncode = 0
    fake_proc.wait = MagicMock()

    with patch("app.api._subprocess.Popen", return_value=fake_proc):
        resp = client.get("/api/finetune/run?model=deberta-small&epochs=1")

    assert resp.status_code == 200
    assert "text/event-stream" in resp.headers.get("content-type", "")
|
||||
|
||||
|
||||
def test_finetune_run_emits_complete_on_success(client):
    """GET /api/finetune/run must emit a complete event on clean exit."""
    from unittest.mock import patch, MagicMock

    fake_proc = MagicMock()
    fake_proc.stdout = iter(["progress line\n"])
    fake_proc.returncode = 0
    fake_proc.wait = MagicMock()

    with patch("app.api._subprocess.Popen", return_value=fake_proc):
        resp = client.get("/api/finetune/run?model=deberta-small&epochs=1")

    assert '{"type": "complete"}' in resp.text
|
||||
|
||||
|
||||
def test_finetune_run_emits_error_on_nonzero_exit(client):
    """GET /api/finetune/run must emit an error event on non-zero exit."""
    from unittest.mock import patch, MagicMock

    fake_proc = MagicMock()
    fake_proc.stdout = iter([])
    fake_proc.returncode = 1
    fake_proc.wait = MagicMock()

    with patch("app.api._subprocess.Popen", return_value=fake_proc):
        resp = client.get("/api/finetune/run?model=deberta-small&epochs=1")

    assert '"type": "error"' in resp.text
|
||||
|
||||
|
||||
def test_finetune_run_passes_score_files_to_subprocess(client):
    """GET /api/finetune/run?score=file1&score=file2 must pass --score args to subprocess."""
    from unittest.mock import patch, MagicMock

    seen_args: list = []

    def fake_popen(cmd, **kwargs):
        seen_args.extend(cmd)
        proc = MagicMock()
        proc.stdout = iter([])
        proc.returncode = 0
        proc.wait = MagicMock()
        return proc

    with patch("app.api._subprocess.Popen", side_effect=fake_popen):
        client.get("/api/finetune/run?model=deberta-small&epochs=1&score=run1.jsonl&score=run2.jsonl")

    assert "--score" in seen_args
    assert seen_args.count("--score") == 2
    # Paths are resolved to absolute — check filenames are present as substrings
    assert any("run1.jsonl" in arg for arg in seen_args)
    assert any("run2.jsonl" in arg for arg in seen_args)
|
||||
|
||||
|
||||
# ---- Cancel endpoint tests ----
|
||||
|
||||
def test_benchmark_cancel_returns_404_when_not_running(client):
    """POST /api/benchmark/cancel must return 404 if no benchmark is running."""
    from app import api as api_module

    # Make sure no stale registry entry is left over from another test.
    api_module._running_procs.pop("benchmark", None)
    resp = client.post("/api/benchmark/cancel")
    assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_finetune_cancel_returns_404_when_not_running(client):
    """POST /api/finetune/cancel must return 404 if no finetune is running."""
    from app import api as api_module

    # Make sure no stale registry entry is left over from another test.
    api_module._running_procs.pop("finetune", None)
    resp = client.post("/api/finetune/cancel")
    assert resp.status_code == 404
|
||||
|
||||
|
||||
def test_benchmark_cancel_terminates_running_process(client):
    """POST /api/benchmark/cancel must call terminate() on the running process."""
    from unittest.mock import MagicMock
    from app import api as api_module

    fake_proc = MagicMock()
    fake_proc.wait = MagicMock()
    api_module._running_procs["benchmark"] = fake_proc

    try:
        resp = client.post("/api/benchmark/cancel")
        assert resp.status_code == 200
        assert resp.json()["status"] == "cancelled"
        fake_proc.terminate.assert_called_once()
    finally:
        # Clean the registry and cancel flag so no state leaks into other tests.
        api_module._running_procs.pop("benchmark", None)
        api_module._cancelled_jobs.discard("benchmark")
|
||||
|
||||
|
||||
def test_finetune_cancel_terminates_running_process(client):
    """POST /api/finetune/cancel must call terminate() on the running process."""
    from unittest.mock import MagicMock
    from app import api as api_module

    fake_proc = MagicMock()
    fake_proc.wait = MagicMock()
    api_module._running_procs["finetune"] = fake_proc

    try:
        resp = client.post("/api/finetune/cancel")
        assert resp.status_code == 200
        assert resp.json()["status"] == "cancelled"
        fake_proc.terminate.assert_called_once()
    finally:
        # Clean the registry and cancel flag so no state leaks into other tests.
        api_module._running_procs.pop("finetune", None)
        api_module._cancelled_jobs.discard("finetune")
|
||||
|
||||
|
||||
def test_benchmark_cancel_kills_process_on_timeout(client):
    """POST /api/benchmark/cancel must call kill() if the process does not exit within 3 s."""
    import subprocess
    from unittest.mock import MagicMock
    from app import api as api_module

    fake_proc = MagicMock()
    # Simulate a process that ignores SIGTERM: wait() times out.
    fake_proc.wait.side_effect = subprocess.TimeoutExpired(cmd="benchmark", timeout=3)
    api_module._running_procs["benchmark"] = fake_proc

    try:
        resp = client.post("/api/benchmark/cancel")
        assert resp.status_code == 200
        fake_proc.kill.assert_called_once()
    finally:
        api_module._running_procs.pop("benchmark", None)
        api_module._cancelled_jobs.discard("benchmark")
|
||||
|
||||
|
||||
def test_finetune_run_emits_cancelled_event(client):
    """GET /api/finetune/run must emit cancelled (not error) when job was cancelled."""
    from unittest.mock import patch, MagicMock
    from app import api as api_module

    fake_proc = MagicMock()
    fake_proc.stdout = iter([])
    fake_proc.returncode = -15  # SIGTERM

    def fake_wait():
        # Set the cancel flag mid-run — i.e. after the endpoint has already
        # discarded any stale flag from a previous job.
        api_module._cancelled_jobs.add("finetune")

    fake_proc.wait = fake_wait

    try:
        with patch("app.api._subprocess.Popen", return_value=fake_proc):
            resp = client.get("/api/finetune/run?model=deberta-small&epochs=1")
        assert '{"type": "cancelled"}' in resp.text
        assert '"type": "error"' not in resp.text
    finally:
        api_module._cancelled_jobs.discard("finetune")
|
||||
|
||||
|
||||
def test_benchmark_run_emits_cancelled_event(client):
    """GET /api/benchmark/run must emit cancelled (not error) when job was cancelled."""
    from unittest.mock import patch, MagicMock
    from app import api as api_module

    fake_proc = MagicMock()
    fake_proc.stdout = iter([])
    fake_proc.returncode = -15

    def fake_wait():
        # Set the cancel flag mid-run — i.e. after the endpoint has already
        # discarded any stale flag from a previous job.
        api_module._cancelled_jobs.add("benchmark")

    fake_proc.wait = fake_wait

    try:
        with patch("app.api._subprocess.Popen", return_value=fake_proc):
            resp = client.get("/api/benchmark/run")
        assert '{"type": "cancelled"}' in resp.text
        assert '"type": "error"' not in resp.text
    finally:
        api_module._cancelled_jobs.discard("benchmark")
|
||||
|
|
|
|||
|
|
@ -92,3 +92,77 @@ def test_run_scoring_handles_classify_error(tmp_path):
|
|||
|
||||
results = run_scoring([broken], str(score_file))
|
||||
assert "broken" in results
|
||||
|
||||
|
||||
# ---- Auto-discovery tests ----
|
||||
|
||||
def test_discover_finetuned_models_finds_training_info_files(tmp_path):
    """discover_finetuned_models() must return one entry per training_info.json found."""
    import json
    from scripts.benchmark_classifier import discover_finetuned_models

    # Create two fake model directories, each with a valid training_info.json.
    for name in ("avocet-deberta-small", "avocet-bge-m3"):
        model_dir = tmp_path / name
        model_dir.mkdir()
        (model_dir / "training_info.json").write_text(json.dumps({
            "name": name,
            "base_model_id": "cross-encoder/nli-deberta-v3-small",
            "timestamp": "2026-03-15T12:00:00Z",
            "val_macro_f1": 0.72,
            "val_accuracy": 0.80,
            "sample_count": 401,
        }))

    found = discover_finetuned_models(tmp_path)
    assert len(found) == 2
    names = {entry["name"] for entry in found}
    assert "avocet-deberta-small" in names
    assert "avocet-bge-m3" in names
    for entry in found:
        assert "model_dir" in entry, "discover_finetuned_models must inject model_dir key"
        assert entry["model_dir"].endswith(entry["name"])
|
||||
|
||||
|
||||
def test_discover_finetuned_models_returns_empty_when_no_models_dir():
    """discover_finetuned_models() must return [] silently if models/ doesn't exist."""
    from pathlib import Path
    from scripts.benchmark_classifier import discover_finetuned_models

    assert discover_finetuned_models(Path("/nonexistent/path/models")) == []
|
||||
|
||||
|
||||
def test_discover_finetuned_models_skips_dirs_without_training_info(tmp_path):
    """Subdirs without training_info.json are silently skipped."""
    from scripts.benchmark_classifier import discover_finetuned_models

    # A directory that is NOT a model dir (no training_info.json inside).
    (tmp_path / "some-other-dir").mkdir()

    assert discover_finetuned_models(tmp_path) == []
|
||||
|
||||
|
||||
def test_active_models_includes_discovered_finetuned(tmp_path):
    """The active models dict must include FineTunedAdapter entries for discovered models."""
    import json
    from unittest.mock import patch
    from scripts.benchmark_classifier import _active_models
    from scripts.classifier_adapters import FineTunedAdapter

    target = tmp_path / "avocet-deberta-small"
    target.mkdir()
    (target / "training_info.json").write_text(json.dumps({
        "name": "avocet-deberta-small",
        "base_model_id": "cross-encoder/nli-deberta-v3-small",
        "val_macro_f1": 0.72,
        "sample_count": 401,
    }))

    with patch("scripts.benchmark_classifier._MODELS_DIR", tmp_path):
        registry = _active_models(include_slow=False)

    assert "avocet-deberta-small" in registry
    assert isinstance(registry["avocet-deberta-small"]["adapter_instance"], FineTunedAdapter)
|
||||
|
|
|
|||
|
|
@ -180,3 +180,91 @@ def test_reranker_adapter_picks_highest_score():
|
|||
def test_reranker_adapter_descriptions_cover_all_labels():
    """Every canonical label must have a description (and no extras)."""
    from scripts.classifier_adapters import LABEL_DESCRIPTIONS, LABELS

    assert set(LABELS) == set(LABEL_DESCRIPTIONS.keys())
|
||||
|
||||
|
||||
# ---- FineTunedAdapter tests ----
|
||||
|
||||
def test_finetuned_adapter_classify_calls_pipeline_with_sep_format(tmp_path):
    """classify() must format input as 'subject [SEP] body[:400]' — not the zero-shot format."""
    from unittest.mock import MagicMock, patch
    from scripts.classifier_adapters import FineTunedAdapter

    pipe = MagicMock(return_value=[{"label": "digest", "score": 0.95}])
    pipe_factory = MagicMock(return_value=pipe)

    adapter = FineTunedAdapter("avocet-deberta-small", str(tmp_path))
    with patch("scripts.classifier_adapters.pipeline", pipe_factory):
        label = adapter.classify("Test subject", "Test body")

    assert label == "digest"
    text_arg = pipe.call_args[0][0]
    assert "[SEP]" in text_arg
    assert "Test subject" in text_arg
    assert "Test body" in text_arg
|
||||
|
||||
|
||||
def test_finetuned_adapter_truncates_body_to_400():
    """Body must be truncated to 400 chars in the [SEP] format."""
    from unittest.mock import MagicMock, patch
    from scripts.classifier_adapters import FineTunedAdapter, LABELS

    pipe = MagicMock(return_value=[{"label": "neutral", "score": 0.9}])
    pipe_factory = MagicMock(return_value=pipe)

    adapter = FineTunedAdapter("avocet-deberta-small", "/fake/path")
    with patch("scripts.classifier_adapters.pipeline", pipe_factory):
        adapter.classify("Subject", "x" * 800)

    pieces = pipe.call_args[0][0].split(" [SEP] ", 1)
    assert len(pieces) == 2, "Input must contain ' [SEP] ' separator"
    assert len(pieces[1]) == 400, f"Body must be exactly 400 chars, got {len(pieces[1])}"
|
||||
|
||||
|
||||
def test_finetuned_adapter_returns_label_string():
    """classify() must return a plain string, not a dict."""
    from unittest.mock import MagicMock, patch
    from scripts.classifier_adapters import FineTunedAdapter

    pipe = MagicMock(return_value=[{"label": "interview_scheduled", "score": 0.87}])
    pipe_factory = MagicMock(return_value=pipe)

    adapter = FineTunedAdapter("avocet-deberta-small", "/fake/path")
    with patch("scripts.classifier_adapters.pipeline", pipe_factory):
        label = adapter.classify("S", "B")

    assert isinstance(label, str)
    assert label == "interview_scheduled"
|
||||
|
||||
|
||||
def test_finetuned_adapter_lazy_loads_pipeline():
    """Pipeline factory must not be called until classify() is first called."""
    from unittest.mock import MagicMock, patch
    from scripts.classifier_adapters import FineTunedAdapter

    pipe_factory = MagicMock(
        return_value=MagicMock(return_value=[{"label": "neutral", "score": 0.9}])
    )

    with patch("scripts.classifier_adapters.pipeline", pipe_factory):
        adapter = FineTunedAdapter("avocet-deberta-small", "/fake/path")
        # Construction alone must not trigger the (expensive) pipeline load.
        assert not pipe_factory.called
        adapter.classify("s", "b")
        assert pipe_factory.called
|
||||
|
||||
|
||||
def test_finetuned_adapter_unload_clears_pipeline():
    """unload() must set _pipeline to None so memory is released."""
    from unittest.mock import MagicMock, patch
    from scripts.classifier_adapters import FineTunedAdapter

    pipe_factory = MagicMock(
        return_value=MagicMock(return_value=[{"label": "neutral", "score": 0.9}])
    )

    with patch("scripts.classifier_adapters.pipeline", pipe_factory):
        adapter = FineTunedAdapter("avocet-deberta-small", "/fake/path")
        adapter.classify("s", "b")
        assert adapter._pipeline is not None
        adapter.unload()
        assert adapter._pipeline is None
|
||||
|
|
|
|||
371
tests/test_finetune.py
Normal file
371
tests/test_finetune.py
Normal file
|
|
@ -0,0 +1,371 @@
|
|||
"""Tests for finetune_classifier — no model downloads required."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import pytest
|
||||
|
||||
|
||||
# ---- Data loading tests ----
|
||||
|
||||
def test_load_and_prepare_data_drops_non_canonical_labels(tmp_path):
    """Rows with labels not in LABELS must be silently dropped."""
    from scripts.finetune_classifier import load_and_prepare_data
    from scripts.classifier_adapters import LABELS

    # Two samples per canonical label so they survive the < 2 class-drop rule.
    samples = [
        {"subject": "s1", "body": "b1", "label": "digest"},
        {"subject": "s2", "body": "b2", "label": "digest"},
        {"subject": "s3", "body": "b3", "label": "profile_alert"},  # non-canonical
        {"subject": "s4", "body": "b4", "label": "neutral"},
        {"subject": "s5", "body": "b5", "label": "neutral"},
    ]
    path = tmp_path / "email_score.jsonl"
    path.write_text("\n".join(json.dumps(s) for s in samples))

    texts, labels = load_and_prepare_data(path)
    assert len(texts) == 4
    assert all(lbl in LABELS for lbl in labels)
|
||||
|
||||
|
||||
def test_load_and_prepare_data_formats_input_as_sep(tmp_path):
    """Input text must be 'subject [SEP] body[:400]'."""
    from scripts.finetune_classifier import load_and_prepare_data

    # Two samples with the same label so the class survives the < 2 drop rule.
    samples = [
        {"subject": "Hello", "body": "World" * 100, "label": "neutral"},
        {"subject": "Hello2", "body": "World" * 100, "label": "neutral"},
    ]
    path = tmp_path / "email_score.jsonl"
    path.write_text("\n".join(json.dumps(s) for s in samples))

    texts, labels = load_and_prepare_data(path)

    assert texts[0].startswith("Hello [SEP] ")
    pieces = texts[0].split(" [SEP] ", 1)
    assert len(pieces[1]) == 400, f"Body must be exactly 400 chars, got {len(pieces[1])}"
|
||||
|
||||
|
||||
def test_load_and_prepare_data_raises_on_missing_file():
    """FileNotFoundError must be raised with actionable message."""
    from pathlib import Path
    from scripts.finetune_classifier import load_and_prepare_data

    # The error message must name the missing file so users can act on it.
    with pytest.raises(FileNotFoundError, match="email_score.jsonl"):
        load_and_prepare_data(Path("/nonexistent/email_score.jsonl"))
|
||||
|
||||
|
||||
def test_load_and_prepare_data_drops_class_with_fewer_than_2_samples(tmp_path, capsys):
    """Classes with < 2 total samples must be dropped with a warning."""
    from scripts.finetune_classifier import load_and_prepare_data

    samples = [
        {"subject": "s1", "body": "b", "label": "digest"},
        {"subject": "s2", "body": "b", "label": "digest"},
        {"subject": "s3", "body": "b", "label": "new_lead"},  # only 1 sample — drop
    ]
    path = tmp_path / "email_score.jsonl"
    path.write_text("\n".join(json.dumps(s) for s in samples))

    texts, labels = load_and_prepare_data(path)
    printed = capsys.readouterr()

    assert "new_lead" not in labels
    assert "new_lead" in printed.out  # warning printed
|
||||
|
||||
|
||||
# ---- Class weights tests ----
|
||||
|
||||
def test_compute_class_weights_returns_tensor_for_each_class():
    """compute_class_weights must return a float tensor of length n_classes."""
    import torch
    from scripts.finetune_classifier import compute_class_weights

    imbalanced = [0, 0, 0, 1, 1, 2]  # 3 classes, imbalanced
    weights = compute_class_weights(imbalanced, n_classes=3)

    assert isinstance(weights, torch.Tensor)
    assert weights.shape == (3,)
    assert all(w > 0 for w in weights)
|
||||
|
||||
|
||||
def test_compute_class_weights_upweights_minority():
    """Minority classes must receive higher weight than majority classes."""
    from scripts.finetune_classifier import compute_class_weights

    # Class 0 is the majority (10 samples); class 1 the minority (2 samples).
    weights = compute_class_weights([0] * 10 + [1] * 2, n_classes=2)

    assert weights[1] > weights[0]
|
||||
|
||||
|
||||
# ---- compute_metrics_for_trainer tests ----
|
||||
|
||||
def test_compute_metrics_for_trainer_returns_macro_f1_key():
    """Must return a dict with 'macro_f1' key."""
    import numpy as np
    from scripts.finetune_classifier import compute_metrics_for_trainer
    from transformers import EvalPrediction

    # Logits argmax matches the labels exactly → perfect macro F1.
    raw_logits = np.array([[2.0, 0.1], [0.1, 2.0], [2.0, 0.1]])
    truth = np.array([0, 1, 0])

    metrics = compute_metrics_for_trainer(
        EvalPrediction(predictions=raw_logits, label_ids=truth)
    )
    assert "macro_f1" in metrics
    assert metrics["macro_f1"] == pytest.approx(1.0)
|
||||
|
||||
|
||||
def test_compute_metrics_for_trainer_returns_accuracy_key():
    """The metrics dict must also expose an 'accuracy' entry."""
    import numpy as np
    from transformers import EvalPrediction
    from scripts.finetune_classifier import compute_metrics_for_trainer

    scores = np.array([[2.0, 0.1], [0.1, 2.0]])
    y_true = np.array([0, 1])

    metrics = compute_metrics_for_trainer(
        EvalPrediction(predictions=scores, label_ids=y_true)
    )

    assert "accuracy" in metrics
    assert metrics["accuracy"] == pytest.approx(1.0)  # both predictions correct
|
||||
|
||||
|
||||
# ---- WeightedTrainer tests ----
|
||||
|
||||
def test_weighted_trainer_compute_loss_returns_scalar():
    """compute_loss returns a 0-dim tensor when return_outputs=False."""
    import torch
    from unittest.mock import MagicMock
    from scripts.finetune_classifier import WeightedTrainer

    num_labels, batch_size = 3, 4

    fake_outputs = MagicMock()
    fake_outputs.logits = torch.randn(batch_size, num_labels)
    fake_model = MagicMock(return_value=fake_outputs)

    # Bypass Trainer.__init__ — compute_loss only needs class_weights.
    tr = WeightedTrainer.__new__(WeightedTrainer)
    tr.class_weights = torch.ones(num_labels)

    batch_inputs = {
        "input_ids": torch.zeros(batch_size, 10, dtype=torch.long),
        "labels": torch.randint(0, num_labels, (batch_size,)),
    }

    loss = tr.compute_loss(fake_model, batch_inputs, return_outputs=False)

    assert isinstance(loss, torch.Tensor)
    assert loss.ndim == 0  # scalar
|
||||
|
||||
|
||||
def test_weighted_trainer_compute_loss_accepts_kwargs():
    """Passing the num_items_in_batch kwarg must not raise TypeError."""
    import torch
    from unittest.mock import MagicMock
    from scripts.finetune_classifier import WeightedTrainer

    num_labels, batch_size = 3, 2

    fake_outputs = MagicMock()
    fake_outputs.logits = torch.randn(batch_size, num_labels)
    fake_model = MagicMock(return_value=fake_outputs)

    tr = WeightedTrainer.__new__(WeightedTrainer)  # skip Trainer.__init__
    tr.class_weights = torch.ones(num_labels)

    batch_inputs = {
        "input_ids": torch.zeros(batch_size, 5, dtype=torch.long),
        "labels": torch.randint(0, num_labels, (batch_size,)),
    }

    loss = tr.compute_loss(
        fake_model, batch_inputs, return_outputs=False, num_items_in_batch=batch_size
    )

    assert isinstance(loss, torch.Tensor)
|
||||
|
||||
|
||||
def test_weighted_trainer_weighted_loss_differs_from_unweighted():
    """Re-weighting classes must change the loss on an imbalanced batch."""
    import torch
    from unittest.mock import MagicMock
    from scripts.finetune_classifier import WeightedTrainer

    n_classes = 2
    batch = 4
    # Three class-0 samples predicted confidently, one class-1 sample predicted
    # poorly: the per-class CE values differ, so re-weighting shifts the
    # weighted mean.
    labels = torch.tensor([0, 0, 0, 1], dtype=torch.long)
    logits = torch.tensor([[3.0, -1.0], [3.0, -1.0], [3.0, -1.0], [0.5, 0.5]])

    def make_trainer(weights):
        # Bypass Trainer.__init__; only class_weights is consulted.
        t = WeightedTrainer.__new__(WeightedTrainer)
        t.class_weights = weights
        return t

    def make_model():
        out = MagicMock()
        out.logits = logits.clone()
        return MagicMock(return_value=out)

    def make_inputs():
        return {
            "input_ids": torch.zeros(batch, 5, dtype=torch.long),
            "labels": labels.clone(),
        }

    loss_flat = make_trainer(torch.ones(n_classes)).compute_loss(
        make_model(), make_inputs()
    )
    loss_skewed = make_trainer(torch.tensor([0.1, 10.0])).compute_loss(
        make_model(), make_inputs()
    )

    assert not torch.isclose(loss_flat, loss_skewed)
|
||||
|
||||
|
||||
def test_weighted_trainer_compute_loss_returns_outputs_when_requested():
    """compute_loss with return_outputs=True must return a (loss, outputs) pair.

    Fix: the original unpacked ``loss, outputs = result`` and never used
    ``outputs`` (F841); the tuple arity was only asserted implicitly via that
    unpacking.  Assert the length explicitly instead.
    """
    import torch
    from unittest.mock import MagicMock
    from scripts.finetune_classifier import WeightedTrainer

    n_classes = 3
    batch = 2
    logits = torch.randn(batch, n_classes)

    mock_outputs = MagicMock()
    mock_outputs.logits = logits
    mock_model = MagicMock(return_value=mock_outputs)

    trainer = WeightedTrainer.__new__(WeightedTrainer)  # skip Trainer.__init__
    trainer.class_weights = torch.ones(n_classes)

    inputs = {
        "input_ids": torch.zeros(batch, 5, dtype=torch.long),
        "labels": torch.randint(0, n_classes, (batch,)),
    }

    result = trainer.compute_loss(mock_model, inputs, return_outputs=True)

    assert isinstance(result, tuple)
    assert len(result) == 2  # exactly (loss, outputs)
    loss = result[0]
    assert isinstance(loss, torch.Tensor)
|
||||
|
||||
|
||||
# ---- Multi-file merge / dedup tests ----
|
||||
|
||||
def test_load_and_prepare_data_merges_multiple_files(tmp_path):
    """Rows from every score file end up in one merged dataset."""
    from scripts.finetune_classifier import load_and_prepare_data

    run1 = tmp_path / "run1.jsonl"
    run2 = tmp_path / "run2.jsonl"
    run1.write_text(
        json.dumps({"subject": "s1", "body": "b1", "label": "digest"}) + "\n" +
        json.dumps({"subject": "s2", "body": "b2", "label": "digest"}) + "\n"
    )
    run2.write_text(
        json.dumps({"subject": "s3", "body": "b3", "label": "neutral"}) + "\n" +
        json.dumps({"subject": "s4", "body": "b4", "label": "neutral"}) + "\n"
    )

    texts, labels = load_and_prepare_data([run1, run2])

    assert len(texts) == 4
    assert labels.count("digest") == 2
    assert labels.count("neutral") == 2
|
||||
|
||||
|
||||
def test_load_and_prepare_data_deduplicates_last_write_wins(tmp_path, capsys):
    """A duplicate row (same content hash) keeps its last-seen label."""
    from scripts.finetune_classifier import load_and_prepare_data

    # Identical subject + body[:100] → identical hash; the later file relabels.
    first = {"subject": "Hello", "body": "World", "label": "neutral"}
    second = {"subject": "Hello", "body": "World", "label": "digest"}

    run1 = tmp_path / "run1.jsonl"
    run2 = tmp_path / "run2.jsonl"
    # Pad each file with a unique extra row so both classes keep >= 2 samples.
    run1.write_text(
        json.dumps(first) + "\n" +
        json.dumps({"subject": "Other1", "body": "Other", "label": "neutral"}) + "\n"
    )
    run2.write_text(
        json.dumps(second) + "\n" +
        json.dumps({"subject": "Other2", "body": "Stuff", "label": "digest"}) + "\n"
    )

    texts, labels = load_and_prepare_data([run1, run2])
    out = capsys.readouterr().out

    # The dropped duplicate is reported.
    assert "Deduped" in out
    # The relabeled row carries "digest" (last-write wins), not "neutral".
    idx = next(i for i, t in enumerate(texts) if t.startswith("Hello [SEP]"))
    assert labels[idx] == "digest"
|
||||
|
||||
|
||||
def test_load_and_prepare_data_single_path_still_works(tmp_path):
    """Backwards compatibility: a bare Path (not a list) is accepted."""
    from scripts.finetune_classifier import load_and_prepare_data

    data = [
        {"subject": "s1", "body": "b1", "label": "digest"},
        {"subject": "s2", "body": "b2", "label": "digest"},
    ]
    score_path = tmp_path / "email_score.jsonl"
    score_path.write_text("\n".join(json.dumps(row) for row in data))

    texts, _labels = load_and_prepare_data(score_path)  # single Path, not list

    assert len(texts) == 2
|
||||
|
||||
|
||||
# ---- Integration test ----
|
||||
|
||||
def test_integration_finetune_on_example_data(tmp_path):
    """End-to-end: fine-tune deberta-small for 1 epoch on the example dataset.

    data/email_score.jsonl.example (8 samples, 5 labels represented) — the
    labels missing from it must trigger the "< 2 samples" drop warning.
    Also verifies training_info.json is written with the expected keys.

    Requires the job-seeker-classifiers env and downloads deberta-small
    (~100MB on first run).
    """
    import io
    import shutil
    from contextlib import redirect_stdout
    from scripts import finetune_classifier as ft_mod
    from scripts.finetune_classifier import run_finetune

    example = ft_mod._ROOT / "data" / "email_score.jsonl.example"
    if not example.exists():
        pytest.skip("email_score.jsonl.example not found")

    # Point the module's root at tmp_path so all artifacts land there.
    saved_root = ft_mod._ROOT
    ft_mod._ROOT = tmp_path
    (tmp_path / "data").mkdir()
    shutil.copy(example, tmp_path / "data" / "email_score.jsonl")

    try:
        buf = io.StringIO()
        with redirect_stdout(buf):
            run_finetune("deberta-small", epochs=1)
        output = buf.getvalue()
    finally:
        ft_mod._ROOT = saved_root  # always restore the real root

    # Under-represented labels must have been dropped with a warning.
    assert "WARNING: Dropping class" in output

    info_path = tmp_path / "models" / "avocet-deberta-small" / "training_info.json"
    assert info_path.exists(), "training_info.json not written"

    info = json.loads(info_path.read_text())
    expected_keys = (
        "name", "base_model_id", "timestamp", "epochs_run",
        "val_macro_f1", "val_accuracy", "sample_count", "train_sample_count",
        "label_counts", "score_files",
    )
    for key in expected_keys:
        assert key in info, f"Missing key: {key}"

    assert info["name"] == "avocet-deberta-small"
    assert info["epochs_run"] == 1
    assert isinstance(info["score_files"], list)
|
||||
|
|
@ -4,7 +4,12 @@
|
|||
<meta charset="UTF-8" />
|
||||
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>web</title>
|
||||
<title>Avocet — Label Tool</title>
|
||||
<!-- Inline background prevents blank-white flash before the CSS bundle loads -->
|
||||
<style>
|
||||
html, body { margin: 0; background: #eaeff8; min-height: 100vh; }
|
||||
@media (prefers-color-scheme: dark) { html, body { background: #16202e; } }
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div id="app"></div>
|
||||
|
|
|
|||
11
web/package-lock.json
generated
11
web/package-lock.json
generated
|
|
@ -13,6 +13,7 @@
|
|||
"@fontsource/jetbrains-mono": "^5.2.8",
|
||||
"@vueuse/core": "^14.2.1",
|
||||
"@vueuse/integrations": "^14.2.1",
|
||||
"animejs": "^4.3.6",
|
||||
"pinia": "^3.0.4",
|
||||
"vue": "^3.5.25",
|
||||
"vue-router": "^5.0.3"
|
||||
|
|
@ -2570,6 +2571,16 @@
|
|||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/animejs": {
|
||||
"version": "4.3.6",
|
||||
"resolved": "https://registry.npmjs.org/animejs/-/animejs-4.3.6.tgz",
|
||||
"integrity": "sha512-rzZ4bDc8JAtyx6hYwxj7s5M/yWfnM5qqY4hZDnhy1cWFvMb6H5/necHS2sbCY3WQTDbRLuZL10dPXSxSCFOr/w==",
|
||||
"license": "MIT",
|
||||
"funding": {
|
||||
"type": "github",
|
||||
"url": "https://github.com/sponsors/juliangarnier"
|
||||
}
|
||||
},
|
||||
"node_modules/ansi-regex": {
|
||||
"version": "6.2.2",
|
||||
"resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-6.2.2.tgz",
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@
|
|||
"@fontsource/jetbrains-mono": "^5.2.8",
|
||||
"@vueuse/core": "^14.2.1",
|
||||
"@vueuse/integrations": "^14.2.1",
|
||||
"animejs": "^4.3.6",
|
||||
"pinia": "^3.0.4",
|
||||
"vue": "^3.5.25",
|
||||
"vue-router": "^5.0.3"
|
||||
|
|
|
|||
|
|
@ -11,11 +11,13 @@
|
|||
import { onMounted } from 'vue'
|
||||
import { RouterView } from 'vue-router'
|
||||
import { useMotion } from './composables/useMotion'
|
||||
import { useHackerMode } from './composables/useEasterEgg'
|
||||
import { useHackerMode, useKonamiCode } from './composables/useEasterEgg'
|
||||
import AppSidebar from './components/AppSidebar.vue'
|
||||
|
||||
const motion = useMotion()
|
||||
const { restore } = useHackerMode()
|
||||
const { toggle, restore } = useHackerMode()
|
||||
|
||||
useKonamiCode(toggle)
|
||||
|
||||
onMounted(() => {
|
||||
restore() // re-apply hacker mode from localStorage on page load
|
||||
|
|
|
|||
|
|
@ -8,8 +8,29 @@
|
|||
Accent — Russet (#B8622A) — inspired by avocet's vivid orange-russet head
|
||||
*/
|
||||
|
||||
/* ── Page-level overrides — must be in avocet.css (applied after theme.css base) ── */
|
||||
html {
|
||||
/* Prevent Mac Chrome's horizontal swipe-to-navigate page animation
|
||||
from triggering when the user scrolls near the viewport edge */
|
||||
overscroll-behavior-x: none;
|
||||
/* clip (not hidden) — prevents overflowing content from expanding the html layout
|
||||
width beyond the viewport. Without this, body's overflow-x:hidden propagates to
|
||||
the viewport and body has no BFC, so long email URLs inflate the layout and
|
||||
margin:0 auto centering drifts rightward as fonts load. */
|
||||
overflow-x: clip;
|
||||
}
|
||||
|
||||
body {
|
||||
/* Prevent horizontal scroll from card swipe animations */
|
||||
overflow-x: hidden;
|
||||
}
|
||||
|
||||
|
||||
/* ── Light mode (default) ──────────────────────────── */
|
||||
:root {
|
||||
/* Aliases bridging avocet component vars to CircuitForge base theme vars */
|
||||
--color-bg: var(--color-surface); /* App.vue body bg → #eaeff8 in light */
|
||||
--color-text-secondary: var(--color-text-muted); /* muted label text */
|
||||
/* Primary — Slate Teal */
|
||||
--app-primary: #2A6080; /* 4.8:1 on light surface #eaeff8 — ✅ AA */
|
||||
--app-primary-hover: #1E4D66; /* darker for hover */
|
||||
|
|
|
|||
|
|
@ -62,10 +62,11 @@ import { RouterLink } from 'vue-router'
|
|||
const LS_KEY = 'cf-avocet-nav-stowed'
|
||||
|
||||
const navItems = [
|
||||
{ path: '/', icon: '🃏', label: 'Label' },
|
||||
{ path: '/fetch', icon: '📥', label: 'Fetch' },
|
||||
{ path: '/stats', icon: '📊', label: 'Stats' },
|
||||
{ path: '/settings', icon: '⚙️', label: 'Settings' },
|
||||
{ path: '/', icon: '🃏', label: 'Label' },
|
||||
{ path: '/fetch', icon: '📥', label: 'Fetch' },
|
||||
{ path: '/stats', icon: '📊', label: 'Stats' },
|
||||
{ path: '/benchmark', icon: '🏁', label: 'Benchmark' },
|
||||
{ path: '/settings', icon: '⚙️', label: 'Settings' },
|
||||
]
|
||||
|
||||
const stowed = ref(localStorage.getItem(LS_KEY) === 'true')
|
||||
|
|
|
|||
|
|
@ -86,6 +86,7 @@ const displayBody = computed(() => {
|
|||
font-size: 0.9375rem;
|
||||
line-height: 1.6;
|
||||
white-space: pre-wrap;
|
||||
overflow-wrap: break-word;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,20 @@ import { mount } from '@vue/test-utils'
|
|||
import EmailCardStack from './EmailCardStack.vue'
|
||||
import { describe, it, expect, vi } from 'vitest'
|
||||
|
||||
vi.mock('../composables/useCardAnimation', () => ({
|
||||
useCardAnimation: vi.fn(() => ({
|
||||
pickup: vi.fn(),
|
||||
setDragPosition: vi.fn(),
|
||||
snapBack: vi.fn(),
|
||||
animateDismiss: vi.fn(),
|
||||
updateAura: vi.fn(),
|
||||
reset: vi.fn(),
|
||||
})),
|
||||
}))
|
||||
|
||||
import { useCardAnimation } from '../composables/useCardAnimation'
|
||||
import { nextTick } from 'vue'
|
||||
|
||||
const item = {
|
||||
id: 'abc',
|
||||
subject: 'Interview at Acme',
|
||||
|
|
@ -22,27 +36,13 @@ describe('EmailCardStack', () => {
|
|||
expect(w.findAll('.card-shadow')).toHaveLength(2)
|
||||
})
|
||||
|
||||
it('applies dismiss-label class when dismissType is label', () => {
|
||||
const w = mount(EmailCardStack, { props: { item, isBucketMode: false, dismissType: 'label' } })
|
||||
expect(w.find('.card-wrapper').classes()).toContain('dismiss-label')
|
||||
})
|
||||
|
||||
it('applies dismiss-discard class when dismissType is discard', () => {
|
||||
const w = mount(EmailCardStack, { props: { item, isBucketMode: false, dismissType: 'discard' } })
|
||||
expect(w.find('.card-wrapper').classes()).toContain('dismiss-discard')
|
||||
})
|
||||
|
||||
it('applies dismiss-skip class when dismissType is skip', () => {
|
||||
const w = mount(EmailCardStack, { props: { item, isBucketMode: false, dismissType: 'skip' } })
|
||||
expect(w.find('.card-wrapper').classes()).toContain('dismiss-skip')
|
||||
})
|
||||
|
||||
it('no dismiss class when dismissType is null', () => {
|
||||
it('calls animateDismiss with type when dismissType prop changes', async () => {
|
||||
;(useCardAnimation as ReturnType<typeof vi.fn>).mockClear()
|
||||
const w = mount(EmailCardStack, { props: { item, isBucketMode: false, dismissType: null } })
|
||||
const wrapperClasses = w.find('.card-wrapper').classes()
|
||||
expect(wrapperClasses).not.toContain('dismiss-label')
|
||||
expect(wrapperClasses).not.toContain('dismiss-discard')
|
||||
expect(wrapperClasses).not.toContain('dismiss-skip')
|
||||
const { animateDismiss } = (useCardAnimation as ReturnType<typeof vi.fn>).mock.results[0].value
|
||||
await w.setProps({ dismissType: 'label' })
|
||||
await nextTick()
|
||||
expect(animateDismiss).toHaveBeenCalledWith('label')
|
||||
})
|
||||
|
||||
// JSDOM doesn't implement setPointerCapture — mock it on the element.
|
||||
|
|
|
|||
|
|
@ -11,8 +11,7 @@
|
|||
<div
|
||||
class="card-wrapper"
|
||||
ref="cardEl"
|
||||
:class="[dismissClass, { 'is-held': isHeld }]"
|
||||
:style="cardStyle"
|
||||
:class="{ 'is-held': isHeld }"
|
||||
@pointerdown="onPointerDown"
|
||||
@pointermove="onPointerMove"
|
||||
@pointerup="onPointerUp"
|
||||
|
|
@ -29,8 +28,9 @@
|
|||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed } from 'vue'
|
||||
import { ref, watch } from 'vue'
|
||||
import { useMotion } from '../composables/useMotion'
|
||||
import { useCardAnimation } from '../composables/useCardAnimation'
|
||||
import EmailCard from './EmailCard.vue'
|
||||
import type { QueueItem } from '../stores/label'
|
||||
|
||||
|
|
@ -54,12 +54,22 @@ const motion = useMotion()
|
|||
const cardEl = ref<HTMLElement | null>(null)
|
||||
const isExpanded = ref(false)
|
||||
|
||||
const { pickup, setDragPosition, snapBack, animateDismiss, updateAura, reset } = useCardAnimation(cardEl, motion)
|
||||
|
||||
watch(() => props.dismissType, (type) => {
|
||||
if (type) animateDismiss(type)
|
||||
})
|
||||
|
||||
// When a new card loads into the same element, clear any inline styles left by the previous animation
|
||||
watch(() => props.item.id, () => {
|
||||
reset()
|
||||
isExpanded.value = false
|
||||
})
|
||||
|
||||
// Toss gesture state
|
||||
const isHeld = ref(false)
|
||||
const pickupX = ref(0)
|
||||
const pickupY = ref(0)
|
||||
const deltaX = ref(0)
|
||||
const deltaY = ref(0)
|
||||
const hoveredZone = ref<'discard' | 'skip' | null>(null)
|
||||
const hoveredBucketName = ref<string | null>(null)
|
||||
|
||||
|
|
@ -74,13 +84,14 @@ const FLING_WINDOW_MS = 50 // rolling sample window in ms
|
|||
let velocityBuf: { x: number; y: number; t: number }[] = []
|
||||
|
||||
function onPointerDown(e: PointerEvent) {
|
||||
// Let clicks on interactive children (expand/collapse, links, etc.) pass through
|
||||
if ((e.target as Element).closest('button, a, input, select, textarea')) return
|
||||
if (!motion.rich.value) return
|
||||
;(e.currentTarget as HTMLElement).setPointerCapture(e.pointerId)
|
||||
pickupX.value = e.clientX
|
||||
pickupY.value = e.clientY
|
||||
deltaX.value = 0
|
||||
deltaY.value = 0
|
||||
isHeld.value = true
|
||||
pickup()
|
||||
hoveredZone.value = null
|
||||
hoveredBucketName.value = null
|
||||
velocityBuf = []
|
||||
|
|
@ -89,8 +100,9 @@ function onPointerDown(e: PointerEvent) {
|
|||
|
||||
function onPointerMove(e: PointerEvent) {
|
||||
if (!isHeld.value) return
|
||||
deltaX.value = e.clientX - pickupX.value
|
||||
deltaY.value = e.clientY - pickupY.value
|
||||
const dx = e.clientX - pickupX.value
|
||||
const dy = e.clientY - pickupY.value
|
||||
setDragPosition(dx, dy)
|
||||
|
||||
// Rolling velocity buffer — keep only the last FLING_WINDOW_MS of samples
|
||||
const now = performance.now()
|
||||
|
|
@ -118,6 +130,7 @@ function onPointerMove(e: PointerEvent) {
|
|||
hoveredBucketName.value = bucketName
|
||||
emit('bucket-hover', bucketName)
|
||||
}
|
||||
updateAura(hoveredZone.value, hoveredBucketName.value)
|
||||
}
|
||||
|
||||
function onPointerUp(e: PointerEvent) {
|
||||
|
|
@ -163,9 +176,9 @@ function onPointerUp(e: PointerEvent) {
|
|||
hoveredBucketName.value = null
|
||||
emit('label', name)
|
||||
} else {
|
||||
// Snap back — reset deltas
|
||||
deltaX.value = 0
|
||||
deltaY.value = 0
|
||||
// Snap back
|
||||
snapBack()
|
||||
updateAura(null, null)
|
||||
hoveredZone.value = null
|
||||
hoveredBucketName.value = null
|
||||
}
|
||||
|
|
@ -175,8 +188,8 @@ function onPointerCancel(e: PointerEvent) {
|
|||
if (!isHeld.value) return
|
||||
;(e.currentTarget as HTMLElement).releasePointerCapture(e.pointerId)
|
||||
isHeld.value = false
|
||||
deltaX.value = 0
|
||||
deltaY.value = 0
|
||||
snapBack()
|
||||
updateAura(null, null)
|
||||
hoveredZone.value = null
|
||||
hoveredBucketName.value = null
|
||||
velocityBuf = []
|
||||
|
|
@ -184,32 +197,6 @@ function onPointerCancel(e: PointerEvent) {
|
|||
emit('zone-hover', null)
|
||||
emit('bucket-hover', null)
|
||||
}
|
||||
|
||||
const dismissClass = computed(() => {
|
||||
if (!props.dismissType) return null
|
||||
return `dismiss-${props.dismissType}`
|
||||
})
|
||||
|
||||
const cardStyle = computed(() => {
|
||||
if (!motion.rich.value || !isHeld.value) return {}
|
||||
|
||||
// Aura color: zone > bucket > neutral
|
||||
const aura =
|
||||
hoveredZone.value === 'discard' ? 'rgba(244,67,54,0.25)' :
|
||||
hoveredZone.value === 'skip' ? 'rgba(255,152,0,0.25)' :
|
||||
hoveredBucketName.value ? 'rgba(42,96,128,0.20)' :
|
||||
'transparent'
|
||||
|
||||
return {
|
||||
transform: `translate(${deltaX.value}px, ${deltaY.value - 80}px) scale(0.55)`,
|
||||
borderRadius: '50%',
|
||||
background: aura,
|
||||
transition: 'border-radius 150ms ease, background 150ms ease',
|
||||
cursor: 'grabbing',
|
||||
zIndex: 100,
|
||||
userSelect: 'none',
|
||||
}
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
|
|
@ -276,30 +263,6 @@ const cardStyle = computed(() => {
|
|||
pointer-events: auto !important;
|
||||
}
|
||||
|
||||
/* Dismissal animations — dismiss class is only applied during the motion.rich await window,
|
||||
so no ancestor guard needed; :global(.rich-motion) was being miscompiled by Vue's scoped
|
||||
CSS transformer (dropping the descendant selector entirely). */
|
||||
.card-wrapper.dismiss-label {
|
||||
animation: fileAway var(--card-dismiss, 350ms ease-in) forwards;
|
||||
}
|
||||
.card-wrapper.dismiss-discard {
|
||||
animation: crumple var(--card-dismiss, 350ms ease-in) forwards;
|
||||
}
|
||||
.card-wrapper.dismiss-skip {
|
||||
animation: slideUnder var(--card-skip, 300ms ease-out) forwards;
|
||||
}
|
||||
|
||||
@keyframes fileAway {
|
||||
to { transform: translateY(-120%) scale(0.85); opacity: 0; }
|
||||
}
|
||||
@keyframes crumple {
|
||||
50% { transform: scale(0.95) rotate(2deg); filter: brightness(0.6) sepia(1) hue-rotate(-20deg); }
|
||||
to { transform: scale(0) rotate(8deg); opacity: 0; }
|
||||
}
|
||||
@keyframes slideUnder {
|
||||
to { transform: translateX(110%) rotate(5deg); opacity: 0; }
|
||||
}
|
||||
|
||||
@media (prefers-reduced-motion: reduce) {
|
||||
.card-stack,
|
||||
.card-wrapper {
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
<template>
|
||||
<div class="label-grid" :class="{ 'bucket-mode': isBucketMode }" role="group" aria-label="Label buttons">
|
||||
<button
|
||||
v-for="label in labels"
|
||||
v-for="label in displayLabels"
|
||||
:key="label.key"
|
||||
data-testid="label-btn"
|
||||
:data-label-key="label.name"
|
||||
|
|
@ -19,6 +19,8 @@
|
|||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { computed } from 'vue'
|
||||
|
||||
interface Label { name: string; emoji: string; color: string; key: string }
|
||||
|
||||
const props = defineProps<{
|
||||
|
|
@ -27,6 +29,16 @@ const props = defineProps<{
|
|||
hoveredBucket?: string | null
|
||||
}>()
|
||||
const emit = defineEmits<{ label: [name: string] }>()
|
||||
|
||||
// Numpad layout: reverse the row order of numeric keys (7-8-9 on top, 1-2-3 on bottom)
|
||||
// Non-numeric keys (e.g. 'h' for hired) stay pinned after the grid.
|
||||
const displayLabels = computed(() => {
|
||||
const numeric = props.labels.filter(l => !isNaN(Number(l.key)))
|
||||
const other = props.labels.filter(l => isNaN(Number(l.key)))
|
||||
const rows: Label[][] = []
|
||||
for (let i = 0; i < numeric.length; i += 3) rows.push(numeric.slice(i, i + 3))
|
||||
return [...rows.reverse().flat(), ...other]
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
|
|
@ -38,11 +50,9 @@ const emit = defineEmits<{ label: [name: string] }>()
|
|||
padding var(--bucket-expand, 250ms cubic-bezier(0.34, 1.56, 0.64, 1));
|
||||
}
|
||||
|
||||
/* 10th button (hired / key h) — centered below the 3×3 like a numpad 0 */
|
||||
/* 10th button (hired / key h) — full-width bar below the 3×3 */
|
||||
.label-btn:last-child {
|
||||
grid-column: 1 / -1;
|
||||
max-width: calc(33.333% - 0.34rem);
|
||||
justify-self: center;
|
||||
}
|
||||
|
||||
.label-grid.bucket-mode {
|
||||
|
|
|
|||
142
web/src/composables/useCardAnimation.test.ts
Normal file
142
web/src/composables/useCardAnimation.test.ts
Normal file
|
|
@ -0,0 +1,142 @@
|
|||
import { ref } from 'vue'
|
||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
||||
|
||||
// Mock animejs before importing the composable
|
||||
vi.mock('animejs', () => ({
|
||||
animate: vi.fn(),
|
||||
spring: vi.fn(() => 'mock-spring'),
|
||||
utils: { set: vi.fn() },
|
||||
}))
|
||||
|
||||
import { useCardAnimation } from './useCardAnimation'
|
||||
import { animate, utils } from 'animejs'
|
||||
|
||||
const mockAnimate = animate as ReturnType<typeof vi.fn>
|
||||
const mockSet = utils.set as ReturnType<typeof vi.fn>
|
||||
|
||||
function makeEl() {
|
||||
return document.createElement('div')
|
||||
}
|
||||
|
||||
describe('useCardAnimation', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
it('pickup() calls animate with ball shape', () => {
|
||||
const el = makeEl()
|
||||
const cardEl = ref<HTMLElement | null>(el)
|
||||
const motion = { rich: ref(true) }
|
||||
const { pickup } = useCardAnimation(cardEl, motion)
|
||||
pickup()
|
||||
expect(mockAnimate).toHaveBeenCalledWith(
|
||||
el,
|
||||
expect.objectContaining({ scale: 0.55, borderRadius: '50%' }),
|
||||
)
|
||||
})
|
||||
|
||||
it('pickup() is a no-op when motion.rich is false', () => {
|
||||
const el = makeEl()
|
||||
const cardEl = ref<HTMLElement | null>(el)
|
||||
const motion = { rich: ref(false) }
|
||||
const { pickup } = useCardAnimation(cardEl, motion)
|
||||
pickup()
|
||||
expect(mockAnimate).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('setDragPosition() calls utils.set with translated coords', () => {
|
||||
const el = makeEl()
|
||||
const cardEl = ref<HTMLElement | null>(el)
|
||||
const motion = { rich: ref(true) }
|
||||
const { setDragPosition } = useCardAnimation(cardEl, motion)
|
||||
setDragPosition(50, 30)
|
||||
expect(mockSet).toHaveBeenCalledWith(el, expect.objectContaining({ x: 50, y: -50 }))
|
||||
// y = deltaY - 80 = 30 - 80 = -50
|
||||
})
|
||||
|
||||
it('snapBack() calls animate returning to card shape', () => {
|
||||
const el = makeEl()
|
||||
const cardEl = ref<HTMLElement | null>(el)
|
||||
const motion = { rich: ref(true) }
|
||||
const { snapBack } = useCardAnimation(cardEl, motion)
|
||||
snapBack()
|
||||
expect(mockAnimate).toHaveBeenCalledWith(
|
||||
el,
|
||||
expect.objectContaining({ x: 0, y: 0, scale: 1 }),
|
||||
)
|
||||
})
|
||||
|
||||
it('animateDismiss("label") calls animate', () => {
|
||||
const el = makeEl()
|
||||
const cardEl = ref<HTMLElement | null>(el)
|
||||
const motion = { rich: ref(true) }
|
||||
const { animateDismiss } = useCardAnimation(cardEl, motion)
|
||||
animateDismiss('label')
|
||||
expect(mockAnimate).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('animateDismiss("discard") calls animate', () => {
|
||||
const el = makeEl()
|
||||
const cardEl = ref<HTMLElement | null>(el)
|
||||
const motion = { rich: ref(true) }
|
||||
const { animateDismiss } = useCardAnimation(cardEl, motion)
|
||||
animateDismiss('discard')
|
||||
expect(mockAnimate).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('animateDismiss("skip") calls animate', () => {
|
||||
const el = makeEl()
|
||||
const cardEl = ref<HTMLElement | null>(el)
|
||||
const motion = { rich: ref(true) }
|
||||
const { animateDismiss } = useCardAnimation(cardEl, motion)
|
||||
animateDismiss('skip')
|
||||
expect(mockAnimate).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('animateDismiss is a no-op when motion.rich is false', () => {
|
||||
const el = makeEl()
|
||||
const cardEl = ref<HTMLElement | null>(el)
|
||||
const motion = { rich: ref(false) }
|
||||
const { animateDismiss } = useCardAnimation(cardEl, motion)
|
||||
animateDismiss('label')
|
||||
expect(mockAnimate).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
describe('updateAura', () => {
|
||||
it('sets red background for discard zone', () => {
|
||||
const el = makeEl()
|
||||
const cardEl = ref<HTMLElement | null>(el)
|
||||
const motion = { rich: ref(true) }
|
||||
const { updateAura } = useCardAnimation(cardEl, motion)
|
||||
updateAura('discard', null)
|
||||
expect(mockSet).toHaveBeenCalledWith(el, expect.objectContaining({ background: 'rgba(244, 67, 54, 0.25)' }))
|
||||
})
|
||||
|
||||
it('sets orange background for skip zone', () => {
|
||||
const el = makeEl()
|
||||
const cardEl = ref<HTMLElement | null>(el)
|
||||
const motion = { rich: ref(true) }
|
||||
const { updateAura } = useCardAnimation(cardEl, motion)
|
||||
updateAura('skip', null)
|
||||
expect(mockSet).toHaveBeenCalledWith(el, expect.objectContaining({ background: 'rgba(255, 152, 0, 0.25)' }))
|
||||
})
|
||||
|
||||
it('sets blue background for bucket hover', () => {
|
||||
const el = makeEl()
|
||||
const cardEl = ref<HTMLElement | null>(el)
|
||||
const motion = { rich: ref(true) }
|
||||
const { updateAura } = useCardAnimation(cardEl, motion)
|
||||
updateAura(null, 'interview_scheduled')
|
||||
expect(mockSet).toHaveBeenCalledWith(el, expect.objectContaining({ background: 'rgba(42, 96, 128, 0.20)' }))
|
||||
})
|
||||
|
||||
it('sets transparent background when no zone/bucket', () => {
|
||||
const el = makeEl()
|
||||
const cardEl = ref<HTMLElement | null>(el)
|
||||
const motion = { rich: ref(true) }
|
||||
const { updateAura } = useCardAnimation(cardEl, motion)
|
||||
updateAura(null, null)
|
||||
expect(mockSet).toHaveBeenCalledWith(el, expect.objectContaining({ background: 'transparent' }))
|
||||
})
|
||||
})
|
||||
})
|
||||
99
web/src/composables/useCardAnimation.ts
Normal file
99
web/src/composables/useCardAnimation.ts
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
import { type Ref } from 'vue'
|
||||
import { animate, spring, utils } from 'animejs'
|
||||
|
||||
const BALL_SCALE = 0.55
|
||||
const BALL_RADIUS = '50%'
|
||||
const CARD_RADIUS = '1rem'
|
||||
const PICKUP_Y_OFFSET = 80 // px above finger
|
||||
const PICKUP_DURATION = 200
|
||||
|
||||
// Anime.js v4: spring() takes an object { mass, stiffness, damping, velocity }
|
||||
const SNAP_SPRING = spring({ mass: 1, stiffness: 80, damping: 10 })
|
||||
|
||||
interface Motion { rich: Ref<boolean> }
|
||||
|
||||
export function useCardAnimation(
|
||||
cardEl: Ref<HTMLElement | null>,
|
||||
motion: Motion,
|
||||
) {
|
||||
function pickup() {
|
||||
if (!motion.rich.value || !cardEl.value) return
|
||||
// Anime.js v4: animate(target, params) — all props + timing in one object
|
||||
animate(cardEl.value, {
|
||||
scale: BALL_SCALE,
|
||||
borderRadius: BALL_RADIUS,
|
||||
y: -PICKUP_Y_OFFSET,
|
||||
duration: PICKUP_DURATION,
|
||||
ease: SNAP_SPRING,
|
||||
})
|
||||
}
|
||||
|
||||
function setDragPosition(dx: number, dy: number) {
|
||||
if (!cardEl.value) return
|
||||
// utils.set() for instant (no-animation) position update — keeps Anime cache consistent
|
||||
utils.set(cardEl.value, { x: dx, y: dy - PICKUP_Y_OFFSET })
|
||||
}
|
||||
|
||||
function snapBack() {
|
||||
if (!motion.rich.value || !cardEl.value) return
|
||||
animate(cardEl.value, {
|
||||
x: 0,
|
||||
y: 0,
|
||||
scale: 1,
|
||||
borderRadius: CARD_RADIUS,
|
||||
ease: SNAP_SPRING,
|
||||
})
|
||||
}
|
||||
|
||||
function animateDismiss(type: 'label' | 'skip' | 'discard') {
|
||||
if (!motion.rich.value || !cardEl.value) return
|
||||
const el = cardEl.value
|
||||
if (type === 'label') {
|
||||
animate(el, { y: '-120%', scale: 0.85, opacity: 0, duration: 280, ease: 'out(3)' })
|
||||
} else if (type === 'discard') {
|
||||
// Anime.js v4 keyframe array: array of param objects, each can have its own duration
|
||||
animate(el, {
|
||||
keyframes: [
|
||||
{ scale: 0.95, rotate: 2, filter: 'brightness(0.6) sepia(1) hue-rotate(-20deg)', duration: 140 },
|
||||
{ scale: 0, rotate: 8, opacity: 0, duration: 210 },
|
||||
],
|
||||
})
|
||||
} else if (type === 'skip') {
|
||||
animate(el, { x: '110%', rotate: 5, opacity: 0, duration: 260, ease: 'out(2)' })
|
||||
}
|
||||
}
|
||||
|
||||
const AURA_COLORS = {
|
||||
discard: 'rgba(244, 67, 54, 0.25)',
|
||||
skip: 'rgba(255, 152, 0, 0.25)',
|
||||
bucket: 'rgba(42, 96, 128, 0.20)',
|
||||
none: 'transparent',
|
||||
} as const
|
||||
|
||||
function updateAura(zone: 'discard' | 'skip' | null, bucket: string | null) {
|
||||
if (!cardEl.value) return
|
||||
const color =
|
||||
zone === 'discard' ? AURA_COLORS.discard :
|
||||
zone === 'skip' ? AURA_COLORS.skip :
|
||||
bucket ? AURA_COLORS.bucket :
|
||||
AURA_COLORS.none
|
||||
utils.set(cardEl.value, { background: color })
|
||||
}
|
||||
|
||||
function reset() {
|
||||
if (!cardEl.value) return
|
||||
// Instantly restore initial card state — called when a new item loads into the same element
|
||||
utils.set(cardEl.value, {
|
||||
x: 0,
|
||||
y: 0,
|
||||
scale: 1,
|
||||
opacity: 1,
|
||||
rotate: 0,
|
||||
borderRadius: CARD_RADIUS,
|
||||
background: 'transparent',
|
||||
filter: 'none',
|
||||
})
|
||||
}
|
||||
|
||||
return { pickup, setDragPosition, snapBack, animateDismiss, updateAura, reset }
|
||||
}
|
||||
|
|
@ -1,14 +1,15 @@
|
|||
import { onMounted, onUnmounted } from 'vue'
|
||||
|
||||
const KONAMI = ['ArrowUp','ArrowUp','ArrowDown','ArrowDown','ArrowLeft','ArrowRight','ArrowLeft','ArrowRight','b','a']
|
||||
const KONAMI = ['ArrowUp','ArrowUp','ArrowDown','ArrowDown','ArrowLeft','ArrowRight','ArrowLeft','ArrowRight','b','a']
|
||||
const KONAMI_AB = ['ArrowUp','ArrowUp','ArrowDown','ArrowDown','ArrowLeft','ArrowRight','ArrowLeft','ArrowRight','a','b']
|
||||
|
||||
export function useKonamiCode(onActivate: () => void) {
|
||||
export function useKeySequence(sequence: string[], onActivate: () => void) {
|
||||
let pos = 0
|
||||
|
||||
function handler(e: KeyboardEvent) {
|
||||
if (e.key === KONAMI[pos]) {
|
||||
if (e.key === sequence[pos]) {
|
||||
pos++
|
||||
if (pos === KONAMI.length) {
|
||||
if (pos === sequence.length) {
|
||||
pos = 0
|
||||
onActivate()
|
||||
}
|
||||
|
|
@ -21,6 +22,11 @@ export function useKonamiCode(onActivate: () => void) {
|
|||
onUnmounted(() => window.removeEventListener('keydown', handler))
|
||||
}
|
||||
|
||||
export function useKonamiCode(onActivate: () => void) {
|
||||
useKeySequence(KONAMI, onActivate)
|
||||
useKeySequence(KONAMI_AB, onActivate)
|
||||
}
|
||||
|
||||
export function useHackerMode() {
|
||||
function toggle() {
|
||||
const root = document.documentElement
|
||||
|
|
|
|||
|
|
@ -2,16 +2,18 @@ import { createRouter, createWebHashHistory } from 'vue-router'
|
|||
import LabelView from '../views/LabelView.vue'
|
||||
|
||||
// Views are lazy-loaded to keep initial bundle small
|
||||
const FetchView = () => import('../views/FetchView.vue')
|
||||
const StatsView = () => import('../views/StatsView.vue')
|
||||
const SettingsView = () => import('../views/SettingsView.vue')
|
||||
const FetchView = () => import('../views/FetchView.vue')
|
||||
const StatsView = () => import('../views/StatsView.vue')
|
||||
const BenchmarkView = () => import('../views/BenchmarkView.vue')
|
||||
const SettingsView = () => import('../views/SettingsView.vue')
|
||||
|
||||
export const router = createRouter({
|
||||
history: createWebHashHistory(),
|
||||
routes: [
|
||||
{ path: '/', component: LabelView, meta: { title: 'Label' } },
|
||||
{ path: '/fetch', component: FetchView, meta: { title: 'Fetch' } },
|
||||
{ path: '/stats', component: StatsView, meta: { title: 'Stats' } },
|
||||
{ path: '/settings', component: SettingsView, meta: { title: 'Settings' } },
|
||||
{ path: '/', component: LabelView, meta: { title: 'Label' } },
|
||||
{ path: '/fetch', component: FetchView, meta: { title: 'Fetch' } },
|
||||
{ path: '/stats', component: StatsView, meta: { title: 'Stats' } },
|
||||
{ path: '/benchmark', component: BenchmarkView, meta: { title: 'Benchmark' } },
|
||||
{ path: '/settings', component: SettingsView, meta: { title: 'Settings' } },
|
||||
],
|
||||
})
|
||||
|
|
|
|||
846
web/src/views/BenchmarkView.vue
Normal file
846
web/src/views/BenchmarkView.vue
Normal file
|
|
@ -0,0 +1,846 @@
|
|||
<template>
|
||||
<div class="bench-view">
|
||||
<header class="bench-header">
|
||||
<h1 class="page-title">🏁 Benchmark</h1>
|
||||
<div class="header-actions">
|
||||
<label class="slow-toggle" :class="{ disabled: running }">
|
||||
<input type="checkbox" v-model="includeSlow" :disabled="running" />
|
||||
Include slow models
|
||||
</label>
|
||||
<button
|
||||
class="btn-run"
|
||||
:disabled="running"
|
||||
@click="startBenchmark"
|
||||
>
|
||||
{{ running ? '⏳ Running…' : results ? '🔄 Re-run' : '▶ Run Benchmark' }}
|
||||
</button>
|
||||
<button
|
||||
v-if="running"
|
||||
class="btn-cancel"
|
||||
@click="cancelBenchmark"
|
||||
>
|
||||
✕ Cancel
|
||||
</button>
|
||||
</div>
|
||||
</header>
|
||||
|
||||
<!-- Trained models badge row -->
|
||||
<div v-if="fineTunedModels.length > 0" class="trained-models-row">
|
||||
<span class="trained-label">Trained:</span>
|
||||
<span
|
||||
v-for="m in fineTunedModels"
|
||||
:key="m.name"
|
||||
class="trained-badge"
|
||||
:title="m.base_model_id ? `Base: ${m.base_model_id} · ${m.sample_count ?? '?'} samples` : m.name"
|
||||
>
|
||||
{{ m.name }}
|
||||
<span v-if="m.val_macro_f1 != null" class="trained-f1">
|
||||
F1 {{ (m.val_macro_f1 * 100).toFixed(1) }}%
|
||||
</span>
|
||||
</span>
|
||||
</div>
|
||||
|
||||
<!-- Progress log -->
|
||||
<div v-if="running || runLog.length" class="run-log">
|
||||
<div class="run-log-title">
|
||||
<span>{{ running ? '⏳ Running benchmark…' : runCancelled ? '⏹ Cancelled' : runError ? '❌ Failed' : '✅ Done' }}</span>
|
||||
<button class="btn-ghost" @click="runLog = []; runError = ''; runCancelled = false">Clear</button>
|
||||
</div>
|
||||
<div class="log-lines" ref="logEl">
|
||||
<div
|
||||
v-for="(line, i) in runLog"
|
||||
:key="i"
|
||||
class="log-line"
|
||||
:class="{ 'log-error': line.startsWith('ERROR') || line.startsWith('[error]') }"
|
||||
>{{ line }}</div>
|
||||
</div>
|
||||
<p v-if="runError" class="run-error">{{ runError }}</p>
|
||||
</div>
|
||||
|
||||
<!-- Loading -->
|
||||
<div v-if="loading" class="status-notice">Loading…</div>
|
||||
|
||||
<!-- No results yet -->
|
||||
<div v-else-if="!results" class="status-notice empty">
|
||||
<p>No benchmark results yet.</p>
|
||||
<p class="hint">Click <strong>Run Benchmark</strong> to score all default models against your labeled data.</p>
|
||||
</div>
|
||||
|
||||
<!-- Results -->
|
||||
<template v-else>
|
||||
<p class="meta-line">
|
||||
<span>{{ results.sample_count.toLocaleString() }} labeled emails</span>
|
||||
<span class="sep">·</span>
|
||||
<span>{{ modelCount }} model{{ modelCount === 1 ? '' : 's' }}</span>
|
||||
<span class="sep">·</span>
|
||||
<span>{{ formatDate(results.timestamp) }}</span>
|
||||
</p>
|
||||
|
||||
<!-- Macro-F1 chart -->
|
||||
<section class="chart-section">
|
||||
<h2 class="chart-title">Macro-F1 (higher = better)</h2>
|
||||
<div class="bar-chart">
|
||||
<div v-for="row in f1Rows" :key="row.name" class="bar-row">
|
||||
<span class="bar-label" :title="row.name">{{ row.name }}</span>
|
||||
<div class="bar-track">
|
||||
<div
|
||||
class="bar-fill"
|
||||
:style="{ width: `${row.pct}%`, background: scoreColor(row.value) }"
|
||||
/>
|
||||
</div>
|
||||
<span class="bar-value" :style="{ color: scoreColor(row.value) }">
|
||||
{{ row.value.toFixed(3) }}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- Latency chart -->
|
||||
<section class="chart-section">
|
||||
<h2 class="chart-title">Latency (ms / email, lower = better)</h2>
|
||||
<div class="bar-chart">
|
||||
<div v-for="row in latencyRows" :key="row.name" class="bar-row">
|
||||
<span class="bar-label" :title="row.name">{{ row.name }}</span>
|
||||
<div class="bar-track">
|
||||
<div
|
||||
class="bar-fill latency-fill"
|
||||
:style="{ width: `${row.pct}%` }"
|
||||
/>
|
||||
</div>
|
||||
<span class="bar-value">{{ row.value.toFixed(1) }} ms</span>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
|
||||
<!-- Per-label F1 heatmap -->
|
||||
<section class="chart-section">
|
||||
<h2 class="chart-title">Per-label F1</h2>
|
||||
<div class="heatmap-scroll">
|
||||
<table class="heatmap">
|
||||
<thead>
|
||||
<tr>
|
||||
<th class="hm-label-col">Label</th>
|
||||
<th v-for="name in modelNames" :key="name" class="hm-model-col" :title="name">
|
||||
{{ name }}
|
||||
</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr v-for="label in labelNames" :key="label">
|
||||
<td class="hm-label-cell">
|
||||
<span class="hm-emoji">{{ LABEL_META[label]?.emoji ?? '🏷️' }}</span>
|
||||
{{ label.replace(/_/g, '\u00a0') }}
|
||||
</td>
|
||||
<td
|
||||
v-for="name in modelNames"
|
||||
:key="name"
|
||||
class="hm-value-cell"
|
||||
:style="{ background: heatmapBg(f1For(name, label)), color: heatmapFg(f1For(name, label)) }"
|
||||
:title="`${name} / ${label}: F1 ${f1For(name, label).toFixed(3)}, support ${supportFor(name, label)}`"
|
||||
>
|
||||
{{ f1For(name, label).toFixed(2) }}
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
<p class="heatmap-hint">Hover a cell for precision / recall / support. Color: 🟢 ≥ 0.7 · 🟡 0.4–0.7 · 🔴 < 0.4</p>
|
||||
</section>
|
||||
</template>
|
||||
|
||||
<!-- Fine-tune section -->
|
||||
<details class="ft-section">
|
||||
<summary class="ft-summary">Fine-tune a model</summary>
|
||||
<div class="ft-body">
|
||||
<div class="ft-controls">
|
||||
<label class="ft-field">
|
||||
<span class="ft-field-label">Model</span>
|
||||
<select v-model="ftModel" class="ft-select" :disabled="ftRunning">
|
||||
<option value="deberta-small">deberta-small (100M, fast)</option>
|
||||
<option value="bge-m3">bge-m3 (600M — stop Peregrine vLLM first)</option>
|
||||
</select>
|
||||
</label>
|
||||
<label class="ft-field">
|
||||
<span class="ft-field-label">Epochs</span>
|
||||
<input
|
||||
v-model.number="ftEpochs"
|
||||
type="number" min="1" max="20"
|
||||
class="ft-epochs"
|
||||
:disabled="ftRunning"
|
||||
/>
|
||||
</label>
|
||||
<button
|
||||
class="btn-run ft-run-btn"
|
||||
:disabled="ftRunning"
|
||||
@click="startFinetune"
|
||||
>
|
||||
{{ ftRunning ? '⏳ Training…' : '▶ Run fine-tune' }}
|
||||
</button>
|
||||
<button
|
||||
v-if="ftRunning"
|
||||
class="btn-cancel"
|
||||
@click="cancelFinetune"
|
||||
>
|
||||
✕ Cancel
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div v-if="ftRunning || ftLog.length || ftError" class="run-log ft-log">
|
||||
<div class="run-log-title">
|
||||
<span>{{ ftRunning ? '⏳ Training…' : ftCancelled ? '⏹ Cancelled' : ftError ? '❌ Failed' : '✅ Done' }}</span>
|
||||
<button class="btn-ghost" @click="ftLog = []; ftError = ''; ftCancelled = false">Clear</button>
|
||||
</div>
|
||||
<div class="log-lines" ref="ftLogEl">
|
||||
<div
|
||||
v-for="(line, i) in ftLog"
|
||||
:key="i"
|
||||
class="log-line"
|
||||
:class="{ 'log-error': line.startsWith('ERROR') || line.startsWith('[error]') }"
|
||||
>{{ line }}</div>
|
||||
</div>
|
||||
<p v-if="ftError" class="run-error">{{ ftError }}</p>
|
||||
</div>
|
||||
</div>
|
||||
</details>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, onMounted, nextTick } from 'vue'
|
||||
import { useApiFetch, useApiSSE } from '../composables/useApi'
|
||||
|
||||
// ── Label metadata (same as StatsView) ──────────────────────────────────────
|
||||
const LABEL_META: Record<string, { emoji: string }> = {
|
||||
interview_scheduled: { emoji: '🗓️' },
|
||||
offer_received: { emoji: '🎉' },
|
||||
rejected: { emoji: '❌' },
|
||||
positive_response: { emoji: '👍' },
|
||||
survey_received: { emoji: '📋' },
|
||||
neutral: { emoji: '⬜' },
|
||||
event_rescheduled: { emoji: '🔄' },
|
||||
digest: { emoji: '📰' },
|
||||
new_lead: { emoji: '🤝' },
|
||||
hired: { emoji: '🎊' },
|
||||
}
|
||||
|
||||
// ── Types ────────────────────────────────────────────────────────────────────
|
||||
interface FineTunedModel {
|
||||
name: string
|
||||
base_model_id?: string
|
||||
val_macro_f1?: number
|
||||
timestamp?: string
|
||||
sample_count?: number
|
||||
}
|
||||
|
||||
interface PerLabel { f1: number; precision: number; recall: number; support: number }
|
||||
interface ModelResult {
|
||||
macro_f1: number
|
||||
accuracy: number
|
||||
latency_ms: number
|
||||
per_label: Record<string, PerLabel>
|
||||
}
|
||||
interface BenchResults {
|
||||
timestamp: string | null
|
||||
sample_count: number
|
||||
models: Record<string, ModelResult>
|
||||
}
|
||||
|
||||
// ── State ────────────────────────────────────────────────────────────────────
|
||||
const results = ref<BenchResults | null>(null)
|
||||
const loading = ref(true)
|
||||
const running = ref(false)
|
||||
const runLog = ref<string[]>([])
|
||||
const runError = ref('')
|
||||
const includeSlow = ref(false)
|
||||
const logEl = ref<HTMLElement | null>(null)
|
||||
|
||||
// Fine-tune state
|
||||
const fineTunedModels = ref<FineTunedModel[]>([])
|
||||
const ftModel = ref('deberta-small')
|
||||
const ftEpochs = ref(5)
|
||||
const ftRunning = ref(false)
|
||||
const ftLog = ref<string[]>([])
|
||||
const ftError = ref('')
|
||||
const ftLogEl = ref<HTMLElement | null>(null)
|
||||
|
||||
const runCancelled = ref(false)
|
||||
const ftCancelled = ref(false)
|
||||
|
||||
async function cancelBenchmark() {
|
||||
await fetch('/api/benchmark/cancel', { method: 'POST' }).catch(() => {})
|
||||
}
|
||||
|
||||
async function cancelFinetune() {
|
||||
await fetch('/api/finetune/cancel', { method: 'POST' }).catch(() => {})
|
||||
}
|
||||
|
||||
// ── Derived ──────────────────────────────────────────────────────────────────
|
||||
const modelNames = computed(() => Object.keys(results.value?.models ?? {}))
|
||||
const modelCount = computed(() => modelNames.value.length)
|
||||
|
||||
const labelNames = computed(() => {
|
||||
const canonical = Object.keys(LABEL_META)
|
||||
const inResults = new Set(
|
||||
modelNames.value.flatMap(n => Object.keys(results.value!.models[n].per_label))
|
||||
)
|
||||
return [...canonical.filter(l => inResults.has(l)), ...[...inResults].filter(l => !canonical.includes(l))]
|
||||
})
|
||||
|
||||
const f1Rows = computed(() => {
|
||||
if (!results.value) return []
|
||||
const rows = modelNames.value.map(name => ({
|
||||
name,
|
||||
value: results.value!.models[name].macro_f1,
|
||||
}))
|
||||
rows.sort((a, b) => b.value - a.value)
|
||||
const max = rows[0]?.value || 1
|
||||
return rows.map(r => ({ ...r, pct: Math.round((r.value / max) * 100) }))
|
||||
})
|
||||
|
||||
const latencyRows = computed(() => {
|
||||
if (!results.value) return []
|
||||
const rows = modelNames.value.map(name => ({
|
||||
name,
|
||||
value: results.value!.models[name].latency_ms,
|
||||
}))
|
||||
rows.sort((a, b) => a.value - b.value) // fastest first
|
||||
const max = rows[rows.length - 1]?.value || 1
|
||||
return rows.map(r => ({ ...r, pct: Math.round((r.value / max) * 100) }))
|
||||
})
|
||||
|
||||
// ── Helpers ──────────────────────────────────────────────────────────────────
|
||||
function f1For(model: string, label: string): number {
|
||||
return results.value?.models[model]?.per_label[label]?.f1 ?? 0
|
||||
}
|
||||
function supportFor(model: string, label: string): number {
|
||||
return results.value?.models[model]?.per_label[label]?.support ?? 0
|
||||
}
|
||||
|
||||
function scoreColor(v: number): string {
|
||||
if (v >= 0.7) return 'var(--color-success, #4CAF50)'
|
||||
if (v >= 0.4) return 'var(--app-accent, #B8622A)'
|
||||
return 'var(--color-error, #ef4444)'
|
||||
}
|
||||
|
||||
function heatmapBg(v: number): string {
|
||||
// Blend red→yellow→green using the F1 value
|
||||
if (v >= 0.7) return `color-mix(in srgb, #4CAF50 ${Math.round(v * 100)}%, #1a2338 ${Math.round((1 - v) * 80)}%)`
|
||||
if (v >= 0.4) return `color-mix(in srgb, #FF9800 ${Math.round(v * 120)}%, #1a2338 40%)`
|
||||
return `color-mix(in srgb, #ef4444 ${Math.round(v * 200 + 30)}%, #1a2338 60%)`
|
||||
}
|
||||
function heatmapFg(v: number): string {
|
||||
return v >= 0.5 ? '#fff' : 'rgba(255,255,255,0.75)'
|
||||
}
|
||||
|
||||
function formatDate(iso: string | null): string {
|
||||
if (!iso) return 'unknown date'
|
||||
const d = new Date(iso)
|
||||
return d.toLocaleString(undefined, { dateStyle: 'medium', timeStyle: 'short' })
|
||||
}
|
||||
|
||||
// ── Data loading ─────────────────────────────────────────────────────────────
|
||||
async function loadResults() {
|
||||
loading.value = true
|
||||
const { data } = await useApiFetch<BenchResults>('/api/benchmark/results')
|
||||
loading.value = false
|
||||
if (data && Object.keys(data.models).length > 0) {
|
||||
results.value = data
|
||||
}
|
||||
}
|
||||
|
||||
// ── Benchmark run ─────────────────────────────────────────────────────────────
|
||||
function startBenchmark() {
|
||||
running.value = true
|
||||
runLog.value = []
|
||||
runError.value = ''
|
||||
runCancelled.value = false
|
||||
|
||||
const url = `/api/benchmark/run${includeSlow.value ? '?include_slow=true' : ''}`
|
||||
useApiSSE(
|
||||
url,
|
||||
async (event) => {
|
||||
if (event.type === 'progress' && typeof event.message === 'string') {
|
||||
runLog.value.push(event.message)
|
||||
await nextTick()
|
||||
logEl.value?.scrollTo({ top: logEl.value.scrollHeight, behavior: 'smooth' })
|
||||
}
|
||||
if (event.type === 'error' && typeof event.message === 'string') {
|
||||
runError.value = event.message
|
||||
}
|
||||
if (event.type === 'cancelled') {
|
||||
running.value = false
|
||||
runCancelled.value = true
|
||||
}
|
||||
},
|
||||
async () => {
|
||||
running.value = false
|
||||
await loadResults()
|
||||
},
|
||||
() => {
|
||||
running.value = false
|
||||
if (!runError.value) runError.value = 'Connection lost'
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
async function loadFineTunedModels() {
|
||||
const { data } = await useApiFetch<FineTunedModel[]>('/api/finetune/status')
|
||||
if (Array.isArray(data)) fineTunedModels.value = data
|
||||
}
|
||||
|
||||
function startFinetune() {
|
||||
if (ftRunning.value) return
|
||||
ftRunning.value = true
|
||||
ftLog.value = []
|
||||
ftError.value = ''
|
||||
ftCancelled.value = false
|
||||
|
||||
const params = new URLSearchParams({ model: ftModel.value, epochs: String(ftEpochs.value) })
|
||||
useApiSSE(
|
||||
`/api/finetune/run?${params}`,
|
||||
async (event) => {
|
||||
if (event.type === 'progress' && typeof event.message === 'string') {
|
||||
ftLog.value.push(event.message)
|
||||
await nextTick()
|
||||
ftLogEl.value?.scrollTo({ top: ftLogEl.value.scrollHeight, behavior: 'smooth' })
|
||||
}
|
||||
if (event.type === 'error' && typeof event.message === 'string') {
|
||||
ftError.value = event.message
|
||||
}
|
||||
if (event.type === 'cancelled') {
|
||||
ftRunning.value = false
|
||||
ftCancelled.value = true
|
||||
}
|
||||
},
|
||||
async () => {
|
||||
ftRunning.value = false
|
||||
await loadFineTunedModels()
|
||||
startBenchmark() // auto-trigger benchmark to refresh charts
|
||||
},
|
||||
() => {
|
||||
ftRunning.value = false
|
||||
if (!ftError.value) ftError.value = 'Connection lost'
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
onMounted(() => {
|
||||
loadResults()
|
||||
loadFineTunedModels()
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.bench-view {
|
||||
max-width: 860px;
|
||||
margin: 0 auto;
|
||||
padding: 1.5rem 1rem 4rem;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 1.75rem;
|
||||
}
|
||||
|
||||
.bench-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: space-between;
|
||||
flex-wrap: wrap;
|
||||
gap: 0.75rem;
|
||||
}
|
||||
|
||||
.page-title {
|
||||
font-family: var(--font-display, var(--font-body, sans-serif));
|
||||
font-size: 1.4rem;
|
||||
font-weight: 700;
|
||||
color: var(--app-primary, #2A6080);
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.header-actions {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.75rem;
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
|
||||
.slow-toggle {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.4rem;
|
||||
font-size: 0.85rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
cursor: pointer;
|
||||
user-select: none;
|
||||
}
|
||||
.slow-toggle.disabled { opacity: 0.5; pointer-events: none; }
|
||||
|
||||
.btn-run {
|
||||
padding: 0.45rem 1.1rem;
|
||||
border-radius: 0.375rem;
|
||||
border: none;
|
||||
background: var(--app-primary, #2A6080);
|
||||
color: #fff;
|
||||
font-size: 0.88rem;
|
||||
font-family: var(--font-body, sans-serif);
|
||||
cursor: pointer;
|
||||
transition: opacity 0.15s;
|
||||
}
|
||||
.btn-run:disabled { opacity: 0.5; cursor: not-allowed; }
|
||||
.btn-run:not(:disabled):hover { opacity: 0.85; }
|
||||
|
||||
.btn-cancel {
|
||||
padding: 0.45rem 0.9rem;
|
||||
background: transparent;
|
||||
border: 1px solid var(--color-text-secondary, #6b7a99);
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
border-radius: 0.4rem;
|
||||
font-size: 0.85rem;
|
||||
font-weight: 500;
|
||||
cursor: pointer;
|
||||
transition: background 0.15s;
|
||||
}
|
||||
|
||||
.btn-cancel:hover {
|
||||
background: color-mix(in srgb, var(--color-text-secondary, #6b7a99) 12%, transparent);
|
||||
}
|
||||
|
||||
/* ── Run log ────────────────────────────────────────────── */
|
||||
.run-log {
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
overflow: hidden;
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.78rem;
|
||||
}
|
||||
|
||||
.run-log-title {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
padding: 0.4rem 0.75rem;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
font-size: 0.8rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
}
|
||||
|
||||
.btn-ghost {
|
||||
background: none;
|
||||
border: none;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
cursor: pointer;
|
||||
font-size: 0.78rem;
|
||||
padding: 0.1rem 0.3rem;
|
||||
border-radius: 0.2rem;
|
||||
}
|
||||
.btn-ghost:hover { background: var(--color-border, #d0d7e8); }
|
||||
|
||||
.log-lines {
|
||||
max-height: 200px;
|
||||
overflow-y: auto;
|
||||
padding: 0.5rem 0.75rem;
|
||||
background: var(--color-surface, #fff);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.1rem;
|
||||
}
|
||||
|
||||
.log-line { color: var(--color-text, #1a2338); line-height: 1.5; }
|
||||
.log-line.log-error { color: var(--color-error, #ef4444); }
|
||||
|
||||
.run-error {
|
||||
margin: 0;
|
||||
padding: 0.4rem 0.75rem;
|
||||
background: color-mix(in srgb, var(--color-error, #ef4444) 10%, transparent);
|
||||
color: var(--color-error, #ef4444);
|
||||
font-size: 0.82rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
}
|
||||
|
||||
/* ── Status notices ─────────────────────────────────────── */
|
||||
.status-notice {
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
font-size: 0.9rem;
|
||||
padding: 1rem;
|
||||
}
|
||||
.status-notice.empty {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
padding: 3rem 1rem;
|
||||
text-align: center;
|
||||
}
|
||||
.hint { font-size: 0.85rem; opacity: 0.75; }
|
||||
|
||||
/* ── Meta line ──────────────────────────────────────────── */
|
||||
.meta-line {
|
||||
display: flex;
|
||||
gap: 0.5rem;
|
||||
align-items: center;
|
||||
font-size: 0.85rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
font-family: var(--font-mono, monospace);
|
||||
flex-wrap: wrap;
|
||||
}
|
||||
.sep { opacity: 0.4; }
|
||||
|
||||
/* ── Chart sections ─────────────────────────────────────── */
|
||||
.chart-section {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.75rem;
|
||||
}
|
||||
|
||||
.chart-title {
|
||||
font-size: 0.95rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text, #1a2338);
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
/* ── Bar charts ─────────────────────────────────────────── */
|
||||
.bar-chart {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.4rem;
|
||||
}
|
||||
|
||||
.bar-row {
|
||||
display: grid;
|
||||
grid-template-columns: 14rem 1fr 5rem;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
font-size: 0.82rem;
|
||||
}
|
||||
|
||||
.bar-label {
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.76rem;
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.bar-track {
|
||||
height: 16px;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
border-radius: 99px;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.bar-fill {
|
||||
height: 100%;
|
||||
border-radius: 99px;
|
||||
transition: width 0.5s cubic-bezier(0.16, 1, 0.3, 1);
|
||||
}
|
||||
|
||||
.latency-fill { background: var(--app-primary, #2A6080); opacity: 0.65; }
|
||||
|
||||
.bar-value {
|
||||
text-align: right;
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.8rem;
|
||||
font-variant-numeric: tabular-nums;
|
||||
}
|
||||
|
||||
/* ── Heatmap ────────────────────────────────────────────── */
|
||||
.heatmap-scroll {
|
||||
overflow-x: auto;
|
||||
border-radius: 0.5rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
}
|
||||
|
||||
.heatmap {
|
||||
border-collapse: collapse;
|
||||
min-width: 100%;
|
||||
font-size: 0.78rem;
|
||||
}
|
||||
|
||||
.hm-label-col {
|
||||
text-align: left;
|
||||
min-width: 11rem;
|
||||
padding: 0.4rem 0.6rem;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
font-weight: 600;
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
position: sticky;
|
||||
left: 0;
|
||||
}
|
||||
|
||||
.hm-model-col {
|
||||
min-width: 5rem;
|
||||
max-width: 8rem;
|
||||
padding: 0.4rem 0.5rem;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.7rem;
|
||||
text-overflow: ellipsis;
|
||||
overflow: hidden;
|
||||
white-space: nowrap;
|
||||
text-align: center;
|
||||
}
|
||||
|
||||
.hm-label-cell {
|
||||
padding: 0.35rem 0.6rem;
|
||||
background: var(--color-surface, #fff);
|
||||
border-top: 1px solid var(--color-border, #d0d7e8);
|
||||
white-space: nowrap;
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.74rem;
|
||||
position: sticky;
|
||||
left: 0;
|
||||
}
|
||||
|
||||
.hm-emoji { margin-right: 0.3rem; }
|
||||
|
||||
.hm-value-cell {
|
||||
padding: 0.35rem 0.5rem;
|
||||
text-align: center;
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-variant-numeric: tabular-nums;
|
||||
border-top: 1px solid rgba(255,255,255,0.08);
|
||||
cursor: default;
|
||||
transition: filter 0.15s;
|
||||
}
|
||||
.hm-value-cell:hover { filter: brightness(1.15); }
|
||||
|
||||
.heatmap-hint {
|
||||
font-size: 0.75rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
/* ── Mobile tweaks ──────────────────────────────────────── */
|
||||
@media (max-width: 600px) {
|
||||
.bar-row { grid-template-columns: 9rem 1fr 4rem; }
|
||||
.bar-label { font-size: 0.7rem; }
|
||||
.bench-header { flex-direction: column; align-items: flex-start; }
|
||||
}
|
||||
|
||||
/* ── Trained models badge row ──────────────────────────── */
|
||||
.trained-models-row {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
padding: 0.6rem 0.75rem;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
border-radius: 0.5rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
}
|
||||
|
||||
.trained-label {
|
||||
font-size: 0.75rem;
|
||||
font-weight: 700;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.05em;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.trained-badge {
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 0.35rem;
|
||||
padding: 0.2rem 0.55rem;
|
||||
background: var(--app-primary, #2A6080);
|
||||
color: #fff;
|
||||
border-radius: 1rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.76rem;
|
||||
cursor: default;
|
||||
}
|
||||
|
||||
.trained-f1 {
|
||||
background: rgba(255,255,255,0.2);
|
||||
border-radius: 0.75rem;
|
||||
padding: 0.05rem 0.35rem;
|
||||
font-size: 0.7rem;
|
||||
font-weight: 700;
|
||||
}
|
||||
|
||||
/* ── Fine-tune section ──────────────────────────────────── */
|
||||
.ft-section {
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.ft-summary {
|
||||
padding: 0.65rem 0.9rem;
|
||||
cursor: pointer;
|
||||
font-size: 0.9rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text, #1a2338);
|
||||
user-select: none;
|
||||
list-style: none;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
}
|
||||
.ft-summary::-webkit-details-marker { display: none; }
|
||||
.ft-summary::before { content: '▶ '; font-size: 0.65rem; color: var(--color-text-secondary, #6b7a99); }
|
||||
details[open] .ft-summary::before { content: '▼ '; }
|
||||
|
||||
.ft-body {
|
||||
padding: 0.75rem;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.75rem;
|
||||
border-top: 1px solid var(--color-border, #d0d7e8);
|
||||
}
|
||||
|
||||
.ft-controls {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 0.75rem;
|
||||
align-items: flex-end;
|
||||
}
|
||||
|
||||
.ft-field {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.25rem;
|
||||
}
|
||||
|
||||
.ft-field-label {
|
||||
font-size: 0.75rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.04em;
|
||||
}
|
||||
|
||||
.ft-select {
|
||||
padding: 0.35rem 0.5rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.375rem;
|
||||
background: var(--color-surface, #fff);
|
||||
font-size: 0.85rem;
|
||||
color: var(--color-text, #1a2338);
|
||||
min-width: 220px;
|
||||
}
|
||||
.ft-select:disabled { opacity: 0.55; }
|
||||
|
||||
.ft-epochs {
|
||||
width: 64px;
|
||||
padding: 0.35rem 0.5rem;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.375rem;
|
||||
background: var(--color-surface, #fff);
|
||||
font-size: 0.85rem;
|
||||
color: var(--color-text, #1a2338);
|
||||
text-align: center;
|
||||
}
|
||||
.ft-epochs:disabled { opacity: 0.55; }
|
||||
|
||||
.ft-run-btn { align-self: flex-end; }
|
||||
|
||||
.ft-log { margin-top: 0; }
|
||||
|
||||
@media (max-width: 600px) {
|
||||
.ft-controls { flex-direction: column; align-items: stretch; }
|
||||
.ft-select { min-width: 0; width: 100%; }
|
||||
}
|
||||
</style>
|
||||
|
|
@ -8,12 +8,24 @@
|
|||
{{ store.totalRemaining }} remaining
|
||||
</template>
|
||||
<span v-else class="queue-status">Queue empty</span>
|
||||
<span v-if="onRoll" class="badge badge-roll">🔥 On a roll!</span>
|
||||
<span v-if="speedRound" class="badge badge-speed">⚡ Speed round!</span>
|
||||
<span v-if="fiftyDeep" class="badge badge-fifty">🎯 Fifty deep!</span>
|
||||
<span v-if="centuryMark" class="badge badge-century">💯 Century!</span>
|
||||
<span v-if="cleanSweep" class="badge badge-sweep">🧹 Clean sweep!</span>
|
||||
<span v-if="midnightLabeler" class="badge badge-midnight">🦉 Midnight labeler!</span>
|
||||
<Transition @enter="onBadgeEnter" :css="false">
|
||||
<span v-if="onRoll" class="badge badge-roll">🔥 On a roll!</span>
|
||||
</Transition>
|
||||
<Transition @enter="onBadgeEnter" :css="false">
|
||||
<span v-if="speedRound" class="badge badge-speed">⚡ Speed round!</span>
|
||||
</Transition>
|
||||
<Transition @enter="onBadgeEnter" :css="false">
|
||||
<span v-if="fiftyDeep" class="badge badge-fifty">🎯 Fifty deep!</span>
|
||||
</Transition>
|
||||
<Transition @enter="onBadgeEnter" :css="false">
|
||||
<span v-if="centuryMark" class="badge badge-century">💯 Century!</span>
|
||||
</Transition>
|
||||
<Transition @enter="onBadgeEnter" :css="false">
|
||||
<span v-if="cleanSweep" class="badge badge-sweep">🧹 Clean sweep!</span>
|
||||
</Transition>
|
||||
<Transition @enter="onBadgeEnter" :css="false">
|
||||
<span v-if="midnightLabeler" class="badge badge-midnight">🦉 Midnight labeler!</span>
|
||||
</Transition>
|
||||
</span>
|
||||
<div class="header-actions">
|
||||
<button @click="handleUndo" :disabled="!store.lastAction" class="btn-action">↩ Undo</button>
|
||||
|
|
@ -69,7 +81,7 @@
|
|||
/>
|
||||
</div>
|
||||
|
||||
<div class="bucket-grid-footer" :class="{ 'grid-active': isHeld }">
|
||||
<div ref="gridEl" class="bucket-grid-footer" :class="{ 'grid-active': isHeld }">
|
||||
<LabelBucketGrid
|
||||
:labels="labels"
|
||||
:is-bucket-mode="isHeld"
|
||||
|
|
@ -90,7 +102,8 @@
|
|||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, onMounted, onUnmounted } from 'vue'
|
||||
import { ref, watch, onMounted, onUnmounted } from 'vue'
|
||||
import { animate } from 'animejs'
|
||||
import { useLabelStore } from '../stores/label'
|
||||
import { useApiFetch } from '../composables/useApi'
|
||||
import { useHaptics } from '../composables/useHaptics'
|
||||
|
|
@ -105,6 +118,8 @@ const store = useLabelStore()
|
|||
const haptics = useHaptics()
|
||||
const motion = useMotion() // only needed to pass to child — actual value used in App.vue
|
||||
|
||||
const gridEl = ref<HTMLElement | null>(null)
|
||||
|
||||
const loading = ref(true)
|
||||
const apiError = ref(false)
|
||||
const isHeld = ref(false)
|
||||
|
|
@ -113,6 +128,22 @@ const hoveredBucket = ref<string | null>(null)
|
|||
const labels = ref<any[]>([])
|
||||
const dismissType = ref<'label' | 'skip' | 'discard' | null>(null)
|
||||
|
||||
watch(isHeld, (held) => {
|
||||
if (!motion.rich.value || !gridEl.value) return
|
||||
animate(gridEl.value,
|
||||
held
|
||||
? { y: -8, opacity: 0.45, ease: 'out(4)', duration: 380 }
|
||||
: { y: 0, opacity: 1, ease: 'out(4)', duration: 320 }
|
||||
)
|
||||
})
|
||||
|
||||
function onBadgeEnter(el: Element, done: () => void) {
|
||||
if (!motion.rich.value) { done(); return }
|
||||
animate(el as HTMLElement,
|
||||
{ scale: [0.6, 1], opacity: [0, 1], ease: spring({ mass: 1.5, stiffness: 80, damping: 8 }), duration: 300, onComplete: done }
|
||||
)
|
||||
}
|
||||
|
||||
// Easter egg state
|
||||
const consecutiveLabeled = ref(0)
|
||||
const recentLabels = ref<number[]>([])
|
||||
|
|
@ -314,8 +345,8 @@ onUnmounted(() => {
|
|||
padding: 1rem;
|
||||
max-width: 640px;
|
||||
margin: 0 auto;
|
||||
height: 100dvh; /* hard cap — prevents grid from drifting below fold */
|
||||
overflow: hidden;
|
||||
min-height: 100dvh;
|
||||
overflow-x: hidden; /* prevent card animations from causing horizontal scroll */
|
||||
}
|
||||
|
||||
.queue-status {
|
||||
|
|
@ -352,12 +383,6 @@ onUnmounted(() => {
|
|||
font-size: 0.75rem;
|
||||
font-weight: 700;
|
||||
font-family: var(--font-body, sans-serif);
|
||||
animation: badge-pop 0.3s cubic-bezier(0.34, 1.56, 0.64, 1);
|
||||
}
|
||||
|
||||
@keyframes badge-pop {
|
||||
from { transform: scale(0.6); opacity: 0; }
|
||||
to { transform: scale(1); opacity: 1; }
|
||||
}
|
||||
|
||||
.badge-roll { background: #ff6b35; color: #fff; }
|
||||
|
|
@ -424,13 +449,10 @@ onUnmounted(() => {
|
|||
|
||||
.card-stack-wrapper {
|
||||
flex: 1;
|
||||
min-height: 0; /* allow flex child to shrink — default auto prevents this */
|
||||
overflow-y: auto;
|
||||
min-height: 0;
|
||||
padding-bottom: 0.5rem;
|
||||
transition: opacity 200ms ease;
|
||||
}
|
||||
/* When held: escape the overflow clip so the ball floats freely,
|
||||
and rise above the footer (z-index 10) so the ball is visible. */
|
||||
/* When held: escape overflow clip so ball floats freely above the footer. */
|
||||
.card-stack-wrapper.is-held {
|
||||
overflow: visible;
|
||||
position: relative;
|
||||
|
|
@ -441,16 +463,17 @@ onUnmounted(() => {
|
|||
can be scrolled freely. "hired" (10th button) may clip on very small screens
|
||||
— that is intentional per design. */
|
||||
.bucket-grid-footer {
|
||||
position: sticky;
|
||||
bottom: 0;
|
||||
background: var(--color-bg, var(--color-surface, #f0f4fc));
|
||||
padding: 0.5rem 0 0.75rem;
|
||||
z-index: 10;
|
||||
transition: transform 250ms cubic-bezier(0.34, 1.56, 0.64, 1),
|
||||
opacity 200ms ease,
|
||||
background 200ms ease;
|
||||
}
|
||||
/* During toss: stay sticky so the grid holds its natural column position
|
||||
(fixed caused a horizontal jump on desktop due to sidebar offset).
|
||||
Opacity and translateY(-8px) are owned by Anime.js. */
|
||||
.bucket-grid-footer.grid-active {
|
||||
transform: translateY(-8px);
|
||||
opacity: 0.45; /* semi-transparent so ball aura is visible through it */
|
||||
opacity: 0.45;
|
||||
}
|
||||
|
||||
/* ── Toss edge zones ── */
|
||||
|
|
|
|||
Loading…
Reference in a new issue