Compare commits
No commits in common. "1fa5b9e2b03d1b8a77088d8ba8e4a522e4fac700" and "e03d91ece9c23f9ed3cf1bbfd007972f1c9d6de4" have entirely different histories.
1fa5b9e2b0
...
e03d91ece9
31 changed files with 145 additions and 5595 deletions
|
|
@ -1,7 +0,0 @@
|
||||||
# Privacy Policy
|
|
||||||
|
|
||||||
CircuitForge LLC's privacy policy applies to this product and is published at:
|
|
||||||
|
|
||||||
**<https://circuitforge.tech/privacy>**
|
|
||||||
|
|
||||||
Last reviewed: March 2026.
|
|
||||||
222
app/api.py
222
app/api.py
|
|
@ -7,8 +7,6 @@ from __future__ import annotations
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import os
|
|
||||||
import subprocess as _subprocess
|
|
||||||
import yaml
|
import yaml
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
@ -19,14 +17,8 @@ from pydantic import BaseModel
|
||||||
|
|
||||||
_ROOT = Path(__file__).parent.parent
|
_ROOT = Path(__file__).parent.parent
|
||||||
_DATA_DIR: Path = _ROOT / "data" # overridable in tests via set_data_dir()
|
_DATA_DIR: Path = _ROOT / "data" # overridable in tests via set_data_dir()
|
||||||
_MODELS_DIR: Path = _ROOT / "models" # overridable in tests via set_models_dir()
|
|
||||||
_CONFIG_DIR: Path | None = None # None = use real path
|
_CONFIG_DIR: Path | None = None # None = use real path
|
||||||
|
|
||||||
# Process registry for running jobs — used by cancel endpoints.
|
|
||||||
# Keys: "benchmark" | "finetune". Values: the live Popen object.
|
|
||||||
_running_procs: dict = {}
|
|
||||||
_cancelled_jobs: set = set()
|
|
||||||
|
|
||||||
|
|
||||||
def set_data_dir(path: Path) -> None:
|
def set_data_dir(path: Path) -> None:
|
||||||
"""Override data directory — used by tests."""
|
"""Override data directory — used by tests."""
|
||||||
|
|
@ -34,40 +26,6 @@ def set_data_dir(path: Path) -> None:
|
||||||
_DATA_DIR = path
|
_DATA_DIR = path
|
||||||
|
|
||||||
|
|
||||||
def _best_cuda_device() -> str:
|
|
||||||
"""Return the index of the GPU with the most free VRAM as a string.
|
|
||||||
|
|
||||||
Uses nvidia-smi so it works in the job-seeker env (no torch). Returns ""
|
|
||||||
if nvidia-smi is unavailable or no GPUs are found. Restricting the
|
|
||||||
training subprocess to a single GPU via CUDA_VISIBLE_DEVICES prevents
|
|
||||||
PyTorch DataParallel from replicating the model across all GPUs, which
|
|
||||||
would OOM the GPU with less headroom.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
out = _subprocess.check_output(
|
|
||||||
["nvidia-smi", "--query-gpu=index,memory.free",
|
|
||||||
"--format=csv,noheader,nounits"],
|
|
||||||
text=True,
|
|
||||||
timeout=5,
|
|
||||||
)
|
|
||||||
best_idx, best_free = "", 0
|
|
||||||
for line in out.strip().splitlines():
|
|
||||||
parts = line.strip().split(", ")
|
|
||||||
if len(parts) == 2:
|
|
||||||
idx, free = parts[0].strip(), int(parts[1].strip())
|
|
||||||
if free > best_free:
|
|
||||||
best_free, best_idx = free, idx
|
|
||||||
return best_idx
|
|
||||||
except Exception:
|
|
||||||
return ""
|
|
||||||
|
|
||||||
|
|
||||||
def set_models_dir(path: Path) -> None:
|
|
||||||
"""Override models directory — used by tests."""
|
|
||||||
global _MODELS_DIR
|
|
||||||
_MODELS_DIR = path
|
|
||||||
|
|
||||||
|
|
||||||
def set_config_dir(path: Path | None) -> None:
|
def set_config_dir(path: Path | None) -> None:
|
||||||
"""Override config directory — used by tests."""
|
"""Override config directory — used by tests."""
|
||||||
global _CONFIG_DIR
|
global _CONFIG_DIR
|
||||||
|
|
@ -329,186 +287,6 @@ def test_account(req: AccountTestRequest):
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Benchmark endpoints
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
@app.get("/api/benchmark/results")
|
|
||||||
def get_benchmark_results():
|
|
||||||
"""Return the most recently saved benchmark results, or an empty envelope."""
|
|
||||||
path = _DATA_DIR / "benchmark_results.json"
|
|
||||||
if not path.exists():
|
|
||||||
return {"models": {}, "sample_count": 0, "timestamp": None}
|
|
||||||
return json.loads(path.read_text())
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/benchmark/run")
|
|
||||||
def run_benchmark(include_slow: bool = False):
|
|
||||||
"""Spawn the benchmark script and stream stdout as SSE progress events."""
|
|
||||||
python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
|
|
||||||
script = str(_ROOT / "scripts" / "benchmark_classifier.py")
|
|
||||||
cmd = [python_bin, script, "--score", "--save"]
|
|
||||||
if include_slow:
|
|
||||||
cmd.append("--include-slow")
|
|
||||||
|
|
||||||
def generate():
|
|
||||||
try:
|
|
||||||
proc = _subprocess.Popen(
|
|
||||||
cmd,
|
|
||||||
stdout=_subprocess.PIPE,
|
|
||||||
stderr=_subprocess.STDOUT,
|
|
||||||
text=True,
|
|
||||||
bufsize=1,
|
|
||||||
cwd=str(_ROOT),
|
|
||||||
)
|
|
||||||
_running_procs["benchmark"] = proc
|
|
||||||
_cancelled_jobs.discard("benchmark") # clear any stale flag from a prior run
|
|
||||||
try:
|
|
||||||
for line in proc.stdout:
|
|
||||||
line = line.rstrip()
|
|
||||||
if line:
|
|
||||||
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
|
|
||||||
proc.wait()
|
|
||||||
if proc.returncode == 0:
|
|
||||||
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
|
|
||||||
elif "benchmark" in _cancelled_jobs:
|
|
||||||
_cancelled_jobs.discard("benchmark")
|
|
||||||
yield f"data: {json.dumps({'type': 'cancelled'})}\n\n"
|
|
||||||
else:
|
|
||||||
yield f"data: {json.dumps({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})}\n\n"
|
|
||||||
finally:
|
|
||||||
_running_procs.pop("benchmark", None)
|
|
||||||
except Exception as exc:
|
|
||||||
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
|
|
||||||
|
|
||||||
return StreamingResponse(
|
|
||||||
generate(),
|
|
||||||
media_type="text/event-stream",
|
|
||||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Finetune endpoints
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
@app.get("/api/finetune/status")
|
|
||||||
def get_finetune_status():
|
|
||||||
"""Scan models/ for training_info.json files. Returns [] if none exist."""
|
|
||||||
models_dir = _MODELS_DIR
|
|
||||||
if not models_dir.exists():
|
|
||||||
return []
|
|
||||||
results = []
|
|
||||||
for sub in models_dir.iterdir():
|
|
||||||
if not sub.is_dir():
|
|
||||||
continue
|
|
||||||
info_path = sub / "training_info.json"
|
|
||||||
if not info_path.exists():
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
info = json.loads(info_path.read_text(encoding="utf-8"))
|
|
||||||
results.append(info)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/finetune/run")
|
|
||||||
def run_finetune_endpoint(
|
|
||||||
model: str = "deberta-small",
|
|
||||||
epochs: int = 5,
|
|
||||||
score: list[str] = Query(default=[]),
|
|
||||||
):
|
|
||||||
"""Spawn finetune_classifier.py and stream stdout as SSE progress events."""
|
|
||||||
python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
|
|
||||||
script = str(_ROOT / "scripts" / "finetune_classifier.py")
|
|
||||||
cmd = [python_bin, script, "--model", model, "--epochs", str(epochs)]
|
|
||||||
data_root = _DATA_DIR.resolve()
|
|
||||||
for score_file in score:
|
|
||||||
resolved = (_DATA_DIR / score_file).resolve()
|
|
||||||
if not str(resolved).startswith(str(data_root)):
|
|
||||||
raise HTTPException(400, f"Invalid score path: {score_file!r}")
|
|
||||||
cmd.extend(["--score", str(resolved)])
|
|
||||||
|
|
||||||
# Pick the GPU with the most free VRAM. Setting CUDA_VISIBLE_DEVICES to a
|
|
||||||
# single device prevents DataParallel from replicating the model across all
|
|
||||||
# GPUs, which would force a full copy onto the more memory-constrained device.
|
|
||||||
proc_env = {**os.environ, "PYTORCH_ALLOC_CONF": "expandable_segments:True"}
|
|
||||||
best_gpu = _best_cuda_device()
|
|
||||||
if best_gpu:
|
|
||||||
proc_env["CUDA_VISIBLE_DEVICES"] = best_gpu
|
|
||||||
|
|
||||||
gpu_note = f"GPU {best_gpu}" if best_gpu else "CPU (no GPU found)"
|
|
||||||
|
|
||||||
def generate():
|
|
||||||
yield f"data: {json.dumps({'type': 'progress', 'message': f'[api] Using {gpu_note} (most free VRAM)'})}\n\n"
|
|
||||||
try:
|
|
||||||
proc = _subprocess.Popen(
|
|
||||||
cmd,
|
|
||||||
stdout=_subprocess.PIPE,
|
|
||||||
stderr=_subprocess.STDOUT,
|
|
||||||
text=True,
|
|
||||||
bufsize=1,
|
|
||||||
cwd=str(_ROOT),
|
|
||||||
env=proc_env,
|
|
||||||
)
|
|
||||||
_running_procs["finetune"] = proc
|
|
||||||
_cancelled_jobs.discard("finetune") # clear any stale flag from a prior run
|
|
||||||
try:
|
|
||||||
for line in proc.stdout:
|
|
||||||
line = line.rstrip()
|
|
||||||
if line:
|
|
||||||
yield f"data: {json.dumps({'type': 'progress', 'message': line})}\n\n"
|
|
||||||
proc.wait()
|
|
||||||
if proc.returncode == 0:
|
|
||||||
yield f"data: {json.dumps({'type': 'complete'})}\n\n"
|
|
||||||
elif "finetune" in _cancelled_jobs:
|
|
||||||
_cancelled_jobs.discard("finetune")
|
|
||||||
yield f"data: {json.dumps({'type': 'cancelled'})}\n\n"
|
|
||||||
else:
|
|
||||||
yield f"data: {json.dumps({'type': 'error', 'message': f'Process exited with code {proc.returncode}'})}\n\n"
|
|
||||||
finally:
|
|
||||||
_running_procs.pop("finetune", None)
|
|
||||||
except Exception as exc:
|
|
||||||
yield f"data: {json.dumps({'type': 'error', 'message': str(exc)})}\n\n"
|
|
||||||
|
|
||||||
return StreamingResponse(
|
|
||||||
generate(),
|
|
||||||
media_type="text/event-stream",
|
|
||||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no"},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/benchmark/cancel")
|
|
||||||
def cancel_benchmark():
|
|
||||||
"""Kill the running benchmark subprocess. 404 if none is running."""
|
|
||||||
proc = _running_procs.get("benchmark")
|
|
||||||
if proc is None:
|
|
||||||
raise HTTPException(404, "No benchmark is running")
|
|
||||||
_cancelled_jobs.add("benchmark")
|
|
||||||
proc.terminate()
|
|
||||||
try:
|
|
||||||
proc.wait(timeout=3)
|
|
||||||
except _subprocess.TimeoutExpired:
|
|
||||||
proc.kill()
|
|
||||||
return {"status": "cancelled"}
|
|
||||||
|
|
||||||
|
|
||||||
@app.post("/api/finetune/cancel")
|
|
||||||
def cancel_finetune():
|
|
||||||
"""Kill the running fine-tune subprocess. 404 if none is running."""
|
|
||||||
proc = _running_procs.get("finetune")
|
|
||||||
if proc is None:
|
|
||||||
raise HTTPException(404, "No finetune is running")
|
|
||||||
_cancelled_jobs.add("finetune")
|
|
||||||
proc.terminate()
|
|
||||||
try:
|
|
||||||
proc.wait(timeout=3)
|
|
||||||
except _subprocess.TimeoutExpired:
|
|
||||||
proc.kill()
|
|
||||||
return {"status": "cancelled"}
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/api/fetch/stream")
|
@app.get("/api/fetch/stream")
|
||||||
def fetch_stream(
|
def fetch_stream(
|
||||||
accounts: str = Query(default=""),
|
accounts: str = Query(default=""),
|
||||||
|
|
|
||||||
|
|
@ -1,95 +0,0 @@
|
||||||
# Anime.js Animation Integration — Design
|
|
||||||
|
|
||||||
**Date:** 2026-03-08
|
|
||||||
**Status:** Approved
|
|
||||||
**Branch:** feat/vue-label-tab
|
|
||||||
|
|
||||||
## Problem
|
|
||||||
|
|
||||||
The current animation system mixes CSS keyframes, CSS transitions, and imperative inline-style bindings across three files. The seams between systems produce:
|
|
||||||
|
|
||||||
- Abrupt ball pickup (instant scale/borderRadius jump)
|
|
||||||
- No spring snap-back on release to no target
|
|
||||||
- Rigid CSS dismissals with no timing control
|
|
||||||
- Bucket grid and badge pop on basic `@keyframes`
|
|
||||||
|
|
||||||
## Decision
|
|
||||||
|
|
||||||
Integrate **Anime.js v4** as a single animation layer. Vue reactive state is unchanged; Anime.js owns all DOM motion imperatively.
|
|
||||||
|
|
||||||
## Architecture
|
|
||||||
|
|
||||||
One new composable, minimal changes to two existing files, CSS cleanup in two files.
|
|
||||||
|
|
||||||
```
|
|
||||||
web/src/composables/useCardAnimation.ts ← NEW
|
|
||||||
web/src/components/EmailCardStack.vue ← modify
|
|
||||||
web/src/views/LabelView.vue ← modify
|
|
||||||
```
|
|
||||||
|
|
||||||
**Data flow:**
|
|
||||||
```
|
|
||||||
pointer events → Vue refs (isHeld, deltaX, deltaY, dismissType)
|
|
||||||
↓ watched by
|
|
||||||
useCardAnimation(cardEl, stackEl, isHeld, ...)
|
|
||||||
↓ imperatively drives
|
|
||||||
Anime.js → DOM transforms
|
|
||||||
```
|
|
||||||
|
|
||||||
`useCardAnimation` is a pure side-effect composable — returns nothing to the template. The `cardStyle` computed in `EmailCardStack.vue` is removed; Anime.js owns the element's transform directly.
|
|
||||||
|
|
||||||
## Animation Surfaces
|
|
||||||
|
|
||||||
### Pickup morph
|
|
||||||
```
|
|
||||||
animate(cardEl, { scale: 0.55, borderRadius: '50%', y: -80 }, { duration: 200, ease: spring(1, 80, 10) })
|
|
||||||
```
|
|
||||||
Replaces the instant CSS transform jump on `onPointerDown`.
|
|
||||||
|
|
||||||
### Drag tracking
|
|
||||||
Raw `cardEl.style.translate` update on `onPointerMove` — no animation, just position. Easing only at boundaries (pickup / release), not during active drag.
|
|
||||||
|
|
||||||
### Snap-back
|
|
||||||
```
|
|
||||||
animate(cardEl, { x: 0, y: 0, scale: 1, borderRadius: '1rem' }, { ease: spring(1, 80, 10) })
|
|
||||||
```
|
|
||||||
Fires on `onPointerUp` when no zone/bucket target was hit.
|
|
||||||
|
|
||||||
### Dismissals (replace CSS `@keyframes`)
|
|
||||||
- **fileAway** — `animate(cardEl, { y: '-120%', scale: 0.85, opacity: 0 }, { duration: 280, ease: 'out(3)' })`
|
|
||||||
- **crumple** — 2-step timeline: shrink + redden → `scale(0)` + rotate
|
|
||||||
- **slideUnder** — `animate(cardEl, { x: '110%', rotate: 5, opacity: 0 }, { duration: 260 })`
|
|
||||||
|
|
||||||
### Bucket grid rise
|
|
||||||
`animate(gridEl, { y: -8, opacity: 0.45 })` on `isHeld` → true; reversed on false. Spring easing.
|
|
||||||
|
|
||||||
### Badge pop
|
|
||||||
`animate(badgeEl, { scale: [0.6, 1], opacity: [0, 1] }, { ease: spring(1.5, 80, 8), duration: 300 })` triggered on badge mount via Vue's `onMounted` lifecycle hook in a `BadgePop` wrapper component or `v-enter-active` transition hook.
|
|
||||||
|
|
||||||
## Constraints
|
|
||||||
|
|
||||||
### Reduced motion
|
|
||||||
`useCardAnimation` checks `motion.rich.value` before firing any Anime.js call. If false, all animations are skipped — instant state changes only. Consistent with existing `useMotion` pattern.
|
|
||||||
|
|
||||||
### Bundle size
|
|
||||||
Anime.js v4 core ~17KB gzipped. Only `animate`, `spring`, and `createTimeline` are imported — Vite ESM tree-shaking keeps footprint minimal. The `draggable` module is not used.
|
|
||||||
|
|
||||||
### Tests
|
|
||||||
Existing `EmailCardStack.test.ts` tests emit behavior, not animation — they remain passing. Anime.js mocked at module level in Vitest via `vi.mock('animejs')` where needed.
|
|
||||||
|
|
||||||
### CSS cleanup
|
|
||||||
Remove from `EmailCardStack.vue` and `LabelView.vue`:
|
|
||||||
- `@keyframes fileAway`, `crumple`, `slideUnder`
|
|
||||||
- `@keyframes badge-pop`
|
|
||||||
- `.dismiss-label`, `.dismiss-skip`, `.dismiss-discard` classes (Anime.js fires on element refs directly)
|
|
||||||
- The `dismissClass` computed in `EmailCardStack.vue`
|
|
||||||
|
|
||||||
## Files Changed
|
|
||||||
|
|
||||||
| File | Change |
|
|
||||||
|------|--------|
|
|
||||||
| `web/package.json` | Add `animejs` dependency |
|
|
||||||
| `web/src/composables/useCardAnimation.ts` | New — all Anime.js animation logic |
|
|
||||||
| `web/src/components/EmailCardStack.vue` | Remove `cardStyle` computed + dismiss classes; call `useCardAnimation` |
|
|
||||||
| `web/src/views/LabelView.vue` | Badge pop + bucket grid rise via Anime.js |
|
|
||||||
| `web/src/assets/avocet.css` | Remove any global animation keyframes if present |
|
|
||||||
|
|
@ -1,573 +0,0 @@
|
||||||
# Anime.js Animation Integration — Implementation Plan
|
|
||||||
|
|
||||||
> **For Claude:** REQUIRED SUB-SKILL: Use superpowers:executing-plans to implement this plan task-by-task.
|
|
||||||
|
|
||||||
**Goal:** Replace the current mixed CSS keyframes / inline-style animation system with Anime.js v4 for all card motion — pickup morph, drag tracking, spring snap-back, dismissals, bucket grid rise, and badge pop.
|
|
||||||
|
|
||||||
**Architecture:** A new `useCardAnimation` composable owns all Anime.js calls imperatively against DOM refs. Vue reactive state (`isHeld`, `deltaX`, `deltaY`, `dismissType`) is unchanged. `cardStyle` computed and `dismissClass` computed are deleted; Anime.js writes to the element directly.
|
|
||||||
|
|
||||||
**Tech Stack:** Anime.js v4 (`animejs`), Vue 3 Composition API, `@vue/test-utils` + Vitest for tests.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Task 1: Install Anime.js
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Modify: `web/package.json`
|
|
||||||
|
|
||||||
**Step 1: Install the package**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd /Library/Development/CircuitForge/avocet/web
|
|
||||||
npm install animejs
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 2: Verify the import resolves**
|
|
||||||
|
|
||||||
Create a throwaway check — open `web/src/main.ts` briefly and confirm:
|
|
||||||
```ts
|
|
||||||
import { animate, spring } from 'animejs'
|
|
||||||
```
|
|
||||||
resolves without error in the editor (TypeScript types ship with animejs v4).
|
|
||||||
Remove the import immediately after verifying — do not commit it.
|
|
||||||
|
|
||||||
**Step 3: Commit**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd /Library/Development/CircuitForge/avocet/web
|
|
||||||
git add package.json package-lock.json
|
|
||||||
git commit -m "feat(avocet): add animejs v4 dependency"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Task 2: Create `useCardAnimation` composable
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Create: `web/src/composables/useCardAnimation.ts`
|
|
||||||
- Create: `web/src/composables/useCardAnimation.test.ts`
|
|
||||||
|
|
||||||
**Background — Anime.js v4 transform model:**
|
|
||||||
Anime.js v4 tracks `x`, `y`, `scale`, `rotate`, etc. as separate transform components internally.
|
|
||||||
Use `utils.set(el, props)` for instant (no-animation) property updates — this keeps the internal cache consistent.
|
|
||||||
Never mix direct `el.style.transform = "..."` with Anime.js on the same element, or the cache desyncs.
|
|
||||||
|
|
||||||
**Step 1: Write the failing tests**
|
|
||||||
|
|
||||||
`web/src/composables/useCardAnimation.test.ts`:
|
|
||||||
```ts
|
|
||||||
import { ref, nextTick } from 'vue'
|
|
||||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
|
||||||
|
|
||||||
// Mock animejs before importing the composable
|
|
||||||
vi.mock('animejs', () => ({
|
|
||||||
animate: vi.fn(),
|
|
||||||
spring: vi.fn(() => 'mock-spring'),
|
|
||||||
utils: { set: vi.fn() },
|
|
||||||
}))
|
|
||||||
|
|
||||||
import { useCardAnimation } from './useCardAnimation'
|
|
||||||
import { animate, utils } from 'animejs'
|
|
||||||
|
|
||||||
const mockAnimate = animate as ReturnType<typeof vi.fn>
|
|
||||||
const mockSet = utils.set as ReturnType<typeof vi.fn>
|
|
||||||
|
|
||||||
function makeEl() {
|
|
||||||
return document.createElement('div')
|
|
||||||
}
|
|
||||||
|
|
||||||
describe('useCardAnimation', () => {
|
|
||||||
beforeEach(() => {
|
|
||||||
vi.clearAllMocks()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('pickup() calls animate with ball shape', () => {
|
|
||||||
const el = makeEl()
|
|
||||||
const cardEl = ref<HTMLElement | null>(el)
|
|
||||||
const motion = { rich: ref(true) }
|
|
||||||
const { pickup } = useCardAnimation(cardEl, motion)
|
|
||||||
pickup()
|
|
||||||
expect(mockAnimate).toHaveBeenCalledWith(
|
|
||||||
el,
|
|
||||||
expect.objectContaining({ scale: 0.55, borderRadius: '50%' }),
|
|
||||||
expect.anything(),
|
|
||||||
)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('pickup() is a no-op when motion.rich is false', () => {
|
|
||||||
const el = makeEl()
|
|
||||||
const cardEl = ref<HTMLElement | null>(el)
|
|
||||||
const motion = { rich: ref(false) }
|
|
||||||
const { pickup } = useCardAnimation(cardEl, motion)
|
|
||||||
pickup()
|
|
||||||
expect(mockAnimate).not.toHaveBeenCalled()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('setDragPosition() calls utils.set with translated coords', () => {
|
|
||||||
const el = makeEl()
|
|
||||||
const cardEl = ref<HTMLElement | null>(el)
|
|
||||||
const motion = { rich: ref(true) }
|
|
||||||
const { setDragPosition } = useCardAnimation(cardEl, motion)
|
|
||||||
setDragPosition(50, 30)
|
|
||||||
expect(mockSet).toHaveBeenCalledWith(el, expect.objectContaining({ x: 50, y: -50 }))
|
|
||||||
// y = deltaY - 80 = 30 - 80 = -50
|
|
||||||
})
|
|
||||||
|
|
||||||
it('snapBack() calls animate returning to card shape', () => {
|
|
||||||
const el = makeEl()
|
|
||||||
const cardEl = ref<HTMLElement | null>(el)
|
|
||||||
const motion = { rich: ref(true) }
|
|
||||||
const { snapBack } = useCardAnimation(cardEl, motion)
|
|
||||||
snapBack()
|
|
||||||
expect(mockAnimate).toHaveBeenCalledWith(
|
|
||||||
el,
|
|
||||||
expect.objectContaining({ x: 0, y: 0, scale: 1 }),
|
|
||||||
expect.anything(),
|
|
||||||
)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('animateDismiss("label") calls animate', () => {
|
|
||||||
const el = makeEl()
|
|
||||||
const cardEl = ref<HTMLElement | null>(el)
|
|
||||||
const motion = { rich: ref(true) }
|
|
||||||
const { animateDismiss } = useCardAnimation(cardEl, motion)
|
|
||||||
animateDismiss('label')
|
|
||||||
expect(mockAnimate).toHaveBeenCalled()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('animateDismiss("discard") calls animate', () => {
|
|
||||||
const el = makeEl()
|
|
||||||
const cardEl = ref<HTMLElement | null>(el)
|
|
||||||
const motion = { rich: ref(true) }
|
|
||||||
const { animateDismiss } = useCardAnimation(cardEl, motion)
|
|
||||||
animateDismiss('discard')
|
|
||||||
expect(mockAnimate).toHaveBeenCalled()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('animateDismiss("skip") calls animate', () => {
|
|
||||||
const el = makeEl()
|
|
||||||
const cardEl = ref<HTMLElement | null>(el)
|
|
||||||
const motion = { rich: ref(true) }
|
|
||||||
const { animateDismiss } = useCardAnimation(cardEl, motion)
|
|
||||||
animateDismiss('skip')
|
|
||||||
expect(mockAnimate).toHaveBeenCalled()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('animateDismiss is a no-op when motion.rich is false', () => {
|
|
||||||
const el = makeEl()
|
|
||||||
const cardEl = ref<HTMLElement | null>(el)
|
|
||||||
const motion = { rich: ref(false) }
|
|
||||||
const { animateDismiss } = useCardAnimation(cardEl, motion)
|
|
||||||
animateDismiss('label')
|
|
||||||
expect(mockAnimate).not.toHaveBeenCalled()
|
|
||||||
})
|
|
||||||
})
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 2: Run tests to confirm they fail**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd /Library/Development/CircuitForge/avocet/web
|
|
||||||
npm test -- useCardAnimation
|
|
||||||
```
|
|
||||||
|
|
||||||
Expected: FAIL — "Cannot find module './useCardAnimation'"
|
|
||||||
|
|
||||||
**Step 3: Implement the composable**
|
|
||||||
|
|
||||||
`web/src/composables/useCardAnimation.ts`:
|
|
||||||
```ts
|
|
||||||
import { type Ref } from 'vue'
|
|
||||||
import { animate, spring, utils } from 'animejs'
|
|
||||||
|
|
||||||
const BALL_SCALE = 0.55
|
|
||||||
const BALL_RADIUS = '50%'
|
|
||||||
const CARD_RADIUS = '1rem'
|
|
||||||
const PICKUP_Y_OFFSET = 80 // px above finger
|
|
||||||
const PICKUP_DURATION = 200
|
|
||||||
// NOTE: animejs v4 — spring() takes an object, not positional args
|
|
||||||
const SNAP_SPRING = spring({ mass: 1, stiffness: 80, damping: 10 })
|
|
||||||
|
|
||||||
interface Motion { rich: Ref<boolean> }
|
|
||||||
|
|
||||||
export function useCardAnimation(
|
|
||||||
cardEl: Ref<HTMLElement | null>,
|
|
||||||
motion: Motion,
|
|
||||||
) {
|
|
||||||
function pickup() {
|
|
||||||
if (!motion.rich.value || !cardEl.value) return
|
|
||||||
// NOTE: animejs v4 — animate() is 2-arg; timing options merge into the params object
|
|
||||||
animate(cardEl.value, {
|
|
||||||
scale: BALL_SCALE,
|
|
||||||
borderRadius: BALL_RADIUS,
|
|
||||||
y: -PICKUP_Y_OFFSET,
|
|
||||||
duration: PICKUP_DURATION,
|
|
||||||
ease: SNAP_SPRING,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
function setDragPosition(dx: number, dy: number) {
|
|
||||||
if (!cardEl.value) return
|
|
||||||
utils.set(cardEl.value, { x: dx, y: dy - PICKUP_Y_OFFSET })
|
|
||||||
}
|
|
||||||
|
|
||||||
function snapBack() {
|
|
||||||
if (!motion.rich.value || !cardEl.value) return
|
|
||||||
// No duration — spring physics determines settling time
|
|
||||||
animate(cardEl.value, {
|
|
||||||
x: 0,
|
|
||||||
y: 0,
|
|
||||||
scale: 1,
|
|
||||||
borderRadius: CARD_RADIUS,
|
|
||||||
ease: SNAP_SPRING,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
function animateDismiss(type: 'label' | 'skip' | 'discard') {
|
|
||||||
if (!motion.rich.value || !cardEl.value) return
|
|
||||||
const el = cardEl.value
|
|
||||||
if (type === 'label') {
|
|
||||||
animate(el, { y: '-120%', scale: 0.85, opacity: 0, duration: 280, ease: 'out(3)' })
|
|
||||||
} else if (type === 'discard') {
|
|
||||||
// Two-step: crumple then shrink (keyframes array in params object)
|
|
||||||
animate(el, { keyframes: [
|
|
||||||
{ scale: 0.95, rotate: 2, filter: 'brightness(0.6) sepia(1) hue-rotate(-20deg)', duration: 140 },
|
|
||||||
{ scale: 0, rotate: 8, opacity: 0, duration: 210 },
|
|
||||||
])
|
|
||||||
} else if (type === 'skip') {
|
|
||||||
animate(el, { x: '110%', rotate: 5, opacity: 0 }, { duration: 260, ease: 'out(2)' })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return { pickup, setDragPosition, snapBack, animateDismiss }
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 4: Run tests — expect pass**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
npm test -- useCardAnimation
|
|
||||||
```
|
|
||||||
|
|
||||||
Expected: All 8 tests PASS.
|
|
||||||
|
|
||||||
**Step 5: Commit**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git add web/src/composables/useCardAnimation.ts web/src/composables/useCardAnimation.test.ts
|
|
||||||
git commit -m "feat(avocet): add useCardAnimation composable with Anime.js"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Task 3: Wire `useCardAnimation` into `EmailCardStack.vue`
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Modify: `web/src/components/EmailCardStack.vue`
|
|
||||||
- Modify: `web/src/components/EmailCardStack.test.ts`
|
|
||||||
|
|
||||||
**What changes:**
|
|
||||||
- Remove `cardStyle` computed and `:style="cardStyle"` binding
|
|
||||||
- Remove `dismissClass` computed and `:class="[dismissClass, ...]"` binding (keep `is-held`)
|
|
||||||
- Remove `deltaX`, `deltaY` reactive refs (position now owned by Anime.js)
|
|
||||||
- Call `pickup()` in `onPointerDown`, `setDragPosition()` in `onPointerMove`, `snapBack()` in `onPointerUp` (no-target path)
|
|
||||||
- Watch `props.dismissType` and call `animateDismiss()`
|
|
||||||
- Remove CSS `@keyframes fileAway`, `crumple`, `slideUnder` and their `.dismiss-*` rule blocks from `<style>`
|
|
||||||
|
|
||||||
**Step 1: Update the tests that check dismiss classes**
|
|
||||||
|
|
||||||
In `EmailCardStack.test.ts`, the 5 tests checking `.dismiss-label`, `.dismiss-discard`, `.dismiss-skip` classes are testing implementation (CSS class name), not behavior. Replace them with a single test that verifies `animateDismiss` is called:
|
|
||||||
|
|
||||||
```ts
|
|
||||||
// Add at the top of the file (after existing imports):
|
|
||||||
vi.mock('../composables/useCardAnimation', () => ({
|
|
||||||
useCardAnimation: vi.fn(() => ({
|
|
||||||
pickup: vi.fn(),
|
|
||||||
setDragPosition: vi.fn(),
|
|
||||||
snapBack: vi.fn(),
|
|
||||||
animateDismiss: vi.fn(),
|
|
||||||
})),
|
|
||||||
}))
|
|
||||||
|
|
||||||
import { useCardAnimation } from '../composables/useCardAnimation'
|
|
||||||
```
|
|
||||||
|
|
||||||
Replace the five `dismissType` class tests (lines 25–46) with:
|
|
||||||
|
|
||||||
```ts
|
|
||||||
it('calls animateDismiss with type when dismissType prop changes', async () => {
|
|
||||||
const w = mount(EmailCardStack, { props: { item, isBucketMode: false, dismissType: null } })
|
|
||||||
const { animateDismiss } = (useCardAnimation as ReturnType<typeof vi.fn>).mock.results[0].value
|
|
||||||
await w.setProps({ dismissType: 'label' })
|
|
||||||
await nextTick()
|
|
||||||
expect(animateDismiss).toHaveBeenCalledWith('label')
|
|
||||||
})
|
|
||||||
```
|
|
||||||
|
|
||||||
Add `nextTick` import to the test file header if not already present:
|
|
||||||
```ts
|
|
||||||
import { nextTick } from 'vue'
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 2: Run tests to confirm the replaced tests fail**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
npm test -- EmailCardStack
|
|
||||||
```
|
|
||||||
|
|
||||||
Expected: FAIL — `animateDismiss` not called (not yet wired in component)
|
|
||||||
|
|
||||||
**Step 3: Modify `EmailCardStack.vue`**
|
|
||||||
|
|
||||||
Script section changes:
|
|
||||||
|
|
||||||
```ts
|
|
||||||
// Remove:
|
|
||||||
// import { ref, computed } from 'vue' → change to:
|
|
||||||
import { ref, watch } from 'vue'
|
|
||||||
|
|
||||||
// Add import:
|
|
||||||
import { useCardAnimation } from '../composables/useCardAnimation'
|
|
||||||
|
|
||||||
// Remove these refs:
|
|
||||||
// const deltaX = ref(0)
|
|
||||||
// const deltaY = ref(0)
|
|
||||||
|
|
||||||
// Add after const motion = useMotion():
|
|
||||||
const { pickup, setDragPosition, snapBack, animateDismiss } = useCardAnimation(cardEl, motion)
|
|
||||||
|
|
||||||
// Add watcher:
|
|
||||||
watch(() => props.dismissType, (type) => {
|
|
||||||
if (type) animateDismiss(type)
|
|
||||||
})
|
|
||||||
|
|
||||||
// Remove dismissClass computed entirely.
|
|
||||||
|
|
||||||
// In onPointerDown — add after isHeld.value = true:
|
|
||||||
pickup()
|
|
||||||
|
|
||||||
// In onPointerMove — replace deltaX/deltaY assignments with:
|
|
||||||
const dx = e.clientX - pickupX.value
|
|
||||||
const dy = e.clientY - pickupY.value
|
|
||||||
setDragPosition(dx, dy)
|
|
||||||
// (keep the zone/bucket detection that uses e.clientX/e.clientY — those stay the same)
|
|
||||||
|
|
||||||
// In onPointerUp — in the snap-back else branch, replace:
|
|
||||||
// deltaX.value = 0
|
|
||||||
// deltaY.value = 0
|
|
||||||
// with:
|
|
||||||
snapBack()
|
|
||||||
```
|
|
||||||
|
|
||||||
Template changes — on the `.card-wrapper` div:
|
|
||||||
```html
|
|
||||||
<!-- Remove: :class="[dismissClass, { 'is-held': isHeld }]" -->
|
|
||||||
<!-- Replace with: -->
|
|
||||||
:class="{ 'is-held': isHeld }"
|
|
||||||
<!-- Remove: :style="cardStyle" -->
|
|
||||||
```
|
|
||||||
|
|
||||||
CSS changes in `<style scoped>` — delete these entire blocks:
|
|
||||||
```
|
|
||||||
@keyframes fileAway { ... }
|
|
||||||
@keyframes crumple { ... }
|
|
||||||
@keyframes slideUnder { ... }
|
|
||||||
.card-wrapper.dismiss-label { ... }
|
|
||||||
.card-wrapper.dismiss-discard { ... }
|
|
||||||
.card-wrapper.dismiss-skip { ... }
|
|
||||||
```
|
|
||||||
|
|
||||||
Also delete `--card-dismiss` and `--card-skip` CSS var usages if present.
|
|
||||||
|
|
||||||
**Step 4: Run all tests**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
npm test
|
|
||||||
```
|
|
||||||
|
|
||||||
Expected: All pass (both `useCardAnimation.test.ts` and `EmailCardStack.test.ts`).
|
|
||||||
|
|
||||||
**Step 5: Commit**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git add web/src/components/EmailCardStack.vue web/src/components/EmailCardStack.test.ts
|
|
||||||
git commit -m "feat(avocet): wire Anime.js card animation into EmailCardStack"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Task 4: Bucket grid rise animation
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Modify: `web/src/views/LabelView.vue`
|
|
||||||
|
|
||||||
**What changes:**
|
|
||||||
Replace the CSS class-toggle animation on `.bucket-grid-footer.grid-active` with an Anime.js watch in `LabelView.vue`. The `position: sticky → fixed` switch stays as a CSS class (can't animate position), but `translateY` and `opacity` move to Anime.js.
|
|
||||||
|
|
||||||
**Step 1: Add gridEl ref and import animate**
|
|
||||||
|
|
||||||
In `LabelView.vue` `<script setup>`:
|
|
||||||
```ts
|
|
||||||
// Add to imports:
|
|
||||||
import { ref, onMounted, onUnmounted, watch } from 'vue'
|
|
||||||
import { animate, spring } from 'animejs'
|
|
||||||
|
|
||||||
// Add ref:
|
|
||||||
const gridEl = ref<HTMLElement | null>(null)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 2: Add watcher for isHeld**
|
|
||||||
|
|
||||||
```ts
|
|
||||||
watch(isHeld, (held) => {
|
|
||||||
if (!motion.rich.value || !gridEl.value) return
|
|
||||||
// animejs v4: 2-arg animate, spring() takes object
|
|
||||||
animate(gridEl.value,
|
|
||||||
held
|
|
||||||
? { y: -8, opacity: 0.45, ease: spring({ mass: 1, stiffness: 80, damping: 10 }), duration: 250 }
|
|
||||||
: { y: 0, opacity: 1, ease: spring({ mass: 1, stiffness: 80, damping: 10 }), duration: 250 }
|
|
||||||
)
|
|
||||||
})
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 3: Wire ref in template**
|
|
||||||
|
|
||||||
On the `.bucket-grid-footer` div:
|
|
||||||
```html
|
|
||||||
<div ref="gridEl" class="bucket-grid-footer" :class="{ 'grid-active': isHeld }">
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 4: Remove CSS transition from `.bucket-grid-footer`**
|
|
||||||
|
|
||||||
In `LabelView.vue <style scoped>`, delete the `transition:` line from `.bucket-grid-footer`:
|
|
||||||
```css
|
|
||||||
/* DELETE this line: */
|
|
||||||
transition: transform 250ms cubic-bezier(0.34, 1.56, 0.64, 1),
|
|
||||||
opacity 200ms ease,
|
|
||||||
background 200ms ease;
|
|
||||||
```
|
|
||||||
Keep the `transform: translateY(-8px)` and `opacity: 0.45` on `.bucket-grid-footer.grid-active` as fallback for reduced-motion users (no-JS fallback too).
|
|
||||||
|
|
||||||
Actually — keep `.grid-active` rules as-is for the no-motion path. The Anime.js `watch` guard (`if (!motion.rich.value)`) means reduced-motion users never hit Anime.js; the CSS class handles them.
|
|
||||||
|
|
||||||
**Step 5: Run tests**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
npm test
|
|
||||||
```
|
|
||||||
|
|
||||||
Expected: All pass (LabelView has no dedicated tests, but full suite should be green).
|
|
||||||
|
|
||||||
**Step 6: Commit**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git add web/src/views/LabelView.vue
|
|
||||||
git commit -m "feat(avocet): animate bucket grid rise with Anime.js spring"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Task 5: Badge pop animation
|
|
||||||
|
|
||||||
**Files:**
|
|
||||||
- Modify: `web/src/views/LabelView.vue`
|
|
||||||
|
|
||||||
**What changes:**
|
|
||||||
Replace `@keyframes badge-pop` (scale + opacity keyframe) with a Vue `<Transition>` `@enter` hook that calls `animate()`. Badges already appear/disappear via `v-if`, so they have natural mount/unmount lifecycle.
|
|
||||||
|
|
||||||
**Step 1: Wrap each badge in a `<Transition>`**
|
|
||||||
|
|
||||||
In `LabelView.vue` template, each badge `<span v-if="...">` gets wrapped:
|
|
||||||
|
|
||||||
```html
|
|
||||||
<Transition @enter="onBadgeEnter" :css="false">
|
|
||||||
<span v-if="onRoll" class="badge badge-roll">🔥 On a roll!</span>
|
|
||||||
</Transition>
|
|
||||||
<Transition @enter="onBadgeEnter" :css="false">
|
|
||||||
<span v-if="speedRound" class="badge badge-speed">⚡ Speed round!</span>
|
|
||||||
</Transition>
|
|
||||||
<!-- repeat for all 6 badges -->
|
|
||||||
```
|
|
||||||
|
|
||||||
`:css="false"` tells Vue not to apply any CSS transition classes — Anime.js owns the enter animation entirely.
|
|
||||||
|
|
||||||
**Step 2: Add `onBadgeEnter` hook**
|
|
||||||
|
|
||||||
```ts
|
|
||||||
function onBadgeEnter(el: Element, done: () => void) {
|
|
||||||
if (!motion.rich.value) { done(); return }
|
|
||||||
animate(el as HTMLElement,
|
|
||||||
{ scale: [0.6, 1], opacity: [0, 1] },
|
|
||||||
  { ease: spring({ mass: 1.5, stiffness: 80, damping: 8 }), duration: 300, onComplete: done }
|
|
||||||
)
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 3: Remove `@keyframes badge-pop` from CSS**
|
|
||||||
|
|
||||||
In `LabelView.vue <style scoped>`:
|
|
||||||
```css
|
|
||||||
/* DELETE: */
|
|
||||||
@keyframes badge-pop {
|
|
||||||
from { transform: scale(0.6); opacity: 0; }
|
|
||||||
to { transform: scale(1); opacity: 1; }
|
|
||||||
}
|
|
||||||
|
|
||||||
/* DELETE animation line from .badge: */
|
|
||||||
animation: badge-pop 0.3s cubic-bezier(0.34, 1.56, 0.64, 1);
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 4: Run tests**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
npm test
|
|
||||||
```
|
|
||||||
|
|
||||||
Expected: All pass.
|
|
||||||
|
|
||||||
**Step 5: Commit**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git add web/src/views/LabelView.vue
|
|
||||||
git commit -m "feat(avocet): badge pop via Anime.js spring transition hook"
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Task 6: Build and smoke test
|
|
||||||
|
|
||||||
**Step 1: Build the SPA**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
cd /Library/Development/CircuitForge/avocet
|
|
||||||
./manage.sh start-api
|
|
||||||
```
|
|
||||||
|
|
||||||
(This builds Vue + starts FastAPI on port 8503.)
|
|
||||||
|
|
||||||
**Step 2: Open the app**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
./manage.sh open-api
|
|
||||||
```
|
|
||||||
|
|
||||||
**Step 3: Manual smoke test checklist**
|
|
||||||
|
|
||||||
- [ ] Pick up a card — ball morph is smooth (not instant jump)
|
|
||||||
- [ ] Drag ball around — follows finger with no lag
|
|
||||||
- [ ] Release in center — springs back to card with bounce
|
|
||||||
- [ ] Release in left zone — discard fires (card crumples)
|
|
||||||
- [ ] Release in right zone — skip fires (card slides right)
|
|
||||||
- [ ] Release on a bucket — label fires (card files up)
|
|
||||||
- [ ] Fling left fast — discard fires
|
|
||||||
- [ ] Bucket grid rises smoothly on pickup, falls on release
|
|
||||||
- [ ] Badge (label 10 in a row for 🔥) pops in with spring
|
|
||||||
- [ ] Reduced motion: toggle in system settings → no animations, instant behavior
|
|
||||||
- [ ] Keyboard labels (1–9) still work (pointer events unchanged)
|
|
||||||
|
|
||||||
**Step 4: Final commit if all green**
|
|
||||||
|
|
||||||
```bash
|
|
||||||
git add -A
|
|
||||||
git commit -m "feat(avocet): complete Anime.js animation integration"
|
|
||||||
```
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -1,254 +0,0 @@
|
||||||
# Fine-tune Email Classifier — Design Spec
|
|
||||||
|
|
||||||
**Date:** 2026-03-15
|
|
||||||
**Status:** Approved
|
|
||||||
**Scope:** Avocet — `scripts/`, `app/api.py`, `web/src/views/BenchmarkView.vue`, `environment.yml`
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Problem
|
|
||||||
|
|
||||||
The benchmark baseline shows zero-shot macro-F1 of 0.366 for the best models (`deberta-zeroshot`, `deberta-base-anli`). Zero-shot inference cannot improve with more labeled data. Fine-tuning the fastest models (`deberta-small` at 111ms, `bge-m3` at 123ms) on the growing labeled dataset is the path to meaningful accuracy gains.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Constraints
|
|
||||||
|
|
||||||
- 501 labeled samples after dropping 2 non-canonical `profile_alert` rows
|
|
||||||
- Heavy class imbalance: `digest` 29%, `neutral` 26%, `new_lead` 2.6%, `survey_received` 3%
|
|
||||||
- 8.2 GB VRAM (shared with Peregrine vLLM during dev)
|
|
||||||
- Target models: `cross-encoder/nli-deberta-v3-small` (100M params), `MoritzLaurer/bge-m3-zeroshot-v2.0` (600M params)
|
|
||||||
- Output: local `models/avocet-{name}/` directory
|
|
||||||
- UI-triggerable via web interface (SSE streaming log)
|
|
||||||
- Stack: transformers 4.57.3, torch 2.10.0, accelerate 1.12.0, sklearn, CUDA 8.2GB
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Environment changes
|
|
||||||
|
|
||||||
`environment.yml` must add:
|
|
||||||
- `scikit-learn` — required for `train_test_split(stratify=...)` and `f1_score`
|
|
||||||
- `peft` is NOT used by this spec; it is available in the env but not required here
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Architecture
|
|
||||||
|
|
||||||
### New file: `scripts/finetune_classifier.py`
|
|
||||||
|
|
||||||
CLI entry point for fine-tuning. All prints use `flush=True` so stdout is SSE-streamable.
|
|
||||||
|
|
||||||
```
|
|
||||||
python scripts/finetune_classifier.py --model deberta-small [--epochs 5]
|
|
||||||
```
|
|
||||||
|
|
||||||
Supported `--model` values: `deberta-small`, `bge-m3`
|
|
||||||
|
|
||||||
**Model registry** (internal to this script):
|
|
||||||
|
|
||||||
| Key | Base model ID | Max tokens | fp16 | Batch size | Grad accum steps | Gradient checkpointing |
|
|
||||||
|-----|--------------|------------|------|------------|-----------------|----------------------|
|
|
||||||
| `deberta-small` | `cross-encoder/nli-deberta-v3-small` | 512 | No | 16 | 1 | No |
|
|
||||||
| `bge-m3` | `MoritzLaurer/bge-m3-zeroshot-v2.0` | 512 | Yes | 4 | 4 | Yes |
|
|
||||||
|
|
||||||
`bge-m3` uses `fp16=True` (halves optimizer state from ~4.8GB to ~2.4GB) with batch size 4 + gradient accumulation 4 = effective batch 16, matching `deberta-small`. These settings are required to fit within 8.2GB VRAM. Still stop Peregrine vLLM before running bge-m3 fine-tuning.
|
|
||||||
|
|
||||||
### Modified: `scripts/classifier_adapters.py`
|
|
||||||
|
|
||||||
Add `FineTunedAdapter(ClassifierAdapter)`:
|
|
||||||
- Takes `model_dir: str` (path to a `models/avocet-*/` checkpoint)
|
|
||||||
- Loads via `pipeline("text-classification", model=model_dir)`
|
|
||||||
- `classify()` input format: **`f"{subject} [SEP] {body[:400]}"`** — must match the training format exactly. Do NOT use the zero-shot adapters' `f"Subject: {subject}\n\n{body[:600]}"` format; distribution shift will degrade accuracy.
|
|
||||||
- Returns the top predicted label directly (single forward pass — no per-label NLI scoring loop)
|
|
||||||
- Expected inference speed: ~10–20ms/email vs 111–338ms for zero-shot
|
|
||||||
|
|
||||||
### Modified: `scripts/benchmark_classifier.py`
|
|
||||||
|
|
||||||
At startup, scan `models/` for subdirectories containing `training_info.json`. Register each as a dynamic entry in the model registry using `FineTunedAdapter`. Silently skips if `models/` does not exist. Existing CLI behaviour unchanged.
|
|
||||||
|
|
||||||
### Modified: `app/api.py`
|
|
||||||
|
|
||||||
Two new GET endpoints (GET required for `EventSource` compatibility):
|
|
||||||
|
|
||||||
**`GET /api/finetune/status`**
|
|
||||||
Scans `models/` for `training_info.json` files. Returns:
|
|
||||||
```json
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"name": "avocet-deberta-small",
|
|
||||||
"base_model": "cross-encoder/nli-deberta-v3-small",
|
|
||||||
"val_macro_f1": 0.712,
|
|
||||||
"timestamp": "2026-03-15T12:00:00Z",
|
|
||||||
"sample_count": 401
|
|
||||||
}
|
|
||||||
]
|
|
||||||
```
|
|
||||||
Returns `[]` if no fine-tuned models exist.
|
|
||||||
|
|
||||||
**`GET /api/finetune/run?model=deberta-small&epochs=5`**
|
|
||||||
Spawns `finetune_classifier.py` via the `job-seeker-classifiers` Python binary. Streams stdout as SSE `{"type":"progress","message":"..."}` events. Emits `{"type":"complete"}` on clean exit, `{"type":"error","message":"..."}` on non-zero exit. Same implementation pattern as `/api/benchmark/run`.
|
|
||||||
|
|
||||||
### Modified: `web/src/views/BenchmarkView.vue`
|
|
||||||
|
|
||||||
**Trained models badge row** (top of view, conditional on fine-tuned models existing):
|
|
||||||
Shows each fine-tuned model name + val macro-F1 chip. Fetches from `/api/finetune/status` on mount.
|
|
||||||
|
|
||||||
**Fine-tune section** (collapsible, below benchmark charts):
|
|
||||||
- Dropdown: `deberta-small` | `bge-m3`
|
|
||||||
- Number input: epochs (default 5, range 1–20)
|
|
||||||
- Run button → streams into existing log component
|
|
||||||
- On `complete`: auto-triggers `/api/benchmark/run` (with `--save`) so charts update immediately
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Training Pipeline
|
|
||||||
|
|
||||||
### Data preparation
|
|
||||||
|
|
||||||
1. Load `data/email_score.jsonl`
|
|
||||||
2. Drop rows where `label` not in canonical `LABELS` (removes `profile_alert` etc.)
|
|
||||||
3. Check for classes with < 2 **total** samples (before any split). Drop those classes and warn. Additionally warn — but do not skip — classes with < 5 training samples, noting eval F1 for those classes will be unreliable.
|
|
||||||
4. Input text: `f"{subject} [SEP] {body[:400]}"` — fits within 512 tokens for both target models
|
|
||||||
5. Stratified 80/20 train/val split via `sklearn.model_selection.train_test_split(stratify=labels)`
|
|
||||||
|
|
||||||
### Class weighting
|
|
||||||
|
|
||||||
Compute per-class weights: `total_samples / (n_classes × class_count)`. Pass to a `WeightedTrainer` subclass:
|
|
||||||
|
|
||||||
```python
|
|
||||||
class WeightedTrainer(Trainer):
|
|
||||||
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
|
|
||||||
# **kwargs is required — absorbs num_items_in_batch added in Transformers 4.38.
|
|
||||||
# Do not remove it; removing it causes TypeError on the first training step.
|
|
||||||
labels = inputs.pop("labels")
|
|
||||||
outputs = model(**inputs)
|
|
||||||
# Move class_weights to the same device as logits — required for GPU training.
|
|
||||||
# class_weights is created on CPU; logits are on cuda:0 during training.
|
|
||||||
weight = self.class_weights.to(outputs.logits.device)
|
|
||||||
loss = F.cross_entropy(outputs.logits, labels, weight=weight)
|
|
||||||
return (loss, outputs) if return_outputs else loss
|
|
||||||
```
|
|
||||||
|
|
||||||
### Model setup
|
|
||||||
|
|
||||||
```python
|
|
||||||
AutoModelForSequenceClassification.from_pretrained(
|
|
||||||
base_model_id,
|
|
||||||
num_labels=10,
|
|
||||||
ignore_mismatched_sizes=True, # see note below
|
|
||||||
id2label=id2label,
|
|
||||||
label2id=label2id,
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
**Note on `ignore_mismatched_sizes=True`:** The pretrained NLI head is a 3-class linear projection. It mismatches the 10-class head constructed by `num_labels=10`, so its weights are skipped during loading. PyTorch initializes the new head from scratch using the model's default init scheme. The backbone weights load normally. Do not set this to `False` — it will raise a shape error.
|
|
||||||
|
|
||||||
### Training config and `compute_metrics`
|
|
||||||
|
|
||||||
The Trainer requires a `compute_metrics` callback that takes an `EvalPrediction` (logits + label_ids) and returns a dict with a `macro_f1` key. This is distinct from the existing `compute_metrics` in `classifier_adapters.py` (which operates on string predictions):
|
|
||||||
|
|
||||||
```python
|
|
||||||
def compute_metrics_for_trainer(eval_pred: EvalPrediction) -> dict:
|
|
||||||
logits, labels = eval_pred
|
|
||||||
preds = logits.argmax(axis=-1)
|
|
||||||
return {
|
|
||||||
"macro_f1": f1_score(labels, preds, average="macro", zero_division=0),
|
|
||||||
"accuracy": accuracy_score(labels, preds),
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
`TrainingArguments` must include:
|
|
||||||
- `load_best_model_at_end=True`
|
|
||||||
- `metric_for_best_model="macro_f1"`
|
|
||||||
- `greater_is_better=True`
|
|
||||||
|
|
||||||
These are required for `EarlyStoppingCallback` to work correctly. Without `load_best_model_at_end=True`, `EarlyStoppingCallback` raises `AssertionError` on init.
|
|
||||||
|
|
||||||
| Hyperparameter | deberta-small | bge-m3 |
|
|
||||||
|---------------|--------------|--------|
|
|
||||||
| Epochs | 5 (default, CLI-overridable) | 5 |
|
|
||||||
| Batch size | 16 | 4 |
|
|
||||||
| Gradient accumulation | 1 | 4 (effective batch = 16) |
|
|
||||||
| Learning rate | 2e-5 | 2e-5 |
|
|
||||||
| LR schedule | Linear with 10% warmup | same |
|
|
||||||
| Optimizer | AdamW | AdamW |
|
|
||||||
| fp16 | No | Yes |
|
|
||||||
| Gradient checkpointing | No | Yes |
|
|
||||||
| Eval strategy | Every epoch | Every epoch |
|
|
||||||
| Best checkpoint | By `macro_f1` | same |
|
|
||||||
| Early stopping patience | 3 epochs | 3 epochs |
|
|
||||||
|
|
||||||
### Output
|
|
||||||
|
|
||||||
Saved to `models/avocet-{name}/`:
|
|
||||||
- Model weights + tokenizer (standard HuggingFace format)
|
|
||||||
- `training_info.json`:
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"name": "avocet-deberta-small",
|
|
||||||
"base_model_id": "cross-encoder/nli-deberta-v3-small",
|
|
||||||
"timestamp": "2026-03-15T12:00:00Z",
|
|
||||||
"epochs_run": 5,
|
|
||||||
"val_macro_f1": 0.712,
|
|
||||||
"val_accuracy": 0.798,
|
|
||||||
"sample_count": 401,
|
|
||||||
"label_counts": { "digest": 116, "neutral": 104, ... }
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Data Flow
|
|
||||||
|
|
||||||
```
|
|
||||||
email_score.jsonl
|
|
||||||
│
|
|
||||||
▼
|
|
||||||
finetune_classifier.py
|
|
||||||
├── drop non-canonical labels
|
|
||||||
├── check for < 2 total samples per class (drop + warn)
|
|
||||||
├── stratified 80/20 split
|
|
||||||
├── tokenize (subject [SEP] body[:400])
|
|
||||||
├── compute class weights
|
|
||||||
├── WeightedTrainer + EarlyStoppingCallback
|
|
||||||
└── save → models/avocet-{name}/
|
|
||||||
│
|
|
||||||
├── FineTunedAdapter (classifier_adapters.py)
|
|
||||||
│ ├── pipeline("text-classification")
|
|
||||||
│ ├── input: subject [SEP] body[:400] ← must match training format
|
|
||||||
│ └── ~10–20ms/email inference
|
|
||||||
│
|
|
||||||
└── training_info.json
|
|
||||||
└── /api/finetune/status
|
|
||||||
└── BenchmarkView badge row
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Error Handling
|
|
||||||
|
|
||||||
- **Insufficient data (< 2 total samples in a class):** Drop class before split, print warning with class name and count.
|
|
||||||
- **Low data warning (< 5 training samples in a class):** Warn but continue; note eval F1 for that class will be unreliable.
|
|
||||||
- **VRAM OOM on bge-m3:** Surface as clear SSE error message. Suggest stopping Peregrine vLLM first (it holds ~5.7GB).
|
|
||||||
- **Missing score file:** Raise `FileNotFoundError` with actionable message (same pattern as `load_scoring_jsonl`).
|
|
||||||
- **Model dir already exists:** Overwrite with a warning log line. Re-running always produces a fresh checkpoint.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Testing
|
|
||||||
|
|
||||||
- Unit test `WeightedTrainer.compute_loss` with a mock model and known label distribution — verify weighted loss differs from unweighted; verify `**kwargs` does not raise `TypeError`
|
|
||||||
- Unit test `compute_metrics_for_trainer` — verify `macro_f1` key in output, correct value on known inputs
|
|
||||||
- Unit test `FineTunedAdapter.classify` with a mock pipeline — verify it returns a string from `LABELS` using `subject [SEP] body[:400]` format
|
|
||||||
- Unit test auto-discovery in `benchmark_classifier.py` — mock `models/` dir with two `training_info.json` files, verify both appear in the active registry
|
|
||||||
- Integration test: fine-tune on `data/email_score.jsonl.example` (8 samples, 5 of 10 labels represented, 1 epoch, `--model deberta-small`). The 5 missing labels trigger the `< 2 total samples` drop path — the test must verify the drop warning is emitted for each missing label rather than treating it as a failure. Verify `models/avocet-deberta-small/training_info.json` is written with correct keys.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Out of Scope
|
|
||||||
|
|
||||||
- Pushing fine-tuned weights to HuggingFace Hub (future)
|
|
||||||
- Cross-validation or k-fold evaluation (future — dataset too small to be meaningful now)
|
|
||||||
- Hyperparameter search (future)
|
|
||||||
- LoRA/PEFT adapter fine-tuning (future — relevant if model sizes grow beyond available VRAM)
|
|
||||||
- Fine-tuning models other than `deberta-small` and `bge-m3`
|
|
||||||
|
|
@ -14,7 +14,6 @@ dependencies:
|
||||||
- transformers>=4.40
|
- transformers>=4.40
|
||||||
- torch>=2.2
|
- torch>=2.2
|
||||||
- accelerate>=0.27
|
- accelerate>=0.27
|
||||||
- scikit-learn>=1.4
|
|
||||||
|
|
||||||
# Optional: GLiClass adapter
|
# Optional: GLiClass adapter
|
||||||
# - gliclass
|
# - gliclass
|
||||||
|
|
|
||||||
|
|
@ -96,7 +96,6 @@ usage() {
|
||||||
echo " Vue API:"
|
echo " Vue API:"
|
||||||
echo -e " ${GREEN}start-api${NC} Build Vue SPA + start FastAPI on port 8503"
|
echo -e " ${GREEN}start-api${NC} Build Vue SPA + start FastAPI on port 8503"
|
||||||
echo -e " ${GREEN}stop-api${NC} Stop FastAPI server"
|
echo -e " ${GREEN}stop-api${NC} Stop FastAPI server"
|
||||||
echo -e " ${GREEN}restart-api${NC} Stop + rebuild + restart FastAPI server"
|
|
||||||
echo -e " ${GREEN}open-api${NC} Open Vue UI in browser (http://localhost:8503)"
|
echo -e " ${GREEN}open-api${NC} Open Vue UI in browser (http://localhost:8503)"
|
||||||
echo ""
|
echo ""
|
||||||
echo " Dev:"
|
echo " Dev:"
|
||||||
|
|
@ -306,11 +305,6 @@ case "$CMD" in
|
||||||
fi
|
fi
|
||||||
;;
|
;;
|
||||||
|
|
||||||
restart-api)
|
|
||||||
bash "$0" stop-api
|
|
||||||
exec bash "$0" start-api
|
|
||||||
;;
|
|
||||||
|
|
||||||
open-api)
|
open-api)
|
||||||
URL="http://localhost:8503"
|
URL="http://localhost:8503"
|
||||||
info "Opening ${URL}"
|
info "Opening ${URL}"
|
||||||
|
|
|
||||||
|
|
@ -32,14 +32,10 @@ from typing import Any
|
||||||
|
|
||||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||||
|
|
||||||
_ROOT = Path(__file__).parent.parent
|
|
||||||
_MODELS_DIR = _ROOT / "models"
|
|
||||||
|
|
||||||
from scripts.classifier_adapters import (
|
from scripts.classifier_adapters import (
|
||||||
LABELS,
|
LABELS,
|
||||||
LABEL_DESCRIPTIONS,
|
LABEL_DESCRIPTIONS,
|
||||||
ClassifierAdapter,
|
ClassifierAdapter,
|
||||||
FineTunedAdapter,
|
|
||||||
GLiClassAdapter,
|
GLiClassAdapter,
|
||||||
RerankerAdapter,
|
RerankerAdapter,
|
||||||
ZeroShotAdapter,
|
ZeroShotAdapter,
|
||||||
|
|
@ -154,55 +150,8 @@ def load_scoring_jsonl(path: str) -> list[dict[str, str]]:
|
||||||
return rows
|
return rows
|
||||||
|
|
||||||
|
|
||||||
def discover_finetuned_models(models_dir: Path | None = None) -> list[dict]:
|
def _active_models(include_slow: bool) -> dict[str, dict[str, Any]]:
|
||||||
"""Scan models/ for subdirs containing training_info.json.
|
return {k: v for k, v in MODEL_REGISTRY.items() if v["default"] or include_slow}
|
||||||
|
|
||||||
Returns a list of training_info dicts, each with an added 'model_dir' key.
|
|
||||||
Returns [] silently if models_dir does not exist.
|
|
||||||
"""
|
|
||||||
if models_dir is None:
|
|
||||||
models_dir = _MODELS_DIR
|
|
||||||
if not models_dir.exists():
|
|
||||||
return []
|
|
||||||
found = []
|
|
||||||
for sub in models_dir.iterdir():
|
|
||||||
if not sub.is_dir():
|
|
||||||
continue
|
|
||||||
info_path = sub / "training_info.json"
|
|
||||||
if not info_path.exists():
|
|
||||||
continue
|
|
||||||
try:
|
|
||||||
info = json.loads(info_path.read_text(encoding="utf-8"))
|
|
||||||
except Exception as exc:
|
|
||||||
print(f"[discover] WARN: skipping {info_path}: {exc}", flush=True)
|
|
||||||
continue
|
|
||||||
if "name" not in info:
|
|
||||||
print(f"[discover] WARN: skipping {info_path}: missing 'name' key", flush=True)
|
|
||||||
continue
|
|
||||||
info["model_dir"] = str(sub)
|
|
||||||
found.append(info)
|
|
||||||
return found
|
|
||||||
|
|
||||||
|
|
||||||
def _active_models(include_slow: bool = False) -> dict[str, dict[str, Any]]:
|
|
||||||
"""Return the active model registry, merged with any discovered fine-tuned models."""
|
|
||||||
active: dict[str, dict[str, Any]] = {
|
|
||||||
key: {**entry, "adapter_instance": entry["adapter"](
|
|
||||||
key,
|
|
||||||
entry["model_id"],
|
|
||||||
**entry.get("kwargs", {}),
|
|
||||||
)}
|
|
||||||
for key, entry in MODEL_REGISTRY.items()
|
|
||||||
if include_slow or entry.get("default", False)
|
|
||||||
}
|
|
||||||
for info in discover_finetuned_models():
|
|
||||||
name = info["name"]
|
|
||||||
active[name] = {
|
|
||||||
"adapter_instance": FineTunedAdapter(name, info["model_dir"]),
|
|
||||||
"params": "fine-tuned",
|
|
||||||
"default": True,
|
|
||||||
}
|
|
||||||
return active
|
|
||||||
|
|
||||||
|
|
||||||
def run_scoring(
|
def run_scoring(
|
||||||
|
|
@ -214,8 +163,7 @@ def run_scoring(
|
||||||
gold = [r["label"] for r in rows]
|
gold = [r["label"] for r in rows]
|
||||||
results: dict[str, Any] = {}
|
results: dict[str, Any] = {}
|
||||||
|
|
||||||
for i, adapter in enumerate(adapters, 1):
|
for adapter in adapters:
|
||||||
print(f"[{i}/{len(adapters)}] Running {adapter.name} ({len(rows)} samples) …", flush=True)
|
|
||||||
preds: list[str] = []
|
preds: list[str] = []
|
||||||
t0 = time.monotonic()
|
t0 = time.monotonic()
|
||||||
for row in rows:
|
for row in rows:
|
||||||
|
|
@ -229,7 +177,6 @@ def run_scoring(
|
||||||
metrics = compute_metrics(preds, gold, LABELS)
|
metrics = compute_metrics(preds, gold, LABELS)
|
||||||
metrics["latency_ms"] = round(elapsed_ms / len(rows), 1)
|
metrics["latency_ms"] = round(elapsed_ms / len(rows), 1)
|
||||||
results[adapter.name] = metrics
|
results[adapter.name] = metrics
|
||||||
print(f" → macro-F1 {metrics['__macro_f1__']:.3f} accuracy {metrics['__accuracy__']:.3f} {metrics['latency_ms']:.1f} ms/email", flush=True)
|
|
||||||
adapter.unload()
|
adapter.unload()
|
||||||
|
|
||||||
return results
|
return results
|
||||||
|
|
@ -398,7 +345,10 @@ def cmd_score(args: argparse.Namespace) -> None:
|
||||||
if args.models:
|
if args.models:
|
||||||
active = {k: v for k, v in active.items() if k in args.models}
|
active = {k: v for k, v in active.items() if k in args.models}
|
||||||
|
|
||||||
adapters = [entry["adapter_instance"] for entry in active.values()]
|
adapters = [
|
||||||
|
entry["adapter"](name, entry["model_id"], **entry.get("kwargs", {}))
|
||||||
|
for name, entry in active.items()
|
||||||
|
]
|
||||||
|
|
||||||
print(f"\nScoring {len(adapters)} model(s) against {args.score_file} …\n")
|
print(f"\nScoring {len(adapters)} model(s) against {args.score_file} …\n")
|
||||||
results = run_scoring(adapters, args.score_file)
|
results = run_scoring(adapters, args.score_file)
|
||||||
|
|
@ -425,31 +375,6 @@ def cmd_score(args: argparse.Namespace) -> None:
|
||||||
print(row_str)
|
print(row_str)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
if args.save:
|
|
||||||
import datetime
|
|
||||||
rows = load_scoring_jsonl(args.score_file)
|
|
||||||
save_data = {
|
|
||||||
"timestamp": datetime.datetime.utcnow().isoformat() + "Z",
|
|
||||||
"sample_count": len(rows),
|
|
||||||
"models": {
|
|
||||||
name: {
|
|
||||||
"macro_f1": round(m["__macro_f1__"], 4),
|
|
||||||
"accuracy": round(m["__accuracy__"], 4),
|
|
||||||
"latency_ms": m["latency_ms"],
|
|
||||||
"per_label": {
|
|
||||||
label: {k: round(v, 4) for k, v in m[label].items()}
|
|
||||||
for label in LABELS
|
|
||||||
if label in m
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for name, m in results.items()
|
|
||||||
},
|
|
||||||
}
|
|
||||||
save_path = Path(args.score_file).parent / "benchmark_results.json"
|
|
||||||
with open(save_path, "w") as f:
|
|
||||||
json.dump(save_data, f, indent=2)
|
|
||||||
print(f"Results saved → {save_path}", flush=True)
|
|
||||||
|
|
||||||
|
|
||||||
def cmd_compare(args: argparse.Namespace) -> None:
|
def cmd_compare(args: argparse.Namespace) -> None:
|
||||||
active = _active_models(args.include_slow)
|
active = _active_models(args.include_slow)
|
||||||
|
|
@ -460,7 +385,10 @@ def cmd_compare(args: argparse.Namespace) -> None:
|
||||||
emails = _fetch_imap_sample(args.limit, args.days)
|
emails = _fetch_imap_sample(args.limit, args.days)
|
||||||
print(f"Fetched {len(emails)} emails. Loading {len(active)} model(s) …\n")
|
print(f"Fetched {len(emails)} emails. Loading {len(active)} model(s) …\n")
|
||||||
|
|
||||||
adapters = [entry["adapter_instance"] for entry in active.values()]
|
adapters = [
|
||||||
|
entry["adapter"](name, entry["model_id"], **entry.get("kwargs", {}))
|
||||||
|
for name, entry in active.items()
|
||||||
|
]
|
||||||
model_names = [a.name for a in adapters]
|
model_names = [a.name for a in adapters]
|
||||||
|
|
||||||
col = 22
|
col = 22
|
||||||
|
|
@ -503,8 +431,6 @@ def main() -> None:
|
||||||
parser.add_argument("--days", type=int, default=90, help="Days back for IMAP search")
|
parser.add_argument("--days", type=int, default=90, help="Days back for IMAP search")
|
||||||
parser.add_argument("--include-slow", action="store_true", help="Include non-default heavy models")
|
parser.add_argument("--include-slow", action="store_true", help="Include non-default heavy models")
|
||||||
parser.add_argument("--models", nargs="+", help="Override: run only these model names")
|
parser.add_argument("--models", nargs="+", help="Override: run only these model names")
|
||||||
parser.add_argument("--save", action="store_true",
|
|
||||||
help="Save results to data/benchmark_results.json (for the web UI)")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,6 @@ __all__ = [
|
||||||
"ZeroShotAdapter",
|
"ZeroShotAdapter",
|
||||||
"GLiClassAdapter",
|
"GLiClassAdapter",
|
||||||
"RerankerAdapter",
|
"RerankerAdapter",
|
||||||
"FineTunedAdapter",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
LABELS: list[str] = [
|
LABELS: list[str] = [
|
||||||
|
|
@ -264,43 +263,3 @@ class RerankerAdapter(ClassifierAdapter):
|
||||||
pairs = [[text, LABEL_DESCRIPTIONS.get(label, label.replace("_", " "))] for label in LABELS]
|
pairs = [[text, LABEL_DESCRIPTIONS.get(label, label.replace("_", " "))] for label in LABELS]
|
||||||
scores: list[float] = self._reranker.compute_score(pairs, normalize=True)
|
scores: list[float] = self._reranker.compute_score(pairs, normalize=True)
|
||||||
return LABELS[scores.index(max(scores))]
|
return LABELS[scores.index(max(scores))]
|
||||||
|
|
||||||
|
|
||||||
class FineTunedAdapter(ClassifierAdapter):
    """Adapter for a fine-tuned checkpoint stored in a local models/ directory.

    Runs a single forward pass via pipeline("text-classification").
    Input format: 'subject [SEP] body[:400]' — must match training format exactly.
    Expected inference speed: ~10–20ms/email vs 111–338ms for zero-shot.
    """

    def __init__(self, name: str, model_dir: str) -> None:
        self._name = name
        self._model_dir = model_dir
        # Lazily-constructed HF pipeline; stays None until load() is called.
        self._pipeline: Any = None

    @property
    def name(self) -> str:
        return self._name

    @property
    def model_id(self) -> str:
        return self._model_dir

    def load(self) -> None:
        """Build the text-classification pipeline from the local checkpoint."""
        # Look the pipeline factory up on the module at call time, so a missing
        # transformers install is reported here rather than at import time.
        import scripts.classifier_adapters as _mod  # noqa: PLC0415

        pipeline_factory = _mod.pipeline
        if pipeline_factory is None:
            raise ImportError("transformers not installed — run: pip install transformers")
        target_device = 0 if _cuda_available() else -1
        self._pipeline = pipeline_factory(
            "text-classification", model=self._model_dir, device=target_device
        )

    def unload(self) -> None:
        """Release the pipeline so the model can be garbage-collected."""
        self._pipeline = None

    def classify(self, subject: str, body: str) -> str:
        """Return the predicted label for one email, loading the model on demand."""
        if self._pipeline is None:
            self.load()
        prediction = self._pipeline(f"{subject} [SEP] {body[:400]}")
        return prediction[0]["label"]
|
|
||||||
|
|
|
||||||
|
|
@ -1,416 +0,0 @@
|
||||||
"""Fine-tune email classifiers on the labeled dataset.
|
|
||||||
|
|
||||||
CLI entry point. All prints use flush=True so stdout is SSE-streamable.
|
|
||||||
|
|
||||||
Usage:
|
|
||||||
python scripts/finetune_classifier.py --model deberta-small [--epochs 5]
|
|
||||||
|
|
||||||
Supported --model values: deberta-small, bge-m3
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import hashlib
|
|
||||||
import json
|
|
||||||
import sys
|
|
||||||
from collections import Counter
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import torch
|
|
||||||
import torch.nn.functional as F
|
|
||||||
from torch.utils.data import Dataset as TorchDataset
|
|
||||||
from sklearn.model_selection import train_test_split
|
|
||||||
from sklearn.metrics import f1_score, accuracy_score
|
|
||||||
from transformers import (
|
|
||||||
AutoTokenizer,
|
|
||||||
AutoModelForSequenceClassification,
|
|
||||||
EvalPrediction,
|
|
||||||
Trainer,
|
|
||||||
TrainingArguments,
|
|
||||||
EarlyStoppingCallback,
|
|
||||||
)
|
|
||||||
|
|
||||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
||||||
|
|
||||||
from scripts.classifier_adapters import LABELS
|
|
||||||
|
|
||||||
_ROOT = Path(__file__).parent.parent
|
|
||||||
|
|
||||||
_MODEL_CONFIG: dict[str, dict[str, Any]] = {
|
|
||||||
"deberta-small": {
|
|
||||||
"base_model_id": "cross-encoder/nli-deberta-v3-small",
|
|
||||||
"max_tokens": 512,
|
|
||||||
# fp16 must stay OFF — DeBERTa-v3 disentangled attention overflows fp16.
|
|
||||||
"fp16": False,
|
|
||||||
# batch_size=8 + grad_accum=2 keeps effective batch of 16 while halving
|
|
||||||
# per-step activation memory. gradient_checkpointing recomputes activations
|
|
||||||
# on backward instead of storing them — ~60% less activation VRAM.
|
|
||||||
"batch_size": 8,
|
|
||||||
"grad_accum": 2,
|
|
||||||
"gradient_checkpointing": True,
|
|
||||||
},
|
|
||||||
"bge-m3": {
|
|
||||||
"base_model_id": "MoritzLaurer/bge-m3-zeroshot-v2.0",
|
|
||||||
"max_tokens": 512,
|
|
||||||
"fp16": True,
|
|
||||||
"batch_size": 4,
|
|
||||||
"grad_accum": 4,
|
|
||||||
"gradient_checkpointing": True,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def load_and_prepare_data(score_files: Path | list[Path]) -> tuple[list[str], list[str]]:
    """Load labeled JSONL and return (texts, labels) filtered to canonical LABELS.

    score_files: a single Path or a list of Paths. When multiple files are given,
    rows are merged with last-write-wins deduplication keyed by content hash
    (MD5 of subject + body[:100]).

    Drops rows with non-canonical labels (with warning), and drops entire classes
    that have fewer than 2 total samples (required for stratified split).
    Warns (but continues) for classes with fewer than 5 samples.

    Raises FileNotFoundError if any of the given score files is missing.
    """
    # Normalise to list — backwards compatible with single-Path callers.
    if isinstance(score_files, Path):
        score_files = [score_files]

    for score_file in score_files:
        if not score_file.exists():
            raise FileNotFoundError(
                f"Labeled data not found: {score_file}\n"
                "Run the label tool first to generate email_score.jsonl."
            )

    label_set = set(LABELS)
    # Use a plain dict keyed by content hash; later entries overwrite earlier ones
    # (last-write wins), which lets later labeling runs correct earlier labels.
    seen: dict[str, dict] = {}
    total = 0

    for score_file in score_files:
        with score_file.open() as fh:
            for line in fh:
                line = line.strip()
                if not line:
                    continue
                try:
                    r = json.loads(line)
                except json.JSONDecodeError:
                    continue
                lbl = r.get("label", "")
                if lbl not in label_set:
                    print(
                        f"[data] WARNING: Dropping row with non-canonical label {lbl!r}",
                        flush=True,
                    )
                    continue
                content_hash = hashlib.md5(
                    (r.get("subject", "") + (r.get("body", "") or "")[:100]).encode(
                        "utf-8", errors="replace"
                    )
                ).hexdigest()
                seen[content_hash] = r
                total += 1

    kept = len(seen)
    dropped = total - kept
    if dropped > 0:
        print(
            f"[data] Deduped: kept {kept} of {total} rows (dropped {dropped} duplicates)",
            flush=True,
        )

    rows = list(seen.values())

    # Count samples per class
    counts: Counter = Counter(r["label"] for r in rows)

    # Drop classes with < 2 total samples (cannot stratify-split)
    drop_classes: set[str] = set()
    for lbl, cnt in counts.items():
        if cnt < 2:
            print(
                f"[data] WARNING: Dropping class {lbl!r} — only {counts[lbl]} total "
                f"sample(s). Need at least 2 for stratified split.",
                flush=True,
            )
            drop_classes.add(lbl)

    # Warn for classes with < 5 samples (unreliable eval F1)
    for lbl, cnt in counts.items():
        if lbl not in drop_classes and cnt < 5:
            print(
                f"[data] WARNING: Class {lbl!r} has only {cnt} sample(s). "
                f"Eval F1 for this class will be unreliable.",
                flush=True,
            )

    # Filter rows
    rows = [r for r in rows if r["label"] not in drop_classes]

    # BUG FIX: the hash step above tolerates a missing/None body via
    # (r.get("body", "") or ""), but the text builder previously indexed
    # r['body'][:400] directly, raising TypeError for rows with body=None.
    # Use the same null-safe access here so such rows survive end-to-end.
    texts = [f"{r.get('subject', '')} [SEP] {(r.get('body') or '')[:400]}" for r in rows]
    labels = [r["label"] for r in rows]

    return texts, labels
|
|
||||||
|
|
||||||
|
|
||||||
def compute_class_weights(label_ids: list[int], n_classes: int) -> torch.Tensor:
    """Compute inverse-frequency class weights.

    Formula: total / (n_classes * class_count) per class. Classes absent from
    label_ids use a count of 1 so the division is always defined.

    Returns a CPU float32 tensor of shape (n_classes,).
    """
    freq = Counter(label_ids)
    n_samples = len(label_ids)
    per_class = [
        n_samples / (n_classes * freq.get(cls, 1))  # count=1 guards unseen classes
        for cls in range(n_classes)
    ]
    return torch.tensor(per_class, dtype=torch.float32)
|
|
||||||
|
|
||||||
|
|
||||||
def compute_metrics_for_trainer(eval_pred: EvalPrediction) -> dict:
    """Return macro F1 and accuracy for a Trainer evaluation step.

    Invoked by the Hugging Face Trainer at each evaluation.
    """
    preds = eval_pred.predictions.argmax(axis=-1)
    refs = eval_pred.label_ids
    return {
        "macro_f1": float(f1_score(refs, preds, average="macro", zero_division=0)),
        "accuracy": float(accuracy_score(refs, preds)),
    }
|
|
||||||
|
|
||||||
|
|
||||||
class WeightedTrainer(Trainer):
    """Trainer subclass that applies per-class weights to the cross-entropy loss.

    A ``class_weights`` tensor must be attached to the instance after
    construction; ``compute_loss`` reads it on every step.
    """

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs must stay: it absorbs num_items_in_batch added in
        # Transformers 4.38. Removing it causes TypeError on the first step.
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        # class_weights is built on CPU; logits sit on cuda:0 during GPU
        # training, so move the weights to the logits' device every call.
        weights = self.class_weights.to(outputs.logits.device)
        loss = F.cross_entropy(outputs.logits, labels, weight=weights)
        if return_outputs:
            return loss, outputs
        return loss
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Training dataset wrapper
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class _EmailDataset(TorchDataset):
    """Map-style dataset pairing tokenizer encodings with integer label ids."""

    def __init__(self, encodings: dict, label_ids: list[int]) -> None:
        self.encodings = encodings
        self.label_ids = label_ids

    def __len__(self) -> int:
        return len(self.label_ids)

    def __getitem__(self, idx: int) -> dict:
        sample = {key: torch.tensor(vals[idx]) for key, vals in self.encodings.items()}
        sample["labels"] = torch.tensor(self.label_ids[idx], dtype=torch.long)
        return sample
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Main training function
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
def run_finetune(model_key: str, epochs: int = 5, score_files: list[Path] | None = None) -> None:
    """Fine-tune the specified model on labeled data.

    Args:
        model_key: key into _MODEL_CONFIG ("deberta-small" or "bge-m3").
        epochs: maximum number of training epochs; EarlyStoppingCallback
            (patience=3) may end training sooner.
        score_files: list of score JSONL paths to merge. Defaults to
            [_ROOT / "data" / "email_score.jsonl"].

    Saves model + tokenizer + training_info.json to models/avocet-{model_key}/.
    All prints use flush=True for SSE streaming.

    Raises:
        ValueError: if model_key is not in _MODEL_CONFIG.
        FileNotFoundError: from load_and_prepare_data if a score file is missing.
    """
    if model_key not in _MODEL_CONFIG:
        raise ValueError(f"Unknown model key: {model_key!r}. Choose from: {list(_MODEL_CONFIG)}")

    if score_files is None:
        score_files = [_ROOT / "data" / "email_score.jsonl"]

    config = _MODEL_CONFIG[model_key]
    base_model_id = config["base_model_id"]
    output_dir = _ROOT / "models" / f"avocet-{model_key}"

    print(f"[finetune] Model: {model_key} ({base_model_id})", flush=True)
    print(f"[finetune] Score files: {[str(f) for f in score_files]}", flush=True)
    print(f"[finetune] Output: {output_dir}", flush=True)
    if output_dir.exists():
        print(f"[finetune] WARNING: {output_dir} already exists — will overwrite.", flush=True)

    # --- Data ---
    print(f"[finetune] Loading data ...", flush=True)
    texts, str_labels = load_and_prepare_data(score_files)

    # Build the label<->id maps only from labels actually present in the data
    # (classes may have been dropped by load_and_prepare_data).
    present_labels = sorted(set(str_labels))
    label2id = {l: i for i, l in enumerate(present_labels)}
    id2label = {i: l for l, i in label2id.items()}
    n_classes = len(present_labels)
    label_ids = [label2id[l] for l in str_labels]

    print(f"[finetune] {len(texts)} samples, {n_classes} classes", flush=True)

    # Stratified 80/20 split — ensure val set has at least n_classes samples.
    # For very small datasets (e.g. example data) we may need to give the val set
    # more than 20% so every class appears at least once in eval.
    desired_test = max(int(len(texts) * 0.2), n_classes)
    # test_size must leave at least n_classes samples for train too
    desired_test = min(desired_test, len(texts) - n_classes)
    (train_texts, val_texts,
     train_label_ids, val_label_ids) = train_test_split(
        texts, label_ids,
        test_size=desired_test,
        stratify=label_ids,
        random_state=42,
    )
    print(f"[finetune] Train: {len(train_texts)}, Val: {len(val_texts)}", flush=True)

    # Warn for classes with < 5 training samples
    train_counts = Counter(train_label_ids)
    for cls_id, cnt in train_counts.items():
        if cnt < 5:
            print(
                f"[finetune] WARNING: Class {id2label[cls_id]!r} has {cnt} training sample(s). "
                "Eval F1 for this class will be unreliable.",
                flush=True,
            )

    # --- Tokenize ---
    print(f"[finetune] Loading tokenizer ...", flush=True)
    tokenizer = AutoTokenizer.from_pretrained(base_model_id)

    train_enc = tokenizer(train_texts, truncation=True,
                          max_length=config["max_tokens"], padding=True)
    val_enc = tokenizer(val_texts, truncation=True,
                        max_length=config["max_tokens"], padding=True)

    train_dataset = _EmailDataset(train_enc, train_label_ids)
    val_dataset = _EmailDataset(val_enc, val_label_ids)

    # --- Class weights ---
    # Computed on the training split only, so eval distribution cannot leak in.
    class_weights = compute_class_weights(train_label_ids, n_classes)
    print(f"[finetune] Class weights computed", flush=True)

    # --- Model ---
    print(f"[finetune] Loading model ...", flush=True)
    model = AutoModelForSequenceClassification.from_pretrained(
        base_model_id,
        num_labels=n_classes,
        ignore_mismatched_sizes=True,  # NLI head (3-class) → new head (n_classes)
        id2label=id2label,
        label2id=label2id,
    )
    if config["gradient_checkpointing"]:
        # use_reentrant=False avoids "backward through graph a second time" errors
        # when Accelerate's gradient accumulation context is layered on top.
        # Reentrant checkpointing (the default) conflicts with Accelerate ≥ 0.27.
        model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )

    # --- TrainingArguments ---
    training_args = TrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=epochs,
        per_device_train_batch_size=config["batch_size"],
        per_device_eval_batch_size=config["batch_size"],
        gradient_accumulation_steps=config["grad_accum"],
        learning_rate=2e-5,
        lr_scheduler_type="linear",
        warmup_ratio=0.1,
        fp16=config["fp16"],
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="macro_f1",
        greater_is_better=True,
        logging_steps=10,
        report_to="none",
        save_total_limit=2,
    )

    trainer = WeightedTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics_for_trainer,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    )
    # Read by WeightedTrainer.compute_loss on every training step.
    trainer.class_weights = class_weights

    # --- Train ---
    print(f"[finetune] Starting training ({epochs} epochs) ...", flush=True)
    train_result = trainer.train()
    print(f"[finetune] Training complete. Steps: {train_result.global_step}", flush=True)

    # --- Evaluate ---
    # load_best_model_at_end=True means this evaluates the best checkpoint,
    # not necessarily the last epoch's weights.
    print(f"[finetune] Evaluating best checkpoint ...", flush=True)
    metrics = trainer.evaluate()
    val_macro_f1 = metrics.get("eval_macro_f1", 0.0)
    val_accuracy = metrics.get("eval_accuracy", 0.0)
    print(f"[finetune] Val macro-F1: {val_macro_f1:.4f}, Accuracy: {val_accuracy:.4f}", flush=True)

    # --- Save model + tokenizer ---
    print(f"[finetune] Saving model to {output_dir} ...", flush=True)
    trainer.save_model(str(output_dir))
    tokenizer.save_pretrained(str(output_dir))

    # --- Write training_info.json ---
    # Consumed by /api/finetune/status and by benchmark auto-discovery.
    label_counts = dict(Counter(str_labels))
    info = {
        "name": f"avocet-{model_key}",
        "base_model_id": base_model_id,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "epochs_run": epochs,
        "val_macro_f1": round(val_macro_f1, 4),
        "val_accuracy": round(val_accuracy, 4),
        "sample_count": len(texts),
        "train_sample_count": len(train_texts),
        "label_counts": label_counts,
        "score_files": [str(f) for f in score_files],
    }
    info_path = output_dir / "training_info.json"
    info_path.write_text(json.dumps(info, indent=2), encoding="utf-8")
    print(f"[finetune] Saved training_info.json: val_macro_f1={val_macro_f1:.4f}", flush=True)
    print(f"[finetune] Done.", flush=True)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# CLI
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Fine-tune an email classifier")
    parser.add_argument("--model", choices=list(_MODEL_CONFIG), required=True,
                        help="Model key to fine-tune")
    parser.add_argument("--epochs", type=int, default=5,
                        help="Number of training epochs (default: 5)")
    parser.add_argument("--score", dest="score_files", type=Path, action="append",
                        metavar="FILE",
                        help="Score JSONL file to include (repeatable; defaults to data/email_score.jsonl)")
    cli_args = parser.parse_args()
    # An empty/absent --score list becomes None so run_finetune uses its default.
    run_finetune(cli_args.model, cli_args.epochs,
                 score_files=cli_args.score_files or None)
