diff --git a/scripts/benchmark_classifier.py b/scripts/benchmark_classifier.py
new file mode 100644
index 0000000..2eec77d
--- /dev/null
+++ b/scripts/benchmark_classifier.py
@@ -0,0 +1,347 @@
+#!/usr/bin/env python
+"""
+Email classifier benchmark — compare HuggingFace models against our 6 labels.
+
+Usage:
+    # List available models
+    conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --list-models
+
+    # Score against labeled JSONL
+    conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --score
+
+    # Visual comparison on live IMAP emails
+    conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --compare --limit 20
+
+    # Include slow/large models
+    conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --score --include-slow
+"""
+from __future__ import annotations
+
+import argparse
+import email as _email_lib
+import imaplib
+import json
+import sys
+import time
+from datetime import datetime, timedelta
+from pathlib import Path
+from typing import Any
+
+sys.path.insert(0, str(Path(__file__).parent.parent))
+
+from scripts.classifier_adapters import (
+    LABELS,
+    ClassifierAdapter,
+    GLiClassAdapter,
+    RerankerAdapter,
+    ZeroShotAdapter,
+    compute_metrics,
+)
+
+# ---------------------------------------------------------------------------
+# Model registry
+# ---------------------------------------------------------------------------
+
+MODEL_REGISTRY: dict[str, dict[str, Any]] = {
+    "deberta-zeroshot": {
+        "adapter": ZeroShotAdapter,
+        "model_id": "MoritzLaurer/DeBERTa-v3-large-zeroshot-v2.0",
+        "params": "400M",
+        "default": True,
+    },
+    "deberta-small": {
+        "adapter": ZeroShotAdapter,
+        "model_id": "cross-encoder/nli-deberta-v3-small",
+        "params": "100M",
+        "default": True,
+    },
+    "gliclass-large": {
+        "adapter": GLiClassAdapter,
+        "model_id": "knowledgator/gliclass-instruct-large-v1.0",
+        "params": "400M",
+        "default": True,
+    },
+    "bart-mnli": {
+        "adapter": ZeroShotAdapter,
+        "model_id": "facebook/bart-large-mnli",
+        "params": "400M",
+        "default": True,
+    },
+    "bge-m3-zeroshot": {
+        "adapter": ZeroShotAdapter,
+        "model_id": "MoritzLaurer/bge-m3-zeroshot-v2.0",
+        "params": "600M",
+        "default": True,
+    },
+    "bge-reranker": {
+        "adapter": RerankerAdapter,
+        "model_id": "BAAI/bge-reranker-v2-m3",
+        "params": "600M",
+        "default": False,
+    },
+    "deberta-xlarge": {
+        "adapter": ZeroShotAdapter,
+        "model_id": "microsoft/deberta-xlarge-mnli",
+        "params": "750M",
+        "default": False,
+    },
+    "mdeberta-mnli": {
+        "adapter": ZeroShotAdapter,
+        "model_id": "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli",
+        "params": "300M",
+        "default": False,
+    },
+    "xlm-roberta-anli": {
+        "adapter": ZeroShotAdapter,
+        "model_id": "vicgalle/xlm-roberta-large-xnli-anli",
+        "params": "600M",
+        "default": False,
+    },
+}
+
+# ---------------------------------------------------------------------------
+# Helpers
+# ---------------------------------------------------------------------------
+
+def load_scoring_jsonl(path: str) -> list[dict[str, str]]:
+    """Load labeled examples from a JSONL file for benchmark scoring."""
+    p = Path(path)
+    if not p.exists():
+        raise FileNotFoundError(
+            f"Scoring file not found: {path}\n"
+            f"Copy data/email_score.jsonl.example → data/email_score.jsonl and label your emails."
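+            # (run_scoring() reads exactly these keys from each row)
+            "\nEach line must be a JSON object with 'subject', 'body', and 'label' keys."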
+        )
+    rows = []
+    with p.open() as f:
+        for line in f:
+            line = line.strip()
+            if line:
+                rows.append(json.loads(line))
+    if not rows:
+        # Fail fast: an empty file would otherwise surface later as a
+        # ZeroDivisionError when run_scoring() averages latency.
+        raise ValueError(f"No labeled rows found in {path}")
+    return rows
+
+
+def _active_models(include_slow: bool) -> dict[str, dict[str, Any]]:
+    """Registry entries marked default, plus the heavy ones when include_slow is set."""
+    return {k: v for k, v in MODEL_REGISTRY.items() if v["default"] or include_slow}
+
+
+def run_scoring(
+    adapters: list[ClassifierAdapter],
+    score_file: str,
+) -> dict[str, Any]:
+    """Run all adapters against a labeled JSONL. Returns per-adapter metrics."""
+    rows = load_scoring_jsonl(score_file)
+    gold = [r["label"] for r in rows]
+    results: dict[str, Any] = {}
+
+    for adapter in adapters:
+        preds: list[str] = []
+        t0 = time.monotonic()
+        for row in rows:
+            try:
+                pred = adapter.classify(row["subject"], row["body"])
+            except Exception as exc:
+                print(f"  [{adapter.name}] ERROR on '{row['subject'][:40]}': {exc}", flush=True)
+                pred = "neutral"
+            preds.append(pred)
+        elapsed_ms = (time.monotonic() - t0) * 1000
+        metrics = compute_metrics(preds, gold, LABELS)
+        metrics["latency_ms"] = round(elapsed_ms / len(rows), 1)
+        results[adapter.name] = metrics
+        adapter.unload()
+
+    return results
+
+
+# ---------------------------------------------------------------------------
+# IMAP helpers (stdlib only — no imap_sync dependency)
+# ---------------------------------------------------------------------------
+
+_BROAD_TERMS = [
+    "interview", "opportunity", "offer letter",
+    "job offer", "application", "recruiting",
+]
+
+
+def _load_imap_config() -> dict[str, Any]:
+    """Read IMAP connection settings from config/email.yaml."""
+    import yaml
+    cfg_path = Path(__file__).parent.parent / "config" / "email.yaml"
+    with cfg_path.open() as f:
+        return yaml.safe_load(f)
+
+
+def _imap_connect(cfg: dict[str, Any]) -> imaplib.IMAP4_SSL:
+    conn = imaplib.IMAP4_SSL(cfg["host"], cfg.get("port", 993))
+    conn.login(cfg["username"], cfg["password"])
+    return conn
+
+
+def _decode_part(part: Any) -> str:
+    """Best-effort decode of one MIME part; returns '' on failure."""
+    charset = part.get_content_charset() or "utf-8"
+    try:
+        return part.get_payload(decode=True).decode(charset, errors="replace")
+    except Exception:
+        return ""
+
+
+def _parse_uid(conn: imaplib.IMAP4_SSL, uid: bytes) -> dict[str, str] | None:
+    """Fetch one message by UID and extract its subject and plain-text body."""
+    try:
+        _, data = conn.uid("fetch", uid, "(RFC822)")
+        raw = data[0][1]
+        msg = _email_lib.message_from_bytes(raw)
+        subject = str(msg.get("subject", "")).strip()
+        body = ""
+        if msg.is_multipart():
+            for part in msg.walk():
+                if part.get_content_type() == "text/plain":
+                    body = _decode_part(part)
+                    break
+        else:
+            body = _decode_part(msg)
+        return {"subject": subject, "body": body}
+    except Exception:
+        return None
+
+
+def _fetch_imap_sample(limit: int, days: int) -> list[dict[str, str]]:
+    """Search INBOX for job-related subjects and return up to `limit` parsed emails."""
+    cfg = _load_imap_config()
+    conn = _imap_connect(cfg)
+    since = (datetime.now() - timedelta(days=days)).strftime("%d-%b-%Y")
+    conn.select("INBOX")
+
+    # dict keys preserve insertion order; used here as an ordered set of UIDs.
+    seen_uids: dict[bytes, None] = {}
+    for term in _BROAD_TERMS:
+        _, data = conn.uid("search", None, f'(SUBJECT "{term}" SINCE {since})')
+        for uid in (data[0] or b"").split():
+            seen_uids[uid] = None
+
+    sample = list(seen_uids.keys())[:limit]
+    emails = []
+    for uid in sample:
+        parsed = _parse_uid(conn, uid)
+        if parsed:
+            emails.append(parsed)
+    try:
+        conn.logout()
+    except Exception:
+        pass
+    return emails
+
+
+# ---------------------------------------------------------------------------
+# Subcommands
+# ---------------------------------------------------------------------------
+
+def cmd_list_models(_args: argparse.Namespace) -> None:
+    print(f"\n{'Name':<20} {'Params':<8} {'Default':<20} {'Adapter':<15} Model ID")
+    print("-" * 100)
+    for name, entry in MODEL_REGISTRY.items():
+        adapter_name = entry["adapter"].__name__
+        default_flag = "yes" if entry["default"] else "(--include-slow)"
+        print(f"{name:<20} {entry['params']:<8} {default_flag:<20} {adapter_name:<15} {entry['model_id']}")
+    print()
+
+
+def cmd_score(args: argparse.Namespace) -> None:
+    active = _active_models(args.include_slow)
+    if args.models:
+        # --models is an explicit override, so select from the full registry;
+        # this lets non-default models run without also passing --include-slow.
+        active = {k: v for k, v in MODEL_REGISTRY.items() if k in args.models}
+
+    adapters = [
+        entry["adapter"](name, entry["model_id"])
+        for name, entry in active.items()
+    ]
+
+    print(f"\nScoring {len(adapters)} model(s) against {args.score_file} …\n")
+    results = run_scoring(adapters, args.score_file)
+
+    col = 12
+    print(f"{'Model':<22}"
+          f"{'macro-F1':>{col}}{'Accuracy':>{col}}{'ms/email':>{col}}")
+    print("-" * (22 + col * 3))
+    for name, m in results.items():
+        print(
+            f"{name:<22}"
+            f"{m['__macro_f1__']:>{col}.3f}"
+            f"{m['__accuracy__']:>{col}.3f}"
+            f"{m['latency_ms']:>{col}.1f}"
+        )
+
+    print("\nPer-label F1:")
+    names = list(results.keys())
+    print(f"{'Label':<25}" + "".join(f"{n[:11]:>{col}}" for n in names))
+    print("-" * (25 + col * len(names)))
+    for label in LABELS:
+        row_str = f"{label:<25}"
+        for m in results.values():
+            row_str += f"{m[label]['f1']:>{col}.3f}"
+        print(row_str)
+    print()
+
+
+def cmd_compare(args: argparse.Namespace) -> None:
+    active = _active_models(args.include_slow)
+    if args.models:
+        # Same override semantics as cmd_score: explicitly named models always qualify.
+        active = {k: v for k, v in MODEL_REGISTRY.items() if k in args.models}
+
+    print(f"Fetching up to {args.limit} emails from IMAP …")
+    emails = _fetch_imap_sample(args.limit, args.days)
+    print(f"Fetched {len(emails)} emails. Loading {len(active)} model(s) …\n")
+
+    adapters = [
+        entry["adapter"](name, entry["model_id"])
+        for name, entry in active.items()
+    ]
+    model_names = [a.name for a in adapters]
+
+    col = 22
+    subj_w = 50
+    print(f"{'Subject':<{subj_w}}" + "".join(f"{n:<{col}}" for n in model_names))
+    print("-" * (subj_w + col * len(model_names)))
+
+    for row in emails:
+        # Truncate to subj_w - 1 so there is always at least one space before the next column.
+        short_subj = row["subject"][:subj_w - 1]
+        line = f"{short_subj:<{subj_w}}"
+        for adapter in adapters:
+            try:
+                label = adapter.classify(row["subject"], row["body"])
+            except Exception as exc:
+                label = f"ERR:{str(exc)[:8]}"
+            line += f"{label:<{col}}"
+        print(line, flush=True)
+
+    for adapter in adapters:
+        adapter.unload()
+    print()
+
+
+# ---------------------------------------------------------------------------
+# CLI
+# ---------------------------------------------------------------------------
+
+def main() -> None:
+    parser = argparse.ArgumentParser(
+        description="Benchmark HuggingFace email classifiers against our 6 labels.",
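+        # Brief pointer for --help output; mirrors the Usage examples in the module docstring.
+        epilog="Start with --list-models, then label data/email_score.jsonl and run --score.",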
+ ) + parser.add_argument("--list-models", action="store_true", help="Show model registry and exit") + parser.add_argument("--score", action="store_true", help="Score against labeled JSONL") + parser.add_argument("--compare", action="store_true", help="Visual table on live IMAP emails") + parser.add_argument("--score-file", default="data/email_score.jsonl", help="Path to labeled JSONL") + parser.add_argument("--limit", type=int, default=20, help="Max emails for --compare") + parser.add_argument("--days", type=int, default=90, help="Days back for IMAP search") + parser.add_argument("--include-slow", action="store_true", help="Include non-default heavy models") + parser.add_argument("--models", nargs="+", help="Override: run only these model names") + + args = parser.parse_args() + + if args.list_models: + cmd_list_models(args) + elif args.score: + cmd_score(args) + elif args.compare: + cmd_compare(args) + else: + parser.print_help() + + +if __name__ == "__main__": + main() diff --git a/tests/test_benchmark_classifier.py b/tests/test_benchmark_classifier.py new file mode 100644 index 0000000..d218c4a --- /dev/null +++ b/tests/test_benchmark_classifier.py @@ -0,0 +1,94 @@ +"""Tests for benchmark_classifier — no model downloads required.""" +import pytest + + +def test_registry_has_nine_models(): + from scripts.benchmark_classifier import MODEL_REGISTRY + assert len(MODEL_REGISTRY) == 9 + + +def test_registry_default_count(): + from scripts.benchmark_classifier import MODEL_REGISTRY + defaults = [k for k, v in MODEL_REGISTRY.items() if v["default"]] + assert len(defaults) == 5 + + +def test_registry_entries_have_required_keys(): + from scripts.benchmark_classifier import MODEL_REGISTRY + from scripts.classifier_adapters import ClassifierAdapter + for name, entry in MODEL_REGISTRY.items(): + assert "adapter" in entry, f"{name} missing 'adapter'" + assert "model_id" in entry, f"{name} missing 'model_id'" + assert "params" in entry, f"{name} missing 'params'" + assert "default" in entry, f"{name} missing 'default'" + assert issubclass(entry["adapter"], ClassifierAdapter), \ + f"{name} adapter must be a ClassifierAdapter subclass" + + +def test_load_scoring_jsonl(tmp_path): + from scripts.benchmark_classifier import load_scoring_jsonl + import json + f = tmp_path / "score.jsonl" + rows = [ + {"subject": "Hi", "body": "Body text", "label": "neutral"}, + {"subject": "Interview", "body": "Schedule a call", "label": "interview_scheduled"}, + ] + f.write_text("\n".join(json.dumps(r) for r in rows)) + result = load_scoring_jsonl(str(f)) + assert len(result) == 2 + assert result[0]["label"] == "neutral" + + +def test_load_scoring_jsonl_missing_file(): + from scripts.benchmark_classifier import load_scoring_jsonl + with pytest.raises(FileNotFoundError): + load_scoring_jsonl("/nonexistent/path.jsonl") + + +def test_run_scoring_with_mock_adapters(tmp_path): + """run_scoring() returns per-model metrics using mock adapters.""" + import json + from unittest.mock import MagicMock + from scripts.benchmark_classifier import run_scoring + + score_file = tmp_path / "score.jsonl" + rows = [ + {"subject": "Interview", "body": "Let's schedule", "label": "interview_scheduled"}, + {"subject": "Sorry", "body": "We went with others", "label": "rejected"}, + {"subject": "Offer", "body": "We are pleased", "label": "offer_received"}, + ] + score_file.write_text("\n".join(json.dumps(r) for r in rows)) + + perfect = MagicMock() + perfect.name = "perfect" + perfect.classify.side_effect = lambda s, b: ( + "interview_scheduled" 
if "Interview" in s else + "rejected" if "Sorry" in s else "offer_received" + ) + + bad = MagicMock() + bad.name = "bad" + bad.classify.return_value = "neutral" + + results = run_scoring([perfect, bad], str(score_file)) + + assert results["perfect"]["__accuracy__"] == pytest.approx(1.0) + assert results["bad"]["__accuracy__"] == pytest.approx(0.0) + assert "latency_ms" in results["perfect"] + + +def test_run_scoring_handles_classify_error(tmp_path): + """run_scoring() falls back to 'neutral' on exception and continues.""" + import json + from unittest.mock import MagicMock + from scripts.benchmark_classifier import run_scoring + + score_file = tmp_path / "score.jsonl" + score_file.write_text(json.dumps({"subject": "Hi", "body": "Body", "label": "neutral"})) + + broken = MagicMock() + broken.name = "broken" + broken.classify.side_effect = RuntimeError("model crashed") + + results = run_scoring([broken], str(score_file)) + assert "broken" in results