From 7c304ebc459d4b4ffe53dd781383d0fae1f1efbf Mon Sep 17 00:00:00 2001 From: pyr0ball Date: Wed, 8 Apr 2026 23:03:56 -0700 Subject: [PATCH] feat: benchmark model picker, category grouping, stats benchmark results MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Backend (app/api.py): - GET /api/benchmark/models — returns installed models grouped by adapter type (ZeroShotAdapter, RerankerAdapter, GenerationAdapter, Unknown); reads _MODELS_DIR via app.models so test overrides are respected - GET /api/benchmark/run — add model_names query param (comma-separated); when set, passes --models to benchmark_classifier.py - GET /api/stats — add benchmark_results field from benchmark_results.json Frontend: - BenchmarkView: collapsible Model Selection panel with per-category checkboxes, select-all per category (supports indeterminate state), collapsed summary badge ("All models (N)" or "N of M selected"); model_names only sent when a strict subset is selected - StatsView: Benchmark Results table (accuracy, macro_f1, weighted_f1) with best-model highlighting per metric; hidden when no results exist --- app/api.py | 45 +++++- web/src/views/BenchmarkView.vue | 251 +++++++++++++++++++++++++++++++- web/src/views/StatsView.vue | 150 +++++++++++++++++++ 3 files changed, 444 insertions(+), 2 deletions(-) diff --git a/app/api.py b/app/api.py index a96b88c..0788e13 100644 --- a/app/api.py +++ b/app/api.py @@ -146,6 +146,7 @@ from app.sft import router as sft_router app.include_router(sft_router, prefix="/api/sft") from app.models import router as models_router +import app.models as _models_module app.include_router(models_router, prefix="/api/models") # In-memory last-action store (single user, local tool — in-memory is fine) @@ -301,10 +302,18 @@ def get_stats(): lbl = r.get("label", "") if lbl: counts[lbl] = counts.get(lbl, 0) + 1 + benchmark_results: dict = {} + benchmark_path = _DATA_DIR / "benchmark_results.json" + if
benchmark_path.exists(): + try: + benchmark_results = json.loads(benchmark_path.read_text(encoding="utf-8")) + except Exception: + pass return { "total": len(records), "counts": counts, "score_file_bytes": _score_file().stat().st_size if _score_file().exists() else 0, + "benchmark_results": benchmark_results, } @@ -339,6 +348,36 @@ from fastapi.responses import StreamingResponse # Benchmark endpoints # --------------------------------------------------------------------------- +@app.get("/api/benchmark/models") +def get_benchmark_models() -> dict: + """Return installed models grouped by adapter_type category.""" + models_dir: Path = _models_module._MODELS_DIR + categories: dict[str, list[dict]] = { + "ZeroShotAdapter": [], + "RerankerAdapter": [], + "GenerationAdapter": [], + "Unknown": [], + } + if models_dir.exists(): + for sub in models_dir.iterdir(): + if not sub.is_dir(): + continue + info_path = sub / "model_info.json" + adapter_type = "Unknown" + repo_id: str | None = None + if info_path.exists(): + try: + info = json.loads(info_path.read_text(encoding="utf-8")) + adapter_type = info.get("adapter_type") or info.get("adapter_recommendation") or "Unknown" + repo_id = info.get("repo_id") + except Exception: + pass + bucket = adapter_type if adapter_type in categories else "Unknown" + entry: dict = {"name": sub.name, "repo_id": repo_id, "adapter_type": adapter_type} + categories[bucket].append(entry) + return {"categories": categories} + + @app.get("/api/benchmark/results") def get_benchmark_results(): """Return the most recently saved benchmark results, or an empty envelope.""" @@ -349,13 +388,17 @@ def get_benchmark_results(): @app.get("/api/benchmark/run") -def run_benchmark(include_slow: bool = False): +def run_benchmark(include_slow: bool = False, model_names: str = ""): """Spawn the benchmark script and stream stdout as SSE progress events.""" python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python" script = str(_ROOT / "scripts" / 
"benchmark_classifier.py") cmd = [python_bin, script, "--score", "--save"] if include_slow: cmd.append("--include-slow") + if model_names: + names = [n.strip() for n in model_names.split(",") if n.strip()] + if names: + cmd.extend(["--models"] + names) def generate(): try: diff --git a/web/src/views/BenchmarkView.vue b/web/src/views/BenchmarkView.vue index 7351b42..8497181 100644 --- a/web/src/views/BenchmarkView.vue +++ b/web/src/views/BenchmarkView.vue @@ -24,6 +24,54 @@ + +
+ + 🎯 Model Selection + {{ pickerSummaryText }} + +
+
Loading models…
+
+ No models found — check API connection. +
+ +
+
+
Trained: @@ -224,6 +272,16 @@ const LABEL_META: Record = { } // ── Types ──────────────────────────────────────────────────────────────────── +interface AvailableModel { + name: string + repo_id?: string + adapter_type: string +} + +interface ModelCategoriesResponse { + categories: Record +} + interface FineTunedModel { name: string base_model_id?: string @@ -254,6 +312,13 @@ const runError = ref('') const includeSlow = ref(false) const logEl = ref(null) +// Model picker state +const modelCategories = ref>({}) +const selectedModels = ref>(new Set()) +const allModels = ref([]) +const modelsLoading = ref(false) +const pickerEl = ref(null) + // Fine-tune state const fineTunedModels = ref([]) const ftModel = ref('deberta-small') @@ -274,6 +339,52 @@ async function cancelFinetune() { await fetch('/api/finetune/cancel', { method: 'POST' }).catch(() => {}) } +// ── Model picker computed ───────────────────────────────────────────────────── +const pickerSummaryText = computed(() => { + const total = allModels.value.length + if (total === 0) return 'No models available' + const selected = selectedModels.value.size + if (selected === total) return `All models (${total})` + return `${selected} of ${total} selected` +}) + +function isCategoryAllSelected(models: AvailableModel[]): boolean { + return models.length > 0 && models.every(m => selectedModels.value.has(m.name)) +} + +function isCategoryIndeterminate(models: AvailableModel[]): boolean { + const someSelected = models.some(m => selectedModels.value.has(m.name)) + return someSelected && !isCategoryAllSelected(models) +} + +function toggleModel(name: string, checked: boolean) { + const next = new Set(selectedModels.value) + if (checked) next.add(name) + else next.delete(name) + selectedModels.value = next +} + +function toggleCategory(models: AvailableModel[], checked: boolean) { + const next = new Set(selectedModels.value) + for (const m of models) { + if (checked) next.add(m.name) + else next.delete(m.name) + } + 
selectedModels.value = next +} + +async function loadModelCategories() { + modelsLoading.value = true + const { data } = await useApiFetch('/api/benchmark/models') + modelsLoading.value = false + if (data?.categories) { + modelCategories.value = data.categories + const flat = Object.values(data.categories).flat().map(m => m.name) + allModels.value = flat + selectedModels.value = new Set(flat) + } +} + // ── Derived ────────────────────────────────────────────────────────────────── const modelNames = computed(() => Object.keys(results.value?.models ?? {})) const modelCount = computed(() => modelNames.value.length) @@ -355,7 +466,16 @@ function startBenchmark() { runError.value = '' runCancelled.value = false - const url = `/api/benchmark/run${includeSlow.value ? '?include_slow=true' : ''}` + const params = new URLSearchParams() + if (includeSlow.value) params.set('include_slow', 'true') + // Only send model_names when a subset is selected (not all, not none) + const total = allModels.value.length + const selected = selectedModels.value.size + if (total > 0 && selected > 0 && selected < total) { + params.set('model_names', [...selectedModels.value].join(',')) + } + const qs = params.toString() + const url = `/api/benchmark/run${qs ? 
`?${qs}` : ''}` useApiSSE( url, async (event) => { @@ -427,6 +547,7 @@ function startFinetune() { onMounted(() => { loadResults() loadFineTunedModels() + loadModelCategories() }) @@ -762,6 +883,134 @@ onMounted(() => { font-weight: 700; } +/* ── Model Picker ───────────────────────────────────────── */ +.model-picker { + border: 1px solid var(--color-border, #d0d7e8); + border-radius: 0.5rem; + overflow: hidden; +} + +.picker-summary { + display: flex; + align-items: center; + gap: 0.6rem; + padding: 0.65rem 0.9rem; + cursor: pointer; + user-select: none; + list-style: none; + background: var(--color-surface-raised, #e4ebf5); +} +.picker-summary::-webkit-details-marker { display: none; } +.picker-summary::before { content: 'β–Ά '; font-size: 0.65rem; color: var(--color-text-secondary, #6b7a99); } +details[open] .picker-summary::before { content: 'β–Ό '; } + +.picker-title { + font-size: 0.9rem; + font-weight: 600; + color: var(--color-text, #1a2338); +} + +.picker-badge { + font-size: 0.75rem; + color: var(--color-text-secondary, #6b7a99); + background: var(--color-surface, #fff); + border: 1px solid var(--color-border, #d0d7e8); + padding: 0.15rem 0.5rem; + border-radius: 1rem; + font-family: var(--font-mono, monospace); + margin-left: auto; +} + +.picker-body { + padding: 0.75rem; + border-top: 1px solid var(--color-border, #d0d7e8); + display: flex; + flex-direction: column; + gap: 0.75rem; +} + +.picker-loading, +.picker-empty { + font-size: 0.85rem; + color: var(--color-text-secondary, #6b7a99); + padding: 0.5rem 0; +} + +.picker-category { + display: flex; + flex-direction: column; + gap: 0.3rem; +} + +.picker-cat-header { + display: flex; + align-items: center; + gap: 0.45rem; + font-size: 0.82rem; + font-weight: 700; + color: var(--color-text, #1a2338); + text-transform: uppercase; + letter-spacing: 0.04em; + cursor: pointer; +} + +.picker-cat-count { + font-weight: 400; + color: var(--color-text-secondary, #6b7a99); + font-family: var(--font-mono, 
monospace); + font-size: 0.75rem; + text-transform: none; + letter-spacing: 0; +} + +.picker-no-models { + font-size: 0.78rem; + color: var(--color-text-secondary, #6b7a99); + opacity: 0.65; + padding-left: 1.4rem; + font-style: italic; +} + +.picker-model-list { + display: flex; + flex-wrap: wrap; + gap: 0.35rem 0.75rem; + padding-left: 1.4rem; +} + +.picker-model-row { + display: flex; + align-items: center; + gap: 0.35rem; + font-size: 0.82rem; + cursor: pointer; + color: var(--color-text, #1a2338); +} + +.picker-model-name { + font-family: var(--font-mono, monospace); + font-size: 0.78rem; + white-space: nowrap; + max-width: 18ch; + overflow: hidden; + text-overflow: ellipsis; +} + +.picker-adapter-type { + font-size: 0.68rem; + color: var(--color-text-secondary, #6b7a99); + background: var(--color-surface-raised, #e4ebf5); + border: 1px solid var(--color-border, #d0d7e8); + border-radius: 0.25rem; + padding: 0.05rem 0.3rem; + font-family: var(--font-mono, monospace); +} + +@media (max-width: 600px) { + .picker-model-list { padding-left: 0; } + .picker-model-name { max-width: 14ch; } +} + /* ── Fine-tune section ──────────────────────────────────── */ .ft-section { border: 1px solid var(--color-border, #d0d7e8); diff --git a/web/src/views/StatsView.vue b/web/src/views/StatsView.vue index d1d3cd8..5290553 100644 --- a/web/src/views/StatsView.vue +++ b/web/src/views/StatsView.vue @@ -35,6 +35,39 @@
+ + +
Score file: data/email_score.jsonl {{ fileSizeLabel }} @@ -54,10 +87,18 @@ import { ref, computed, onMounted } from 'vue' import { useApiFetch } from '../composables/useApi' +interface BenchmarkModelResult { + accuracy?: number + macro_f1?: number + weighted_f1?: number + [key: string]: number | undefined +} + interface StatsResponse { total: number counts: Record score_file_bytes: number + benchmark_results?: Record } // Canonical label order + metadata @@ -108,6 +149,42 @@ const fileSizeLabel = computed(() => { return `${(b / 1024 / 1024).toFixed(2)} MB` }) +// Benchmark results helpers +const BENCH_METRICS: Array<{ key: keyof BenchmarkModelResult; label: string }> = [ + { key: 'accuracy', label: 'Accuracy' }, + { key: 'macro_f1', label: 'Macro F1' }, + { key: 'weighted_f1', label: 'Weighted F1' }, +] + +const benchRows = computed(() => { + const br = stats.value.benchmark_results + if (!br || Object.keys(br).length === 0) return [] + return Object.entries(br).map(([name, result]) => ({ name, result })) +}) + +// Find the best model name for each metric +const bestByMetric = computed((): Record => { + const result: Record = {} + for (const { key } of BENCH_METRICS) { + let bestName = '' + let bestVal = -Infinity + for (const { name, result: r } of benchRows.value) { + const v = r[key] + if (v != null && v > bestVal) { bestVal = v; bestName = name } + } + result[key as string] = bestName + } + return result +}) + +function formatMetric(v: number | undefined): string { + if (v == null) return 'β€”' + // Values in 0-1 range: format as percentage + if (v <= 1) return `${(v * 100).toFixed(1)}%` + // Already a percentage + return `${v.toFixed(1)}%` +} + async function load() { loading.value = true error.value = '' @@ -234,6 +311,79 @@ onMounted(load) padding: 1rem; } +/* ── Benchmark Results ──────────────────────────── */ +.section-title { + font-family: var(--font-display, var(--font-body, sans-serif)); + font-size: 1.05rem; + font-weight: 700; + color: 
var(--app-primary, #2A6080); + margin: 0; +} + +.bench-table-wrap { + overflow-x: auto; + border: 1px solid var(--color-border, #d0d7e8); + border-radius: 0.5rem; +} + +.bench-table { + border-collapse: collapse; + width: 100%; + font-size: 0.82rem; +} + +.bt-model-col { + text-align: left; + padding: 0.45rem 0.75rem; + background: var(--color-surface-raised, #e4ebf5); + border-bottom: 1px solid var(--color-border, #d0d7e8); + font-weight: 600; + min-width: 12rem; +} + +.bt-metric-col { + text-align: right; + padding: 0.45rem 0.75rem; + background: var(--color-surface-raised, #e4ebf5); + border-bottom: 1px solid var(--color-border, #d0d7e8); + font-weight: 600; + white-space: nowrap; + min-width: 6rem; +} + +.bt-model-cell { + padding: 0.4rem 0.75rem; + border-top: 1px solid var(--color-border, #d0d7e8); + font-family: var(--font-mono, monospace); + font-size: 0.76rem; + white-space: nowrap; + overflow: hidden; + text-overflow: ellipsis; + max-width: 16rem; + color: var(--color-text, #1a2338); +} + +.bt-metric-cell { + padding: 0.4rem 0.75rem; + border-top: 1px solid var(--color-border, #d0d7e8); + text-align: right; + font-family: var(--font-mono, monospace); + font-variant-numeric: tabular-nums; + color: var(--color-text, #1a2338); +} + +.bt-metric-cell.bt-best { + color: var(--color-success, #3a7a32); + font-weight: 700; + background: color-mix(in srgb, var(--color-success, #3a7a32) 8%, transparent); +} + +.bench-hint { + font-size: 0.75rem; + color: var(--color-text-secondary, #6b7a99); + margin: 0; +} + @media (max-width: 480px) { .bar-row { grid-template-columns: 1.5rem 1fr 1fr 3rem;