|
|
||||||
|
|
@ -325,237 +325,3 @@ def test_fetch_stream_with_mock_imap(client, config_dir, data_dir):
|
||||||
assert "start" in types
|
assert "start" in types
|
||||||
assert "done" in types
|
assert "done" in types
|
||||||
assert "complete" in types
|
assert "complete" in types
|
||||||
|
|
||||||
|
|
||||||
# ---- /api/finetune/status tests ----
|
|
||||||
|
|
||||||
def test_finetune_status_returns_empty_when_no_models_dir(client):
    """GET /api/finetune/status must return [] if models/ does not exist."""
    resp = client.get("/api/finetune/status")
    assert resp.status_code == 200
    assert resp.json() == []
|
|
||||||
|
|
||||||
|
|
||||||
def test_finetune_status_returns_training_info(client, tmp_path):
    """GET /api/finetune/status must return one entry per training_info.json found."""
    import json as _json
    from app import api as api_module

    models_root = tmp_path / "models"
    model_dir = models_root / "avocet-deberta-small"
    model_dir.mkdir(parents=True)
    payload = {
        "name": "avocet-deberta-small",
        "base_model_id": "cross-encoder/nli-deberta-v3-small",
        "val_macro_f1": 0.712,
        "timestamp": "2026-03-15T12:00:00Z",
        "sample_count": 401,
    }
    (model_dir / "training_info.json").write_text(_json.dumps(payload))

    api_module.set_models_dir(models_root)
    try:
        resp = client.get("/api/finetune/status")
        assert resp.status_code == 200
        assert any(entry["name"] == "avocet-deberta-small" for entry in resp.json())
    finally:
        # Always restore the real models dir so later tests are unaffected.
        api_module.set_models_dir(api_module._ROOT / "models")
