feat: benchmark_classifier — MODEL_REGISTRY, --list-models, --score, --compare modes
This commit is contained in:
parent
b0ab34dd17
commit
731e4d1aa2
2 changed files with 441 additions and 0 deletions
347
scripts/benchmark_classifier.py
Normal file
347
scripts/benchmark_classifier.py
Normal file
|
|
@ -0,0 +1,347 @@
|
|||
#!/usr/bin/env python
|
||||
"""
|
||||
Email classifier benchmark — compare HuggingFace models against our 6 labels.
|
||||
|
||||
Usage:
|
||||
# List available models
|
||||
conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --list-models
|
||||
|
||||
# Score against labeled JSONL
|
||||
conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --score
|
||||
|
||||
# Visual comparison on live IMAP emails
|
||||
conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --compare --limit 20
|
||||
|
||||
# Include slow/large models
|
||||
conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --score --include-slow
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import email as _email_lib
|
||||
import imaplib
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from scripts.classifier_adapters import (
|
||||
LABELS,
|
||||
LABEL_DESCRIPTIONS,
|
||||
ClassifierAdapter,
|
||||
GLiClassAdapter,
|
||||
RerankerAdapter,
|
||||
ZeroShotAdapter,
|
||||
compute_metrics,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Candidate classifier models, keyed by short CLI name.
# Each entry: "adapter" (ClassifierAdapter subclass to wrap the model),
# "model_id" (HuggingFace hub id), "params" (human-readable size shown by
# --list-models), and "default" (False = skipped unless --include-slow).
MODEL_REGISTRY: dict[str, dict[str, Any]] = {
    "deberta-zeroshot": {
        "adapter": ZeroShotAdapter,
        "model_id": "MoritzLaurer/DeBERTa-v3-large-zeroshot-v2.0",
        "params": "400M",
        "default": True,
    },
    "deberta-small": {
        "adapter": ZeroShotAdapter,
        "model_id": "cross-encoder/nli-deberta-v3-small",
        "params": "100M",
        "default": True,
    },
    "gliclass-large": {
        "adapter": GLiClassAdapter,
        "model_id": "knowledgator/gliclass-instruct-large-v1.0",
        "params": "400M",
        "default": True,
    },
    "bart-mnli": {
        "adapter": ZeroShotAdapter,
        "model_id": "facebook/bart-large-mnli",
        "params": "400M",
        "default": True,
    },
    "bge-m3-zeroshot": {
        "adapter": ZeroShotAdapter,
        "model_id": "MoritzLaurer/bge-m3-zeroshot-v2.0",
        "params": "600M",
        "default": True,
    },
    # Non-default (heavier / slower) models below — opt in via --include-slow.
    "bge-reranker": {
        "adapter": RerankerAdapter,
        "model_id": "BAAI/bge-reranker-v2-m3",
        "params": "600M",
        "default": False,
    },
    "deberta-xlarge": {
        "adapter": ZeroShotAdapter,
        "model_id": "microsoft/deberta-xlarge-mnli",
        "params": "750M",
        "default": False,
    },
    "mdeberta-mnli": {
        "adapter": ZeroShotAdapter,
        "model_id": "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli",
        "params": "300M",
        "default": False,
    },
    "xlm-roberta-anli": {
        "adapter": ZeroShotAdapter,
        "model_id": "vicgalle/xlm-roberta-large-xnli-anli",
        "params": "600M",
        "default": False,
    },
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def load_scoring_jsonl(path: str) -> list[dict[str, str]]:
    """Load labeled examples from a JSONL file for benchmark scoring.

    Args:
        path: Path to a JSONL file; each non-blank line is one JSON object
            (expected keys per callers: ``subject``, ``body``, ``label``).

    Returns:
        One dict per non-blank line, in file order.

    Raises:
        FileNotFoundError: If *path* does not exist (with a setup hint).
        json.JSONDecodeError: If a line is not valid JSON.
    """
    p = Path(path)
    if not p.exists():
        raise FileNotFoundError(
            f"Scoring file not found: {path}\n"
            "Copy data/email_score.jsonl.example → data/email_score.jsonl and label your emails."
        )
    # Explicit encoding so results don't depend on the platform default;
    # blank lines (e.g. trailing newline) are skipped.
    with p.open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
|
||||
|
||||
|
||||
def _active_models(include_slow: bool) -> dict[str, dict[str, Any]]:
    """Pick registry entries to run: defaults only, or everything with --include-slow."""
    if include_slow:
        return dict(MODEL_REGISTRY)
    return {name: spec for name, spec in MODEL_REGISTRY.items() if spec["default"]}
|
||||
|
||||
|
||||
def run_scoring(
    adapters: list[ClassifierAdapter],
    score_file: str,
) -> dict[str, Any]:
    """Run all adapters against a labeled JSONL. Returns per-adapter metrics.

    Adapters run sequentially and each one is unloaded after scoring, so at
    most one model is resident at a time.

    Args:
        adapters: Instantiated classifier adapters to benchmark.
        score_file: Path to the labeled JSONL (see ``load_scoring_jsonl``).

    Returns:
        Mapping of adapter name -> metrics dict from ``compute_metrics``,
        augmented with ``latency_ms`` (mean wall-clock ms per email).

    Raises:
        ValueError: If the scoring file has no rows (previously this
            surfaced as an obscure ZeroDivisionError when computing latency).
    """
    rows = load_scoring_jsonl(score_file)
    if not rows:
        raise ValueError(f"Scoring file has no labeled rows: {score_file}")
    gold = [r["label"] for r in rows]
    results: dict[str, Any] = {}

    for adapter in adapters:
        preds: list[str] = []
        t0 = time.monotonic()
        for row in rows:
            try:
                pred = adapter.classify(row["subject"], row["body"])
            except Exception as exc:
                # One crashing email must not abort the whole benchmark;
                # fall back to the safe "neutral" label and keep going.
                print(f" [{adapter.name}] ERROR on '{row['subject'][:40]}': {exc}", flush=True)
                pred = "neutral"
            preds.append(pred)
        elapsed_ms = (time.monotonic() - t0) * 1000
        metrics = compute_metrics(preds, gold, LABELS)
        metrics["latency_ms"] = round(elapsed_ms / len(rows), 1)
        results[adapter.name] = metrics
        adapter.unload()

    return results
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# IMAP helpers (stdlib only — no imap_sync dependency)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Subject-line search terms used to pull a job-search-relevant sample from
# IMAP for --compare; each term is queried with a separate UID SEARCH and
# the results are de-duplicated (see _fetch_imap_sample).
_BROAD_TERMS = [
    "interview", "opportunity", "offer letter",
    "job offer", "application", "recruiting",
]
|
||||
|
||||
|
||||
def _load_imap_config() -> dict[str, Any]:
    """Read IMAP connection settings from config/email.yaml at the repo root."""
    import yaml

    config_file = Path(__file__).parent.parent / "config" / "email.yaml"
    with config_file.open() as fh:
        return yaml.safe_load(fh)
|
||||
|
||||
|
||||
def _imap_connect(cfg: dict[str, Any]) -> imaplib.IMAP4_SSL:
    """Open an SSL IMAP connection and log in using credentials from *cfg*."""
    host = cfg["host"]
    port = cfg.get("port", 993)  # 993 = standard IMAPS port
    client = imaplib.IMAP4_SSL(host, port)
    client.login(cfg["username"], cfg["password"])
    return client
|
||||
|
||||
|
||||
def _decode_part(part: Any) -> str:
|
||||
charset = part.get_content_charset() or "utf-8"
|
||||
try:
|
||||
return part.get_payload(decode=True).decode(charset, errors="replace")
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
def _parse_uid(conn: imaplib.IMAP4_SSL, uid: bytes) -> dict[str, str] | None:
    """Fetch one message by UID and extract its subject and plain-text body.

    Returns None when the fetch or parse fails for any reason; the caller
    simply skips such messages.
    """
    try:
        _, payload = conn.uid("fetch", uid, "(RFC822)")
        message = _email_lib.message_from_bytes(payload[0][1])
        subject = str(message.get("subject", "")).strip()
        if message.is_multipart():
            # First text/plain part wins; other parts are ignored.
            body = ""
            for part in message.walk():
                if part.get_content_type() == "text/plain":
                    body = _decode_part(part)
                    break
        else:
            body = _decode_part(message)
        return {"subject": subject, "body": body}
    except Exception:
        return None
|
||||
|
||||
|
||||
def _fetch_imap_sample(limit: int, days: int) -> list[dict[str, str]]:
    """Pull up to *limit* job-related emails from INBOX over the last *days* days."""
    cfg = _load_imap_config()
    conn = _imap_connect(cfg)
    cutoff = (datetime.now() - timedelta(days=days)).strftime("%d-%b-%Y")
    conn.select("INBOX")

    # A dict keyed by UID acts as an insertion-ordered de-dup set across
    # the per-term searches.
    matched: dict[bytes, None] = {}
    for term in _BROAD_TERMS:
        _, data = conn.uid("search", None, f'(SUBJECT "{term}" SINCE {cutoff})')
        for uid in (data[0] or b"").split():
            matched[uid] = None

    emails: list[dict[str, str]] = []
    for uid in list(matched)[:limit]:
        parsed = _parse_uid(conn, uid)
        if parsed is not None:
            emails.append(parsed)
    try:
        conn.logout()
    except Exception:
        pass  # best-effort cleanup — the sample is already collected
    return emails
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Subcommands
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def cmd_list_models(_args: argparse.Namespace) -> None:
    """Print the model registry as an aligned table."""
    print(f"\n{'Name':<20} {'Params':<8} {'Default':<20} {'Adapter':<15} Model ID")
    print("-" * 100)
    for name, spec in MODEL_REGISTRY.items():
        flag = "yes" if spec["default"] else "(--include-slow)"
        print(
            f"{name:<20} {spec['params']:<8} {flag:<20} "
            f"{spec['adapter'].__name__:<15} {spec['model_id']}"
        )
    print()
|
||||
|
||||
|
||||
def cmd_score(args: argparse.Namespace) -> None:
    """Score the selected models against the labeled JSONL and print tables."""
    selected = _active_models(args.include_slow)
    if args.models:
        selected = {name: spec for name, spec in selected.items() if name in args.models}

    adapters = []
    for name, spec in selected.items():
        adapters.append(spec["adapter"](name, spec["model_id"]))

    print(f"\nScoring {len(adapters)} model(s) against {args.score_file} …\n")
    results = run_scoring(adapters, args.score_file)

    col = 12
    # Summary table: one row per model.
    print(f"{'Model':<22}" + f"{'macro-F1':>{col}} {'Accuracy':>{col}} {'ms/email':>{col}}")
    print("-" * (22 + col * 3 + 2))
    for name, metrics in results.items():
        summary = (
            f"{name:<22}"
            f"{metrics['__macro_f1__']:>{col}.3f}"
            f"{metrics['__accuracy__']:>{col}.3f}"
            f"{metrics['latency_ms']:>{col}.1f}"
        )
        print(summary)

    # Breakdown table: one row per label, one column per model.
    print("\nPer-label F1:")
    model_names = list(results)
    print(f"{'Label':<25}" + "".join(f"{n[:11]:>{col}}" for n in model_names))
    print("-" * (25 + col * len(model_names)))
    for label in LABELS:
        cells = [f"{label:<25}"]
        for metrics in results.values():
            cells.append(f"{metrics[label]['f1']:>{col}.3f}")
        print("".join(cells))
    print()
|
||||
|
||||
|
||||
def cmd_compare(args: argparse.Namespace) -> None:
    """Classify live IMAP emails with every selected model, side by side."""
    selected = _active_models(args.include_slow)
    if args.models:
        selected = {name: spec for name, spec in selected.items() if name in args.models}

    print(f"Fetching up to {args.limit} emails from IMAP …")
    emails = _fetch_imap_sample(args.limit, args.days)
    print(f"Fetched {len(emails)} emails. Loading {len(selected)} model(s) …\n")

    adapters = []
    for name, spec in selected.items():
        adapters.append(spec["adapter"](name, spec["model_id"]))
    model_names = [adapter.name for adapter in adapters]

    col = 22
    subj_w = 50
    print(f"{'Subject':<{subj_w}}" + "".join(f"{n:<{col}}" for n in model_names))
    print("-" * (subj_w + col * len(model_names)))

    for message in emails:
        subject = message["subject"]
        if len(subject) > subj_w:
            # Trim to width minus one so the column boundary stays visible.
            subject = subject[:subj_w - 1]
        cells = [f"{subject:<{subj_w}}"]
        for adapter in adapters:
            try:
                verdict = adapter.classify(message["subject"], message["body"])
            except Exception as exc:
                verdict = f"ERR:{str(exc)[:8]}"
            cells.append(f"{verdict:<{col}}")
        # Flush per row so progress is visible while slow models run.
        print("".join(cells), flush=True)

    for adapter in adapters:
        adapter.unload()
    print()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def main() -> None:
    """CLI entry point: parse flags and dispatch to the first selected mode."""
    parser = argparse.ArgumentParser(
        description="Benchmark HuggingFace email classifiers against our 6 labels."
    )
    parser.add_argument("--list-models", action="store_true", help="Show model registry and exit")
    parser.add_argument("--score", action="store_true", help="Score against labeled JSONL")
    parser.add_argument("--compare", action="store_true", help="Visual table on live IMAP emails")
    parser.add_argument("--score-file", default="data/email_score.jsonl", help="Path to labeled JSONL")
    parser.add_argument("--limit", type=int, default=20, help="Max emails for --compare")
    parser.add_argument("--days", type=int, default=90, help="Days back for IMAP search")
    parser.add_argument("--include-slow", action="store_true", help="Include non-default heavy models")
    parser.add_argument("--models", nargs="+", help="Override: run only these model names")

    args = parser.parse_args()

    # Modes are mutually exclusive in practice; the first flag set wins,
    # in the same precedence order as the original if/elif chain.
    dispatch = (
        ("list_models", cmd_list_models),
        ("score", cmd_score),
        ("compare", cmd_compare),
    )
    for attr, handler in dispatch:
        if getattr(args, attr):
            handler(args)
            return
    parser.print_help()
|
||||
|
||||
|
||||
# Run as a script (the module can also be imported for its registry/helpers).
if __name__ == "__main__":
    main()
|
||||
94
tests/test_benchmark_classifier.py
Normal file
94
tests/test_benchmark_classifier.py
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
"""Tests for benchmark_classifier — no model downloads required."""
|
||||
import pytest
|
||||
|
||||
|
||||
def test_registry_has_nine_models():
    """The registry ships exactly nine candidate models."""
    from scripts.benchmark_classifier import MODEL_REGISTRY

    assert len(MODEL_REGISTRY) == 9
|
||||
|
||||
|
||||
def test_registry_default_count():
    """Exactly five models run without --include-slow."""
    from scripts.benchmark_classifier import MODEL_REGISTRY

    default_names = [name for name, spec in MODEL_REGISTRY.items() if spec["default"]]
    assert len(default_names) == 5
|
||||
|
||||
|
||||
def test_registry_entries_have_required_keys():
    """Every registry entry carries the full spec and a valid adapter class."""
    from scripts.benchmark_classifier import MODEL_REGISTRY
    from scripts.classifier_adapters import ClassifierAdapter

    for name, entry in MODEL_REGISTRY.items():
        for key in ("adapter", "model_id", "params", "default"):
            assert key in entry, f"{name} missing '{key}'"
        assert issubclass(entry["adapter"], ClassifierAdapter), \
            f"{name} adapter must be a ClassifierAdapter subclass"
|
||||
|
||||
|
||||
def test_load_scoring_jsonl(tmp_path):
    """Rows written as JSONL come back parsed, in order."""
    import json

    from scripts.benchmark_classifier import load_scoring_jsonl

    fixture = tmp_path / "score.jsonl"
    source_rows = [
        {"subject": "Hi", "body": "Body text", "label": "neutral"},
        {"subject": "Interview", "body": "Schedule a call", "label": "interview_scheduled"},
    ]
    fixture.write_text("\n".join(json.dumps(row) for row in source_rows))

    loaded = load_scoring_jsonl(str(fixture))
    assert len(loaded) == 2
    assert loaded[0]["label"] == "neutral"
|
||||
|
||||
|
||||
def test_load_scoring_jsonl_missing_file():
    """A nonexistent path raises FileNotFoundError rather than returning []."""
    from scripts.benchmark_classifier import load_scoring_jsonl

    with pytest.raises(FileNotFoundError):
        load_scoring_jsonl("/nonexistent/path.jsonl")
|
||||
|
||||
|
||||
def test_run_scoring_with_mock_adapters(tmp_path):
    """run_scoring() returns per-model metrics using mock adapters."""
    import json
    from unittest.mock import MagicMock

    from scripts.benchmark_classifier import run_scoring

    score_file = tmp_path / "score.jsonl"
    labeled = [
        {"subject": "Interview", "body": "Let's schedule", "label": "interview_scheduled"},
        {"subject": "Sorry", "body": "We went with others", "label": "rejected"},
        {"subject": "Offer", "body": "We are pleased", "label": "offer_received"},
    ]
    score_file.write_text("\n".join(json.dumps(row) for row in labeled))

    def oracle(subject, body):
        # Mirror the gold labels exactly, keyed off the subject line.
        if "Interview" in subject:
            return "interview_scheduled"
        if "Sorry" in subject:
            return "rejected"
        return "offer_received"

    perfect = MagicMock()
    perfect.name = "perfect"
    perfect.classify.side_effect = oracle

    bad = MagicMock()
    bad.name = "bad"
    bad.classify.return_value = "neutral"

    results = run_scoring([perfect, bad], str(score_file))

    assert results["perfect"]["__accuracy__"] == pytest.approx(1.0)
    assert results["bad"]["__accuracy__"] == pytest.approx(0.0)
    assert "latency_ms" in results["perfect"]
|
||||
|
||||
|
||||
def test_run_scoring_handles_classify_error(tmp_path):
    """run_scoring() falls back to 'neutral' on exception and continues."""
    import json
    from unittest.mock import MagicMock

    from scripts.benchmark_classifier import run_scoring

    score_file = tmp_path / "score.jsonl"
    score_file.write_text(
        json.dumps({"subject": "Hi", "body": "Body", "label": "neutral"})
    )

    broken = MagicMock()
    broken.name = "broken"
    broken.classify.side_effect = RuntimeError("model crashed")

    results = run_scoring([broken], str(score_file))
    assert "broken" in results
|
||||
Loading…
Reference in a new issue