Merge pull request 'feat: benchmark model picker, category grouping, stats benchmark results' (#20) from feat/benchmark-model-picker into main
This commit is contained in:
commit
49ec85706c
3 changed files with 444 additions and 2 deletions
45
app/api.py
45
app/api.py
|
|
@ -146,6 +146,7 @@ from app.sft import router as sft_router
|
|||
app.include_router(sft_router, prefix="/api/sft")
|
||||
|
||||
from app.models import router as models_router
|
||||
import app.models as _models_module
|
||||
app.include_router(models_router, prefix="/api/models")
|
||||
|
||||
# In-memory last-action store (single user, local tool — in-memory is fine)
|
||||
|
|
@ -301,10 +302,18 @@ def get_stats():
|
|||
lbl = r.get("label", "")
|
||||
if lbl:
|
||||
counts[lbl] = counts.get(lbl, 0) + 1
|
||||
benchmark_results: dict = {}
|
||||
benchmark_path = _DATA_DIR / "benchmark_results.json"
|
||||
if benchmark_path.exists():
|
||||
try:
|
||||
benchmark_results = json.loads(benchmark_path.read_text(encoding="utf-8"))
|
||||
except Exception:
|
||||
pass
|
||||
return {
|
||||
"total": len(records),
|
||||
"counts": counts,
|
||||
"score_file_bytes": _score_file().stat().st_size if _score_file().exists() else 0,
|
||||
"benchmark_results": benchmark_results,
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -339,6 +348,36 @@ from fastapi.responses import StreamingResponse
|
|||
# Benchmark endpoints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
@app.get("/api/benchmark/models")
def get_benchmark_models() -> dict:
    """Return installed models grouped by adapter_type category.

    Every sub-directory of the models directory is reported as one model.
    Its category is read from ``model_info.json`` (``adapter_type``, falling
    back to ``adapter_recommendation``). Models whose metadata is missing,
    unreadable, malformed, or names a category that is not listed below are
    placed in the ``"Unknown"`` bucket.

    Returns:
        ``{"categories": {<category>: [{"name", "repo_id", "adapter_type"}, ...]}}``
        with all four category keys always present, even when empty.
    """
    models_dir: Path = _models_module._MODELS_DIR
    categories: dict[str, list[dict]] = {
        "ZeroShotAdapter": [],
        "RerankerAdapter": [],
        "GenerationAdapter": [],
        "Unknown": [],
    }
    # is_dir() rather than exists(): iterdir() raises if the path is a file.
    if not models_dir.is_dir():
        return {"categories": categories}

    # iterdir() yields entries in arbitrary order; sort for a stable listing.
    for sub in sorted(models_dir.iterdir(), key=lambda p: p.name):
        if not sub.is_dir():
            continue

        adapter_type = "Unknown"
        repo_id: str | None = None
        try:
            info = json.loads((sub / "model_info.json").read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # Missing/unreadable file, bad encoding or invalid JSON: the model
            # is still listed, just without metadata (best effort).
            info = None

        if isinstance(info, dict):
            declared = info.get("adapter_type") or info.get("adapter_recommendation")
            # Only trust string values: an unhashable value (e.g. a JSON list)
            # would make the dict membership test below raise TypeError.
            if isinstance(declared, str):
                adapter_type = declared
            repo_id = info.get("repo_id")

        bucket = adapter_type if adapter_type in categories else "Unknown"
        categories[bucket].append(
            {"name": sub.name, "repo_id": repo_id, "adapter_type": adapter_type}
        )

    return {"categories": categories}
|
||||
|
||||
|
||||
@app.get("/api/benchmark/results")
|
||||
def get_benchmark_results():
|
||||
"""Return the most recently saved benchmark results, or an empty envelope."""
|
||||
|
|
@ -349,13 +388,17 @@ def get_benchmark_results():
|
|||
|
||||
|
||||
@app.get("/api/benchmark/run")
|
||||
def run_benchmark(include_slow: bool = False):
|
||||
def run_benchmark(include_slow: bool = False, model_names: str = ""):
|
||||
"""Spawn the benchmark script and stream stdout as SSE progress events."""
|
||||
python_bin = "/devl/miniconda3/envs/job-seeker-classifiers/bin/python"
|
||||
script = str(_ROOT / "scripts" / "benchmark_classifier.py")
|
||||
cmd = [python_bin, script, "--score", "--save"]
|
||||
if include_slow:
|
||||
cmd.append("--include-slow")
|
||||
if model_names:
|
||||
names = [n.strip() for n in model_names.split(",") if n.strip()]
|
||||
if names:
|
||||
cmd.extend(["--models"] + names)
|
||||
|
||||
def generate():
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -24,6 +24,54 @@
|
|||
</div>
|
||||
</header>
|
||||
|
||||
<!-- Model Picker -->
|
||||
<details class="model-picker" ref="pickerEl">
|
||||
<summary class="picker-summary">
|
||||
<span class="picker-title">🎯 Model Selection</span>
|
||||
<span class="picker-badge">{{ pickerSummaryText }}</span>
|
||||
</summary>
|
||||
<div class="picker-body">
|
||||
<div v-if="modelsLoading" class="picker-loading">Loading models…</div>
|
||||
<div v-else-if="Object.keys(modelCategories).length === 0" class="picker-empty">
|
||||
No models found — check API connection.
|
||||
</div>
|
||||
<template v-else>
|
||||
<div
|
||||
v-for="(models, category) in modelCategories"
|
||||
:key="category"
|
||||
class="picker-category"
|
||||
>
|
||||
<label class="picker-cat-header">
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="isCategoryAllSelected(models)"
|
||||
:indeterminate="isCategoryIndeterminate(models)"
|
||||
@change="toggleCategory(models, ($event.target as HTMLInputElement).checked)"
|
||||
/>
|
||||
<span class="picker-cat-name">{{ category }}</span>
|
||||
<span class="picker-cat-count">({{ models.length }})</span>
|
||||
</label>
|
||||
<div v-if="models.length === 0" class="picker-no-models">No models installed</div>
|
||||
<div v-else class="picker-model-list">
|
||||
<label
|
||||
v-for="m in models"
|
||||
:key="m.name"
|
||||
class="picker-model-row"
|
||||
>
|
||||
<input
|
||||
type="checkbox"
|
||||
:checked="selectedModels.has(m.name)"
|
||||
@change="toggleModel(m.name, ($event.target as HTMLInputElement).checked)"
|
||||
/>
|
||||
<span class="picker-model-name" :title="m.repo_id ?? m.name">{{ m.name }}</span>
|
||||
<span class="picker-adapter-type">{{ m.adapter_type }}</span>
|
||||
</label>
|
||||
</div>
|
||||
</div>
|
||||
</template>
|
||||
</div>
|
||||
</details>
|
||||
|
||||
<!-- Trained models badge row -->
|
||||
<div v-if="fineTunedModels.length > 0" class="trained-models-row">
|
||||
<span class="trained-label">Trained:</span>
|
||||
|
|
@ -224,6 +272,16 @@ const LABEL_META: Record<string, { emoji: string }> = {
|
|||
}
|
||||
|
||||
// ── Types ────────────────────────────────────────────────────────────────────
|
||||
/** One installed model as returned by GET /api/benchmark/models. */
interface AvailableModel {
  /** Model directory name; used as the selection key in the picker. */
  name: string
  /**
   * Upstream repository id from the model's metadata. The backend always
   * sends this key and uses `null` when the metadata has no repo id, so
   * `null` must be accepted alongside a missing value.
   */
  repo_id?: string | null
  /** Adapter category reported by the backend ("Unknown" when not declared). */
  adapter_type: string
}
|
||||
|
||||
/** Response envelope of GET /api/benchmark/models. */
interface ModelCategoriesResponse {
  /** Models grouped by adapter category name. */
  categories: Record<string, AvailableModel[]>
}
|
||||
|
||||
interface FineTunedModel {
|
||||
name: string
|
||||
base_model_id?: string
|
||||
|
|
@ -254,6 +312,13 @@ const runError = ref('')
|
|||
const includeSlow = ref(false)
|
||||
const logEl = ref<HTMLElement | null>(null)
|
||||
|
||||
// Model picker state
|
||||
const modelCategories = ref<Record<string, AvailableModel[]>>({})
|
||||
const selectedModels = ref<Set<string>>(new Set())
|
||||
const allModels = ref<string[]>([])
|
||||
const modelsLoading = ref(false)
|
||||
const pickerEl = ref<HTMLDetailsElement | null>(null)
|
||||
|
||||
// Fine-tune state
|
||||
const fineTunedModels = ref<FineTunedModel[]>([])
|
||||
const ftModel = ref('deberta-small')
|
||||
|
|
@ -274,6 +339,52 @@ async function cancelFinetune() {
|
|||
await fetch('/api/finetune/cancel', { method: 'POST' }).catch(() => {})
|
||||
}
|
||||
|
||||
// ── Model picker computed ─────────────────────────────────────────────────────
|
||||
/** Badge text for the picker header: how much of the model list is selected. */
const pickerSummaryText = computed(() => {
  const totalCount = allModels.value.length
  if (totalCount === 0) {
    return 'No models available'
  }
  const selectedCount = selectedModels.value.size
  return selectedCount === totalCount
    ? `All models (${totalCount})`
    : `${selectedCount} of ${totalCount} selected`
})
|
||||
|
||||
/**
 * True when the category has at least one model and every one of them is
 * currently selected. An empty category is never "all selected".
 */
function isCategoryAllSelected(models: AvailableModel[]): boolean {
  if (models.length === 0) return false
  const chosen = selectedModels.value
  return !models.some(model => !chosen.has(model.name))
}
|
||||
|
||||
/**
 * True when some, but not all, models of the category are selected;
 * drives the tri-state look of the category checkbox.
 */
function isCategoryIndeterminate(models: AvailableModel[]): boolean {
  const chosen = selectedModels.value
  const pickedCount = models.filter(model => chosen.has(model.name)).length
  return pickedCount > 0 && pickedCount < models.length
}
|
||||
|
||||
/**
 * Select or deselect a single model by name.
 * A fresh Set is assigned instead of mutating the current one in place.
 */
function toggleModel(name: string, checked: boolean) {
  const updated = new Set(selectedModels.value)
  if (checked) {
    updated.add(name)
  } else {
    updated.delete(name)
  }
  selectedModels.value = updated
}
|
||||
|
||||
/**
 * Select or deselect every model of one category at once.
 * Models belonging to other categories keep their current state.
 */
function toggleCategory(models: AvailableModel[], checked: boolean) {
  const updated = new Set(selectedModels.value)
  models.forEach(({ name }) => {
    if (checked) {
      updated.add(name)
    } else {
      updated.delete(name)
    }
  })
  selectedModels.value = updated
}
|
||||
|
||||
/**
 * Fetch the installed models grouped by category and reset the picker so
 * that every model starts out selected. When the response carries no
 * categories the previous picker state is left untouched.
 */
async function loadModelCategories() {
  modelsLoading.value = true
  let categories: Record<string, AvailableModel[]> | undefined
  try {
    const { data } = await useApiFetch<ModelCategoriesResponse>('/api/benchmark/models')
    categories = data?.categories
  } finally {
    // Clear the flag even if the request rejects; otherwise the picker would
    // show "Loading models…" forever. The rejection itself still propagates.
    // NOTE(review): whether useApiFetch can reject is not visible here — confirm.
    modelsLoading.value = false
  }
  if (!categories) return

  modelCategories.value = categories
  const names = Object.values(categories).flat().map(m => m.name)
  allModels.value = names
  selectedModels.value = new Set(names)
}
|
||||
|
||||
// ── Derived ──────────────────────────────────────────────────────────────────
|
||||
const modelNames = computed(() => Object.keys(results.value?.models ?? {}))
|
||||
const modelCount = computed(() => modelNames.value.length)
|
||||
|
|
@ -355,7 +466,16 @@ function startBenchmark() {
|
|||
runError.value = ''
|
||||
runCancelled.value = false
|
||||
|
||||
const url = `/api/benchmark/run${includeSlow.value ? '?include_slow=true' : ''}`
|
||||
const params = new URLSearchParams()
|
||||
if (includeSlow.value) params.set('include_slow', 'true')
|
||||
// Only send model_names when a subset is selected (not all, not none)
|
||||
const total = allModels.value.length
|
||||
const selected = selectedModels.value.size
|
||||
if (total > 0 && selected > 0 && selected < total) {
|
||||
params.set('model_names', [...selectedModels.value].join(','))
|
||||
}
|
||||
const qs = params.toString()
|
||||
const url = `/api/benchmark/run${qs ? `?${qs}` : ''}`
|
||||
useApiSSE(
|
||||
url,
|
||||
async (event) => {
|
||||
|
|
@ -427,6 +547,7 @@ function startFinetune() {
|
|||
onMounted(() => {
|
||||
loadResults()
|
||||
loadFineTunedModels()
|
||||
loadModelCategories()
|
||||
})
|
||||
</script>
|
||||
|
||||
|
|
@ -762,6 +883,134 @@ onMounted(() => {
|
|||
font-weight: 700;
|
||||
}
|
||||
|
||||
/* ── Model Picker ───────────────────────────────────────── */
|
||||
.model-picker {
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.picker-summary {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.6rem;
|
||||
padding: 0.65rem 0.9rem;
|
||||
cursor: pointer;
|
||||
user-select: none;
|
||||
list-style: none;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
}
|
||||
.picker-summary::-webkit-details-marker { display: none; }
|
||||
.picker-summary::before { content: '▶ '; font-size: 0.65rem; color: var(--color-text-secondary, #6b7a99); }
|
||||
details[open] .picker-summary::before { content: '▼ '; }
|
||||
|
||||
.picker-title {
|
||||
font-size: 0.9rem;
|
||||
font-weight: 600;
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.picker-badge {
|
||||
font-size: 0.75rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
background: var(--color-surface, #fff);
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
padding: 0.15rem 0.5rem;
|
||||
border-radius: 1rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
margin-left: auto;
|
||||
}
|
||||
|
||||
.picker-body {
|
||||
padding: 0.75rem;
|
||||
border-top: 1px solid var(--color-border, #d0d7e8);
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.75rem;
|
||||
}
|
||||
|
||||
.picker-loading,
|
||||
.picker-empty {
|
||||
font-size: 0.85rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
padding: 0.5rem 0;
|
||||
}
|
||||
|
||||
.picker-category {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 0.3rem;
|
||||
}
|
||||
|
||||
.picker-cat-header {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.45rem;
|
||||
font-size: 0.82rem;
|
||||
font-weight: 700;
|
||||
color: var(--color-text, #1a2338);
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.04em;
|
||||
cursor: pointer;
|
||||
}
|
||||
|
||||
.picker-cat-count {
|
||||
font-weight: 400;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.75rem;
|
||||
text-transform: none;
|
||||
letter-spacing: 0;
|
||||
}
|
||||
|
||||
.picker-no-models {
|
||||
font-size: 0.78rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
opacity: 0.65;
|
||||
padding-left: 1.4rem;
|
||||
font-style: italic;
|
||||
}
|
||||
|
||||
.picker-model-list {
|
||||
display: flex;
|
||||
flex-wrap: wrap;
|
||||
gap: 0.35rem 0.75rem;
|
||||
padding-left: 1.4rem;
|
||||
}
|
||||
|
||||
.picker-model-row {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.35rem;
|
||||
font-size: 0.82rem;
|
||||
cursor: pointer;
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.picker-model-name {
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.78rem;
|
||||
white-space: nowrap;
|
||||
max-width: 18ch;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
}
|
||||
|
||||
.picker-adapter-type {
|
||||
font-size: 0.68rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.25rem;
|
||||
padding: 0.05rem 0.3rem;
|
||||
font-family: var(--font-mono, monospace);
|
||||
}
|
||||
|
||||
@media (max-width: 600px) {
|
||||
.picker-model-list { padding-left: 0; }
|
||||
.picker-model-name { max-width: 14ch; }
|
||||
}
|
||||
|
||||
/* ── Fine-tune section ──────────────────────────────────── */
|
||||
.ft-section {
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
|
|
|
|||
|
|
@ -35,6 +35,39 @@
|
|||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Benchmark Results -->
|
||||
<template v-if="benchRows.length > 0">
|
||||
<h2 class="section-title">🏁 Benchmark Results</h2>
|
||||
<div class="bench-table-wrap">
|
||||
<table class="bench-table">
|
||||
<thead>
|
||||
<tr>
|
||||
<th class="bt-model-col">Model</th>
|
||||
<th
|
||||
v-for="m in BENCH_METRICS"
|
||||
:key="m.key as string"
|
||||
class="bt-metric-col"
|
||||
>{{ m.label }}</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
<tr v-for="row in benchRows" :key="row.name">
|
||||
<td class="bt-model-cell" :title="row.name">{{ row.name }}</td>
|
||||
<td
|
||||
v-for="m in BENCH_METRICS"
|
||||
:key="m.key as string"
|
||||
class="bt-metric-cell"
|
||||
:class="{ 'bt-best': bestByMetric[m.key as string] === row.name }"
|
||||
>
|
||||
{{ formatMetric(row.result[m.key]) }}
|
||||
</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
<p class="bench-hint">Highlighted cells are the best-scoring model per metric.</p>
|
||||
</template>
|
||||
|
||||
<div class="file-info">
|
||||
<span class="file-path">Score file: <code>data/email_score.jsonl</code></span>
|
||||
<span class="file-size">{{ fileSizeLabel }}</span>
|
||||
|
|
@ -54,10 +87,18 @@
|
|||
import { ref, computed, onMounted } from 'vue'
|
||||
import { useApiFetch } from '../composables/useApi'
|
||||
|
||||
/** Metric scores recorded for a single model in the benchmark results. */
interface BenchmarkModelResult {
  accuracy?: number
  macro_f1?: number
  weighted_f1?: number
  // Index signature: any other numeric metric may be present as well.
  [key: string]: number | undefined
}
|
||||
|
||||
/** Payload of the stats endpoint consumed by this view. */
interface StatsResponse {
  /** Number of scored records. */
  total: number
  /** Record count per label. */
  counts: Record<string, number>
  /** Size of the score file in bytes (0 when the file does not exist). */
  score_file_bytes: number
  /** Saved benchmark results keyed by model name; may be absent or empty. */
  benchmark_results?: Record<string, BenchmarkModelResult>
}
|
||||
|
||||
// Canonical label order + metadata
|
||||
|
|
@ -108,6 +149,42 @@ const fileSizeLabel = computed(() => {
|
|||
return `${(b / 1024 / 1024).toFixed(2)} MB`
|
||||
})
|
||||
|
||||
// Benchmark results helpers
|
||||
// Columns of the benchmark table, in display order: result key + header label.
const BENCH_METRICS: Array<{ key: keyof BenchmarkModelResult; label: string }> = [
  { key: 'accuracy', label: 'Accuracy' },
  { key: 'macro_f1', label: 'Macro F1' },
  { key: 'weighted_f1', label: 'Weighted F1' },
]
|
||||
|
||||
/**
 * Table rows for the benchmark section: one `{ name, result }` per entry of
 * `benchmark_results`, or an empty list when there are no results.
 *
 * NOTE(review): this treats every top-level key of `benchmark_results` as a
 * model name. Confirm the backend sends a flat per-model map here rather
 * than a wrapper object with the models nested under another key.
 */
const benchRows = computed(() =>
  Object.entries(stats.value.benchmark_results ?? {}).map(
    ([name, result]) => ({ name, result }),
  ),
)
|
||||
|
||||
/**
 * For each metric in BENCH_METRICS, the name of the model with the highest
 * value. Missing values are ignored; on a tie the earlier row wins; the
 * entry is '' when no row has a value for that metric.
 */
const bestByMetric = computed((): Record<string, string> => {
  const winners: Record<string, string> = {}
  for (const { key } of BENCH_METRICS) {
    let leader = ''
    let topScore = -Infinity
    for (const row of benchRows.value) {
      const score = row.result[key]
      // Written as !(a > b) so that a NaN score never replaces the leader.
      if (score == null || !(score > topScore)) continue
      topScore = score
      leader = row.name
    }
    winners[key as string] = leader
  }
  return winners
})
|
||||
|
||||
/**
 * Render a metric as a percentage with one decimal place.
 * Missing values become an em dash. Values up to and including 1 are treated
 * as fractions and scaled by 100; anything larger is taken to already be a
 * percentage.
 */
function formatMetric(v: number | undefined): string {
  if (v == null) {
    return '—'
  }
  const percent = v <= 1 ? v * 100 : v
  return `${percent.toFixed(1)}%`
}
|
||||
|
||||
async function load() {
|
||||
loading.value = true
|
||||
error.value = ''
|
||||
|
|
@ -234,6 +311,79 @@ onMounted(load)
|
|||
padding: 1rem;
|
||||
}
|
||||
|
||||
/* ── Benchmark Results ──────────────────────────── */
|
||||
.section-title {
|
||||
font-family: var(--font-display, var(--font-body, sans-serif));
|
||||
font-size: 1.05rem;
|
||||
font-weight: 700;
|
||||
color: var(--app-primary, #2A6080);
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
.bench-table-wrap {
|
||||
overflow-x: auto;
|
||||
border: 1px solid var(--color-border, #d0d7e8);
|
||||
border-radius: 0.5rem;
|
||||
}
|
||||
|
||||
.bench-table {
|
||||
border-collapse: collapse;
|
||||
width: 100%;
|
||||
font-size: 0.82rem;
|
||||
}
|
||||
|
||||
.bt-model-col {
|
||||
text-align: left;
|
||||
padding: 0.45rem 0.75rem;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
font-weight: 600;
|
||||
min-width: 12rem;
|
||||
}
|
||||
|
||||
.bt-metric-col {
|
||||
text-align: right;
|
||||
padding: 0.45rem 0.75rem;
|
||||
background: var(--color-surface-raised, #e4ebf5);
|
||||
border-bottom: 1px solid var(--color-border, #d0d7e8);
|
||||
font-weight: 600;
|
||||
white-space: nowrap;
|
||||
min-width: 6rem;
|
||||
}
|
||||
|
||||
.bt-model-cell {
|
||||
padding: 0.4rem 0.75rem;
|
||||
border-top: 1px solid var(--color-border, #d0d7e8);
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-size: 0.76rem;
|
||||
white-space: nowrap;
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
max-width: 16rem;
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.bt-metric-cell {
|
||||
padding: 0.4rem 0.75rem;
|
||||
border-top: 1px solid var(--color-border, #d0d7e8);
|
||||
text-align: right;
|
||||
font-family: var(--font-mono, monospace);
|
||||
font-variant-numeric: tabular-nums;
|
||||
color: var(--color-text, #1a2338);
|
||||
}
|
||||
|
||||
.bt-metric-cell.bt-best {
|
||||
color: var(--color-success, #3a7a32);
|
||||
font-weight: 700;
|
||||
background: color-mix(in srgb, var(--color-success, #3a7a32) 8%, transparent);
|
||||
}
|
||||
|
||||
.bench-hint {
|
||||
font-size: 0.75rem;
|
||||
color: var(--color-text-secondary, #6b7a99);
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
@media (max-width: 480px) {
|
||||
.bar-row {
|
||||
grid-template-columns: 1.5rem 1fr 1fr 3rem;
|
||||
|
|
|
|||
Loading…
Reference in a new issue