|
|
||||||
|
|
||||||
|
|
||||||
def test_finetune_run_streams_sse_events(client):
    """GET /api/finetune/run must return text/event-stream content type."""
    from unittest.mock import MagicMock, patch

    proc = MagicMock()
    proc.stdout = iter(["Training epoch 1\n", "Done\n"])
    proc.returncode = 0
    proc.wait = MagicMock()

    with patch("app.api._subprocess.Popen", return_value=proc):
        resp = client.get("/api/finetune/run?model=deberta-small&epochs=1")

    assert resp.status_code == 200
    assert "text/event-stream" in resp.headers.get("content-type", "")
|
|
||||||
|
|
||||||
|
|
||||||
def test_finetune_run_emits_complete_on_success(client):
    """GET /api/finetune/run must emit a complete event on clean exit."""
    from unittest.mock import MagicMock, patch

    proc = MagicMock()
    proc.stdout = iter(["progress line\n"])
    proc.returncode = 0
    proc.wait = MagicMock()

    with patch("app.api._subprocess.Popen", return_value=proc):
        resp = client.get("/api/finetune/run?model=deberta-small&epochs=1")

    assert '{"type": "complete"}' in resp.text
|
|
||||||
|
|
||||||
|
|
||||||
def test_finetune_run_emits_error_on_nonzero_exit(client):
    """GET /api/finetune/run must emit an error event on non-zero exit."""
    from unittest.mock import MagicMock, patch

    proc = MagicMock()
    proc.stdout = iter([])
    proc.returncode = 1
    proc.wait = MagicMock()

    with patch("app.api._subprocess.Popen", return_value=proc):
        resp = client.get("/api/finetune/run?model=deberta-small&epochs=1")

    assert '"type": "error"' in resp.text
