Scrape → Store → Process pipeline for building email classifier
benchmark data across the CircuitForge menagerie.
- app/label_tool.py — Streamlit card-stack UI, multi-account IMAP fetch,
6-bucket labeling, undo/skip, keyboard shortcuts (1-6/S/U)
- scripts/classifier_adapters.py — ZeroShotAdapter (+ two_pass),
GLiClassAdapter, RerankerAdapter; ABC with lazy model loading
- scripts/benchmark_classifier.py — 13-model registry, --score,
--compare, --list-models, --export-db; uses label_tool.yaml for IMAP
- tests/ — 20 tests, all passing, zero model downloads required
- config/label_tool.yaml.example — multi-account IMAP template
- data/email_score.jsonl.example — sample labeled data for CI
Labels: interview_scheduled, offer_received, rejected,
positive_response, survey_received, neutral
450 lines
15 KiB
Python
#!/usr/bin/env python
|
|
"""
|
|
Email classifier benchmark — compare HuggingFace models against our 6 labels.
|
|
|
|
Usage:
|
|
# List available models
|
|
conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --list-models
|
|
|
|
# Score against labeled JSONL
|
|
conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --score
|
|
|
|
# Visual comparison on live IMAP emails (uses first account in label_tool.yaml)
|
|
conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --compare --limit 20
|
|
|
|
# Include slow/large models
|
|
conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --score --include-slow
|
|
|
|
# Export DB-labeled emails (⚠️ LLM-generated labels)
|
|
conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --export-db --db /path/to/staging.db
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import email as _email_lib
|
|
import imaplib
|
|
import json
|
|
import sys
|
|
import time
|
|
from datetime import datetime, timedelta
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
|
|
from scripts.classifier_adapters import (
|
|
LABELS,
|
|
LABEL_DESCRIPTIONS,
|
|
ClassifierAdapter,
|
|
GLiClassAdapter,
|
|
RerankerAdapter,
|
|
ZeroShotAdapter,
|
|
compute_metrics,
|
|
)
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Model registry
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Registry of candidate classifiers, keyed by the short CLI name.
# Each entry provides:
#   adapter  — ClassifierAdapter subclass that loads and runs the model
#   model_id — HuggingFace Hub identifier passed to the adapter
#   params   — approximate parameter count (display only, not parsed)
#   default  — True: always benchmarked; False: needs --include-slow
#   kwargs   — optional extra keyword args forwarded to the adapter
MODEL_REGISTRY: dict[str, dict[str, Any]] = {
    # --- default set (run without --include-slow) ---
    "deberta-zeroshot": {
        "adapter": ZeroShotAdapter,
        "model_id": "MoritzLaurer/DeBERTa-v3-large-zeroshot-v2.0",
        "params": "400M",
        "default": True,
    },
    "deberta-small": {
        "adapter": ZeroShotAdapter,
        "model_id": "cross-encoder/nli-deberta-v3-small",
        "params": "100M",
        "default": True,
    },
    "gliclass-large": {
        "adapter": GLiClassAdapter,
        "model_id": "knowledgator/gliclass-instruct-large-v1.0",
        "params": "400M",
        "default": True,
    },
    "bart-mnli": {
        "adapter": ZeroShotAdapter,
        "model_id": "facebook/bart-large-mnli",
        "params": "400M",
        "default": True,
    },
    "bge-m3-zeroshot": {
        "adapter": ZeroShotAdapter,
        "model_id": "MoritzLaurer/bge-m3-zeroshot-v2.0",
        "params": "600M",
        "default": True,
    },
    # Same checkpoint as "deberta-small" but run with the adapter's
    # two-pass strategy enabled.
    "deberta-small-2pass": {
        "adapter": ZeroShotAdapter,
        "model_id": "cross-encoder/nli-deberta-v3-small",
        "params": "100M",
        "default": True,
        "kwargs": {"two_pass": True},
    },
    "deberta-base-anli": {
        "adapter": ZeroShotAdapter,
        "model_id": "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
        "params": "200M",
        "default": True,
    },
    # --- slow/large set (only with --include-slow) ---
    "deberta-large-ling": {
        "adapter": ZeroShotAdapter,
        "model_id": "MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli",
        "params": "400M",
        "default": False,
    },
    "mdeberta-xnli-2m": {
        "adapter": ZeroShotAdapter,
        "model_id": "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7",
        "params": "300M",
        "default": False,
    },
    "bge-reranker": {
        "adapter": RerankerAdapter,
        "model_id": "BAAI/bge-reranker-v2-m3",
        "params": "600M",
        "default": False,
    },
    "deberta-xlarge": {
        "adapter": ZeroShotAdapter,
        "model_id": "microsoft/deberta-xlarge-mnli",
        "params": "750M",
        "default": False,
    },
    "mdeberta-mnli": {
        "adapter": ZeroShotAdapter,
        "model_id": "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli",
        "params": "300M",
        "default": False,
    },
    "xlm-roberta-anli": {
        "adapter": ZeroShotAdapter,
        "model_id": "vicgalle/xlm-roberta-large-xnli-anli",
        "params": "600M",
        "default": False,
    },
}
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def load_scoring_jsonl(path: str) -> list[dict[str, str]]:
    """Load labeled examples from a JSONL file for benchmark scoring.

    Args:
        path: Path to a JSONL file; each non-blank line is one JSON object
            (expected keys: ``subject``, ``body``, ``label``).

    Returns:
        One dict per non-blank line, in file order.

    Raises:
        FileNotFoundError: If *path* does not exist (message includes a hint
            on how to create the file).
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    p = Path(path)
    if not p.exists():
        raise FileNotFoundError(
            f"Scoring file not found: {path}\n"
            f"Copy data/email_score.jsonl.example → data/email_score.jsonl "
            f"or use the label tool (app/label_tool.py) to label your own emails."
        )
    # Explicit UTF-8: don't depend on the platform default encoding.
    # json.loads tolerates surrounding whitespace, so blank-line filtering
    # is the only stripping needed.
    with p.open(encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
|
|
|
|
|
|
def _active_models(include_slow: bool) -> dict[str, dict[str, Any]]:
    """Select registry entries to run: defaults, plus slow ones when requested."""
    selected: dict[str, dict[str, Any]] = {}
    for name, entry in MODEL_REGISTRY.items():
        if include_slow or entry["default"]:
            selected[name] = entry
    return selected
|
|
|
|
|
|
def run_scoring(
    adapters: list[ClassifierAdapter],
    score_file: str,
) -> dict[str, Any]:
    """Run all adapters against a labeled JSONL. Returns per-adapter metrics.

    Each adapter is unloaded after it finishes so only one model occupies
    memory at a time.

    Args:
        adapters: Instantiated classifier adapters to benchmark.
        score_file: Path to the labeled JSONL (see load_scoring_jsonl).

    Returns:
        Mapping of adapter name -> metrics dict from compute_metrics, with an
        added ``latency_ms`` key (mean wall-clock milliseconds per email).

    Raises:
        ValueError: If the scoring file contains no examples.
    """
    rows = load_scoring_jsonl(score_file)
    if not rows:
        # Guard the latency average below (division by len(rows)) and give a
        # clearer message than a ZeroDivisionError traceback.
        raise ValueError(f"No labeled examples found in {score_file}")
    gold = [r["label"] for r in rows]
    results: dict[str, Any] = {}

    for adapter in adapters:
        preds: list[str] = []
        t0 = time.monotonic()
        for row in rows:
            try:
                pred = adapter.classify(row["subject"], row["body"])
            except Exception as exc:
                # Best-effort: one failing email shouldn't abort the whole
                # benchmark — record the safe fallback label and continue.
                print(f" [{adapter.name}] ERROR on '{row['subject'][:40]}': {exc}", flush=True)
                pred = "neutral"
            preds.append(pred)
        elapsed_ms = (time.monotonic() - t0) * 1000
        metrics = compute_metrics(preds, gold, LABELS)
        metrics["latency_ms"] = round(elapsed_ms / len(rows), 1)
        results[adapter.name] = metrics
        # Free the model before the next adapter loads its own.
        adapter.unload()

    return results
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# IMAP helpers (stdlib only — reads label_tool.yaml, uses first account)
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Subject-line search terms used to pull a broad, job-related email sample
# from IMAP for --compare (each term is a separate UID SEARCH; results are
# merged and de-duplicated in _fetch_imap_sample).
_BROAD_TERMS = [
    "interview", "opportunity", "offer letter",
    "job offer", "application", "recruiting",
]
|
|
|
|
|
|
def _load_imap_config() -> dict[str, Any]:
    """Read config/label_tool.yaml and return its first account as a flat dict.

    Raises:
        FileNotFoundError: If the YAML config has not been created yet.
        ValueError: If the config exists but lists no accounts.
    """
    # Local import: yaml is only needed for the IMAP-backed subcommands.
    import yaml

    config_file = Path(__file__).parent.parent / "config" / "label_tool.yaml"
    if not config_file.exists():
        raise FileNotFoundError(
            f"IMAP config not found: {config_file}\n"
            f"Copy config/label_tool.yaml.example → config/label_tool.yaml"
        )

    parsed = yaml.safe_load(config_file.read_text()) or {}
    account_list = parsed.get("accounts", [])
    if not account_list:
        raise ValueError("No accounts configured in config/label_tool.yaml")
    return account_list[0]
|
|
|
|
|
|
def _imap_connect(cfg: dict[str, Any]) -> imaplib.IMAP4_SSL:
    """Open an SSL IMAP connection for the account and log in."""
    port = cfg.get("port", 993)  # 993 = standard IMAPS port
    connection = imaplib.IMAP4_SSL(cfg["host"], port)
    connection.login(cfg["username"], cfg["password"])
    return connection
|
|
|
|
|
|
def _decode_part(part: Any) -> str:
|
|
charset = part.get_content_charset() or "utf-8"
|
|
try:
|
|
return part.get_payload(decode=True).decode(charset, errors="replace")
|
|
except Exception:
|
|
return ""
|
|
|
|
|
|
def _parse_uid(conn: imaplib.IMAP4_SSL, uid: bytes) -> dict[str, str] | None:
    """Fetch one message by UID and reduce it to ``{"subject", "body"}``.

    Returns None on any fetch/parse failure so callers can skip broken
    messages without special-casing.
    """
    try:
        _, data = conn.uid("fetch", uid, "(RFC822)")
        message = _email_lib.message_from_bytes(data[0][1])
        subject = str(message.get("subject", "")).strip()

        body = ""
        if message.is_multipart():
            # Take the first text/plain part; HTML alternatives are ignored.
            for candidate in message.walk():
                if candidate.get_content_type() == "text/plain":
                    body = _decode_part(candidate)
                    break
        else:
            body = _decode_part(message)

        return {"subject": subject, "body": body}
    except Exception:
        return None
|
|
|
|
|
|
def _fetch_imap_sample(limit: int, days: int) -> list[dict[str, str]]:
    """Fetch up to *limit* job-related emails from the last *days* days of INBOX.

    Uses the first account from label_tool.yaml; searches each term in
    _BROAD_TERMS and merges the results before fetching bodies.
    """
    account = _load_imap_config()
    conn = _imap_connect(account)
    cutoff = (datetime.now() - timedelta(days=days)).strftime("%d-%b-%Y")
    conn.select("INBOX")

    # Dict keys act as an ordered set: UIDs keep first-seen order while
    # duplicates across overlapping search terms collapse.
    found: dict[bytes, None] = {}
    for term in _BROAD_TERMS:
        _, data = conn.uid("search", None, f'(SUBJECT "{term}" SINCE {cutoff})')
        for uid in (data[0] or b"").split():
            found[uid] = None

    messages = []
    for uid in list(found)[:limit]:
        parsed = _parse_uid(conn, uid)
        if parsed:
            messages.append(parsed)

    try:
        conn.logout()
    except Exception:
        # A failed logout doesn't matter — we already have the data.
        pass
    return messages
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# DB export
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def cmd_export_db(args: argparse.Namespace) -> None:
    """Export LLM-labeled emails from a peregrine-style job_contacts table → scoring JSONL.

    Reads inbound rows whose stage_signal is set, skips any whose label is
    outside LABELS, writes the rest to args.score_file as JSONL (bodies
    truncated to 600 chars), then prints a label histogram.

    Exits with status 1 if args.db does not exist.
    """
    import sqlite3

    db_path = Path(args.db)
    if not db_path.exists():
        print(f"ERROR: Database not found: {args.db}", file=sys.stderr)
        sys.exit(1)

    conn = sqlite3.connect(db_path)
    try:
        cur = conn.cursor()
        cur.execute("""
            SELECT subject, body, stage_signal
            FROM job_contacts
            WHERE stage_signal IS NOT NULL
              AND stage_signal != ''
              AND direction = 'inbound'
            ORDER BY received_at
        """)
        rows = cur.fetchall()
    finally:
        # Close even if the query fails (missing table, locked DB, ...).
        conn.close()

    if not rows:
        print("No labeled emails in job_contacts. Run imap_sync first to populate.")
        return

    out_path = Path(args.score_file)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    written = 0
    skipped = 0
    label_counts: dict[str, int] = {}
    # Explicit UTF-8 so the JSONL round-trips regardless of platform default.
    with out_path.open("w", encoding="utf-8") as f:
        for subject, body, label in rows:
            if label not in LABELS:
                # stage_signal values are LLM output and may fall outside our
                # six-label taxonomy — skip rather than pollute the set.
                print(f" SKIP unknown label '{label}': {subject[:50]}")
                skipped += 1
                continue
            # 600 chars of body is enough for classification and keeps the
            # scoring file small.
            json.dump({"subject": subject or "", "body": (body or "")[:600], "label": label}, f)
            f.write("\n")
            label_counts[label] = label_counts.get(label, 0) + 1
            written += 1

    print(f"\nExported {written} emails → {out_path}" + (f" ({skipped} skipped)" if skipped else ""))
    print("\nLabel distribution:")
    for label in LABELS:
        count = label_counts.get(label, 0)
        bar = "█" * count
        print(f" {label:<25} {count:>3} {bar}")
    print(
        "\nNOTE: Labels are LLM predictions from imap_sync — review before treating as ground truth."
    )
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Subcommands
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def cmd_list_models(_args: argparse.Namespace) -> None:
    """Print the model registry as a table: name, size, default flag, adapter, hub ID."""
    print(f"\n{'Name':<24} {'Params':<8} {'Default':<20} {'Adapter':<15} Model ID")
    print("-" * 104)
    for name, entry in MODEL_REGISTRY.items():
        adapter_label = entry["adapter"].__name__
        # Mark entries that run the adapter's two-pass strategy.
        if entry.get("kwargs", {}).get("two_pass"):
            adapter_label += " (2-pass)"
        flag = "yes" if entry["default"] else "(--include-slow)"
        print(f"{name:<24} {entry['params']:<8} {flag:<20} {adapter_label:<15} {entry['model_id']}")
    print()
|
|
|
|
|
|
def cmd_score(args: argparse.Namespace) -> None:
    """Benchmark the selected models against the labeled JSONL and print tables.

    Prints a summary table (macro-F1, accuracy, latency per model) followed by
    a per-label F1 breakdown with models as columns.

    Exits with status 1 if --models names anything not currently selectable.
    """
    active = _active_models(args.include_slow)
    if args.models:
        # Fail fast on names the filter would otherwise silently drop:
        # typos, or non-default models requested without --include-slow.
        unknown = [m for m in args.models if m not in active]
        if unknown:
            print(f"ERROR: Unknown or inactive model(s): {', '.join(unknown)}", file=sys.stderr)
            print("Use --list-models to see available names (some require --include-slow).",
                  file=sys.stderr)
            sys.exit(1)
        active = {k: v for k, v in active.items() if k in args.models}

    adapters = [
        entry["adapter"](name, entry["model_id"], **entry.get("kwargs", {}))
        for name, entry in active.items()
    ]

    print(f"\nScoring {len(adapters)} model(s) against {args.score_file} …\n")
    results = run_scoring(adapters, args.score_file)

    # Summary table: one row per model.
    col = 12
    print(f"{'Model':<22}" + f"{'macro-F1':>{col}} {'Accuracy':>{col}} {'ms/email':>{col}}")
    print("-" * (22 + col * 3 + 2))
    for name, m in results.items():
        print(
            f"{name:<22}"
            f"{m['__macro_f1__']:>{col}.3f}"
            f"{m['__accuracy__']:>{col}.3f}"
            f"{m['latency_ms']:>{col}.1f}"
        )

    # Breakdown: labels as rows, models as columns (names truncated to fit).
    print("\nPer-label F1:")
    names = list(results.keys())
    print(f"{'Label':<25}" + "".join(f"{n[:11]:>{col}}" for n in names))
    print("-" * (25 + col * len(names)))
    for label in LABELS:
        row_str = f"{label:<25}"
        for m in results.values():
            row_str += f"{m[label]['f1']:>{col}.3f}"
        print(row_str)
    print()
|
|
|
|
|
|
def cmd_compare(args: argparse.Namespace) -> None:
    """Classify live IMAP emails with every active model and print them side by side."""
    active = _active_models(args.include_slow)
    if args.models:
        active = {name: entry for name, entry in active.items() if name in args.models}

    print(f"Fetching up to {args.limit} emails from IMAP …")
    emails = _fetch_imap_sample(args.limit, args.days)
    print(f"Fetched {len(emails)} emails. Loading {len(active)} model(s) …\n")

    adapters = []
    for name, entry in active.items():
        extra_kwargs = entry.get("kwargs", {})
        adapters.append(entry["adapter"](name, entry["model_id"], **extra_kwargs))
    model_names = [a.name for a in adapters]

    col = 22
    subj_w = 50
    print(f"{'Subject':<{subj_w}}" + "".join(f"{n:<{col}}" for n in model_names))
    print("-" * (subj_w + col * len(model_names)))

    for msg in emails:
        subject = msg["subject"]
        # Truncate long subjects so the prediction columns stay aligned.
        short_subj = subject[:subj_w - 1] if len(subject) > subj_w else subject
        line = f"{short_subj:<{subj_w}}"
        for adapter in adapters:
            try:
                label = adapter.classify(msg["subject"], msg["body"])
            except Exception as exc:
                label = f"ERR:{str(exc)[:8]}"
            line += f"{label:<{col}}"
        # Flush per row so predictions appear as each model finishes.
        print(line, flush=True)

    for adapter in adapters:
        adapter.unload()
    print()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# CLI
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def main() -> None:
    """CLI entry point: parse arguments and dispatch to the chosen subcommand."""
    parser = argparse.ArgumentParser(
        description="Benchmark HuggingFace email classifiers against our 6 labels."
    )
    # Mode flags — exactly one is expected; the first set flag wins.
    parser.add_argument("--list-models", action="store_true", help="Show model registry and exit")
    parser.add_argument("--score", action="store_true", help="Score against labeled JSONL")
    parser.add_argument("--compare", action="store_true", help="Visual table on live IMAP emails")
    parser.add_argument("--export-db", action="store_true",
                        help="Export labeled emails from a staging.db → score JSONL")
    # Shared options.
    parser.add_argument("--score-file", default="data/email_score.jsonl", help="Path to labeled JSONL")
    parser.add_argument("--db", default="data/staging.db", help="Path to staging.db for --export-db")
    parser.add_argument("--limit", type=int, default=20, help="Max emails for --compare")
    parser.add_argument("--days", type=int, default=90, help="Days back for IMAP search")
    parser.add_argument("--include-slow", action="store_true", help="Include non-default heavy models")
    parser.add_argument("--models", nargs="+", help="Override: run only these model names")

    args = parser.parse_args()

    # Dispatch table preserves the original precedence order.
    dispatch = (
        (args.list_models, cmd_list_models),
        (args.score, cmd_score),
        (args.compare, cmd_compare),
        (args.export_db, cmd_export_db),
    )
    for flag, handler in dispatch:
        if flag:
            handler(args)
            return
    parser.print_help()
|
|
|
|
|
|
# Run the CLI only when executed as a script, not when imported (e.g. by tests).
if __name__ == "__main__":
    main()
|