|
|
||||||
|
|
||||||
|
|
||||||
def test_finetune_run_passes_score_files_to_subprocess(client):
    """GET /api/finetune/run?score=file1&score=file2 must pass --score args to subprocess."""
    from unittest.mock import MagicMock, patch

    seen_args: list = []

    def fake_popen(cmd, **kwargs):
        seen_args.extend(cmd)
        proc = MagicMock()
        proc.stdout = iter([])
        proc.returncode = 0
        proc.wait = MagicMock()
        return proc

    with patch("app.api._subprocess.Popen", side_effect=fake_popen):
        client.get("/api/finetune/run?model=deberta-small&epochs=1&score=run1.jsonl&score=run2.jsonl")

    assert "--score" in seen_args
    assert seen_args.count("--score") == 2
    # Paths are resolved to absolute — check filenames are present as substrings
    assert any("run1.jsonl" in arg for arg in seen_args)
    assert any("run2.jsonl" in arg for arg in seen_args)
|
|
||||||
|
|
||||||
|
|
||||||
# ---- Cancel endpoint tests ----
|
|
||||||
|
|
||||||
def test_benchmark_cancel_returns_404_when_not_running(client):
    """POST /api/benchmark/cancel must return 404 if no benchmark is running."""
    from app import api as api_module

    # Ensure no stale registry entry from a previous test.
    api_module._running_procs.pop("benchmark", None)
    assert client.post("/api/benchmark/cancel").status_code == 404
|
|
||||||
|
|
||||||
|
|
||||||
def test_finetune_cancel_returns_404_when_not_running(client):
    """POST /api/finetune/cancel must return 404 if no finetune is running."""
    from app import api as api_module

    # Ensure no stale registry entry from a previous test.
    api_module._running_procs.pop("finetune", None)
    assert client.post("/api/finetune/cancel").status_code == 404
|
|
||||||
|
|
||||||
|
|
||||||
def test_benchmark_cancel_terminates_running_process(client):
    """POST /api/benchmark/cancel must call terminate() on the running process."""
    from unittest.mock import MagicMock
    from app import api as api_module

    proc = MagicMock()
    proc.wait = MagicMock()
    api_module._running_procs["benchmark"] = proc

    try:
        resp = client.post("/api/benchmark/cancel")
        assert resp.status_code == 200
        assert resp.json()["status"] == "cancelled"
        proc.terminate.assert_called_once()
    finally:
        api_module._running_procs.pop("benchmark", None)
        api_module._cancelled_jobs.discard("benchmark")
|
|
||||||
|
|
||||||
|
|
||||||
def test_finetune_cancel_terminates_running_process(client):
    """POST /api/finetune/cancel must call terminate() on the running process."""
    from unittest.mock import MagicMock
    from app import api as api_module

    proc = MagicMock()
    proc.wait = MagicMock()
    api_module._running_procs["finetune"] = proc

    try:
        resp = client.post("/api/finetune/cancel")
        assert resp.status_code == 200
        assert resp.json()["status"] == "cancelled"
        proc.terminate.assert_called_once()
    finally:
        api_module._running_procs.pop("finetune", None)
        api_module._cancelled_jobs.discard("finetune")
|
|
||||||
|
|
||||||
|
|
||||||
def test_benchmark_cancel_kills_process_on_timeout(client):
    """POST /api/benchmark/cancel must call kill() if the process does not exit within 3 s."""
    import subprocess
    from unittest.mock import MagicMock
    from app import api as api_module

    proc = MagicMock()
    # Simulate a process that ignores SIGTERM: wait() times out.
    proc.wait.side_effect = subprocess.TimeoutExpired(cmd="benchmark", timeout=3)
    api_module._running_procs["benchmark"] = proc

    try:
        resp = client.post("/api/benchmark/cancel")
        assert resp.status_code == 200
        proc.kill.assert_called_once()
    finally:
        api_module._running_procs.pop("benchmark", None)
        api_module._cancelled_jobs.discard("benchmark")
|
|
||||||
|
|
||||||
|
|
||||||
def test_finetune_run_emits_cancelled_event(client):
    """GET /api/finetune/run must emit cancelled (not error) when job was cancelled."""
    from unittest.mock import MagicMock, patch
    from app import api as api_module

    proc = MagicMock()
    proc.stdout = iter([])
    proc.returncode = -15  # SIGTERM

    def fake_wait():
        # Simulate cancel being called while the process is running (after discard clears stale flag)
        api_module._cancelled_jobs.add("finetune")

    proc.wait = fake_wait

    try:
        with patch("app.api._subprocess.Popen", return_value=proc):
            resp = client.get("/api/finetune/run?model=deberta-small&epochs=1")
        assert '{"type": "cancelled"}' in resp.text
        assert '"type": "error"' not in resp.text
    finally:
        api_module._cancelled_jobs.discard("finetune")
|
|
||||||
|
|
||||||
|
|
||||||
def test_benchmark_run_emits_cancelled_event(client):
    """GET /api/benchmark/run must emit cancelled (not error) when job was cancelled."""
    from unittest.mock import MagicMock, patch
    from app import api as api_module

    proc = MagicMock()
    proc.stdout = iter([])
    proc.returncode = -15

    def fake_wait():
        # Simulate cancel being called while the process is running (after discard clears stale flag)
        api_module._cancelled_jobs.add("benchmark")

    proc.wait = fake_wait

    try:
        with patch("app.api._subprocess.Popen", return_value=proc):
            resp = client.get("/api/benchmark/run")
        assert '{"type": "cancelled"}' in resp.text
        assert '"type": "error"' not in resp.text
    finally:
        api_module._cancelled_jobs.discard("benchmark")
|
|
||||||
|
|
|
||||||
|
|
@ -92,77 +92,3 @@ def test_run_scoring_handles_classify_error(tmp_path):
|
||||||
|
|
||||||
results = run_scoring([broken], str(score_file))
|
results = run_scoring([broken], str(score_file))
|
||||||
assert "broken" in results
|
assert "broken" in results
|
||||||
|
|
||||||
|
|
||||||
# ---- Auto-discovery tests ----
|
|
||||||
|
|
||||||
def test_discover_finetuned_models_finds_training_info_files(tmp_path):
    """discover_finetuned_models() must return one entry per training_info.json found."""
    import json
    from scripts.benchmark_classifier import discover_finetuned_models

    # Create two fake model directories, each with a training_info.json.
    for name in ("avocet-deberta-small", "avocet-bge-m3"):
        sub = tmp_path / name
        sub.mkdir()
        (sub / "training_info.json").write_text(json.dumps({
            "name": name,
            "base_model_id": "cross-encoder/nli-deberta-v3-small",
            "timestamp": "2026-03-15T12:00:00Z",
            "val_macro_f1": 0.72,
            "val_accuracy": 0.80,
            "sample_count": 401,
        }))

    entries = discover_finetuned_models(tmp_path)
    assert len(entries) == 2
    found = {e["name"] for e in entries}
    assert "avocet-deberta-small" in found
    assert "avocet-bge-m3" in found
    for e in entries:
        assert "model_dir" in e, "discover_finetuned_models must inject model_dir key"
        assert e["model_dir"].endswith(e["name"])
|
|
||||||
|
|
||||||
|
|
||||||
def test_discover_finetuned_models_returns_empty_when_no_models_dir():
    """discover_finetuned_models() must return [] silently if models/ doesn't exist."""
    from pathlib import Path
    from scripts.benchmark_classifier import discover_finetuned_models

    assert discover_finetuned_models(Path("/nonexistent/path/models")) == []
|
|
||||||
|
|
||||||
|
|
||||||
def test_discover_finetuned_models_skips_dirs_without_training_info(tmp_path):
    """Subdirs without training_info.json are silently skipped."""
    from scripts.benchmark_classifier import discover_finetuned_models

    # A dir WITHOUT training_info.json must not produce an entry.
    (tmp_path / "some-other-dir").mkdir()
    assert discover_finetuned_models(tmp_path) == []
|
|
||||||
|
|
||||||
|
|
||||||
def test_active_models_includes_discovered_finetuned(tmp_path):
    """The active models dict must include FineTunedAdapter entries for discovered models."""
    import json
    from unittest.mock import patch
    from scripts.benchmark_classifier import _active_models
    from scripts.classifier_adapters import FineTunedAdapter

    sub = tmp_path / "avocet-deberta-small"
    sub.mkdir()
    (sub / "training_info.json").write_text(json.dumps({
        "name": "avocet-deberta-small",
        "base_model_id": "cross-encoder/nli-deberta-v3-small",
        "val_macro_f1": 0.72,
        "sample_count": 401,
    }))

    with patch("scripts.benchmark_classifier._MODELS_DIR", tmp_path):
        active = _active_models(include_slow=False)

    assert "avocet-deberta-small" in active
    assert isinstance(active["avocet-deberta-small"]["adapter_instance"], FineTunedAdapter)
|
|
||||||
|
|
|
||||||
|
|
@ -180,91 +180,3 @@ def test_reranker_adapter_picks_highest_score():
|
||||||
def test_reranker_adapter_descriptions_cover_all_labels():
|
def test_reranker_adapter_descriptions_cover_all_labels():
|
||||||
from scripts.classifier_adapters import LABEL_DESCRIPTIONS, LABELS
|
from scripts.classifier_adapters import LABEL_DESCRIPTIONS, LABELS
|
||||||
assert set(LABEL_DESCRIPTIONS.keys()) == set(LABELS)
|
assert set(LABEL_DESCRIPTIONS.keys()) == set(LABELS)
|
||||||
|
|
||||||
|
|
||||||
# ---- FineTunedAdapter tests ----
|
|
||||||
|
|
||||||
def test_finetuned_adapter_classify_calls_pipeline_with_sep_format(tmp_path):
|
|
||||||
"""classify() must format input as 'subject [SEP] body[:400]' — not the zero-shot format."""
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
from scripts.classifier_adapters import FineTunedAdapter
|
|
||||||
|
|
||||||
mock_result = [{"label": "digest", "score": 0.95}]
|
|
||||||
mock_pipe_instance = MagicMock(return_value=mock_result)
|
|
||||||
mock_pipe_factory = MagicMock(return_value=mock_pipe_instance)
|
|
||||||
|
|
||||||
adapter = FineTunedAdapter("avocet-deberta-small", str(tmp_path))
|
|
||||||
with patch("scripts.classifier_adapters.pipeline", mock_pipe_factory):
|
|
||||||
result = adapter.classify("Test subject", "Test body")
|
|
||||||
|
|
||||||
assert result == "digest"
|
|
||||||
call_args = mock_pipe_instance.call_args[0][0]
|
|
||||||
assert "[SEP]" in call_args
|
|
||||||
assert "Test subject" in call_args
|
|
||||||
assert "Test body" in call_args
|
|
||||||
|
|
||||||
|
|
||||||
def test_finetuned_adapter_truncates_body_to_400():
|
|
||||||
"""Body must be truncated to 400 chars in the [SEP] format."""
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
from scripts.classifier_adapters import FineTunedAdapter, LABELS
|
|
||||||
|
|
||||||
long_body = "x" * 800
|
|
||||||
mock_result = [{"label": "neutral", "score": 0.9}]
|
|
||||||
mock_pipe_instance = MagicMock(return_value=mock_result)
|
|
||||||
mock_pipe_factory = MagicMock(return_value=mock_pipe_instance)
|
|
||||||
|
|
||||||
adapter = FineTunedAdapter("avocet-deberta-small", "/fake/path")
|
|
||||||
with patch("scripts.classifier_adapters.pipeline", mock_pipe_factory):
|
|
||||||
adapter.classify("Subject", long_body)
|
|
||||||
|
|
||||||
call_text = mock_pipe_instance.call_args[0][0]
|
|
||||||
parts = call_text.split(" [SEP] ", 1)
|
|
||||||
assert len(parts) == 2, "Input must contain ' [SEP] ' separator"
|
|
||||||
assert len(parts[1]) == 400, f"Body must be exactly 400 chars, got {len(parts[1])}"
|
|
||||||
|
|
||||||
|
|
||||||
def test_finetuned_adapter_returns_label_string():
|
|
||||||
"""classify() must return a plain string, not a dict."""
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
from scripts.classifier_adapters import FineTunedAdapter
|
|
||||||
|
|
||||||
mock_result = [{"label": "interview_scheduled", "score": 0.87}]
|
|
||||||
mock_pipe_instance = MagicMock(return_value=mock_result)
|
|
||||||
mock_pipe_factory = MagicMock(return_value=mock_pipe_instance)
|
|
||||||
|
|
||||||
adapter = FineTunedAdapter("avocet-deberta-small", "/fake/path")
|
|
||||||
with patch("scripts.classifier_adapters.pipeline", mock_pipe_factory):
|
|
||||||
result = adapter.classify("S", "B")
|
|
||||||
|
|
||||||
assert isinstance(result, str)
|
|
||||||
assert result == "interview_scheduled"
|
|
||||||
|
|
||||||
|
|
||||||
def test_finetuned_adapter_lazy_loads_pipeline():
|
|
||||||
"""Pipeline factory must not be called until classify() is first called."""
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
from scripts.classifier_adapters import FineTunedAdapter
|
|
||||||
|
|
||||||
mock_pipe_factory = MagicMock(return_value=MagicMock(return_value=[{"label": "neutral", "score": 0.9}]))
|
|
||||||
|
|
||||||
with patch("scripts.classifier_adapters.pipeline", mock_pipe_factory):
|
|
||||||
adapter = FineTunedAdapter("avocet-deberta-small", "/fake/path")
|
|
||||||
assert not mock_pipe_factory.called
|
|
||||||
adapter.classify("s", "b")
|
|
||||||
assert mock_pipe_factory.called
|
|
||||||
|
|
||||||
|
|
||||||
def test_finetuned_adapter_unload_clears_pipeline():
|
|
||||||
"""unload() must set _pipeline to None so memory is released."""
|
|
||||||
from unittest.mock import MagicMock, patch
|
|
||||||
from scripts.classifier_adapters import FineTunedAdapter
|
|
||||||
|
|
||||||
mock_pipe_factory = MagicMock(return_value=MagicMock(return_value=[{"label": "neutral", "score": 0.9}]))
|
|
||||||
|
|
||||||
with patch("scripts.classifier_adapters.pipeline", mock_pipe_factory):
|
|
||||||
adapter = FineTunedAdapter("avocet-deberta-small", "/fake/path")
|
|
||||||
adapter.classify("s", "b")
|
|
||||||
assert adapter._pipeline is not None
|
|
||||||
adapter.unload()
|
|
||||||
assert adapter._pipeline is None
|
|
||||||
|
|
|
||||||
|
|
@ -1,371 +0,0 @@
|
||||||
"""Tests for finetune_classifier — no model downloads required."""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
|
|
||||||
# ---- Data loading tests ----
|
|
||||||
|
|
||||||
def test_load_and_prepare_data_drops_non_canonical_labels(tmp_path):
|
|
||||||
"""Rows with labels not in LABELS must be silently dropped."""
|
|
||||||
from scripts.finetune_classifier import load_and_prepare_data
|
|
||||||
from scripts.classifier_adapters import LABELS
|
|
||||||
|
|
||||||
# Two samples per canonical label so they survive the < 2 class-drop rule.
|
|
||||||
rows = [
|
|
||||||
{"subject": "s1", "body": "b1", "label": "digest"},
|
|
||||||
{"subject": "s2", "body": "b2", "label": "digest"},
|
|
||||||
{"subject": "s3", "body": "b3", "label": "profile_alert"}, # non-canonical
|
|
||||||
{"subject": "s4", "body": "b4", "label": "neutral"},
|
|
||||||
{"subject": "s5", "body": "b5", "label": "neutral"},
|
|
||||||
]
|
|
||||||
score_file = tmp_path / "email_score.jsonl"
|
|
||||||
score_file.write_text("\n".join(json.dumps(r) for r in rows))
|
|
||||||
|
|
||||||
texts, labels = load_and_prepare_data(score_file)
|
|
||||||
assert len(texts) == 4
|
|
||||||
assert all(l in LABELS for l in labels)
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_and_prepare_data_formats_input_as_sep(tmp_path):
|
|
||||||
"""Input text must be 'subject [SEP] body[:400]'."""
|
|
||||||
# Two samples with the same label so the class survives the < 2 drop rule.
|
|
||||||
rows = [
|
|
||||||
{"subject": "Hello", "body": "World" * 100, "label": "neutral"},
|
|
||||||
{"subject": "Hello2", "body": "World" * 100, "label": "neutral"},
|
|
||||||
]
|
|
||||||
score_file = tmp_path / "email_score.jsonl"
|
|
||||||
score_file.write_text("\n".join(json.dumps(r) for r in rows))
|
|
||||||
|
|
||||||
from scripts.finetune_classifier import load_and_prepare_data
|
|
||||||
texts, labels = load_and_prepare_data(score_file)
|
|
||||||
|
|
||||||
assert texts[0].startswith("Hello [SEP] ")
|
|
||||||
parts = texts[0].split(" [SEP] ", 1)
|
|
||||||
assert len(parts[1]) == 400, f"Body must be exactly 400 chars, got {len(parts[1])}"
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_and_prepare_data_raises_on_missing_file():
|
|
||||||
"""FileNotFoundError must be raised with actionable message."""
|
|
||||||
from pathlib import Path
|
|
||||||
from scripts.finetune_classifier import load_and_prepare_data
|
|
||||||
|
|
||||||
with pytest.raises(FileNotFoundError, match="email_score.jsonl"):
|
|
||||||
load_and_prepare_data(Path("/nonexistent/email_score.jsonl"))
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_and_prepare_data_drops_class_with_fewer_than_2_samples(tmp_path, capsys):
|
|
||||||
"""Classes with < 2 total samples must be dropped with a warning."""
|
|
||||||
from scripts.finetune_classifier import load_and_prepare_data
|
|
||||||
|
|
||||||
rows = [
|
|
||||||
{"subject": "s1", "body": "b", "label": "digest"},
|
|
||||||
{"subject": "s2", "body": "b", "label": "digest"},
|
|
||||||
{"subject": "s3", "body": "b", "label": "new_lead"}, # only 1 sample — drop
|
|
||||||
]
|
|
||||||
score_file = tmp_path / "email_score.jsonl"
|
|
||||||
score_file.write_text("\n".join(json.dumps(r) for r in rows))
|
|
||||||
|
|
||||||
texts, labels = load_and_prepare_data(score_file)
|
|
||||||
captured = capsys.readouterr()
|
|
||||||
|
|
||||||
assert "new_lead" not in labels
|
|
||||||
assert "new_lead" in captured.out # warning printed
|
|
||||||
|
|
||||||
|
|
||||||
# ---- Class weights tests ----
|
|
||||||
|
|
||||||
def test_compute_class_weights_returns_tensor_for_each_class():
|
|
||||||
"""compute_class_weights must return a float tensor of length n_classes."""
|
|
||||||
import torch
|
|
||||||
from scripts.finetune_classifier import compute_class_weights
|
|
||||||
|
|
||||||
label_ids = [0, 0, 0, 1, 1, 2] # 3 classes, imbalanced
|
|
||||||
weights = compute_class_weights(label_ids, n_classes=3)
|
|
||||||
|
|
||||||
assert isinstance(weights, torch.Tensor)
|
|
||||||
assert weights.shape == (3,)
|
|
||||||
assert all(w > 0 for w in weights)
|
|
||||||
|
|
||||||
|
|
||||||
def test_compute_class_weights_upweights_minority():
|
|
||||||
"""Minority classes must receive higher weight than majority classes."""
|
|
||||||
from scripts.finetune_classifier import compute_class_weights
|
|
||||||
|
|
||||||
# Class 0: 10 samples, Class 1: 2 samples
|
|
||||||
label_ids = [0] * 10 + [1] * 2
|
|
||||||
weights = compute_class_weights(label_ids, n_classes=2)
|
|
||||||
|
|
||||||
assert weights[1] > weights[0]
|
|
||||||
|
|
||||||
|
|
||||||
# ---- compute_metrics_for_trainer tests ----
|
|
||||||
|
|
||||||
def test_compute_metrics_for_trainer_returns_macro_f1_key():
|
|
||||||
"""Must return a dict with 'macro_f1' key."""
|
|
||||||
import numpy as np
|
|
||||||
from scripts.finetune_classifier import compute_metrics_for_trainer
|
|
||||||
from transformers import EvalPrediction
|
|
||||||
|
|
||||||
logits = np.array([[2.0, 0.1], [0.1, 2.0], [2.0, 0.1]])
|
|
||||||
labels = np.array([0, 1, 0])
|
|
||||||
pred = EvalPrediction(predictions=logits, label_ids=labels)
|
|
||||||
|
|
||||||
result = compute_metrics_for_trainer(pred)
|
|
||||||
assert "macro_f1" in result
|
|
||||||
assert result["macro_f1"] == pytest.approx(1.0)
|
|
||||||
|
|
||||||
|
|
||||||
def test_compute_metrics_for_trainer_returns_accuracy_key():
|
|
||||||
"""Must also return 'accuracy' key."""
|
|
||||||
import numpy as np
|
|
||||||
from scripts.finetune_classifier import compute_metrics_for_trainer
|
|
||||||
from transformers import EvalPrediction
|
|
||||||
|
|
||||||
logits = np.array([[2.0, 0.1], [0.1, 2.0]])
|
|
||||||
labels = np.array([0, 1])
|
|
||||||
pred = EvalPrediction(predictions=logits, label_ids=labels)
|
|
||||||
|
|
||||||
result = compute_metrics_for_trainer(pred)
|
|
||||||
assert "accuracy" in result
|
|
||||||
assert result["accuracy"] == pytest.approx(1.0)
|
|
||||||
|
|
||||||
|
|
||||||
# ---- WeightedTrainer tests ----
|
|
||||||
|
|
||||||
def test_weighted_trainer_compute_loss_returns_scalar():
|
|
||||||
"""compute_loss must return a scalar tensor when return_outputs=False."""
|
|
||||||
import torch
|
|
||||||
from unittest.mock import MagicMock
|
|
||||||
from scripts.finetune_classifier import WeightedTrainer
|
|
||||||
|
|
||||||
n_classes = 3
|
|
||||||
batch = 4
|
|
||||||
logits = torch.randn(batch, n_classes)
|
|
||||||
|
|
||||||
mock_outputs = MagicMock()
|
|
||||||
mock_outputs.logits = logits
|
|
||||||
mock_model = MagicMock(return_value=mock_outputs)
|
|
||||||
|
|
||||||
trainer = WeightedTrainer.__new__(WeightedTrainer)
|
|
||||||
trainer.class_weights = torch.ones(n_classes)
|
|
||||||
|
|
||||||
inputs = {
|
|
||||||
"input_ids": torch.zeros(batch, 10, dtype=torch.long),
|
|
||||||
"labels": torch.randint(0, n_classes, (batch,)),
|
|
||||||
}
|
|
||||||
|
|
||||||
loss = trainer.compute_loss(mock_model, inputs, return_outputs=False)
|
|
||||||
assert isinstance(loss, torch.Tensor)
|
|
||||||
assert loss.ndim == 0 # scalar
|
|
||||||
|
|
||||||
|
|
||||||
def test_weighted_trainer_compute_loss_accepts_kwargs():
|
|
||||||
"""compute_loss must not raise TypeError when called with num_items_in_batch kwarg."""
|
|
||||||
import torch
|
|
||||||
from unittest.mock import MagicMock
|
|
||||||
from scripts.finetune_classifier import WeightedTrainer
|
|
||||||
|
|
||||||
n_classes = 3
|
|
||||||
batch = 2
|
|
||||||
logits = torch.randn(batch, n_classes)
|
|
||||||
|
|
||||||
mock_outputs = MagicMock()
|
|
||||||
mock_outputs.logits = logits
|
|
||||||
mock_model = MagicMock(return_value=mock_outputs)
|
|
||||||
|
|
||||||
trainer = WeightedTrainer.__new__(WeightedTrainer)
|
|
||||||
trainer.class_weights = torch.ones(n_classes)
|
|
||||||
|
|
||||||
inputs = {
|
|
||||||
"input_ids": torch.zeros(batch, 5, dtype=torch.long),
|
|
||||||
"labels": torch.randint(0, n_classes, (batch,)),
|
|
||||||
}
|
|
||||||
|
|
||||||
loss = trainer.compute_loss(mock_model, inputs, return_outputs=False,
|
|
||||||
num_items_in_batch=batch)
|
|
||||||
assert isinstance(loss, torch.Tensor)
|
|
||||||
|
|
||||||
|
|
||||||
def test_weighted_trainer_weighted_loss_differs_from_unweighted():
|
|
||||||
"""Weighted loss must differ from uniform-weight loss for imbalanced inputs."""
|
|
||||||
import torch
|
|
||||||
from unittest.mock import MagicMock
|
|
||||||
from scripts.finetune_classifier import WeightedTrainer
|
|
||||||
|
|
||||||
n_classes = 2
|
|
||||||
batch = 4
|
|
||||||
# Mixed labels: 3× class-0, 1× class-1.
|
|
||||||
# Asymmetric logits (class-0 samples predicted well, class-1 predicted poorly)
|
|
||||||
# ensure per-class CE values differ, so re-weighting changes the weighted mean.
|
|
||||||
labels = torch.tensor([0, 0, 0, 1], dtype=torch.long)
|
|
||||||
logits = torch.tensor([[3.0, -1.0], [3.0, -1.0], [3.0, -1.0], [0.5, 0.5]])
|
|
||||||
|
|
||||||
mock_outputs = MagicMock()
|
|
||||||
mock_outputs.logits = logits
|
|
||||||
|
|
||||||
trainer_uniform = WeightedTrainer.__new__(WeightedTrainer)
|
|
||||||
trainer_uniform.class_weights = torch.ones(n_classes)
|
|
||||||
inputs_uniform = {"input_ids": torch.zeros(batch, 5, dtype=torch.long), "labels": labels.clone()}
|
|
||||||
loss_uniform = trainer_uniform.compute_loss(MagicMock(return_value=mock_outputs),
|
|
||||||
inputs_uniform)
|
|
||||||
|
|
||||||
trainer_weighted = WeightedTrainer.__new__(WeightedTrainer)
|
|
||||||
trainer_weighted.class_weights = torch.tensor([0.1, 10.0])
|
|
||||||
inputs_weighted = {"input_ids": torch.zeros(batch, 5, dtype=torch.long), "labels": labels.clone()}
|
|
||||||
|
|
||||||
mock_outputs2 = MagicMock()
|
|
||||||
mock_outputs2.logits = logits.clone()
|
|
||||||
loss_weighted = trainer_weighted.compute_loss(MagicMock(return_value=mock_outputs2),
|
|
||||||
inputs_weighted)
|
|
||||||
|
|
||||||
assert not torch.isclose(loss_uniform, loss_weighted)
|
|
||||||
|
|
||||||
|
|
||||||
def test_weighted_trainer_compute_loss_returns_outputs_when_requested():
|
|
||||||
"""compute_loss with return_outputs=True must return (loss, outputs) tuple."""
|
|
||||||
import torch
|
|
||||||
from unittest.mock import MagicMock
|
|
||||||
from scripts.finetune_classifier import WeightedTrainer
|
|
||||||
|
|
||||||
n_classes = 3
|
|
||||||
batch = 2
|
|
||||||
logits = torch.randn(batch, n_classes)
|
|
||||||
|
|
||||||
mock_outputs = MagicMock()
|
|
||||||
mock_outputs.logits = logits
|
|
||||||
mock_model = MagicMock(return_value=mock_outputs)
|
|
||||||
|
|
||||||
trainer = WeightedTrainer.__new__(WeightedTrainer)
|
|
||||||
trainer.class_weights = torch.ones(n_classes)
|
|
||||||
|
|
||||||
inputs = {
|
|
||||||
"input_ids": torch.zeros(batch, 5, dtype=torch.long),
|
|
||||||
"labels": torch.randint(0, n_classes, (batch,)),
|
|
||||||
}
|
|
||||||
|
|
||||||
result = trainer.compute_loss(mock_model, inputs, return_outputs=True)
|
|
||||||
assert isinstance(result, tuple)
|
|
||||||
loss, outputs = result
|
|
||||||
assert isinstance(loss, torch.Tensor)
|
|
||||||
|
|
||||||
|
|
||||||
# ---- Multi-file merge / dedup tests ----
|
|
||||||
|
|
||||||
def test_load_and_prepare_data_merges_multiple_files(tmp_path):
|
|
||||||
"""Multiple score files must be merged into a single dataset."""
|
|
||||||
from scripts.finetune_classifier import load_and_prepare_data
|
|
||||||
|
|
||||||
file1 = tmp_path / "run1.jsonl"
|
|
||||||
file2 = tmp_path / "run2.jsonl"
|
|
||||||
file1.write_text(
|
|
||||||
json.dumps({"subject": "s1", "body": "b1", "label": "digest"}) + "\n" +
|
|
||||||
json.dumps({"subject": "s2", "body": "b2", "label": "digest"}) + "\n"
|
|
||||||
)
|
|
||||||
file2.write_text(
|
|
||||||
json.dumps({"subject": "s3", "body": "b3", "label": "neutral"}) + "\n" +
|
|
||||||
json.dumps({"subject": "s4", "body": "b4", "label": "neutral"}) + "\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
texts, labels = load_and_prepare_data([file1, file2])
|
|
||||||
assert len(texts) == 4
|
|
||||||
assert labels.count("digest") == 2
|
|
||||||
assert labels.count("neutral") == 2
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_and_prepare_data_deduplicates_last_write_wins(tmp_path, capsys):
|
|
||||||
"""Duplicate rows (same content hash) keep the last occurrence."""
|
|
||||||
from scripts.finetune_classifier import load_and_prepare_data
|
|
||||||
|
|
||||||
# Same subject+body[:100] = same hash
|
|
||||||
row_early = {"subject": "Hello", "body": "World", "label": "neutral"}
|
|
||||||
row_late = {"subject": "Hello", "body": "World", "label": "digest"} # relabeled
|
|
||||||
|
|
||||||
file1 = tmp_path / "run1.jsonl"
|
|
||||||
file2 = tmp_path / "run2.jsonl"
|
|
||||||
# Add a second row with different content so class count >= 2 for both classes
|
|
||||||
file1.write_text(
|
|
||||||
json.dumps(row_early) + "\n" +
|
|
||||||
json.dumps({"subject": "Other1", "body": "Other", "label": "neutral"}) + "\n"
|
|
||||||
)
|
|
||||||
file2.write_text(
|
|
||||||
json.dumps(row_late) + "\n" +
|
|
||||||
json.dumps({"subject": "Other2", "body": "Stuff", "label": "digest"}) + "\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
texts, labels = load_and_prepare_data([file1, file2])
|
|
||||||
captured = capsys.readouterr()
|
|
||||||
|
|
||||||
# The duplicate row should be counted as dropped
|
|
||||||
assert "Deduped" in captured.out
|
|
||||||
# The relabeled row should have "digest" (last-write wins), not "neutral"
|
|
||||||
hello_idx = next(i for i, t in enumerate(texts) if t.startswith("Hello [SEP]"))
|
|
||||||
assert labels[hello_idx] == "digest"
|
|
||||||
|
|
||||||
|
|
||||||
def test_load_and_prepare_data_single_path_still_works(tmp_path):
|
|
||||||
"""Passing a single Path (not a list) must still work — backwards compatibility."""
|
|
||||||
from scripts.finetune_classifier import load_and_prepare_data
|
|
||||||
|
|
||||||
rows = [
|
|
||||||
{"subject": "s1", "body": "b1", "label": "digest"},
|
|
||||||
{"subject": "s2", "body": "b2", "label": "digest"},
|
|
||||||
]
|
|
||||||
score_file = tmp_path / "email_score.jsonl"
|
|
||||||
score_file.write_text("\n".join(json.dumps(r) for r in rows))
|
|
||||||
|
|
||||||
texts, labels = load_and_prepare_data(score_file) # single Path, not list
|
|
||||||
assert len(texts) == 2
|
|
||||||
|
|
||||||
|
|
||||||
# ---- Integration test ----
|
|
||||||
|
|
||||||
def test_integration_finetune_on_example_data(tmp_path):
|
|
||||||
"""Fine-tune deberta-small on example data for 1 epoch.
|
|
||||||
|
|
||||||
Uses data/email_score.jsonl.example (8 samples, 5 labels represented).
|
|
||||||
The 5 missing labels must trigger the < 2 samples drop warning.
|
|
||||||
Verifies training_info.json is written with correct keys.
|
|
||||||
|
|
||||||
Requires job-seeker-classifiers env and downloads deberta-small (~100MB on first run).
|
|
||||||
"""
|
|
||||||
import shutil
|
|
||||||
from scripts import finetune_classifier as ft_mod
|
|
||||||
from scripts.finetune_classifier import run_finetune
|
|
||||||
|
|
||||||
example_file = ft_mod._ROOT / "data" / "email_score.jsonl.example"
|
|
||||||
if not example_file.exists():
|
|
||||||
pytest.skip("email_score.jsonl.example not found")
|
|
||||||
|
|
||||||
orig_root = ft_mod._ROOT
|
|
||||||
ft_mod._ROOT = tmp_path
|
|
||||||
(tmp_path / "data").mkdir()
|
|
||||||
shutil.copy(example_file, tmp_path / "data" / "email_score.jsonl")
|
|
||||||
|
|
||||||
try:
|
|
||||||
import io
|
|
||||||
from contextlib import redirect_stdout
|
|
||||||
captured = io.StringIO()
|
|
||||||
with redirect_stdout(captured):
|
|
||||||
run_finetune("deberta-small", epochs=1)
|
|
||||||
output = captured.getvalue()
|
|
||||||
finally:
|
|
||||||
ft_mod._ROOT = orig_root
|
|
||||||
|
|
||||||
# Missing labels should trigger the < 2 samples drop warning
|
|
||||||
assert "WARNING: Dropping class" in output
|
|
||||||
|
|
||||||
# training_info.json must exist with correct keys
|
|
||||||
info_path = tmp_path / "models" / "avocet-deberta-small" / "training_info.json"
|
|
||||||
assert info_path.exists(), "training_info.json not written"
|
|
||||||
|
|
||||||
info = json.loads(info_path.read_text())
|
|
||||||
for key in ("name", "base_model_id", "timestamp", "epochs_run",
|
|
||||||
"val_macro_f1", "val_accuracy", "sample_count", "train_sample_count",
|
|
||||||
"label_counts", "score_files"):
|
|
||||||
assert key in info, f"Missing key: {key}"
|
|
||||||
|
|
||||||
assert info["name"] == "avocet-deberta-small"
|
|
||||||
assert info["epochs_run"] == 1
|
|
||||||
assert isinstance(info["score_files"], list)
|
|
||||||
|
|
@ -4,12 +4,7 @@
|
||||||
<meta charset="UTF-8" />
|
<meta charset="UTF-8" />
|
||||||
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
|
<link rel="icon" type="image/svg+xml" href="/vite.svg" />
|
||||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||||
<title>Avocet — Label Tool</title>
|
<title>web</title>
|
||||||
<!-- Inline background prevents blank-white flash before the CSS bundle loads -->
|
|
||||||
<style>
|
|
||||||
html, body { margin: 0; background: #eaeff8; min-height: 100vh; }
|
|
||||||
@media (prefers-color-scheme: dark) { html, body { background: #16202e; } }
|
|
||||||
</style>
|
|
||||||
</head>
|
</head>
|
||||||
<body>
|
<body>
|
||||||
<div id="app"></div>
|
<div id="app"></div>
|
||||||
|
|
|
||||||
11
web/package-lock.json
generated
11
web/package-lock.json
generated
|
|
@ -13,7 +13,6 @@
|
||||||
"@fontsource/jetbrains-mono": "^5.2.8",
|
"@fontsource/jetbrains-mono": "^5.2.8",
|
||||||
"@vueuse/core": "^14.2.1",
|
"@vueuse/core": "^14.2.1",
|
||||||
"@vueuse/integrations": "^14.2.1",
|
"@vueuse/integrations": "^14.2.1",
|
||||||
"animejs": "^4.3.6",
|
|
||||||
"pinia": "^3.0.4",
|
"pinia": "^3.0.4",
|
||||||
"vue": "^3.5.25",
|
"vue": "^3.5.25",
|
||||||
"vue-router": "^5.0.3"
|
"vue-router": "^5.0.3"
|
||||||
|
|
@ -2571,16 +2570,6 @@
|
||||||
"dev": true,
|
"dev": true,
|
||||||
"license": "MIT"
|
"license": "MIT"
|
||||||
},
|
},
|
||||||
"node_modules/animejs": {
|
|
||||||
"version": "4.3.6",
|
|
||||||
"resolved": "https://registry.npmjs.org/animejs/-/animejs-4.3.6.tgz",
|
|
||||||
"integrity": "sha512-rzZ4bDc8JAtyx6hYwxj7s5M/yWfnM5qqY4hZDnhy1cWFvMb6H5/necHS2sbCY3WQTDbRLuZL10dPXSxSCFOr/w==",
|
|
||||||
"license": "MIT",
|
|
||||||
"funding": {
|
|
||||||
"type": "github",
|
|
||||||
"url": "https://github.com/sponsors/juliangarnier"
|
|
||||||
}
|
|
||||||
},
|
|
||||||
"node_modules/ansi-regex": {
|
"node_modules/ansi-regex": {
|
||||||
"version": "6.2.2",
|
"version": "6.2.2",
|
||||||
"resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-6.2.2.tgz",
|
"resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-6.2.2.tgz",
|
||||||
|
|
|
||||||
|
|
@ -16,7 +16,6 @@
|
||||||
"@fontsource/jetbrains-mono": "^5.2.8",
|
"@fontsource/jetbrains-mono": "^5.2.8",
|
||||||
"@vueuse/core": "^14.2.1",
|
"@vueuse/core": "^14.2.1",
|
||||||
"@vueuse/integrations": "^14.2.1",
|
"@vueuse/integrations": "^14.2.1",
|
||||||
"animejs": "^4.3.6",
|
|
||||||
"pinia": "^3.0.4",
|
"pinia": "^3.0.4",
|
||||||
"vue": "^3.5.25",
|
"vue": "^3.5.25",
|
||||||
"vue-router": "^5.0.3"
|
"vue-router": "^5.0.3"
|
||||||
|
|
|
||||||
|
|
@ -11,13 +11,11 @@
|
||||||
import { onMounted } from 'vue'
|
import { onMounted } from 'vue'
|
||||||
import { RouterView } from 'vue-router'
|
import { RouterView } from 'vue-router'
|
||||||
import { useMotion } from './composables/useMotion'
|
import { useMotion } from './composables/useMotion'
|
||||||
import { useHackerMode, useKonamiCode } from './composables/useEasterEgg'
|
import { useHackerMode } from './composables/useEasterEgg'
|
||||||
import AppSidebar from './components/AppSidebar.vue'
|
import AppSidebar from './components/AppSidebar.vue'
|
||||||
|
|
||||||
const motion = useMotion()
|
const motion = useMotion()
|
||||||
const { toggle, restore } = useHackerMode()
|
const { restore } = useHackerMode()
|
||||||
|
|
||||||
useKonamiCode(toggle)
|
|
||||||
|
|
||||||
onMounted(() => {
|
onMounted(() => {
|
||||||
restore() // re-apply hacker mode from localStorage on page load
|
restore() // re-apply hacker mode from localStorage on page load
|
||||||
|
|
|
||||||
|
|
@ -8,29 +8,8 @@
|
||||||
Accent — Russet (#B8622A) — inspired by avocet's vivid orange-russet head
|
Accent — Russet (#B8622A) — inspired by avocet's vivid orange-russet head
|
||||||
*/
|
*/
|
||||||
|
|
||||||
/* ── Page-level overrides — must be in avocet.css (applied after theme.css base) ── */
|
|
||||||
html {
|
|
||||||
/* Prevent Mac Chrome's horizontal swipe-to-navigate page animation
|
|
||||||
from triggering when the user scrolls near the viewport edge */
|
|
||||||
overscroll-behavior-x: none;
|
|
||||||
/* clip (not hidden) — prevents overflowing content from expanding the html layout
|
|
||||||
width beyond the viewport. Without this, body's overflow-x:hidden propagates to
|
|
||||||
the viewport and body has no BFC, so long email URLs inflate the layout and
|
|
||||||
margin:0 auto centering drifts rightward as fonts load. */
|
|
||||||
overflow-x: clip;
|
|
||||||
}
|
|
||||||
|
|
||||||
body {
|
|
||||||
/* Prevent horizontal scroll from card swipe animations */
|
|
||||||
overflow-x: hidden;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
/* ── Light mode (default) ──────────────────────────── */
|
/* ── Light mode (default) ──────────────────────────── */
|
||||||
:root {
|
:root {
|
||||||
/* Aliases bridging avocet component vars to CircuitForge base theme vars */
|
|
||||||
--color-bg: var(--color-surface); /* App.vue body bg → #eaeff8 in light */
|
|
||||||
--color-text-secondary: var(--color-text-muted); /* muted label text */
|
|
||||||
/* Primary — Slate Teal */
|
/* Primary — Slate Teal */
|
||||||
--app-primary: #2A6080; /* 4.8:1 on light surface #eaeff8 — ✅ AA */
|
--app-primary: #2A6080; /* 4.8:1 on light surface #eaeff8 — ✅ AA */
|
||||||
--app-primary-hover: #1E4D66; /* darker for hover */
|
--app-primary-hover: #1E4D66; /* darker for hover */
|
||||||
|
|
|
||||||
|
|
@ -65,7 +65,6 @@ const navItems = [
|
||||||
{ path: '/', icon: '🃏', label: 'Label' },
|
{ path: '/', icon: '🃏', label: 'Label' },
|
||||||
{ path: '/fetch', icon: '📥', label: 'Fetch' },
|
{ path: '/fetch', icon: '📥', label: 'Fetch' },
|
||||||
{ path: '/stats', icon: '📊', label: 'Stats' },
|
{ path: '/stats', icon: '📊', label: 'Stats' },
|
||||||
{ path: '/benchmark', icon: '🏁', label: 'Benchmark' },
|
|
||||||
{ path: '/settings', icon: '⚙️', label: 'Settings' },
|
{ path: '/settings', icon: '⚙️', label: 'Settings' },
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -86,7 +86,6 @@ const displayBody = computed(() => {
|
||||||
font-size: 0.9375rem;
|
font-size: 0.9375rem;
|
||||||
line-height: 1.6;
|
line-height: 1.6;
|
||||||
white-space: pre-wrap;
|
white-space: pre-wrap;
|
||||||
overflow-wrap: break-word;
|
|
||||||
margin: 0;
|
margin: 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -2,20 +2,6 @@ import { mount } from '@vue/test-utils'
|
||||||
import EmailCardStack from './EmailCardStack.vue'
|
import EmailCardStack from './EmailCardStack.vue'
|
||||||
import { describe, it, expect, vi } from 'vitest'
|
import { describe, it, expect, vi } from 'vitest'
|
||||||
|
|
||||||
vi.mock('../composables/useCardAnimation', () => ({
|
|
||||||
useCardAnimation: vi.fn(() => ({
|
|
||||||
pickup: vi.fn(),
|
|
||||||
setDragPosition: vi.fn(),
|
|
||||||
snapBack: vi.fn(),
|
|
||||||
animateDismiss: vi.fn(),
|
|
||||||
updateAura: vi.fn(),
|
|
||||||
reset: vi.fn(),
|
|
||||||
})),
|
|
||||||
}))
|
|
||||||
|
|
||||||
import { useCardAnimation } from '../composables/useCardAnimation'
|
|
||||||
import { nextTick } from 'vue'
|
|
||||||
|
|
||||||
const item = {
|
const item = {
|
||||||
id: 'abc',
|
id: 'abc',
|
||||||
subject: 'Interview at Acme',
|
subject: 'Interview at Acme',
|
||||||
|
|
@ -36,13 +22,27 @@ describe('EmailCardStack', () => {
|
||||||
expect(w.findAll('.card-shadow')).toHaveLength(2)
|
expect(w.findAll('.card-shadow')).toHaveLength(2)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('calls animateDismiss with type when dismissType prop changes', async () => {
|
it('applies dismiss-label class when dismissType is label', () => {
|
||||||
;(useCardAnimation as ReturnType<typeof vi.fn>).mockClear()
|
const w = mount(EmailCardStack, { props: { item, isBucketMode: false, dismissType: 'label' } })
|
||||||
|
expect(w.find('.card-wrapper').classes()).toContain('dismiss-label')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('applies dismiss-discard class when dismissType is discard', () => {
|
||||||
|
const w = mount(EmailCardStack, { props: { item, isBucketMode: false, dismissType: 'discard' } })
|
||||||
|
expect(w.find('.card-wrapper').classes()).toContain('dismiss-discard')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('applies dismiss-skip class when dismissType is skip', () => {
|
||||||
|
const w = mount(EmailCardStack, { props: { item, isBucketMode: false, dismissType: 'skip' } })
|
||||||
|
expect(w.find('.card-wrapper').classes()).toContain('dismiss-skip')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('no dismiss class when dismissType is null', () => {
|
||||||
const w = mount(EmailCardStack, { props: { item, isBucketMode: false, dismissType: null } })
|
const w = mount(EmailCardStack, { props: { item, isBucketMode: false, dismissType: null } })
|
||||||
const { animateDismiss } = (useCardAnimation as ReturnType<typeof vi.fn>).mock.results[0].value
|
const wrapperClasses = w.find('.card-wrapper').classes()
|
||||||
await w.setProps({ dismissType: 'label' })
|
expect(wrapperClasses).not.toContain('dismiss-label')
|
||||||
await nextTick()
|
expect(wrapperClasses).not.toContain('dismiss-discard')
|
||||||
expect(animateDismiss).toHaveBeenCalledWith('label')
|
expect(wrapperClasses).not.toContain('dismiss-skip')
|
||||||
})
|
})
|
||||||
|
|
||||||
// JSDOM doesn't implement setPointerCapture — mock it on the element.
|
// JSDOM doesn't implement setPointerCapture — mock it on the element.
|
||||||
|
|
|
||||||
|
|
@ -11,7 +11,8 @@
|
||||||
<div
|
<div
|
||||||
class="card-wrapper"
|
class="card-wrapper"
|
||||||
ref="cardEl"
|
ref="cardEl"
|
||||||
:class="{ 'is-held': isHeld }"
|
:class="[dismissClass, { 'is-held': isHeld }]"
|
||||||
|
:style="cardStyle"
|
||||||
@pointerdown="onPointerDown"
|
@pointerdown="onPointerDown"
|
||||||
@pointermove="onPointerMove"
|
@pointermove="onPointerMove"
|
||||||
@pointerup="onPointerUp"
|
@pointerup="onPointerUp"
|
||||||
|
|
@ -28,9 +29,8 @@
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, watch } from 'vue'
|
import { ref, computed } from 'vue'
|
||||||
import { useMotion } from '../composables/useMotion'
|
import { useMotion } from '../composables/useMotion'
|
||||||
import { useCardAnimation } from '../composables/useCardAnimation'
|
|
||||||
import EmailCard from './EmailCard.vue'
|
import EmailCard from './EmailCard.vue'
|
||||||
import type { QueueItem } from '../stores/label'
|
import type { QueueItem } from '../stores/label'
|
||||||
|
|
||||||
|
|
@ -54,22 +54,12 @@ const motion = useMotion()
|
||||||
const cardEl = ref<HTMLElement | null>(null)
|
const cardEl = ref<HTMLElement | null>(null)
|
||||||
const isExpanded = ref(false)
|
const isExpanded = ref(false)
|
||||||
|
|
||||||
const { pickup, setDragPosition, snapBack, animateDismiss, updateAura, reset } = useCardAnimation(cardEl, motion)
|
|
||||||
|
|
||||||
watch(() => props.dismissType, (type) => {
|
|
||||||
if (type) animateDismiss(type)
|
|
||||||
})
|
|
||||||
|
|
||||||
// When a new card loads into the same element, clear any inline styles left by the previous animation
|
|
||||||
watch(() => props.item.id, () => {
|
|
||||||
reset()
|
|
||||||
isExpanded.value = false
|
|
||||||
})
|
|
||||||
|
|
||||||
// Toss gesture state
|
// Toss gesture state
|
||||||
const isHeld = ref(false)
|
const isHeld = ref(false)
|
||||||
const pickupX = ref(0)
|
const pickupX = ref(0)
|
||||||
const pickupY = ref(0)
|
const pickupY = ref(0)
|
||||||
|
const deltaX = ref(0)
|
||||||
|
const deltaY = ref(0)
|
||||||
const hoveredZone = ref<'discard' | 'skip' | null>(null)
|
const hoveredZone = ref<'discard' | 'skip' | null>(null)
|
||||||
const hoveredBucketName = ref<string | null>(null)
|
const hoveredBucketName = ref<string | null>(null)
|
||||||
|
|
||||||
|
|
@ -84,14 +74,13 @@ const FLING_WINDOW_MS = 50 // rolling sample window in ms
|
||||||
let velocityBuf: { x: number; y: number; t: number }[] = []
|
let velocityBuf: { x: number; y: number; t: number }[] = []
|
||||||
|
|
||||||
function onPointerDown(e: PointerEvent) {
|
function onPointerDown(e: PointerEvent) {
|
||||||
// Let clicks on interactive children (expand/collapse, links, etc.) pass through
|
|
||||||
if ((e.target as Element).closest('button, a, input, select, textarea')) return
|
|
||||||
if (!motion.rich.value) return
|
if (!motion.rich.value) return
|
||||||
;(e.currentTarget as HTMLElement).setPointerCapture(e.pointerId)
|
;(e.currentTarget as HTMLElement).setPointerCapture(e.pointerId)
|
||||||
pickupX.value = e.clientX
|
pickupX.value = e.clientX
|
||||||
pickupY.value = e.clientY
|
pickupY.value = e.clientY
|
||||||
|
deltaX.value = 0
|
||||||
|
deltaY.value = 0
|
||||||
isHeld.value = true
|
isHeld.value = true
|
||||||
pickup()
|
|
||||||
hoveredZone.value = null
|
hoveredZone.value = null
|
||||||
hoveredBucketName.value = null
|
hoveredBucketName.value = null
|
||||||
velocityBuf = []
|
velocityBuf = []
|
||||||
|
|
@ -100,9 +89,8 @@ function onPointerDown(e: PointerEvent) {
|
||||||
|
|
||||||
function onPointerMove(e: PointerEvent) {
|
function onPointerMove(e: PointerEvent) {
|
||||||
if (!isHeld.value) return
|
if (!isHeld.value) return
|
||||||
const dx = e.clientX - pickupX.value
|
deltaX.value = e.clientX - pickupX.value
|
||||||
const dy = e.clientY - pickupY.value
|
deltaY.value = e.clientY - pickupY.value
|
||||||
setDragPosition(dx, dy)
|
|
||||||
|
|
||||||
// Rolling velocity buffer — keep only the last FLING_WINDOW_MS of samples
|
// Rolling velocity buffer — keep only the last FLING_WINDOW_MS of samples
|
||||||
const now = performance.now()
|
const now = performance.now()
|
||||||
|
|
@ -130,7 +118,6 @@ function onPointerMove(e: PointerEvent) {
|
||||||
hoveredBucketName.value = bucketName
|
hoveredBucketName.value = bucketName
|
||||||
emit('bucket-hover', bucketName)
|
emit('bucket-hover', bucketName)
|
||||||
}
|
}
|
||||||
updateAura(hoveredZone.value, hoveredBucketName.value)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
function onPointerUp(e: PointerEvent) {
|
function onPointerUp(e: PointerEvent) {
|
||||||
|
|
@ -176,9 +163,9 @@ function onPointerUp(e: PointerEvent) {
|
||||||
hoveredBucketName.value = null
|
hoveredBucketName.value = null
|
||||||
emit('label', name)
|
emit('label', name)
|
||||||
} else {
|
} else {
|
||||||
// Snap back
|
// Snap back — reset deltas
|
||||||
snapBack()
|
deltaX.value = 0
|
||||||
updateAura(null, null)
|
deltaY.value = 0
|
||||||
hoveredZone.value = null
|
hoveredZone.value = null
|
||||||
hoveredBucketName.value = null
|
hoveredBucketName.value = null
|
||||||
}
|
}
|
||||||
|
|
@ -188,8 +175,8 @@ function onPointerCancel(e: PointerEvent) {
|
||||||
if (!isHeld.value) return
|
if (!isHeld.value) return
|
||||||
;(e.currentTarget as HTMLElement).releasePointerCapture(e.pointerId)
|
;(e.currentTarget as HTMLElement).releasePointerCapture(e.pointerId)
|
||||||
isHeld.value = false
|
isHeld.value = false
|
||||||
snapBack()
|
deltaX.value = 0
|
||||||
updateAura(null, null)
|
deltaY.value = 0
|
||||||
hoveredZone.value = null
|
hoveredZone.value = null
|
||||||
hoveredBucketName.value = null
|
hoveredBucketName.value = null
|
||||||
velocityBuf = []
|
velocityBuf = []
|
||||||
|
|
@ -197,6 +184,32 @@ function onPointerCancel(e: PointerEvent) {
|
||||||
emit('zone-hover', null)
|
emit('zone-hover', null)
|
||||||
emit('bucket-hover', null)
|
emit('bucket-hover', null)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const dismissClass = computed(() => {
|
||||||
|
if (!props.dismissType) return null
|
||||||
|
return `dismiss-${props.dismissType}`
|
||||||
|
})
|
||||||
|
|
||||||
|
const cardStyle = computed(() => {
|
||||||
|
if (!motion.rich.value || !isHeld.value) return {}
|
||||||
|
|
||||||
|
// Aura color: zone > bucket > neutral
|
||||||
|
const aura =
|
||||||
|
hoveredZone.value === 'discard' ? 'rgba(244,67,54,0.25)' :
|
||||||
|
hoveredZone.value === 'skip' ? 'rgba(255,152,0,0.25)' :
|
||||||
|
hoveredBucketName.value ? 'rgba(42,96,128,0.20)' :
|
||||||
|
'transparent'
|
||||||
|
|
||||||
|
return {
|
||||||
|
transform: `translate(${deltaX.value}px, ${deltaY.value - 80}px) scale(0.55)`,
|
||||||
|
borderRadius: '50%',
|
||||||
|
background: aura,
|
||||||
|
transition: 'border-radius 150ms ease, background 150ms ease',
|
||||||
|
cursor: 'grabbing',
|
||||||
|
zIndex: 100,
|
||||||
|
userSelect: 'none',
|
||||||
|
}
|
||||||
|
})
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<style scoped>
|
<style scoped>
|
||||||
|
|
@ -263,6 +276,30 @@ function onPointerCancel(e: PointerEvent) {
|
||||||
pointer-events: auto !important;
|
pointer-events: auto !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* Dismissal animations — dismiss class is only applied during the motion.rich await window,
|
||||||
|
so no ancestor guard needed; :global(.rich-motion) was being miscompiled by Vue's scoped
|
||||||
|
CSS transformer (dropping the descendant selector entirely). */
|
||||||
|
.card-wrapper.dismiss-label {
|
||||||
|
animation: fileAway var(--card-dismiss, 350ms ease-in) forwards;
|
||||||
|
}
|
||||||
|
.card-wrapper.dismiss-discard {
|
||||||
|
animation: crumple var(--card-dismiss, 350ms ease-in) forwards;
|
||||||
|
}
|
||||||
|
.card-wrapper.dismiss-skip {
|
||||||
|
animation: slideUnder var(--card-skip, 300ms ease-out) forwards;
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes fileAway {
|
||||||
|
to { transform: translateY(-120%) scale(0.85); opacity: 0; }
|
||||||
|
}
|
||||||
|
@keyframes crumple {
|
||||||
|
50% { transform: scale(0.95) rotate(2deg); filter: brightness(0.6) sepia(1) hue-rotate(-20deg); }
|
||||||
|
to { transform: scale(0) rotate(8deg); opacity: 0; }
|
||||||
|
}
|
||||||
|
@keyframes slideUnder {
|
||||||
|
to { transform: translateX(110%) rotate(5deg); opacity: 0; }
|
||||||
|
}
|
||||||
|
|
||||||
@media (prefers-reduced-motion: reduce) {
|
@media (prefers-reduced-motion: reduce) {
|
||||||
.card-stack,
|
.card-stack,
|
||||||
.card-wrapper {
|
.card-wrapper {
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
<template>
|
<template>
|
||||||
<div class="label-grid" :class="{ 'bucket-mode': isBucketMode }" role="group" aria-label="Label buttons">
|
<div class="label-grid" :class="{ 'bucket-mode': isBucketMode }" role="group" aria-label="Label buttons">
|
||||||
<button
|
<button
|
||||||
v-for="label in displayLabels"
|
v-for="label in labels"
|
||||||
:key="label.key"
|
:key="label.key"
|
||||||
data-testid="label-btn"
|
data-testid="label-btn"
|
||||||
:data-label-key="label.name"
|
:data-label-key="label.name"
|
||||||
|
|
@ -19,8 +19,6 @@
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { computed } from 'vue'
|
|
||||||
|
|
||||||
interface Label { name: string; emoji: string; color: string; key: string }
|
interface Label { name: string; emoji: string; color: string; key: string }
|
||||||
|
|
||||||
const props = defineProps<{
|
const props = defineProps<{
|
||||||
|
|
@ -29,16 +27,6 @@ const props = defineProps<{
|
||||||
hoveredBucket?: string | null
|
hoveredBucket?: string | null
|
||||||
}>()
|
}>()
|
||||||
const emit = defineEmits<{ label: [name: string] }>()
|
const emit = defineEmits<{ label: [name: string] }>()
|
||||||
|
|
||||||
// Numpad layout: reverse the row order of numeric keys (7-8-9 on top, 1-2-3 on bottom)
|
|
||||||
// Non-numeric keys (e.g. 'h' for hired) stay pinned after the grid.
|
|
||||||
const displayLabels = computed(() => {
|
|
||||||
const numeric = props.labels.filter(l => !isNaN(Number(l.key)))
|
|
||||||
const other = props.labels.filter(l => isNaN(Number(l.key)))
|
|
||||||
const rows: Label[][] = []
|
|
||||||
for (let i = 0; i < numeric.length; i += 3) rows.push(numeric.slice(i, i + 3))
|
|
||||||
return [...rows.reverse().flat(), ...other]
|
|
||||||
})
|
|
||||||
</script>
|
</script>
|
||||||
|
|
||||||
<style scoped>
|
<style scoped>
|
||||||
|
|
@ -50,9 +38,11 @@ const displayLabels = computed(() => {
|
||||||
padding var(--bucket-expand, 250ms cubic-bezier(0.34, 1.56, 0.64, 1));
|
padding var(--bucket-expand, 250ms cubic-bezier(0.34, 1.56, 0.64, 1));
|
||||||
}
|
}
|
||||||
|
|
||||||
/* 10th button (hired / key h) — full-width bar below the 3×3 */
|
/* 10th button (hired / key h) — centered below the 3×3 like a numpad 0 */
|
||||||
.label-btn:last-child {
|
.label-btn:last-child {
|
||||||
grid-column: 1 / -1;
|
grid-column: 1 / -1;
|
||||||
|
max-width: calc(33.333% - 0.34rem);
|
||||||
|
justify-self: center;
|
||||||
}
|
}
|
||||||
|
|
||||||
.label-grid.bucket-mode {
|
.label-grid.bucket-mode {
|
||||||
|
|
|
||||||
|
|
@ -1,142 +0,0 @@
|
||||||
import { ref } from 'vue'
|
|
||||||
import { describe, it, expect, vi, beforeEach } from 'vitest'
|
|
||||||
|
|
||||||
// Mock animejs before importing the composable
|
|
||||||
vi.mock('animejs', () => ({
|
|
||||||
animate: vi.fn(),
|
|
||||||
spring: vi.fn(() => 'mock-spring'),
|
|
||||||
utils: { set: vi.fn() },
|
|
||||||
}))
|
|
||||||
|
|
||||||
import { useCardAnimation } from './useCardAnimation'
|
|
||||||
import { animate, utils } from 'animejs'
|
|
||||||
|
|
||||||
const mockAnimate = animate as ReturnType<typeof vi.fn>
|
|
||||||
const mockSet = utils.set as ReturnType<typeof vi.fn>
|
|
||||||
|
|
||||||
function makeEl() {
|
|
||||||
return document.createElement('div')
|
|
||||||
}
|
|
||||||
|
|
||||||
describe('useCardAnimation', () => {
|
|
||||||
beforeEach(() => {
|
|
||||||
vi.clearAllMocks()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('pickup() calls animate with ball shape', () => {
|
|
||||||
const el = makeEl()
|
|
||||||
const cardEl = ref<HTMLElement | null>(el)
|
|
||||||
const motion = { rich: ref(true) }
|
|
||||||
const { pickup } = useCardAnimation(cardEl, motion)
|
|
||||||
pickup()
|
|
||||||
expect(mockAnimate).toHaveBeenCalledWith(
|
|
||||||
el,
|
|
||||||
expect.objectContaining({ scale: 0.55, borderRadius: '50%' }),
|
|
||||||
)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('pickup() is a no-op when motion.rich is false', () => {
|
|
||||||
const el = makeEl()
|
|
||||||
const cardEl = ref<HTMLElement | null>(el)
|
|
||||||
const motion = { rich: ref(false) }
|
|
||||||
const { pickup } = useCardAnimation(cardEl, motion)
|
|
||||||
pickup()
|
|
||||||
expect(mockAnimate).not.toHaveBeenCalled()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('setDragPosition() calls utils.set with translated coords', () => {
|
|
||||||
const el = makeEl()
|
|
||||||
const cardEl = ref<HTMLElement | null>(el)
|
|
||||||
const motion = { rich: ref(true) }
|
|
||||||
const { setDragPosition } = useCardAnimation(cardEl, motion)
|
|
||||||
setDragPosition(50, 30)
|
|
||||||
expect(mockSet).toHaveBeenCalledWith(el, expect.objectContaining({ x: 50, y: -50 }))
|
|
||||||
// y = deltaY - 80 = 30 - 80 = -50
|
|
||||||
})
|
|
||||||
|
|
||||||
it('snapBack() calls animate returning to card shape', () => {
|
|
||||||
const el = makeEl()
|
|
||||||
const cardEl = ref<HTMLElement | null>(el)
|
|
||||||
const motion = { rich: ref(true) }
|
|
||||||
const { snapBack } = useCardAnimation(cardEl, motion)
|
|
||||||
snapBack()
|
|
||||||
expect(mockAnimate).toHaveBeenCalledWith(
|
|
||||||
el,
|
|
||||||
expect.objectContaining({ x: 0, y: 0, scale: 1 }),
|
|
||||||
)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('animateDismiss("label") calls animate', () => {
|
|
||||||
const el = makeEl()
|
|
||||||
const cardEl = ref<HTMLElement | null>(el)
|
|
||||||
const motion = { rich: ref(true) }
|
|
||||||
const { animateDismiss } = useCardAnimation(cardEl, motion)
|
|
||||||
animateDismiss('label')
|
|
||||||
expect(mockAnimate).toHaveBeenCalled()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('animateDismiss("discard") calls animate', () => {
|
|
||||||
const el = makeEl()
|
|
||||||
const cardEl = ref<HTMLElement | null>(el)
|
|
||||||
const motion = { rich: ref(true) }
|
|
||||||
const { animateDismiss } = useCardAnimation(cardEl, motion)
|
|
||||||
animateDismiss('discard')
|
|
||||||
expect(mockAnimate).toHaveBeenCalled()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('animateDismiss("skip") calls animate', () => {
|
|
||||||
const el = makeEl()
|
|
||||||
const cardEl = ref<HTMLElement | null>(el)
|
|
||||||
const motion = { rich: ref(true) }
|
|
||||||
const { animateDismiss } = useCardAnimation(cardEl, motion)
|
|
||||||
animateDismiss('skip')
|
|
||||||
expect(mockAnimate).toHaveBeenCalled()
|
|
||||||
})
|
|
||||||
|
|
||||||
it('animateDismiss is a no-op when motion.rich is false', () => {
|
|
||||||
const el = makeEl()
|
|
||||||
const cardEl = ref<HTMLElement | null>(el)
|
|
||||||
const motion = { rich: ref(false) }
|
|
||||||
const { animateDismiss } = useCardAnimation(cardEl, motion)
|
|
||||||
animateDismiss('label')
|
|
||||||
expect(mockAnimate).not.toHaveBeenCalled()
|
|
||||||
})
|
|
||||||
|
|
||||||
describe('updateAura', () => {
|
|
||||||
it('sets red background for discard zone', () => {
|
|
||||||
const el = makeEl()
|
|
||||||
const cardEl = ref<HTMLElement | null>(el)
|
|
||||||
const motion = { rich: ref(true) }
|
|
||||||
const { updateAura } = useCardAnimation(cardEl, motion)
|
|
||||||
updateAura('discard', null)
|
|
||||||
expect(mockSet).toHaveBeenCalledWith(el, expect.objectContaining({ background: 'rgba(244, 67, 54, 0.25)' }))
|
|
||||||
})
|
|
||||||
|
|
||||||
it('sets orange background for skip zone', () => {
|
|
||||||
const el = makeEl()
|
|
||||||
const cardEl = ref<HTMLElement | null>(el)
|
|
||||||
const motion = { rich: ref(true) }
|
|
||||||
const { updateAura } = useCardAnimation(cardEl, motion)
|
|
||||||
updateAura('skip', null)
|
|
||||||
expect(mockSet).toHaveBeenCalledWith(el, expect.objectContaining({ background: 'rgba(255, 152, 0, 0.25)' }))
|
|
||||||
})
|
|
||||||
|
|
||||||
it('sets blue background for bucket hover', () => {
|
|
||||||
const el = makeEl()
|
|
||||||
const cardEl = ref<HTMLElement | null>(el)
|
|
||||||
const motion = { rich: ref(true) }
|
|
||||||
const { updateAura } = useCardAnimation(cardEl, motion)
|
|
||||||
updateAura(null, 'interview_scheduled')
|
|
||||||
expect(mockSet).toHaveBeenCalledWith(el, expect.objectContaining({ background: 'rgba(42, 96, 128, 0.20)' }))
|
|
||||||
})
|
|
||||||
|
|
||||||
it('sets transparent background when no zone/bucket', () => {
|
|
||||||
const el = makeEl()
|
|
||||||
const cardEl = ref<HTMLElement | null>(el)
|
|
||||||
const motion = { rich: ref(true) }
|
|
||||||
const { updateAura } = useCardAnimation(cardEl, motion)
|
|
||||||
updateAura(null, null)
|
|
||||||
expect(mockSet).toHaveBeenCalledWith(el, expect.objectContaining({ background: 'transparent' }))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
@ -1,99 +0,0 @@
|
||||||
import { type Ref } from 'vue'
|
|
||||||
import { animate, spring, utils } from 'animejs'
|
|
||||||
|
|
||||||
const BALL_SCALE = 0.55
|
|
||||||
const BALL_RADIUS = '50%'
|
|
||||||
const CARD_RADIUS = '1rem'
|
|
||||||
const PICKUP_Y_OFFSET = 80 // px above finger
|
|
||||||
const PICKUP_DURATION = 200
|
|
||||||
|
|
||||||
// Anime.js v4: spring() takes an object { mass, stiffness, damping, velocity }
|
|
||||||
const SNAP_SPRING = spring({ mass: 1, stiffness: 80, damping: 10 })
|
|
||||||
|
|
||||||
interface Motion { rich: Ref<boolean> }
|
|
||||||
|
|
||||||
export function useCardAnimation(
|
|
||||||
cardEl: Ref<HTMLElement | null>,
|
|
||||||
motion: Motion,
|
|
||||||
) {
|
|
||||||
function pickup() {
|
|
||||||
if (!motion.rich.value || !cardEl.value) return
|
|
||||||
// Anime.js v4: animate(target, params) — all props + timing in one object
|
|
||||||
animate(cardEl.value, {
|
|
||||||
scale: BALL_SCALE,
|
|
||||||
borderRadius: BALL_RADIUS,
|
|
||||||
y: -PICKUP_Y_OFFSET,
|
|
||||||
duration: PICKUP_DURATION,
|
|
||||||
ease: SNAP_SPRING,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
function setDragPosition(dx: number, dy: number) {
|
|
||||||
if (!cardEl.value) return
|
|
||||||
// utils.set() for instant (no-animation) position update — keeps Anime cache consistent
|
|
||||||
utils.set(cardEl.value, { x: dx, y: dy - PICKUP_Y_OFFSET })
|
|
||||||
}
|
|
||||||
|
|
||||||
function snapBack() {
|
|
||||||
if (!motion.rich.value || !cardEl.value) return
|
|
||||||
animate(cardEl.value, {
|
|
||||||
x: 0,
|
|
||||||
y: 0,
|
|
||||||
scale: 1,
|
|
||||||
borderRadius: CARD_RADIUS,
|
|
||||||
ease: SNAP_SPRING,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
function animateDismiss(type: 'label' | 'skip' | 'discard') {
|
|
||||||
if (!motion.rich.value || !cardEl.value) return
|
|
||||||
const el = cardEl.value
|
|
||||||
if (type === 'label') {
|
|
||||||
animate(el, { y: '-120%', scale: 0.85, opacity: 0, duration: 280, ease: 'out(3)' })
|
|
||||||
} else if (type === 'discard') {
|
|
||||||
// Anime.js v4 keyframe array: array of param objects, each can have its own duration
|
|
||||||
animate(el, {
|
|
||||||
keyframes: [
|
|
||||||
{ scale: 0.95, rotate: 2, filter: 'brightness(0.6) sepia(1) hue-rotate(-20deg)', duration: 140 },
|
|
||||||
{ scale: 0, rotate: 8, opacity: 0, duration: 210 },
|
|
||||||
],
|
|
||||||
})
|
|
||||||
} else if (type === 'skip') {
|
|
||||||
animate(el, { x: '110%', rotate: 5, opacity: 0, duration: 260, ease: 'out(2)' })
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const AURA_COLORS = {
|
|
||||||
discard: 'rgba(244, 67, 54, 0.25)',
|
|
||||||
skip: 'rgba(255, 152, 0, 0.25)',
|
|
||||||
bucket: 'rgba(42, 96, 128, 0.20)',
|
|
||||||
none: 'transparent',
|
|
||||||
} as const
|
|
||||||
|
|
||||||
function updateAura(zone: 'discard' | 'skip' | null, bucket: string | null) {
|
|
||||||
if (!cardEl.value) return
|
|
||||||
const color =
|
|
||||||
zone === 'discard' ? AURA_COLORS.discard :
|
|
||||||
zone === 'skip' ? AURA_COLORS.skip :
|
|
||||||
bucket ? AURA_COLORS.bucket :
|
|
||||||
AURA_COLORS.none
|
|
||||||
utils.set(cardEl.value, { background: color })
|
|
||||||
}
|
|
||||||
|
|
||||||
function reset() {
|
|
||||||
if (!cardEl.value) return
|
|
||||||
// Instantly restore initial card state — called when a new item loads into the same element
|
|
||||||
utils.set(cardEl.value, {
|
|
||||||
x: 0,
|
|
||||||
y: 0,
|
|
||||||
scale: 1,
|
|
||||||
opacity: 1,
|
|
||||||
rotate: 0,
|
|
||||||
borderRadius: CARD_RADIUS,
|
|
||||||
background: 'transparent',
|
|
||||||
filter: 'none',
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return { pickup, setDragPosition, snapBack, animateDismiss, updateAura, reset }
|
|
||||||
}
|
|
||||||
|
|
@ -1,15 +1,14 @@
|
||||||
import { onMounted, onUnmounted } from 'vue'
|
import { onMounted, onUnmounted } from 'vue'
|
||||||
|
|
||||||
const KONAMI = ['ArrowUp','ArrowUp','ArrowDown','ArrowDown','ArrowLeft','ArrowRight','ArrowLeft','ArrowRight','b','a']
|
const KONAMI = ['ArrowUp','ArrowUp','ArrowDown','ArrowDown','ArrowLeft','ArrowRight','ArrowLeft','ArrowRight','b','a']
|
||||||
const KONAMI_AB = ['ArrowUp','ArrowUp','ArrowDown','ArrowDown','ArrowLeft','ArrowRight','ArrowLeft','ArrowRight','a','b']
|
|
||||||
|
|
||||||
export function useKeySequence(sequence: string[], onActivate: () => void) {
|
export function useKonamiCode(onActivate: () => void) {
|
||||||
let pos = 0
|
let pos = 0
|
||||||
|
|
||||||
function handler(e: KeyboardEvent) {
|
function handler(e: KeyboardEvent) {
|
||||||
if (e.key === sequence[pos]) {
|
if (e.key === KONAMI[pos]) {
|
||||||
pos++
|
pos++
|
||||||
if (pos === sequence.length) {
|
if (pos === KONAMI.length) {
|
||||||
pos = 0
|
pos = 0
|
||||||
onActivate()
|
onActivate()
|
||||||
}
|
}
|
||||||
|
|
@ -22,11 +21,6 @@ export function useKeySequence(sequence: string[], onActivate: () => void) {
|
||||||
onUnmounted(() => window.removeEventListener('keydown', handler))
|
onUnmounted(() => window.removeEventListener('keydown', handler))
|
||||||
}
|
}
|
||||||
|
|
||||||
export function useKonamiCode(onActivate: () => void) {
|
|
||||||
useKeySequence(KONAMI, onActivate)
|
|
||||||
useKeySequence(KONAMI_AB, onActivate)
|
|
||||||
}
|
|
||||||
|
|
||||||
export function useHackerMode() {
|
export function useHackerMode() {
|
||||||
function toggle() {
|
function toggle() {
|
||||||
const root = document.documentElement
|
const root = document.documentElement
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,6 @@ import LabelView from '../views/LabelView.vue'
|
||||||
// Views are lazy-loaded to keep initial bundle small
|
// Views are lazy-loaded to keep initial bundle small
|
||||||
const FetchView = () => import('../views/FetchView.vue')
|
const FetchView = () => import('../views/FetchView.vue')
|
||||||
const StatsView = () => import('../views/StatsView.vue')
|
const StatsView = () => import('../views/StatsView.vue')
|
||||||
const BenchmarkView = () => import('../views/BenchmarkView.vue')
|
|
||||||
const SettingsView = () => import('../views/SettingsView.vue')
|
const SettingsView = () => import('../views/SettingsView.vue')
|
||||||
|
|
||||||
export const router = createRouter({
|
export const router = createRouter({
|
||||||
|
|
@ -13,7 +12,6 @@ export const router = createRouter({
|
||||||
{ path: '/', component: LabelView, meta: { title: 'Label' } },
|
{ path: '/', component: LabelView, meta: { title: 'Label' } },
|
||||||
{ path: '/fetch', component: FetchView, meta: { title: 'Fetch' } },
|
{ path: '/fetch', component: FetchView, meta: { title: 'Fetch' } },
|
||||||
{ path: '/stats', component: StatsView, meta: { title: 'Stats' } },
|
{ path: '/stats', component: StatsView, meta: { title: 'Stats' } },
|
||||||
{ path: '/benchmark', component: BenchmarkView, meta: { title: 'Benchmark' } },
|
|
||||||
{ path: '/settings', component: SettingsView, meta: { title: 'Settings' } },
|
{ path: '/settings', component: SettingsView, meta: { title: 'Settings' } },
|
||||||
],
|
],
|
||||||
})
|
})
|
||||||
|
|
|
||||||
|
|
@ -1,846 +0,0 @@
|
||||||
<template>
|
|
||||||
<div class="bench-view">
|
|
||||||
<header class="bench-header">
|
|
||||||
<h1 class="page-title">🏁 Benchmark</h1>
|
|
||||||
<div class="header-actions">
|
|
||||||
<label class="slow-toggle" :class="{ disabled: running }">
|
|
||||||
<input type="checkbox" v-model="includeSlow" :disabled="running" />
|
|
||||||
Include slow models
|
|
||||||
</label>
|
|
||||||
<button
|
|
||||||
class="btn-run"
|
|
||||||
:disabled="running"
|
|
||||||
@click="startBenchmark"
|
|
||||||
>
|
|
||||||
{{ running ? '⏳ Running…' : results ? '🔄 Re-run' : '▶ Run Benchmark' }}
|
|
||||||
</button>
|
|
||||||
<button
|
|
||||||
v-if="running"
|
|
||||||
class="btn-cancel"
|
|
||||||
@click="cancelBenchmark"
|
|
||||||
>
|
|
||||||
✕ Cancel
|
|
||||||
</button>
|
|
||||||
</div>
|
|
||||||
</header>
|
|
||||||
|
|
||||||
<!-- Trained models badge row -->
|
|
||||||
<div v-if="fineTunedModels.length > 0" class="trained-models-row">
|
|
||||||
<span class="trained-label">Trained:</span>
|
|
||||||
<span
|
|
||||||
v-for="m in fineTunedModels"
|
|
||||||
:key="m.name"
|
|
||||||
class="trained-badge"
|
|
||||||
:title="m.base_model_id ? `Base: ${m.base_model_id} · ${m.sample_count ?? '?'} samples` : m.name"
|
|
||||||
>
|
|
||||||
{{ m.name }}
|
|
||||||
<span v-if="m.val_macro_f1 != null" class="trained-f1">
|
|
||||||
F1 {{ (m.val_macro_f1 * 100).toFixed(1) }}%
|
|
||||||
</span>
|
|
||||||
</span>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<!-- Progress log -->
|
|
||||||
<div v-if="running || runLog.length" class="run-log">
|
|
||||||
<div class="run-log-title">
|
|
||||||
<span>{{ running ? '⏳ Running benchmark…' : runCancelled ? '⏹ Cancelled' : runError ? '❌ Failed' : '✅ Done' }}</span>
|
|
||||||
<button class="btn-ghost" @click="runLog = []; runError = ''; runCancelled = false">Clear</button>
|
|
||||||
</div>
|
|
||||||
<div class="log-lines" ref="logEl">
|
|
||||||
<div
|
|
||||||
v-for="(line, i) in runLog"
|
|
||||||
:key="i"
|
|
||||||
class="log-line"
|
|
||||||
:class="{ 'log-error': line.startsWith('ERROR') || line.startsWith('[error]') }"
|
|
||||||
>{{ line }}</div>
|
|
||||||
</div>
|
|
||||||
<p v-if="runError" class="run-error">{{ runError }}</p>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<!-- Loading -->
|
|
||||||
<div v-if="loading" class="status-notice">Loading…</div>
|
|
||||||
|
|
||||||
<!-- No results yet -->
|
|
||||||
<div v-else-if="!results" class="status-notice empty">
|
|
||||||
<p>No benchmark results yet.</p>
|
|
||||||
<p class="hint">Click <strong>Run Benchmark</strong> to score all default models against your labeled data.</p>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<!-- Results -->
|
|
||||||
<template v-else>
|
|
||||||
<p class="meta-line">
|
|
||||||
<span>{{ results.sample_count.toLocaleString() }} labeled emails</span>
|
|
||||||
<span class="sep">·</span>
|
|
||||||
<span>{{ modelCount }} model{{ modelCount === 1 ? '' : 's' }}</span>
|
|
||||||
<span class="sep">·</span>
|
|
||||||
<span>{{ formatDate(results.timestamp) }}</span>
|
|
||||||
</p>
|
|
||||||
|
|
||||||
<!-- Macro-F1 chart -->
|
|
||||||
<section class="chart-section">
|
|
||||||
<h2 class="chart-title">Macro-F1 (higher = better)</h2>
|
|
||||||
<div class="bar-chart">
|
|
||||||
<div v-for="row in f1Rows" :key="row.name" class="bar-row">
|
|
||||||
<span class="bar-label" :title="row.name">{{ row.name }}</span>
|
|
||||||
<div class="bar-track">
|
|
||||||
<div
|
|
||||||
class="bar-fill"
|
|
||||||
:style="{ width: `${row.pct}%`, background: scoreColor(row.value) }"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
<span class="bar-value" :style="{ color: scoreColor(row.value) }">
|
|
||||||
{{ row.value.toFixed(3) }}
|
|
||||||
</span>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</section>
|
|
||||||
|
|
||||||
<!-- Latency chart -->
|
|
||||||
<section class="chart-section">
|
|
||||||
<h2 class="chart-title">Latency (ms / email, lower = better)</h2>
|
|
||||||
<div class="bar-chart">
|
|
||||||
<div v-for="row in latencyRows" :key="row.name" class="bar-row">
|
|
||||||
<span class="bar-label" :title="row.name">{{ row.name }}</span>
|
|
||||||
<div class="bar-track">
|
|
||||||
<div
|
|
||||||
class="bar-fill latency-fill"
|
|
||||||
:style="{ width: `${row.pct}%` }"
|
|
||||||
/>
|
|
||||||
</div>
|
|
||||||
<span class="bar-value">{{ row.value.toFixed(1) }} ms</span>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</section>
|
|
||||||
|
|
||||||
<!-- Per-label F1 heatmap -->
|
|
||||||
<section class="chart-section">
|
|
||||||
<h2 class="chart-title">Per-label F1</h2>
|
|
||||||
<div class="heatmap-scroll">
|
|
||||||
<table class="heatmap">
|
|
||||||
<thead>
|
|
||||||
<tr>
|
|
||||||
<th class="hm-label-col">Label</th>
|
|
||||||
<th v-for="name in modelNames" :key="name" class="hm-model-col" :title="name">
|
|
||||||
{{ name }}
|
|
||||||
</th>
|
|
||||||
</tr>
|
|
||||||
</thead>
|
|
||||||
<tbody>
|
|
||||||
<tr v-for="label in labelNames" :key="label">
|
|
||||||
<td class="hm-label-cell">
|
|
||||||
<span class="hm-emoji">{{ LABEL_META[label]?.emoji ?? '🏷️' }}</span>
|
|
||||||
{{ label.replace(/_/g, '\u00a0') }}
|
|
||||||
</td>
|
|
||||||
<td
|
|
||||||
v-for="name in modelNames"
|
|
||||||
:key="name"
|
|
||||||
class="hm-value-cell"
|
|
||||||
:style="{ background: heatmapBg(f1For(name, label)), color: heatmapFg(f1For(name, label)) }"
|
|
||||||
:title="`${name} / ${label}: F1 ${f1For(name, label).toFixed(3)}, support ${supportFor(name, label)}`"
|
|
||||||
>
|
|
||||||
{{ f1For(name, label).toFixed(2) }}
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
</tbody>
|
|
||||||
</table>
|
|
||||||
</div>
|
|
||||||
<p class="heatmap-hint">Hover a cell for precision / recall / support. Color: 🟢 ≥ 0.7 · 🟡 0.4–0.7 · 🔴 < 0.4</p>
|
|
||||||
</section>
|
|
||||||
</template>
|
|
||||||
|
|
||||||
<!-- Fine-tune section -->
|
|
||||||
<details class="ft-section">
|
|
||||||
<summary class="ft-summary">Fine-tune a model</summary>
|
|
||||||
<div class="ft-body">
|
|
||||||
<div class="ft-controls">
|
|
||||||
<label class="ft-field">
|
|
||||||
<span class="ft-field-label">Model</span>
|
|
||||||
<select v-model="ftModel" class="ft-select" :disabled="ftRunning">
|
|
||||||
<option value="deberta-small">deberta-small (100M, fast)</option>
|
|
||||||
<option value="bge-m3">bge-m3 (600M — stop Peregrine vLLM first)</option>
|
|
||||||
</select>
|
|
||||||
</label>
|
|
||||||
<label class="ft-field">
|
|
||||||
<span class="ft-field-label">Epochs</span>
|
|
||||||
<input
|
|
||||||
v-model.number="ftEpochs"
|
|
||||||
type="number" min="1" max="20"
|
|
||||||
class="ft-epochs"
|
|
||||||
:disabled="ftRunning"
|
|
||||||
/>
|
|
||||||
</label>
|
|
||||||
<button
|
|
||||||
class="btn-run ft-run-btn"
|
|
||||||
:disabled="ftRunning"
|
|
||||||
@click="startFinetune"
|
|
||||||
>
|
|
||||||
{{ ftRunning ? '⏳ Training…' : '▶ Run fine-tune' }}
|
|
||||||
</button>
|
|
||||||
<button
|
|
||||||
v-if="ftRunning"
|
|
||||||
class="btn-cancel"
|
|
||||||
@click="cancelFinetune"
|
|
||||||
>
|
|
||||||
✕ Cancel
|
|
||||||
</button>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div v-if="ftRunning || ftLog.length || ftError" class="run-log ft-log">
|
|
||||||
<div class="run-log-title">
|
|
||||||
<span>{{ ftRunning ? '⏳ Training…' : ftCancelled ? '⏹ Cancelled' : ftError ? '❌ Failed' : '✅ Done' }}</span>
|
|
||||||
<button class="btn-ghost" @click="ftLog = []; ftError = ''; ftCancelled = false">Clear</button>
|
|
||||||
</div>
|
|
||||||
<div class="log-lines" ref="ftLogEl">
|
|
||||||
<div
|
|
||||||
v-for="(line, i) in ftLog"
|
|
||||||
:key="i"
|
|
||||||
class="log-line"
|
|
||||||
:class="{ 'log-error': line.startsWith('ERROR') || line.startsWith('[error]') }"
|
|
||||||
>{{ line }}</div>
|
|
||||||
</div>
|
|
||||||
<p v-if="ftError" class="run-error">{{ ftError }}</p>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
</details>
|
|
||||||
</div>
|
|
||||||
</template>
|
|
||||||
|
|
||||||
<script setup lang="ts">
|
|
||||||
import { ref, computed, onMounted, nextTick } from 'vue'
|
|
||||||
import { useApiFetch, useApiSSE } from '../composables/useApi'
|
|
||||||
|
|
||||||
// ── Label metadata (same as StatsView) ──────────────────────────────────────
|
|
||||||
const LABEL_META: Record<string, { emoji: string }> = {
|
|
||||||
interview_scheduled: { emoji: '🗓️' },
|
|
||||||
offer_received: { emoji: '🎉' },
|
|
||||||
rejected: { emoji: '❌' },
|
|
||||||
positive_response: { emoji: '👍' },
|
|
||||||
survey_received: { emoji: '📋' },
|
|
||||||
neutral: { emoji: '⬜' },
|
|
||||||
event_rescheduled: { emoji: '🔄' },
|
|
||||||
digest: { emoji: '📰' },
|
|
||||||
new_lead: { emoji: '🤝' },
|
|
||||||
hired: { emoji: '🎊' },
|
|
||||||
}
|
|
||||||
|
|
||||||
// ── Types ────────────────────────────────────────────────────────────────────
|
|
||||||
interface FineTunedModel {
|
|
||||||
name: string
|
|
||||||
base_model_id?: string
|
|
||||||
val_macro_f1?: number
|
|
||||||
timestamp?: string
|
|
||||||
sample_count?: number
|
|
||||||
}
|
|
||||||
|
|
||||||
interface PerLabel { f1: number; precision: number; recall: number; support: number }
|
|
||||||
interface ModelResult {
|
|
||||||
macro_f1: number
|
|
||||||
accuracy: number
|
|
||||||
latency_ms: number
|
|
||||||
per_label: Record<string, PerLabel>
|
|
||||||
}
|
|
||||||
interface BenchResults {
|
|
||||||
timestamp: string | null
|
|
||||||
sample_count: number
|
|
||||||
models: Record<string, ModelResult>
|
|
||||||
}
|
|
||||||
|
|
||||||
// ── State ────────────────────────────────────────────────────────────────────
|
|
||||||
const results = ref<BenchResults | null>(null)
|
|
||||||
const loading = ref(true)
|
|
||||||
const running = ref(false)
|
|
||||||
const runLog = ref<string[]>([])
|
|
||||||
const runError = ref('')
|
|
||||||
const includeSlow = ref(false)
|
|
||||||
const logEl = ref<HTMLElement | null>(null)
|
|
||||||
|
|
||||||
// Fine-tune state
|
|
||||||
const fineTunedModels = ref<FineTunedModel[]>([])
|
|
||||||
const ftModel = ref('deberta-small')
|
|
||||||
const ftEpochs = ref(5)
|
|
||||||
const ftRunning = ref(false)
|
|
||||||
const ftLog = ref<string[]>([])
|
|
||||||
const ftError = ref('')
|
|
||||||
const ftLogEl = ref<HTMLElement | null>(null)
|
|
||||||
|
|
||||||
const runCancelled = ref(false)
|
|
||||||
const ftCancelled = ref(false)
|
|
||||||
|
|
||||||
async function cancelBenchmark() {
|
|
||||||
await fetch('/api/benchmark/cancel', { method: 'POST' }).catch(() => {})
|
|
||||||
}
|
|
||||||
|
|
||||||
async function cancelFinetune() {
|
|
||||||
await fetch('/api/finetune/cancel', { method: 'POST' }).catch(() => {})
|
|
||||||
}
|
|
||||||
|
|
||||||
// ── Derived ──────────────────────────────────────────────────────────────────
|
|
||||||
// Model identifiers present in the loaded benchmark results
// (iteration order is the key order of the JSON payload).
const modelNames = computed(() => Object.keys(results.value?.models ?? {}))
// Count of benchmarked models.
const modelCount = computed(() => modelNames.value.length)
|
|
||||||
|
|
||||||
const labelNames = computed(() => {
|
|
||||||
const canonical = Object.keys(LABEL_META)
|
|
||||||
const inResults = new Set(
|
|
||||||
modelNames.value.flatMap(n => Object.keys(results.value!.models[n].per_label))
|
|
||||||
)
|
|
||||||
return [...canonical.filter(l => inResults.has(l)), ...[...inResults].filter(l => !canonical.includes(l))]
|
|
||||||
})
|
|
||||||
|
|
||||||
const f1Rows = computed(() => {
|
|
||||||
if (!results.value) return []
|
|
||||||
const rows = modelNames.value.map(name => ({
|
|
||||||
name,
|
|
||||||
value: results.value!.models[name].macro_f1,
|
|
||||||
}))
|
|
||||||
rows.sort((a, b) => b.value - a.value)
|
|
||||||
const max = rows[0]?.value || 1
|
|
||||||
return rows.map(r => ({ ...r, pct: Math.round((r.value / max) * 100) }))
|
|
||||||
})
|
|
||||||
|
|
||||||
const latencyRows = computed(() => {
|
|
||||||
if (!results.value) return []
|
|
||||||
const rows = modelNames.value.map(name => ({
|
|
||||||
name,
|
|
||||||
value: results.value!.models[name].latency_ms,
|
|
||||||
}))
|
|
||||||
rows.sort((a, b) => a.value - b.value) // fastest first
|
|
||||||
const max = rows[rows.length - 1]?.value || 1
|
|
||||||
return rows.map(r => ({ ...r, pct: Math.round((r.value / max) * 100) }))
|
|
||||||
})
|
|
||||||
|
|
||||||
// ── Helpers ──────────────────────────────────────────────────────────────────
|
|
||||||
function f1For(model: string, label: string): number {
|
|
||||||
return results.value?.models[model]?.per_label[label]?.f1 ?? 0
|
|
||||||
}
|
|
||||||
function supportFor(model: string, label: string): number {
|
|
||||||
return results.value?.models[model]?.per_label[label]?.support ?? 0
|
|
||||||
}
|
|
||||||
|
|
||||||
function scoreColor(v: number): string {
|
|
||||||
if (v >= 0.7) return 'var(--color-success, #4CAF50)'
|
|
||||||
if (v >= 0.4) return 'var(--app-accent, #B8622A)'
|
|
||||||
return 'var(--color-error, #ef4444)'
|
|
||||||
}
|
|
||||||
|
|
||||||
function heatmapBg(v: number): string {
|
|
||||||
// Blend red→yellow→green using the F1 value
|
|
||||||
if (v >= 0.7) return `color-mix(in srgb, #4CAF50 ${Math.round(v * 100)}%, #1a2338 ${Math.round((1 - v) * 80)}%)`
|
|
||||||
if (v >= 0.4) return `color-mix(in srgb, #FF9800 ${Math.round(v * 120)}%, #1a2338 40%)`
|
|
||||||
return `color-mix(in srgb, #ef4444 ${Math.round(v * 200 + 30)}%, #1a2338 60%)`
|
|
||||||
}
|
|
||||||
function heatmapFg(v: number): string {
|
|
||||||
return v >= 0.5 ? '#fff' : 'rgba(255,255,255,0.75)'
|
|
||||||
}
|
|
||||||
|
|
||||||
function formatDate(iso: string | null): string {
|
|
||||||
if (!iso) return 'unknown date'
|
|
||||||
const d = new Date(iso)
|
|
||||||
return d.toLocaleString(undefined, { dateStyle: 'medium', timeStyle: 'short' })
|
|
||||||
}
|
|
||||||
|
|
||||||
// ── Data loading ─────────────────────────────────────────────────────────────
|
|
||||||
// Fetch the latest persisted benchmark results.
// An empty `models` map (e.g. no run yet, or a failed fetch returning a
// bare payload) leaves any previously loaded results in place, so the UI
// does not flash back to the empty state on refresh.
async function loadResults() {
  loading.value = true
  const { data } = await useApiFetch<BenchResults>('/api/benchmark/results')
  loading.value = false
  if (data && Object.keys(data.models).length > 0) {
    results.value = data
  }
}
|
|
||||||
|
|
||||||
// ── Benchmark run ─────────────────────────────────────────────────────────────
|
|
||||||
// Start a benchmark run and stream its progress over SSE.
// Resets the log/error/cancel state up front, then wires three callbacks:
// per-event handler, normal-completion handler, and connection-error handler.
function startBenchmark() {
  running.value = true
  runLog.value = []
  runError.value = ''
  runCancelled.value = false

  // `include_slow` is forwarded as a query flag only when enabled.
  const url = `/api/benchmark/run${includeSlow.value ? '?include_slow=true' : ''}`
  useApiSSE(
    url,
    async (event) => {
      if (event.type === 'progress' && typeof event.message === 'string') {
        runLog.value.push(event.message)
        // Wait for the DOM to render the new line before auto-scrolling.
        await nextTick()
        logEl.value?.scrollTo({ top: logEl.value.scrollHeight, behavior: 'smooth' })
      }
      if (event.type === 'error' && typeof event.message === 'string') {
        runError.value = event.message
      }
      if (event.type === 'cancelled') {
        running.value = false
        runCancelled.value = true
      }
    },
    async () => {
      // Stream closed normally: refresh charts from the persisted results.
      running.value = false
      await loadResults()
    },
    () => {
      // Transport failure: keep any server-sent error, else a generic one.
      running.value = false
      if (!runError.value) runError.value = 'Connection lost'
    },
  )
}
|
|
||||||
|
|
||||||
async function loadFineTunedModels() {
|
|
||||||
const { data } = await useApiFetch<FineTunedModel[]>('/api/finetune/status')
|
|
||||||
if (Array.isArray(data)) fineTunedModels.value = data
|
|
||||||
}
|
|
||||||
|
|
||||||
// Kick off a fine-tune job and stream its training log over SSE.
// Guarded against double-starts; mirrors startBenchmark's callback wiring
// but drives the ft* state instead.
function startFinetune() {
  if (ftRunning.value) return // already training — ignore repeated clicks
  ftRunning.value = true
  ftLog.value = []
  ftError.value = ''
  ftCancelled.value = false

  const params = new URLSearchParams({ model: ftModel.value, epochs: String(ftEpochs.value) })
  useApiSSE(
    `/api/finetune/run?${params}`,
    async (event) => {
      if (event.type === 'progress' && typeof event.message === 'string') {
        ftLog.value.push(event.message)
        // Wait for the DOM to render the new line before auto-scrolling.
        await nextTick()
        ftLogEl.value?.scrollTo({ top: ftLogEl.value.scrollHeight, behavior: 'smooth' })
      }
      if (event.type === 'error' && typeof event.message === 'string') {
        ftError.value = event.message
      }
      if (event.type === 'cancelled') {
        ftRunning.value = false
        ftCancelled.value = true
      }
    },
    async () => {
      // Stream closed: reload the trained-model registry, then re-benchmark.
      // NOTE(review): this completion callback also fires after a cancel —
      // confirm auto-triggering a benchmark post-cancellation is intended.
      ftRunning.value = false
      await loadFineTunedModels()
      startBenchmark() // auto-trigger benchmark to refresh charts
    },
    () => {
      // Transport failure: keep any server-sent error, else a generic one.
      ftRunning.value = false
      if (!ftError.value) ftError.value = 'Connection lost'
    },
  )
}
|
|
||||||
|
|
||||||
// Initial load: benchmark results and trained-model registry are fetched
// concurrently (fire-and-forget — each updates its own reactive state).
onMounted(() => {
  loadResults()
  loadFineTunedModels()
})
|
|
||||||
</script>
|
|
||||||
|
|
||||||
<style scoped>
|
|
||||||
.bench-view {
|
|
||||||
max-width: 860px;
|
|
||||||
margin: 0 auto;
|
|
||||||
padding: 1.5rem 1rem 4rem;
|
|
||||||
display: flex;
|
|
||||||
flex-direction: column;
|
|
||||||
gap: 1.75rem;
|
|
||||||
}
|
|
||||||
|
|
||||||
.bench-header {
|
|
||||||
display: flex;
|
|
||||||
align-items: center;
|
|
||||||
justify-content: space-between;
|
|
||||||
flex-wrap: wrap;
|
|
||||||
gap: 0.75rem;
|
|
||||||
}
|
|
||||||
|
|
||||||
.page-title {
|
|
||||||
font-family: var(--font-display, var(--font-body, sans-serif));
|
|
||||||
font-size: 1.4rem;
|
|
||||||
font-weight: 700;
|
|
||||||
color: var(--app-primary, #2A6080);
|
|
||||||
margin: 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
.header-actions {
|
|
||||||
display: flex;
|
|
||||||
align-items: center;
|
|
||||||
gap: 0.75rem;
|
|
||||||
flex-wrap: wrap;
|
|
||||||
}
|
|
||||||
|
|
||||||
.slow-toggle {
|
|
||||||
display: flex;
|
|
||||||
align-items: center;
|
|
||||||
gap: 0.4rem;
|
|
||||||
font-size: 0.85rem;
|
|
||||||
color: var(--color-text-secondary, #6b7a99);
|
|
||||||
cursor: pointer;
|
|
||||||
user-select: none;
|
|
||||||
}
|
|
||||||
.slow-toggle.disabled { opacity: 0.5; pointer-events: none; }
|
|
||||||
|
|
||||||
.btn-run {
|
|
||||||
padding: 0.45rem 1.1rem;
|
|
||||||
border-radius: 0.375rem;
|
|
||||||
border: none;
|
|
||||||
background: var(--app-primary, #2A6080);
|
|
||||||
color: #fff;
|
|
||||||
font-size: 0.88rem;
|
|
||||||
font-family: var(--font-body, sans-serif);
|
|
||||||
cursor: pointer;
|
|
||||||
transition: opacity 0.15s;
|
|
||||||
}
|
|
||||||
.btn-run:disabled { opacity: 0.5; cursor: not-allowed; }
|
|
||||||
.btn-run:not(:disabled):hover { opacity: 0.85; }
|
|
||||||
|
|
||||||
.btn-cancel {
|
|
||||||
padding: 0.45rem 0.9rem;
|
|
||||||
background: transparent;
|
|
||||||
border: 1px solid var(--color-text-secondary, #6b7a99);
|
|
||||||
color: var(--color-text-secondary, #6b7a99);
|
|
||||||
border-radius: 0.4rem;
|
|
||||||
font-size: 0.85rem;
|
|
||||||
font-weight: 500;
|
|
||||||
cursor: pointer;
|
|
||||||
transition: background 0.15s;
|
|
||||||
}
|
|
||||||
|
|
||||||
.btn-cancel:hover {
|
|
||||||
background: color-mix(in srgb, var(--color-text-secondary, #6b7a99) 12%, transparent);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ── Run log ────────────────────────────────────────────── */
|
|
||||||
.run-log {
|
|
||||||
border: 1px solid var(--color-border, #d0d7e8);
|
|
||||||
border-radius: 0.5rem;
|
|
||||||
overflow: hidden;
|
|
||||||
font-family: var(--font-mono, monospace);
|
|
||||||
font-size: 0.78rem;
|
|
||||||
}
|
|
||||||
|
|
||||||
.run-log-title {
|
|
||||||
display: flex;
|
|
||||||
justify-content: space-between;
|
|
||||||
align-items: center;
|
|
||||||
padding: 0.4rem 0.75rem;
|
|
||||||
background: var(--color-surface-raised, #e4ebf5);
|
|
||||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
|
||||||
font-size: 0.8rem;
|
|
||||||
color: var(--color-text-secondary, #6b7a99);
|
|
||||||
}
|
|
||||||
|
|
||||||
.btn-ghost {
|
|
||||||
background: none;
|
|
||||||
border: none;
|
|
||||||
color: var(--color-text-secondary, #6b7a99);
|
|
||||||
cursor: pointer;
|
|
||||||
font-size: 0.78rem;
|
|
||||||
padding: 0.1rem 0.3rem;
|
|
||||||
border-radius: 0.2rem;
|
|
||||||
}
|
|
||||||
.btn-ghost:hover { background: var(--color-border, #d0d7e8); }
|
|
||||||
|
|
||||||
.log-lines {
|
|
||||||
max-height: 200px;
|
|
||||||
overflow-y: auto;
|
|
||||||
padding: 0.5rem 0.75rem;
|
|
||||||
background: var(--color-surface, #fff);
|
|
||||||
display: flex;
|
|
||||||
flex-direction: column;
|
|
||||||
gap: 0.1rem;
|
|
||||||
}
|
|
||||||
|
|
||||||
.log-line { color: var(--color-text, #1a2338); line-height: 1.5; }
|
|
||||||
.log-line.log-error { color: var(--color-error, #ef4444); }
|
|
||||||
|
|
||||||
.run-error {
|
|
||||||
margin: 0;
|
|
||||||
padding: 0.4rem 0.75rem;
|
|
||||||
background: color-mix(in srgb, var(--color-error, #ef4444) 10%, transparent);
|
|
||||||
color: var(--color-error, #ef4444);
|
|
||||||
font-size: 0.82rem;
|
|
||||||
font-family: var(--font-mono, monospace);
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ── Status notices ─────────────────────────────────────── */
|
|
||||||
.status-notice {
|
|
||||||
color: var(--color-text-secondary, #6b7a99);
|
|
||||||
font-size: 0.9rem;
|
|
||||||
padding: 1rem;
|
|
||||||
}
|
|
||||||
.status-notice.empty {
|
|
||||||
display: flex;
|
|
||||||
flex-direction: column;
|
|
||||||
align-items: center;
|
|
||||||
gap: 0.5rem;
|
|
||||||
padding: 3rem 1rem;
|
|
||||||
text-align: center;
|
|
||||||
}
|
|
||||||
.hint { font-size: 0.85rem; opacity: 0.75; }
|
|
||||||
|
|
||||||
/* ── Meta line ──────────────────────────────────────────── */
|
|
||||||
.meta-line {
|
|
||||||
display: flex;
|
|
||||||
gap: 0.5rem;
|
|
||||||
align-items: center;
|
|
||||||
font-size: 0.85rem;
|
|
||||||
color: var(--color-text-secondary, #6b7a99);
|
|
||||||
font-family: var(--font-mono, monospace);
|
|
||||||
flex-wrap: wrap;
|
|
||||||
}
|
|
||||||
.sep { opacity: 0.4; }
|
|
||||||
|
|
||||||
/* ── Chart sections ─────────────────────────────────────── */
|
|
||||||
.chart-section {
|
|
||||||
display: flex;
|
|
||||||
flex-direction: column;
|
|
||||||
gap: 0.75rem;
|
|
||||||
}
|
|
||||||
|
|
||||||
.chart-title {
|
|
||||||
font-size: 0.95rem;
|
|
||||||
font-weight: 600;
|
|
||||||
color: var(--color-text, #1a2338);
|
|
||||||
margin: 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ── Bar charts ─────────────────────────────────────────── */
|
|
||||||
.bar-chart {
|
|
||||||
display: flex;
|
|
||||||
flex-direction: column;
|
|
||||||
gap: 0.4rem;
|
|
||||||
}
|
|
||||||
|
|
||||||
.bar-row {
|
|
||||||
display: grid;
|
|
||||||
grid-template-columns: 14rem 1fr 5rem;
|
|
||||||
align-items: center;
|
|
||||||
gap: 0.5rem;
|
|
||||||
font-size: 0.82rem;
|
|
||||||
}
|
|
||||||
|
|
||||||
.bar-label {
|
|
||||||
font-family: var(--font-mono, monospace);
|
|
||||||
font-size: 0.76rem;
|
|
||||||
white-space: nowrap;
|
|
||||||
overflow: hidden;
|
|
||||||
text-overflow: ellipsis;
|
|
||||||
color: var(--color-text, #1a2338);
|
|
||||||
}
|
|
||||||
|
|
||||||
.bar-track {
|
|
||||||
height: 16px;
|
|
||||||
background: var(--color-surface-raised, #e4ebf5);
|
|
||||||
border-radius: 99px;
|
|
||||||
overflow: hidden;
|
|
||||||
}
|
|
||||||
|
|
||||||
.bar-fill {
|
|
||||||
height: 100%;
|
|
||||||
border-radius: 99px;
|
|
||||||
transition: width 0.5s cubic-bezier(0.16, 1, 0.3, 1);
|
|
||||||
}
|
|
||||||
|
|
||||||
.latency-fill { background: var(--app-primary, #2A6080); opacity: 0.65; }
|
|
||||||
|
|
||||||
.bar-value {
|
|
||||||
text-align: right;
|
|
||||||
font-family: var(--font-mono, monospace);
|
|
||||||
font-size: 0.8rem;
|
|
||||||
font-variant-numeric: tabular-nums;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ── Heatmap ────────────────────────────────────────────── */
|
|
||||||
.heatmap-scroll {
|
|
||||||
overflow-x: auto;
|
|
||||||
border-radius: 0.5rem;
|
|
||||||
border: 1px solid var(--color-border, #d0d7e8);
|
|
||||||
}
|
|
||||||
|
|
||||||
.heatmap {
|
|
||||||
border-collapse: collapse;
|
|
||||||
min-width: 100%;
|
|
||||||
font-size: 0.78rem;
|
|
||||||
}
|
|
||||||
|
|
||||||
.hm-label-col {
|
|
||||||
text-align: left;
|
|
||||||
min-width: 11rem;
|
|
||||||
padding: 0.4rem 0.6rem;
|
|
||||||
background: var(--color-surface-raised, #e4ebf5);
|
|
||||||
font-weight: 600;
|
|
||||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
|
||||||
position: sticky;
|
|
||||||
left: 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
.hm-model-col {
|
|
||||||
min-width: 5rem;
|
|
||||||
max-width: 8rem;
|
|
||||||
padding: 0.4rem 0.5rem;
|
|
||||||
background: var(--color-surface-raised, #e4ebf5);
|
|
||||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
|
||||||
font-family: var(--font-mono, monospace);
|
|
||||||
font-size: 0.7rem;
|
|
||||||
text-overflow: ellipsis;
|
|
||||||
overflow: hidden;
|
|
||||||
white-space: nowrap;
|
|
||||||
text-align: center;
|
|
||||||
}
|
|
||||||
|
|
||||||
.hm-label-cell {
|
|
||||||
padding: 0.35rem 0.6rem;
|
|
||||||
background: var(--color-surface, #fff);
|
|
||||||
border-top: 1px solid var(--color-border, #d0d7e8);
|
|
||||||
white-space: nowrap;
|
|
||||||
font-family: var(--font-mono, monospace);
|
|
||||||
font-size: 0.74rem;
|
|
||||||
position: sticky;
|
|
||||||
left: 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
.hm-emoji { margin-right: 0.3rem; }
|
|
||||||
|
|
||||||
.hm-value-cell {
|
|
||||||
padding: 0.35rem 0.5rem;
|
|
||||||
text-align: center;
|
|
||||||
font-family: var(--font-mono, monospace);
|
|
||||||
font-variant-numeric: tabular-nums;
|
|
||||||
border-top: 1px solid rgba(255,255,255,0.08);
|
|
||||||
cursor: default;
|
|
||||||
transition: filter 0.15s;
|
|
||||||
}
|
|
||||||
.hm-value-cell:hover { filter: brightness(1.15); }
|
|
||||||
|
|
||||||
.heatmap-hint {
|
|
||||||
font-size: 0.75rem;
|
|
||||||
color: var(--color-text-secondary, #6b7a99);
|
|
||||||
margin: 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ── Mobile tweaks ──────────────────────────────────────── */
|
|
||||||
@media (max-width: 600px) {
|
|
||||||
.bar-row { grid-template-columns: 9rem 1fr 4rem; }
|
|
||||||
.bar-label { font-size: 0.7rem; }
|
|
||||||
.bench-header { flex-direction: column; align-items: flex-start; }
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ── Trained models badge row ──────────────────────────── */
|
|
||||||
.trained-models-row {
|
|
||||||
display: flex;
|
|
||||||
flex-wrap: wrap;
|
|
||||||
align-items: center;
|
|
||||||
gap: 0.5rem;
|
|
||||||
padding: 0.6rem 0.75rem;
|
|
||||||
background: var(--color-surface-raised, #e4ebf5);
|
|
||||||
border-radius: 0.5rem;
|
|
||||||
border: 1px solid var(--color-border, #d0d7e8);
|
|
||||||
}
|
|
||||||
|
|
||||||
.trained-label {
|
|
||||||
font-size: 0.75rem;
|
|
||||||
font-weight: 700;
|
|
||||||
color: var(--color-text-secondary, #6b7a99);
|
|
||||||
text-transform: uppercase;
|
|
||||||
letter-spacing: 0.05em;
|
|
||||||
flex-shrink: 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
.trained-badge {
|
|
||||||
display: inline-flex;
|
|
||||||
align-items: center;
|
|
||||||
gap: 0.35rem;
|
|
||||||
padding: 0.2rem 0.55rem;
|
|
||||||
background: var(--app-primary, #2A6080);
|
|
||||||
color: #fff;
|
|
||||||
border-radius: 1rem;
|
|
||||||
font-family: var(--font-mono, monospace);
|
|
||||||
font-size: 0.76rem;
|
|
||||||
cursor: default;
|
|
||||||
}
|
|
||||||
|
|
||||||
.trained-f1 {
|
|
||||||
background: rgba(255,255,255,0.2);
|
|
||||||
border-radius: 0.75rem;
|
|
||||||
padding: 0.05rem 0.35rem;
|
|
||||||
font-size: 0.7rem;
|
|
||||||
font-weight: 700;
|
|
||||||
}
|
|
||||||
|
|
||||||
/* ── Fine-tune section ──────────────────────────────────── */
|
|
||||||
.ft-section {
|
|
||||||
border: 1px solid var(--color-border, #d0d7e8);
|
|
||||||
border-radius: 0.5rem;
|
|
||||||
overflow: hidden;
|
|
||||||
}
|
|
||||||
|
|
||||||
.ft-summary {
|
|
||||||
padding: 0.65rem 0.9rem;
|
|
||||||
cursor: pointer;
|
|
||||||
font-size: 0.9rem;
|
|
||||||
font-weight: 600;
|
|
||||||
color: var(--color-text, #1a2338);
|
|
||||||
user-select: none;
|
|
||||||
list-style: none;
|
|
||||||
background: var(--color-surface-raised, #e4ebf5);
|
|
||||||
}
|
|
||||||
.ft-summary::-webkit-details-marker { display: none; }
|
|
||||||
.ft-summary::before { content: '▶ '; font-size: 0.65rem; color: var(--color-text-secondary, #6b7a99); }
|
|
||||||
details[open] .ft-summary::before { content: '▼ '; }
|
|
||||||
|
|
||||||
.ft-body {
|
|
||||||
padding: 0.75rem;
|
|
||||||
display: flex;
|
|
||||||
flex-direction: column;
|
|
||||||
gap: 0.75rem;
|
|
||||||
border-top: 1px solid var(--color-border, #d0d7e8);
|
|
||||||
}
|
|
||||||
|
|
||||||
.ft-controls {
|
|
||||||
display: flex;
|
|
||||||
flex-wrap: wrap;
|
|
||||||
gap: 0.75rem;
|
|
||||||
align-items: flex-end;
|
|
||||||
}
|
|
||||||
|
|
||||||
.ft-field {
|
|
||||||
display: flex;
|
|
||||||
flex-direction: column;
|
|
||||||
gap: 0.25rem;
|
|
||||||
}
|
|
||||||
|
|
||||||
.ft-field-label {
|
|
||||||
font-size: 0.75rem;
|
|
||||||
font-weight: 600;
|
|
||||||
color: var(--color-text-secondary, #6b7a99);
|
|
||||||
text-transform: uppercase;
|
|
||||||
letter-spacing: 0.04em;
|
|
||||||
}
|
|
||||||
|
|
||||||
.ft-select {
|
|
||||||
padding: 0.35rem 0.5rem;
|
|
||||||
border: 1px solid var(--color-border, #d0d7e8);
|
|
||||||
border-radius: 0.375rem;
|
|
||||||
background: var(--color-surface, #fff);
|
|
||||||
font-size: 0.85rem;
|
|
||||||
color: var(--color-text, #1a2338);
|
|
||||||
min-width: 220px;
|
|
||||||
}
|
|
||||||
.ft-select:disabled { opacity: 0.55; }
|
|
||||||
|
|
||||||
.ft-epochs {
|
|
||||||
width: 64px;
|
|
||||||
padding: 0.35rem 0.5rem;
|
|
||||||
border: 1px solid var(--color-border, #d0d7e8);
|
|
||||||
border-radius: 0.375rem;
|
|
||||||
background: var(--color-surface, #fff);
|
|
||||||
font-size: 0.85rem;
|
|
||||||
color: var(--color-text, #1a2338);
|
|
||||||
text-align: center;
|
|
||||||
}
|
|
||||||
.ft-epochs:disabled { opacity: 0.55; }
|
|
||||||
|
|
||||||
.ft-run-btn { align-self: flex-end; }
|
|
||||||
|
|
||||||
.ft-log { margin-top: 0; }
|
|
||||||
|
|
||||||
@media (max-width: 600px) {
|
|
||||||
.ft-controls { flex-direction: column; align-items: stretch; }
|
|
||||||
.ft-select { min-width: 0; width: 100%; }
|
|
||||||
}
|
|
||||||
</style>
|
|
||||||
|
|
@ -8,24 +8,12 @@
|
||||||
{{ store.totalRemaining }} remaining
|
{{ store.totalRemaining }} remaining
|
||||||
</template>
|
</template>
|
||||||
<span v-else class="queue-status">Queue empty</span>
|
<span v-else class="queue-status">Queue empty</span>
|
||||||
<Transition @enter="onBadgeEnter" :css="false">
|
|
||||||
<span v-if="onRoll" class="badge badge-roll">🔥 On a roll!</span>
|
<span v-if="onRoll" class="badge badge-roll">🔥 On a roll!</span>
|
||||||
</Transition>
|
|
||||||
<Transition @enter="onBadgeEnter" :css="false">
|
|
||||||
<span v-if="speedRound" class="badge badge-speed">⚡ Speed round!</span>
|
<span v-if="speedRound" class="badge badge-speed">⚡ Speed round!</span>
|
||||||
</Transition>
|
|
||||||
<Transition @enter="onBadgeEnter" :css="false">
|
|
||||||
<span v-if="fiftyDeep" class="badge badge-fifty">🎯 Fifty deep!</span>
|
<span v-if="fiftyDeep" class="badge badge-fifty">🎯 Fifty deep!</span>
|
||||||
</Transition>
|
|
||||||
<Transition @enter="onBadgeEnter" :css="false">
|
|
||||||
<span v-if="centuryMark" class="badge badge-century">💯 Century!</span>
|
<span v-if="centuryMark" class="badge badge-century">💯 Century!</span>
|
||||||
</Transition>
|
|
||||||
<Transition @enter="onBadgeEnter" :css="false">
|
|
||||||
<span v-if="cleanSweep" class="badge badge-sweep">🧹 Clean sweep!</span>
|
<span v-if="cleanSweep" class="badge badge-sweep">🧹 Clean sweep!</span>
|
||||||
</Transition>
|
|
||||||
<Transition @enter="onBadgeEnter" :css="false">
|
|
||||||
<span v-if="midnightLabeler" class="badge badge-midnight">🦉 Midnight labeler!</span>
|
<span v-if="midnightLabeler" class="badge badge-midnight">🦉 Midnight labeler!</span>
|
||||||
</Transition>
|
|
||||||
</span>
|
</span>
|
||||||
<div class="header-actions">
|
<div class="header-actions">
|
||||||
<button @click="handleUndo" :disabled="!store.lastAction" class="btn-action">↩ Undo</button>
|
<button @click="handleUndo" :disabled="!store.lastAction" class="btn-action">↩ Undo</button>
|
||||||
|
|
@ -81,7 +69,7 @@
|
||||||
/>
|
/>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
<div ref="gridEl" class="bucket-grid-footer" :class="{ 'grid-active': isHeld }">
|
<div class="bucket-grid-footer" :class="{ 'grid-active': isHeld }">
|
||||||
<LabelBucketGrid
|
<LabelBucketGrid
|
||||||
:labels="labels"
|
:labels="labels"
|
||||||
:is-bucket-mode="isHeld"
|
:is-bucket-mode="isHeld"
|
||||||
|
|
@ -102,8 +90,7 @@
|
||||||
</template>
|
</template>
|
||||||
|
|
||||||
<script setup lang="ts">
|
<script setup lang="ts">
|
||||||
import { ref, watch, onMounted, onUnmounted } from 'vue'
|
import { ref, onMounted, onUnmounted } from 'vue'
|
||||||
import { animate } from 'animejs'
|
|
||||||
import { useLabelStore } from '../stores/label'
|
import { useLabelStore } from '../stores/label'
|
||||||
import { useApiFetch } from '../composables/useApi'
|
import { useApiFetch } from '../composables/useApi'
|
||||||
import { useHaptics } from '../composables/useHaptics'
|
import { useHaptics } from '../composables/useHaptics'
|
||||||
|
|
@ -118,8 +105,6 @@ const store = useLabelStore()
|
||||||
const haptics = useHaptics()
|
const haptics = useHaptics()
|
||||||
const motion = useMotion() // only needed to pass to child — actual value used in App.vue
|
const motion = useMotion() // only needed to pass to child — actual value used in App.vue
|
||||||
|
|
||||||
const gridEl = ref<HTMLElement | null>(null)
|
|
||||||
|
|
||||||
const loading = ref(true)
|
const loading = ref(true)
|
||||||
const apiError = ref(false)
|
const apiError = ref(false)
|
||||||
const isHeld = ref(false)
|
const isHeld = ref(false)
|
||||||
|
|
@ -128,22 +113,6 @@ const hoveredBucket = ref<string | null>(null)
|
||||||
const labels = ref<any[]>([])
|
const labels = ref<any[]>([])
|
||||||
const dismissType = ref<'label' | 'skip' | 'discard' | null>(null)
|
const dismissType = ref<'label' | 'skip' | 'discard' | null>(null)
|
||||||
|
|
||||||
watch(isHeld, (held) => {
|
|
||||||
if (!motion.rich.value || !gridEl.value) return
|
|
||||||
animate(gridEl.value,
|
|
||||||
held
|
|
||||||
? { y: -8, opacity: 0.45, ease: 'out(4)', duration: 380 }
|
|
||||||
: { y: 0, opacity: 1, ease: 'out(4)', duration: 320 }
|
|
||||||
)
|
|
||||||
})
|
|
||||||
|
|
||||||
function onBadgeEnter(el: Element, done: () => void) {
|
|
||||||
if (!motion.rich.value) { done(); return }
|
|
||||||
animate(el as HTMLElement,
|
|
||||||
{ scale: [0.6, 1], opacity: [0, 1], ease: spring({ mass: 1.5, stiffness: 80, damping: 8 }), duration: 300, onComplete: done }
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Easter egg state
|
// Easter egg state
|
||||||
const consecutiveLabeled = ref(0)
|
const consecutiveLabeled = ref(0)
|
||||||
const recentLabels = ref<number[]>([])
|
const recentLabels = ref<number[]>([])
|
||||||
|
|
@ -345,8 +314,8 @@ onUnmounted(() => {
|
||||||
padding: 1rem;
|
padding: 1rem;
|
||||||
max-width: 640px;
|
max-width: 640px;
|
||||||
margin: 0 auto;
|
margin: 0 auto;
|
||||||
min-height: 100dvh;
|
height: 100dvh; /* hard cap — prevents grid from drifting below fold */
|
||||||
overflow-x: hidden; /* prevent card animations from causing horizontal scroll */
|
overflow: hidden;
|
||||||
}
|
}
|
||||||
|
|
||||||
.queue-status {
|
.queue-status {
|
||||||
|
|
@ -383,6 +352,12 @@ onUnmounted(() => {
|
||||||
font-size: 0.75rem;
|
font-size: 0.75rem;
|
||||||
font-weight: 700;
|
font-weight: 700;
|
||||||
font-family: var(--font-body, sans-serif);
|
font-family: var(--font-body, sans-serif);
|
||||||
|
animation: badge-pop 0.3s cubic-bezier(0.34, 1.56, 0.64, 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
@keyframes badge-pop {
|
||||||
|
from { transform: scale(0.6); opacity: 0; }
|
||||||
|
to { transform: scale(1); opacity: 1; }
|
||||||
}
|
}
|
||||||
|
|
||||||
.badge-roll { background: #ff6b35; color: #fff; }
|
.badge-roll { background: #ff6b35; color: #fff; }
|
||||||
|
|
@ -449,10 +424,13 @@ onUnmounted(() => {
|
||||||
|
|
||||||
.card-stack-wrapper {
|
.card-stack-wrapper {
|
||||||
flex: 1;
|
flex: 1;
|
||||||
min-height: 0;
|
min-height: 0; /* allow flex child to shrink — default auto prevents this */
|
||||||
|
overflow-y: auto;
|
||||||
padding-bottom: 0.5rem;
|
padding-bottom: 0.5rem;
|
||||||
|
transition: opacity 200ms ease;
|
||||||
}
|
}
|
||||||
/* When held: escape overflow clip so ball floats freely above the footer. */
|
/* When held: escape the overflow clip so the ball floats freely,
|
||||||
|
and rise above the footer (z-index 10) so the ball is visible. */
|
||||||
.card-stack-wrapper.is-held {
|
.card-stack-wrapper.is-held {
|
||||||
overflow: visible;
|
overflow: visible;
|
||||||
position: relative;
|
position: relative;
|
||||||
|
|
@ -463,17 +441,16 @@ onUnmounted(() => {
|
||||||
can be scrolled freely. "hired" (10th button) may clip on very small screens
|
can be scrolled freely. "hired" (10th button) may clip on very small screens
|
||||||
— that is intentional per design. */
|
— that is intentional per design. */
|
||||||
.bucket-grid-footer {
|
.bucket-grid-footer {
|
||||||
position: sticky;
|
|
||||||
bottom: 0;
|
|
||||||
background: var(--color-bg, var(--color-surface, #f0f4fc));
|
background: var(--color-bg, var(--color-surface, #f0f4fc));
|
||||||
padding: 0.5rem 0 0.75rem;
|
padding: 0.5rem 0 0.75rem;
|
||||||
z-index: 10;
|
z-index: 10;
|
||||||
|
transition: transform 250ms cubic-bezier(0.34, 1.56, 0.64, 1),
|
||||||
|
opacity 200ms ease,
|
||||||
|
background 200ms ease;
|
||||||
}
|
}
|
||||||
/* During toss: stay sticky so the grid holds its natural column position
|
|
||||||
(fixed caused a horizontal jump on desktop due to sidebar offset).
|
|
||||||
Opacity and translateY(-8px) are owned by Anime.js. */
|
|
||||||
.bucket-grid-footer.grid-active {
|
.bucket-grid-footer.grid-active {
|
||||||
opacity: 0.45;
|
transform: translateY(-8px);
|
||||||
|
opacity: 0.45; /* semi-transparent so ball aura is visible through it */
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ── Toss edge zones ── */
|
/* ── Toss edge zones ── */
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue