feat: initial avocet repo — email classifier training tool

Scrape → Store → Process pipeline for building email classifier
benchmark data across the CircuitForge menagerie.

- app/label_tool.py — Streamlit card-stack UI, multi-account IMAP fetch,
  6-bucket labeling, undo/skip, keyboard shortcuts (1-6/S/U)
- scripts/classifier_adapters.py — ZeroShotAdapter (+ two_pass),
  GLiClassAdapter, RerankerAdapter; ABC with lazy model loading
- scripts/benchmark_classifier.py — 13-model registry, --score,
  --compare, --list-models, --export-db; uses label_tool.yaml for IMAP
- tests/ — 20 tests, all passing, zero model downloads required
- config/label_tool.yaml.example — multi-account IMAP template
- data/email_score.jsonl.example — sample labeled data for CI

Labels: interview_scheduled, offer_received, rejected,
        positive_response, survey_received, neutral
pyr0ball 2026-02-27 14:07:38 -08:00
commit 0e238a9e37
14 changed files with 1723 additions and 0 deletions

16
.gitignore vendored Normal file

@@ -0,0 +1,16 @@
__pycache__/
*.pyc
.pytest_cache/
.coverage
*.egg-info/
# Secrets and personal data
config/label_tool.yaml
# Data files (user-generated, not for version control)
data/email_score.jsonl
data/email_label_queue.jsonl
data/email_compare_sample.jsonl
# Conda/pip artifacts
.env

100
CLAUDE.md Normal file

@@ -0,0 +1,100 @@
# Avocet — Email Classifier Training Tool
## What it is
Shared infrastructure for building and benchmarking email classifiers across the CircuitForge menagerie.
Named for the avocet's sweeping-bill feeding technique — it sweeps through email streams and filters messages into categories.
**Pipeline:**
```
Scrape (IMAP, wide search, multi-account) → data/email_label_queue.jsonl
Label (card-stack UI) → data/email_score.jsonl
Benchmark (HuggingFace NLI/reranker) → per-model macro-F1 + latency
```
## Environment
- Python env: `conda run -n job-seeker <cmd>` for basic use (streamlit, yaml, stdlib only)
- Classifier env: `conda run -n job-seeker-classifiers <cmd>` for benchmark (transformers, FlagEmbedding, gliclass)
- Run tests: `/devl/miniconda3/envs/job-seeker/bin/pytest tests/ -v`
(direct binary — `conda run pytest` can spawn runaway processes)
- Create classifier env: `conda env create -f environment.yml`
## Label Tool (app/label_tool.py)
Card-stack Streamlit UI for manually labeling recruitment emails.
```
conda run -n job-seeker streamlit run app/label_tool.py --server.port 8503
```
- Config: `config/label_tool.yaml` (gitignored — copy from `.example`)
- Queue: `data/email_label_queue.jsonl` (gitignored)
- Output: `data/email_score.jsonl` (gitignored)
- Three tabs: 🃏 Label, 📥 Fetch, 📊 Stats
- Keyboard shortcuts: 1-6 = label, S = skip, U = undo
- Dedup: MD5 of `(subject + body[:100])` — cross-account safe
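That dedup key can be sketched as a standalone function (a minimal mirror of `_entry_key` in `app/label_tool.py`; the name `entry_key` here is illustrative, not the shipped one):

```python
import hashlib

def entry_key(entry: dict) -> str:
    """Dedup key: MD5 over subject + first 100 body chars.
    The account never enters the hash, so the same email fetched
    from two accounts collapses to one queue entry."""
    text = entry.get("subject", "") + (entry.get("body") or "")[:100]
    return hashlib.md5(text.encode()).hexdigest()

a = {"subject": "Offer Letter", "body": "x" * 500, "account": "Gmail"}
b = {"subject": "Offer Letter", "body": "x" * 500, "account": "Outlook"}
assert entry_key(a) == entry_key(b)  # cross-account duplicates collapse
```

Truncating the body to 100 chars keeps the key stable when providers append differing footers or tracking blocks.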
## Benchmark (scripts/benchmark_classifier.py)
```
# List available models
conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --list-models
# Score against labeled JSONL
conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --score
# Visual comparison on live IMAP emails
conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --compare --limit 20
# Include slow/large models
conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --score --include-slow
# Export DB-labeled emails (⚠️ LLM-generated labels — review first)
conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --export-db --db /path/to/staging.db
```
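`--score` reports per-model macro-F1 over the labeled JSONL. As a reference for what that number means, here is a minimal stdlib sketch of the standard definition (the real computation lives in `compute_metrics` in `scripts/classifier_adapters.py`; this mirrors the metric, not that exact code):

```python
def macro_f1(preds: list[str], gold: list[str], labels: list[str]) -> float:
    """Unweighted mean of per-label F1, so rare labels (offer_received)
    weigh as much as common ones (neutral)."""
    f1s = []
    for lbl in labels:
        tp = sum(p == lbl and g == lbl for p, g in zip(preds, gold))
        fp = sum(p == lbl and g != lbl for p, g in zip(preds, gold))
        fn = sum(p != lbl and g == lbl for p, g in zip(preds, gold))
        prec = tp / (tp + fp) if tp + fp else 0.0
        rec = tp / (tp + fn) if tp + fn else 0.0
        f1s.append(2 * prec * rec / (prec + rec) if prec + rec else 0.0)
    return sum(f1s) / len(f1s)
```

Because the mean is unweighted, a model that dumps everything into `neutral` scores poorly even if `neutral` dominates the data — which is exactly why the benchmark uses macro rather than micro averaging.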
## Labels (peregrine defaults — configurable per product)
| Label | Meaning |
|-------|---------|
| `interview_scheduled` | Phone screen, video call, or on-site invitation |
| `offer_received` | Formal job offer or offer letter |
| `rejected` | Application declined or not moving forward |
| `positive_response` | Recruiter interest or request to connect |
| `survey_received` | Culture-fit survey or assessment invitation |
| `neutral` | ATS confirmation or unrelated email |
## Model Registry (13 models, 7 defaults)
See `scripts/benchmark_classifier.py:MODEL_REGISTRY`.
Default models run without `--include-slow`.
Add `--models deberta-small deberta-small-2pass` to test a specific subset.
## Config Files
- `config/label_tool.yaml` — gitignored; multi-account IMAP config
- `config/label_tool.yaml.example` — committed template
## Data Files
- `data/email_score.jsonl` — gitignored; manually-labeled ground truth
- `data/email_score.jsonl.example` — committed sample for CI
- `data/email_label_queue.jsonl` — gitignored; IMAP fetch queue
## Key Design Notes
- `ZeroShotAdapter.load()` instantiates the pipeline object; `classify()` calls the object.
Tests patch `scripts.classifier_adapters.pipeline` (the module-level factory) with a
two-level mock: `mock_factory.return_value = MagicMock(return_value={...})`.
- `two_pass=True` on ZeroShotAdapter: first pass ranks all 6 labels; second pass re-runs
with only top-2, forcing a binary choice. 2× cost, better confidence.
- `--compare` uses the first account in `label_tool.yaml` for live IMAP emails.
- DB export labels are llama3.1:8b-generated — treat as noisy, not gold truth.
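The two-level mock described in the first note can be sketched in isolation (hypothetical strings; in the real tests the factory is installed with `unittest.mock.patch("scripts.classifier_adapters.pipeline", ...)`):

```python
from unittest.mock import MagicMock

# Level 1: the factory stands in for the module-level `pipeline` that
# load() calls. Level 2: its return_value is itself callable, standing
# in for the pipeline object that classify() invokes.
mock_factory = MagicMock()
mock_factory.return_value = MagicMock(
    return_value={"labels": ["rejected"], "scores": [0.97]}
)

pipe = mock_factory("zero-shot-classification")   # what load() would do
result = pipe("Unfortunately, we are not moving forward")  # what classify() would do
assert result["labels"][0] == "rejected"
```

Patching the factory rather than the pipeline instance is what lets the 20 tests run with zero model downloads.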
## Relationship to Peregrine
Avocet started as `peregrine/tools/label_tool.py` + `peregrine/scripts/classifier_adapters.py`.
Peregrine retains copies during stabilization; once avocet is proven, peregrine will import from here.

568
app/label_tool.py Normal file

@@ -0,0 +1,568 @@
"""Email Label Tool — card-stack UI for building classifier benchmark data.
Philosophy: Scrape → Store → Process
Fetch (IMAP, wide search, multi-account) → data/email_label_queue.jsonl
Label (card stack) → data/email_score.jsonl
Run:
conda run -n job-seeker streamlit run app/label_tool.py --server.port 8503
Config: config/label_tool.yaml (gitignored — see config/label_tool.yaml.example)
"""
from __future__ import annotations
import email as _email_lib
import hashlib
import imaplib
import json
import sys
from datetime import datetime, timedelta
from email.header import decode_header as _raw_decode
from pathlib import Path
from typing import Any
import streamlit as st
import yaml
# ── Path setup ─────────────────────────────────────────────────────────────
_ROOT = Path(__file__).parent.parent
sys.path.insert(0, str(_ROOT))
_QUEUE_FILE = _ROOT / "data" / "email_label_queue.jsonl"
_SCORE_FILE = _ROOT / "data" / "email_score.jsonl"
_CFG_FILE = _ROOT / "config" / "label_tool.yaml"
# ── Labels ─────────────────────────────────────────────────────────────────
LABELS = [
"interview_scheduled",
"offer_received",
"rejected",
"positive_response",
"survey_received",
"neutral",
]
_LABEL_META: dict[str, dict] = {
"interview_scheduled": {"emoji": "🗓️", "color": "#4CAF50", "key": "1"},
"offer_received": {"emoji": "🎉", "color": "#2196F3", "key": "2"},
"rejected": {"emoji": "", "color": "#F44336", "key": "3"},
"positive_response": {"emoji": "👍", "color": "#FF9800", "key": "4"},
"survey_received": {"emoji": "📋", "color": "#9C27B0", "key": "5"},
"neutral": {"emoji": "", "color": "#607D8B", "key": "6"},
}
# ── Wide IMAP search terms (cast a net across all 6 categories) ─────────────
_WIDE_TERMS = [
# interview_scheduled
"interview", "phone screen", "video call", "zoom link", "schedule a call",
# offer_received
"offer letter", "job offer", "offer of employment", "pleased to offer",
# rejected
"unfortunately", "not moving forward", "other candidates", "regret to inform",
"no longer", "decided not to", "decided to go with",
# positive_response
"opportunity", "interested in your background", "reached out", "great fit",
"exciting role", "love to connect",
# survey_received
"assessment", "questionnaire", "culture fit", "culture-fit", "online assessment",
# neutral / ATS confirms
"application received", "thank you for applying", "application confirmation",
"you applied", "your application for",
# general recruitment
"application", "recruiter", "recruiting", "hiring", "candidate",
]
# ── IMAP helpers ────────────────────────────────────────────────────────────
def _decode_str(value: str | None) -> str:
if not value:
return ""
parts = _raw_decode(value)
out = []
for part, enc in parts:
if isinstance(part, bytes):
out.append(part.decode(enc or "utf-8", errors="replace"))
else:
out.append(str(part))
return " ".join(out).strip()
def _extract_body(msg: Any) -> str:
if msg.is_multipart():
for part in msg.walk():
if part.get_content_type() == "text/plain":
try:
charset = part.get_content_charset() or "utf-8"
return part.get_payload(decode=True).decode(charset, errors="replace")
except Exception:
pass
else:
try:
charset = msg.get_content_charset() or "utf-8"
return msg.get_payload(decode=True).decode(charset, errors="replace")
except Exception:
pass
return ""
def _fetch_account(cfg: dict, days: int, limit: int, known_keys: set[str],
progress_cb=None) -> list[dict]:
"""Fetch emails from one IMAP account using wide recruitment search terms."""
since = (datetime.now() - timedelta(days=days)).strftime("%d-%b-%Y")
host = cfg.get("host", "imap.gmail.com")
port = int(cfg.get("port", 993))
use_ssl = cfg.get("use_ssl", True)
username = cfg["username"]
password = cfg["password"]
name = cfg.get("name", username)
conn = (imaplib.IMAP4_SSL if use_ssl else imaplib.IMAP4)(host, port)
conn.login(username, password)
seen_uids: dict[bytes, None] = {}
conn.select("INBOX", readonly=True)
for term in _WIDE_TERMS:
try:
_, data = conn.search(None, f'(SUBJECT "{term}" SINCE "{since}")')
for uid in (data[0] or b"").split():
seen_uids[uid] = None
except Exception:
pass
emails: list[dict] = []
uids = list(seen_uids.keys())[:limit * 3] # overfetch; filter after dedup
for i, uid in enumerate(uids):
if len(emails) >= limit:
break
if progress_cb:
progress_cb(i / len(uids), f"{name}: {len(emails)} fetched…")
try:
_, raw_data = conn.fetch(uid, "(RFC822)")
if not raw_data or not raw_data[0]:
continue
msg = _email_lib.message_from_bytes(raw_data[0][1])
subj = _decode_str(msg.get("Subject", ""))
from_addr = _decode_str(msg.get("From", ""))
date = _decode_str(msg.get("Date", ""))
body = _extract_body(msg)[:800]
entry = {
"subject": subj,
"body": body,
"from_addr": from_addr,
"date": date,
"account": name,
}
key = _entry_key(entry)
if key not in known_keys:
known_keys.add(key)
emails.append(entry)
except Exception:
pass
try:
conn.logout()
except Exception:
pass
return emails
# ── Queue / score file helpers ───────────────────────────────────────────────
def _entry_key(e: dict) -> str:
return hashlib.md5(
(e.get("subject", "") + (e.get("body") or "")[:100]).encode()
).hexdigest()
def _load_jsonl(path: Path) -> list[dict]:
if not path.exists():
return []
rows = []
with path.open() as f:
for line in f:
line = line.strip()
if line:
try:
rows.append(json.loads(line))
except Exception:
pass
return rows
def _save_jsonl(path: Path, rows: list[dict]) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
with path.open("w") as f:
for row in rows:
f.write(json.dumps(row, ensure_ascii=False) + "\n")
def _append_jsonl(path: Path, row: dict) -> None:
path.parent.mkdir(parents=True, exist_ok=True)
with path.open("a") as f:
f.write(json.dumps(row, ensure_ascii=False) + "\n")
# ── Config ──────────────────────────────────────────────────────────────────
def _load_config() -> list[dict]:
if not _CFG_FILE.exists():
return []
cfg = yaml.safe_load(_CFG_FILE.read_text()) or {}
return cfg.get("accounts", [])
# ── Page setup ──────────────────────────────────────────────────────────────
st.set_page_config(
page_title="Avocet — Email Labeler",
page_icon="📬",
layout="wide",
)
st.markdown("""
<style>
/* Card stack */
.email-card {
border: 1px solid rgba(128,128,128,0.25);
border-radius: 14px;
padding: 28px 32px;
box-shadow: 0 6px 24px rgba(0,0,0,0.18);
margin-bottom: 4px;
position: relative;
}
.card-stack-hint {
height: 10px;
border-radius: 0 0 12px 12px;
border: 1px solid rgba(128,128,128,0.15);
margin: 0 16px;
box-shadow: 0 4px 12px rgba(0,0,0,0.10);
}
.card-stack-hint2 {
height: 8px;
border-radius: 0 0 10px 10px;
border: 1px solid rgba(128,128,128,0.08);
margin: 0 32px;
}
/* Subject line */
.card-subject { font-size: 1.3rem; font-weight: 700; margin-bottom: 6px; }
.card-meta { font-size: 0.82rem; opacity: 0.6; margin-bottom: 16px; }
.card-body { font-size: 0.92rem; opacity: 0.85; white-space: pre-wrap; line-height: 1.5; }
/* Bucket buttons */
div[data-testid="stButton"] > button.bucket-btn {
height: 70px;
font-size: 1.05rem;
font-weight: 600;
border-radius: 12px;
}
</style>
""", unsafe_allow_html=True)
st.title("📬 Avocet — Email Label Tool")
st.caption("Scrape → Store → Process | card-stack edition")
# ── Session state init ───────────────────────────────────────────────────────
if "queue" not in st.session_state:
st.session_state.queue: list[dict] = _load_jsonl(_QUEUE_FILE)
if "labeled" not in st.session_state:
st.session_state.labeled: list[dict] = _load_jsonl(_SCORE_FILE)
st.session_state.labeled_keys: set[str] = {
_entry_key(r) for r in st.session_state.labeled
}
if "idx" not in st.session_state:
# Start past already-labeled entries in the queue
labeled_keys = st.session_state.labeled_keys
for i, entry in enumerate(st.session_state.queue):
if _entry_key(entry) not in labeled_keys:
st.session_state.idx = i
break
else:
st.session_state.idx = len(st.session_state.queue)
if "history" not in st.session_state:
st.session_state.history: list[tuple[int, str]] = [] # (queue_idx, label)
# ── Sidebar stats ────────────────────────────────────────────────────────────
with st.sidebar:
labeled = st.session_state.labeled
queue = st.session_state.queue
unlabeled = [e for e in queue if _entry_key(e) not in st.session_state.labeled_keys]
st.metric("✅ Labeled", len(labeled))
st.metric("📥 Queue", len(unlabeled))
if labeled:
st.caption("**Label distribution**")
counts = {lbl: 0 for lbl in LABELS}
for r in labeled:
counts[r.get("label", "")] = counts.get(r.get("label", ""), 0) + 1
for lbl in LABELS:
m = _LABEL_META[lbl]
st.caption(f"{m['emoji']} {lbl}: **{counts[lbl]}**")
# ── Tabs ─────────────────────────────────────────────────────────────────────
tab_label, tab_fetch, tab_stats = st.tabs(["🃏 Label", "📥 Fetch", "📊 Stats"])
# ══════════════════════════════════════════════════════════════════════════════
# FETCH TAB
# ══════════════════════════════════════════════════════════════════════════════
with tab_fetch:
accounts = _load_config()
if not accounts:
st.warning(
f"No accounts configured. Copy `config/label_tool.yaml.example` → "
f"`config/label_tool.yaml` and add your IMAP accounts.",
icon="⚠️",
)
else:
st.markdown(f"**{len(accounts)} account(s) configured:**")
for acc in accounts:
st.caption(f"{acc.get('name', acc.get('username'))} ({acc.get('host')})")
col_days, col_limit = st.columns(2)
days = col_days.number_input("Days back", min_value=7, max_value=730, value=180)
limit = col_limit.number_input("Max emails per account", min_value=10, max_value=1000, value=150)
all_accs = [a.get("name", a.get("username")) for a in accounts]
selected = st.multiselect("Accounts to fetch", all_accs, default=all_accs)
if st.button("📥 Fetch from IMAP", disabled=not accounts or not selected, type="primary"):
existing_keys = {_entry_key(e) for e in st.session_state.queue}
existing_keys.update(st.session_state.labeled_keys)
fetched_all: list[dict] = []
status = st.status("Fetching…", expanded=True)
for acc in accounts:
name = acc.get("name", acc.get("username"))
if name not in selected:
continue
status.write(f"Connecting to **{name}**…")
try:
emails = _fetch_account(
acc, days=int(days), limit=int(limit),
known_keys=existing_keys,
progress_cb=lambda p, msg: status.write(msg),
)
fetched_all.extend(emails)
status.write(f"{name}: {len(emails)} new emails")
except Exception as e:
status.write(f"{name}: {e}")
if fetched_all:
_save_jsonl(_QUEUE_FILE, st.session_state.queue + fetched_all)
st.session_state.queue = _load_jsonl(_QUEUE_FILE)
# Reset idx to first unlabeled
labeled_keys = st.session_state.labeled_keys
for i, entry in enumerate(st.session_state.queue):
if _entry_key(entry) not in labeled_keys:
st.session_state.idx = i
break
status.update(label=f"Done — {len(fetched_all)} new emails added to queue", state="complete")
else:
status.update(label="No new emails found (all already in queue or score file)", state="complete")
# ══════════════════════════════════════════════════════════════════════════════
# LABEL TAB
# ══════════════════════════════════════════════════════════════════════════════
with tab_label:
queue = st.session_state.queue
labeled_keys = st.session_state.labeled_keys
idx = st.session_state.idx
# Advance idx past already-labeled entries
while idx < len(queue) and _entry_key(queue[idx]) in labeled_keys:
idx += 1
st.session_state.idx = idx
unlabeled = [e for e in queue if _entry_key(e) not in labeled_keys]
total_in_queue = len(queue)
n_labeled = len(st.session_state.labeled)
if not queue:
st.info("Queue is empty — go to **Fetch** to pull emails from IMAP.", icon="📥")
elif not unlabeled:
st.success(
f"🎉 All {n_labeled} emails labeled! Go to **Stats** to review and export."
)
else:
# Progress
labeled_in_queue = total_in_queue - len(unlabeled)
progress_pct = labeled_in_queue / total_in_queue if total_in_queue else 0
st.progress(progress_pct, text=f"{labeled_in_queue} / {total_in_queue} labeled in queue")
# Current email
entry = queue[idx]
# Card HTML
subj = entry.get("subject", "(no subject)") or "(no subject)"
from_ = entry.get("from_addr", "") or ""
date_ = entry.get("date", "") or ""
acct = entry.get("account", "") or ""
body = (entry.get("body") or "").strip()
st.markdown(
f"""<div class="email-card">
<div class="card-meta">{from_} &nbsp;·&nbsp; {date_[:16]} &nbsp;·&nbsp; <em>{acct}</em></div>
<div class="card-subject">{subj}</div>
<div class="card-body">{body[:500].replace(chr(10), '<br>')}</div>
</div>""",
unsafe_allow_html=True,
)
if len(body) > 500:
with st.expander("Show full body"):
st.text(body)
# Stack hint (visual depth)
st.markdown('<div class="card-stack-hint"></div>', unsafe_allow_html=True)
st.markdown('<div class="card-stack-hint2"></div>', unsafe_allow_html=True)
st.markdown("") # spacer
# ── Bucket buttons ────────────────────────────────────────────────
def _do_label(label: str) -> None:
row = {"subject": entry.get("subject", ""), "body": body[:600], "label": label}
st.session_state.labeled.append(row)
st.session_state.labeled_keys.add(_entry_key(entry))
_append_jsonl(_SCORE_FILE, row)
st.session_state.history.append((idx, label))
# Advance
next_idx = idx + 1
while next_idx < len(queue) and _entry_key(queue[next_idx]) in labeled_keys:
next_idx += 1
st.session_state.idx = next_idx
row1_cols = st.columns(3)
row2_cols = st.columns(3)
bucket_pairs = [
(row1_cols[0], "interview_scheduled"),
(row1_cols[1], "offer_received"),
(row1_cols[2], "rejected"),
(row2_cols[0], "positive_response"),
(row2_cols[1], "survey_received"),
(row2_cols[2], "neutral"),
]
for col, lbl in bucket_pairs:
m = _LABEL_META[lbl]
counts = {l: 0 for l in LABELS}
for r in st.session_state.labeled:
counts[r.get("label", "")] = counts.get(r.get("label", ""), 0) + 1
label_display = f"{m['emoji']} **{lbl}** [{counts[lbl]}]\n`{m['key']}`"
if col.button(label_display, key=f"lbl_{lbl}", use_container_width=True):
_do_label(lbl)
st.rerun()
# ── Navigation ────────────────────────────────────────────────────
st.markdown("")
nav_cols = st.columns([2, 1, 1])
remaining = len(unlabeled) - 1
nav_cols[0].caption(f"**{remaining}** remaining · Keys: 1-6 = label, S = skip, U = undo")
if nav_cols[1].button("↩ Undo", disabled=not st.session_state.history, use_container_width=True):
prev_idx, prev_label = st.session_state.history.pop()
# Remove the last labeled entry
if st.session_state.labeled:
removed = st.session_state.labeled.pop()
st.session_state.labeled_keys.discard(_entry_key(removed))
_save_jsonl(_SCORE_FILE, st.session_state.labeled)
st.session_state.idx = prev_idx
st.rerun()
if nav_cols[2].button("→ Skip", use_container_width=True):
next_idx = idx + 1
while next_idx < len(queue) and _entry_key(queue[next_idx]) in labeled_keys:
next_idx += 1
st.session_state.idx = next_idx
st.rerun()
# Keyboard shortcut capture (JS → hidden button click)
st.components.v1.html(
"""<script>
document.addEventListener('keydown', function(e) {
if (e.target.tagName === 'INPUT' || e.target.tagName === 'TEXTAREA') return;
const keyToLabel = {
'1':'interview_scheduled','2':'offer_received','3':'rejected',
'4':'positive_response','5':'survey_received','6':'neutral'
};
const label = keyToLabel[e.key];
if (label) {
const btns = window.parent.document.querySelectorAll('button');
for (const btn of btns) {
if (btn.innerText.toLowerCase().includes(label.replace('_',' '))) {
btn.click(); break;
}
}
} else if (e.key.toLowerCase() === 's') {
const btns = window.parent.document.querySelectorAll('button');
for (const btn of btns) {
if (btn.innerText.includes('Skip')) { btn.click(); break; }
}
} else if (e.key.toLowerCase() === 'u') {
const btns = window.parent.document.querySelectorAll('button');
for (const btn of btns) {
if (btn.innerText.includes('Undo')) { btn.click(); break; }
}
}
});
</script>""",
height=0,
)
# ══════════════════════════════════════════════════════════════════════════════
# STATS TAB
# ══════════════════════════════════════════════════════════════════════════════
with tab_stats:
labeled = st.session_state.labeled
if not labeled:
st.info("No labeled emails yet.")
else:
counts = {lbl: 0 for lbl in LABELS}
for r in labeled:
lbl = r.get("label", "")
if lbl in counts:
counts[lbl] += 1
st.markdown(f"**{len(labeled)} labeled emails total**")
for lbl in LABELS:
m = _LABEL_META[lbl]
col_name, col_bar, col_n = st.columns([3, 5, 1])
col_name.markdown(f"{m['emoji']} {lbl}")
col_bar.progress(counts[lbl] / (max(counts.values()) or 1))  # `or 1` guards against all-zero counts
col_n.markdown(f"**{counts[lbl]}**")
st.divider()
st.caption(
f"Score file: `{_SCORE_FILE.relative_to(_ROOT)}` "
f"({_SCORE_FILE.stat().st_size if _SCORE_FILE.exists() else 0:,} bytes)"
)
if st.button("🔄 Re-sync from disk"):
st.session_state.labeled = _load_jsonl(_SCORE_FILE)
st.session_state.labeled_keys = {_entry_key(r) for r in st.session_state.labeled}
st.rerun()
if _SCORE_FILE.exists():
st.download_button(
"⬇️ Download email_score.jsonl",
data=_SCORE_FILE.read_bytes(),
file_name="email_score.jsonl",
mime="application/jsonlines",
)

23
config/label_tool.yaml.example Normal file

@@ -0,0 +1,23 @@
# config/label_tool.yaml — Multi-account IMAP config for the email label tool
# Copy to config/label_tool.yaml and fill in your credentials.
# This file is gitignored.
accounts:
  - name: "Gmail"
    host: "imap.gmail.com"
    port: 993
    username: "you@gmail.com"
    password: "your-app-password"  # Use an App Password, not your login password
    folder: "INBOX"
    days_back: 90
  - name: "Outlook"
    host: "outlook.office365.com"
    port: 993
    username: "you@outlook.com"
    password: "your-app-password"
    folder: "INBOX"
    days_back: 90
# Optional: limit emails fetched per account per run (0 = unlimited)
max_per_account: 500

0
data/.gitkeep Normal file

8
data/email_score.jsonl.example Normal file

@@ -0,0 +1,8 @@
{"subject": "Interview Invitation — Senior Engineer", "body": "Hi Meghan, we'd love to schedule a 30-min phone screen. Are you available Thursday at 2pm? Please reply to confirm.", "label": "interview_scheduled"}
{"subject": "Your application to Acme Corp", "body": "Thank you for your interest in the Senior Engineer role. After careful consideration, we have decided to move forward with other candidates whose experience more closely matches our current needs.", "label": "rejected"}
{"subject": "Offer Letter — Product Manager at Initech", "body": "Dear Meghan, we are thrilled to extend an offer of employment for the Product Manager position. Please find the attached offer letter outlining compensation and start date.", "label": "offer_received"}
{"subject": "Quick question about your background", "body": "Hi Meghan, I came across your profile and would love to connect. We have a few roles that seem like a great match. Would you be open to a brief chat this week?", "label": "positive_response"}
{"subject": "Company Culture Survey — Acme Corp", "body": "Meghan, as part of our evaluation process, we invite all candidates to complete our culture fit assessment. The survey takes approximately 15 minutes. Please click the link below.", "label": "survey_received"}
{"subject": "Application Received — DataCo", "body": "Thank you for submitting your application for the Data Engineer role at DataCo. We have received your materials and will be in touch if your qualifications match our needs.", "label": "neutral"}
{"subject": "Following up on your application", "body": "Hi Meghan, I wanted to follow up on your recent application. Your background looks interesting and we'd like to learn more. Can we set up a quick call?", "label": "positive_response"}
{"subject": "We're moving forward with other candidates", "body": "Dear Meghan, thank you for taking the time to interview with us. After thoughtful consideration, we have decided not to move forward with your candidacy at this time.", "label": "rejected"}

25
environment.yml Normal file

@@ -0,0 +1,25 @@
name: job-seeker-classifiers
channels:
  - conda-forge
  - defaults
dependencies:
  - python=3.11
  - pip
  - pip:
      # UI
      - streamlit>=1.32
      - pyyaml>=6.0
      # Classifier backends (heavy — install selectively)
      - transformers>=4.40
      - torch>=2.2
      - accelerate>=0.27
      # Optional: GLiClass adapter
      # - gliclass
      # Optional: BGE reranker adapter
      # - FlagEmbedding
      # Dev
      - pytest>=8.0

5
pytest.ini Normal file

@@ -0,0 +1,5 @@
[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*

0
scripts/__init__.py Normal file

450
scripts/benchmark_classifier.py Normal file

@@ -0,0 +1,450 @@
#!/usr/bin/env python
"""
Email classifier benchmark — compare HuggingFace models against our 6 labels.
Usage:
# List available models
conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --list-models
# Score against labeled JSONL
conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --score
# Visual comparison on live IMAP emails (uses first account in label_tool.yaml)
conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --compare --limit 20
# Include slow/large models
conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --score --include-slow
# Export DB-labeled emails (⚠️ LLM-generated labels)
conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --export-db --db /path/to/staging.db
"""
from __future__ import annotations
import argparse
import email as _email_lib
import imaplib
import json
import sys
import time
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any
sys.path.insert(0, str(Path(__file__).parent.parent))
from scripts.classifier_adapters import (
LABELS,
LABEL_DESCRIPTIONS,
ClassifierAdapter,
GLiClassAdapter,
RerankerAdapter,
ZeroShotAdapter,
compute_metrics,
)
# ---------------------------------------------------------------------------
# Model registry
# ---------------------------------------------------------------------------
MODEL_REGISTRY: dict[str, dict[str, Any]] = {
"deberta-zeroshot": {
"adapter": ZeroShotAdapter,
"model_id": "MoritzLaurer/DeBERTa-v3-large-zeroshot-v2.0",
"params": "400M",
"default": True,
},
"deberta-small": {
"adapter": ZeroShotAdapter,
"model_id": "cross-encoder/nli-deberta-v3-small",
"params": "100M",
"default": True,
},
"gliclass-large": {
"adapter": GLiClassAdapter,
"model_id": "knowledgator/gliclass-instruct-large-v1.0",
"params": "400M",
"default": True,
},
"bart-mnli": {
"adapter": ZeroShotAdapter,
"model_id": "facebook/bart-large-mnli",
"params": "400M",
"default": True,
},
"bge-m3-zeroshot": {
"adapter": ZeroShotAdapter,
"model_id": "MoritzLaurer/bge-m3-zeroshot-v2.0",
"params": "600M",
"default": True,
},
"deberta-small-2pass": {
"adapter": ZeroShotAdapter,
"model_id": "cross-encoder/nli-deberta-v3-small",
"params": "100M",
"default": True,
"kwargs": {"two_pass": True},
},
"deberta-base-anli": {
"adapter": ZeroShotAdapter,
"model_id": "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
"params": "200M",
"default": True,
},
"deberta-large-ling": {
"adapter": ZeroShotAdapter,
"model_id": "MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli",
"params": "400M",
"default": False,
},
"mdeberta-xnli-2m": {
"adapter": ZeroShotAdapter,
"model_id": "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7",
"params": "300M",
"default": False,
},
"bge-reranker": {
"adapter": RerankerAdapter,
"model_id": "BAAI/bge-reranker-v2-m3",
"params": "600M",
"default": False,
},
"deberta-xlarge": {
"adapter": ZeroShotAdapter,
"model_id": "microsoft/deberta-xlarge-mnli",
"params": "750M",
"default": False,
},
"mdeberta-mnli": {
"adapter": ZeroShotAdapter,
"model_id": "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli",
"params": "300M",
"default": False,
},
"xlm-roberta-anli": {
"adapter": ZeroShotAdapter,
"model_id": "vicgalle/xlm-roberta-large-xnli-anli",
"params": "600M",
"default": False,
},
}
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def load_scoring_jsonl(path: str) -> list[dict[str, str]]:
"""Load labeled examples from a JSONL file for benchmark scoring."""
p = Path(path)
if not p.exists():
raise FileNotFoundError(
f"Scoring file not found: {path}\n"
f"Copy data/email_score.jsonl.example → data/email_score.jsonl "
f"or use the label tool (app/label_tool.py) to label your own emails."
)
rows = []
with p.open() as f:
for line in f:
line = line.strip()
if line:
rows.append(json.loads(line))
return rows
def _active_models(include_slow: bool) -> dict[str, dict[str, Any]]:
return {k: v for k, v in MODEL_REGISTRY.items() if v["default"] or include_slow}
def run_scoring(
adapters: list[ClassifierAdapter],
score_file: str,
) -> dict[str, Any]:
"""Run all adapters against a labeled JSONL. Returns per-adapter metrics."""
rows = load_scoring_jsonl(score_file)
gold = [r["label"] for r in rows]
results: dict[str, Any] = {}
for adapter in adapters:
preds: list[str] = []
t0 = time.monotonic()
for row in rows:
try:
pred = adapter.classify(row["subject"], row["body"])
except Exception as exc:
print(f" [{adapter.name}] ERROR on '{row['subject'][:40]}': {exc}", flush=True)
pred = "neutral"
preds.append(pred)
elapsed_ms = (time.monotonic() - t0) * 1000
metrics = compute_metrics(preds, gold, LABELS)
metrics["latency_ms"] = round(elapsed_ms / len(rows), 1)
results[adapter.name] = metrics
adapter.unload()
return results
# ---------------------------------------------------------------------------
# IMAP helpers (stdlib only — reads label_tool.yaml, uses first account)
# ---------------------------------------------------------------------------
_BROAD_TERMS = [
"interview", "opportunity", "offer letter",
"job offer", "application", "recruiting",
]
def _load_imap_config() -> dict[str, Any]:
"""Load IMAP config from label_tool.yaml, returning first account as a flat dict."""
import yaml
cfg_path = Path(__file__).parent.parent / "config" / "label_tool.yaml"
if not cfg_path.exists():
raise FileNotFoundError(
f"IMAP config not found: {cfg_path}\n"
f"Copy config/label_tool.yaml.example → config/label_tool.yaml"
)
cfg = yaml.safe_load(cfg_path.read_text()) or {}
accounts = cfg.get("accounts", [])
if not accounts:
raise ValueError("No accounts configured in config/label_tool.yaml")
return accounts[0]
def _imap_connect(cfg: dict[str, Any]) -> imaplib.IMAP4_SSL:
conn = imaplib.IMAP4_SSL(cfg["host"], cfg.get("port", 993))
conn.login(cfg["username"], cfg["password"])
return conn
def _decode_part(part: Any) -> str:
charset = part.get_content_charset() or "utf-8"
try:
return part.get_payload(decode=True).decode(charset, errors="replace")
except Exception:
return ""
def _parse_uid(conn: imaplib.IMAP4_SSL, uid: bytes) -> dict[str, str] | None:
try:
_, data = conn.uid("fetch", uid, "(RFC822)")
raw = data[0][1]
msg = _email_lib.message_from_bytes(raw)
subject = str(msg.get("subject", "")).strip()
body = ""
if msg.is_multipart():
for part in msg.walk():
if part.get_content_type() == "text/plain":
body = _decode_part(part)
break
else:
body = _decode_part(msg)
return {"subject": subject, "body": body}
except Exception:
return None
def _fetch_imap_sample(limit: int, days: int) -> list[dict[str, str]]:
cfg = _load_imap_config()
conn = _imap_connect(cfg)
# NOTE: IMAP SINCE wants DD-Mon-YYYY with an English month abbreviation; %b is locale-dependent.
since = (datetime.now() - timedelta(days=days)).strftime("%d-%b-%Y")
conn.select("INBOX")
seen_uids: dict[bytes, None] = {}  # dict keys dedupe UIDs across search terms, preserving order
for term in _BROAD_TERMS:
_, data = conn.uid("search", None, f'(SUBJECT "{term}" SINCE {since})')
for uid in (data[0] or b"").split():
seen_uids[uid] = None
sample = list(seen_uids.keys())[:limit]
emails = []
for uid in sample:
parsed = _parse_uid(conn, uid)
if parsed:
emails.append(parsed)
try:
conn.logout()
except Exception:
pass
return emails
# ---------------------------------------------------------------------------
# DB export
# ---------------------------------------------------------------------------
def cmd_export_db(args: argparse.Namespace) -> None:
"""Export LLM-labeled emails from a peregrine-style job_contacts table → scoring JSONL."""
import sqlite3
db_path = Path(args.db)
if not db_path.exists():
print(f"ERROR: Database not found: {args.db}", file=sys.stderr)
sys.exit(1)
conn = sqlite3.connect(db_path)
cur = conn.cursor()
cur.execute("""
SELECT subject, body, stage_signal
FROM job_contacts
WHERE stage_signal IS NOT NULL
AND stage_signal != ''
AND direction = 'inbound'
ORDER BY received_at
""")
rows = cur.fetchall()
conn.close()
if not rows:
print("No labeled emails in job_contacts. Run imap_sync first to populate.")
return
out_path = Path(args.score_file)
out_path.parent.mkdir(parents=True, exist_ok=True)
written = 0
skipped = 0
label_counts: dict[str, int] = {}
with out_path.open("w") as f:
for subject, body, label in rows:
if label not in LABELS:
print(f" SKIP unknown label '{label}': {subject[:50]}")
skipped += 1
continue
json.dump({"subject": subject or "", "body": (body or "")[:600], "label": label}, f)
f.write("\n")
label_counts[label] = label_counts.get(label, 0) + 1
written += 1
print(f"\nExported {written} emails → {out_path}" + (f" ({skipped} skipped)" if skipped else ""))
print("\nLabel distribution:")
for label in LABELS:
count = label_counts.get(label, 0)
bar = "█" * count
print(f" {label:<25} {count:>3} {bar}")
print(
"\nNOTE: Labels are LLM predictions from imap_sync — review before treating as ground truth."
)
# ---------------------------------------------------------------------------
# Subcommands
# ---------------------------------------------------------------------------
def cmd_list_models(_args: argparse.Namespace) -> None:
print(f"\n{'Name':<24} {'Params':<8} {'Default':<20} {'Adapter':<15} Model ID")
print("-" * 104)
for name, entry in MODEL_REGISTRY.items():
adapter_name = entry["adapter"].__name__
if entry.get("kwargs", {}).get("two_pass"):
adapter_name += " (2-pass)"
default_flag = "yes" if entry["default"] else "(--include-slow)"
print(f"{name:<24} {entry['params']:<8} {default_flag:<20} {adapter_name:<15} {entry['model_id']}")
print()
def cmd_score(args: argparse.Namespace) -> None:
active = _active_models(args.include_slow)
if args.models:
active = {k: v for k, v in active.items() if k in args.models}
adapters = [
entry["adapter"](name, entry["model_id"], **entry.get("kwargs", {}))
for name, entry in active.items()
]
print(f"\nScoring {len(adapters)} model(s) against {args.score_file}\n")
results = run_scoring(adapters, args.score_file)
col = 12
print(f"{'Model':<22}{'macro-F1':>{col}}{'Accuracy':>{col}}{'ms/email':>{col}}")
print("-" * (22 + col * 3))
for name, m in results.items():
print(
f"{name:<22}"
f"{m['__macro_f1__']:>{col}.3f}"
f"{m['__accuracy__']:>{col}.3f}"
f"{m['latency_ms']:>{col}.1f}"
)
print("\nPer-label F1:")
names = list(results.keys())
print(f"{'Label':<25}" + "".join(f"{n[:11]:>{col}}" for n in names))
print("-" * (25 + col * len(names)))
for label in LABELS:
row_str = f"{label:<25}"
for m in results.values():
row_str += f"{m[label]['f1']:>{col}.3f}"
print(row_str)
print()
def cmd_compare(args: argparse.Namespace) -> None:
active = _active_models(args.include_slow)
if args.models:
active = {k: v for k, v in active.items() if k in args.models}
print(f"Fetching up to {args.limit} emails from IMAP …")
emails = _fetch_imap_sample(args.limit, args.days)
print(f"Fetched {len(emails)} emails. Loading {len(active)} model(s) …\n")
adapters = [
entry["adapter"](name, entry["model_id"], **entry.get("kwargs", {}))
for name, entry in active.items()
]
model_names = [a.name for a in adapters]
col = 22
subj_w = 50
print(f"{'Subject':<{subj_w}}" + "".join(f"{n:<{col}}" for n in model_names))
print("-" * (subj_w + col * len(model_names)))
for row in emails:
short_subj = row["subject"][:subj_w - 1] if len(row["subject"]) >= subj_w else row["subject"]
line = f"{short_subj:<{subj_w}}"
for adapter in adapters:
try:
label = adapter.classify(row["subject"], row["body"])
except Exception as exc:
label = f"ERR:{str(exc)[:8]}"
line += f"{label:<{col}}"
print(line, flush=True)
for adapter in adapters:
adapter.unload()
print()
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def main() -> None:
parser = argparse.ArgumentParser(
description="Benchmark HuggingFace email classifiers against our 6 labels."
)
parser.add_argument("--list-models", action="store_true", help="Show model registry and exit")
parser.add_argument("--score", action="store_true", help="Score against labeled JSONL")
parser.add_argument("--compare", action="store_true", help="Visual table on live IMAP emails")
parser.add_argument("--export-db", action="store_true",
help="Export labeled emails from a staging.db → score JSONL")
parser.add_argument("--score-file", default="data/email_score.jsonl", help="Path to labeled JSONL")
parser.add_argument("--db", default="data/staging.db", help="Path to staging.db for --export-db")
parser.add_argument("--limit", type=int, default=20, help="Max emails for --compare")
parser.add_argument("--days", type=int, default=90, help="Days back for IMAP search")
parser.add_argument("--include-slow", action="store_true", help="Include non-default heavy models")
parser.add_argument("--models", nargs="+", help="Override: run only these model names")
args = parser.parse_args()
if args.list_models:
cmd_list_models(args)
elif args.score:
cmd_score(args)
elif args.compare:
cmd_compare(args)
elif args.export_db:
cmd_export_db(args)
else:
parser.print_help()
if __name__ == "__main__":
main()


@ -0,0 +1,257 @@
"""Classifier adapters for email classification benchmark.
Each adapter wraps a HuggingFace model and normalizes output to LABELS.
Models load lazily on first classify() call; call unload() to free VRAM.
"""
from __future__ import annotations
import abc
from collections import defaultdict
from typing import Any
__all__ = [
"LABELS",
"LABEL_DESCRIPTIONS",
"compute_metrics",
"ClassifierAdapter",
"ZeroShotAdapter",
"GLiClassAdapter",
"RerankerAdapter",
]
LABELS: list[str] = [
"interview_scheduled",
"offer_received",
"rejected",
"positive_response",
"survey_received",
"neutral",
]
# Natural-language descriptions used by the RerankerAdapter.
LABEL_DESCRIPTIONS: dict[str, str] = {
"interview_scheduled": "scheduling an interview, phone screen, or video call",
"offer_received": "a formal job offer or employment offer letter",
"rejected": "application rejected or not moving forward with candidacy",
"positive_response": "positive recruiter interest or request to connect",
"survey_received": "invitation to complete a culture-fit survey or assessment",
"neutral": "automated ATS confirmation or unrelated email",
}
# Lazy import shims — allow tests to patch without requiring the libs installed.
try:
from transformers import pipeline # type: ignore[assignment]
except ImportError:
pipeline = None # type: ignore[assignment]
try:
from gliclass import GLiClassModel, ZeroShotClassificationPipeline # type: ignore
from transformers import AutoTokenizer
except ImportError:
GLiClassModel = None # type: ignore
ZeroShotClassificationPipeline = None # type: ignore
AutoTokenizer = None # type: ignore
try:
from FlagEmbedding import FlagReranker # type: ignore
except ImportError:
FlagReranker = None # type: ignore
def _cuda_available() -> bool:
try:
import torch
return torch.cuda.is_available()
except ImportError:
return False
def compute_metrics(
predictions: list[str],
gold: list[str],
labels: list[str],
) -> dict[str, Any]:
"""Return per-label precision/recall/F1 + macro_f1 + accuracy."""
tp: dict[str, int] = defaultdict(int)
fp: dict[str, int] = defaultdict(int)
fn: dict[str, int] = defaultdict(int)
for pred, true in zip(predictions, gold):
if pred == true:
tp[pred] += 1
else:
fp[pred] += 1
fn[true] += 1
result: dict[str, Any] = {}
for label in labels:
denom_p = tp[label] + fp[label]
denom_r = tp[label] + fn[label]
p = tp[label] / denom_p if denom_p else 0.0
r = tp[label] / denom_r if denom_r else 0.0
f1 = 2 * p * r / (p + r) if (p + r) else 0.0
result[label] = {
"precision": p,
"recall": r,
"f1": f1,
"support": denom_r,
}
labels_with_support = [label for label in labels if result[label]["support"] > 0]
if labels_with_support:
result["__macro_f1__"] = (
sum(result[label]["f1"] for label in labels_with_support) / len(labels_with_support)
)
else:
result["__macro_f1__"] = 0.0
result["__accuracy__"] = sum(tp.values()) / len(predictions) if predictions else 0.0
return result
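A hand-check of the precision/recall/F1 arithmetic above on a tiny batch (hypothetical predictions, worked out label by label):

```python
preds = ["rejected", "rejected", "neutral"]
gold  = ["rejected", "neutral",  "neutral"]
# "rejected": tp=1, fp=1, fn=0 → precision 0.5, recall 1.0
p_rej, r_rej = 1 / 2, 1 / 1
f1_rej = 2 * p_rej * r_rej / (p_rej + r_rej)      # ≈ 0.667
# "neutral": tp=1, fp=0, fn=1 → precision 1.0, recall 0.5
p_neu, r_neu = 1 / 1, 1 / 2
f1_neu = 2 * p_neu * r_neu / (p_neu + r_neu)      # ≈ 0.667
macro_f1 = (f1_rej + f1_neu) / 2                  # only labels with support count
accuracy = 2 / 3                                  # total tp / number of predictions
print(round(f1_rej, 3), round(macro_f1, 3))       # → 0.667 0.667
```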
class ClassifierAdapter(abc.ABC):
"""Abstract base for all email classifier adapters."""
@property
@abc.abstractmethod
def name(self) -> str: ...
@property
@abc.abstractmethod
def model_id(self) -> str: ...
@abc.abstractmethod
def load(self) -> None:
"""Download/load the model into memory."""
@abc.abstractmethod
def unload(self) -> None:
"""Release model from memory."""
@abc.abstractmethod
def classify(self, subject: str, body: str) -> str:
"""Return one of LABELS for the given email."""
class ZeroShotAdapter(ClassifierAdapter):
"""Wraps any transformers zero-shot-classification pipeline.
load() calls pipeline("zero-shot-classification", model=..., device=...) to get
an inference callable, stored as self._pipeline. classify() then calls
self._pipeline(text, LABELS, multi_label=False). In tests, patch
'scripts.classifier_adapters.pipeline' with a MagicMock whose .return_value is
itself a MagicMock(return_value={...}) to simulate both the factory call and the
inference call.
two_pass: if True, classify() runs a second pass restricted to the top-2 labels
from the first pass, forcing a binary choice. This typically improves confidence
without the accuracy cost of a full 6-label second run.
"""
def __init__(self, name: str, model_id: str, two_pass: bool = False) -> None:
self._name = name
self._model_id = model_id
self._pipeline: Any = None
self._two_pass = two_pass
@property
def name(self) -> str:
return self._name
@property
def model_id(self) -> str:
return self._model_id
def load(self) -> None:
import scripts.classifier_adapters as _mod # noqa: PLC0415
_pipe_fn = _mod.pipeline
if _pipe_fn is None:
raise ImportError("transformers not installed — run: pip install transformers")
device = 0 if _cuda_available() else -1
# Instantiate the pipeline once; classify() calls the resulting object on each text.
self._pipeline = _pipe_fn("zero-shot-classification", model=self._model_id, device=device)
def unload(self) -> None:
self._pipeline = None
def classify(self, subject: str, body: str) -> str:
if self._pipeline is None:
self.load()
text = f"Subject: {subject}\n\n{body[:600]}"
result = self._pipeline(text, LABELS, multi_label=False)
if self._two_pass and len(result["labels"]) >= 2:
top2 = result["labels"][:2]
result = self._pipeline(text, top2, multi_label=False)
return result["labels"][0]
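The two_pass flow can be sketched standalone with a fake pipeline (no transformers required; the ranking baked into `order` is hypothetical):

```python
def fake_pipeline(text, candidate_labels, multi_label=False):
    # Pretend the model always ranks labels in this fixed preference order.
    order = ["rejected", "neutral", "interview_scheduled", "offer_received"]
    ranked = sorted(candidate_labels, key=order.index)
    return {"labels": ranked, "scores": [0.5, 0.3, 0.15, 0.05][: len(ranked)]}

text = "Subject: Update on your application\n\nThank you for your interest."
result = fake_pipeline(text, ["interview_scheduled", "offer_received", "rejected", "neutral"])
top2 = result["labels"][:2]             # first pass narrows to the top 2 labels
result = fake_pipeline(text, top2)      # second pass forces a binary choice
print(result["labels"][0])              # → rejected
```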
class GLiClassAdapter(ClassifierAdapter):
"""Wraps knowledgator GLiClass models via the gliclass library."""
def __init__(self, name: str, model_id: str) -> None:
self._name = name
self._model_id = model_id
self._pipeline: Any = None
@property
def name(self) -> str:
return self._name
@property
def model_id(self) -> str:
return self._model_id
def load(self) -> None:
if GLiClassModel is None:
raise ImportError("gliclass not installed — run: pip install gliclass")
device = "cuda:0" if _cuda_available() else "cpu"
model = GLiClassModel.from_pretrained(self._model_id)
tokenizer = AutoTokenizer.from_pretrained(self._model_id)
self._pipeline = ZeroShotClassificationPipeline(
model,
tokenizer,
classification_type="single-label",
device=device,
)
def unload(self) -> None:
self._pipeline = None
def classify(self, subject: str, body: str) -> str:
if self._pipeline is None:
self.load()
text = f"Subject: {subject}\n\n{body[:600]}"
results = self._pipeline(text, LABELS, threshold=0.0)[0]
return max(results, key=lambda r: r["score"])["label"]
class RerankerAdapter(ClassifierAdapter):
"""Uses a BGE reranker to score (email, label_description) pairs."""
def __init__(self, name: str, model_id: str) -> None:
self._name = name
self._model_id = model_id
self._reranker: Any = None
@property
def name(self) -> str:
return self._name
@property
def model_id(self) -> str:
return self._model_id
def load(self) -> None:
if FlagReranker is None:
raise ImportError("FlagEmbedding not installed — run: pip install FlagEmbedding")
self._reranker = FlagReranker(self._model_id, use_fp16=_cuda_available())
def unload(self) -> None:
self._reranker = None
def classify(self, subject: str, body: str) -> str:
if self._reranker is None:
self.load()
text = f"Subject: {subject}\n\n{body[:600]}"
pairs = [[text, LABEL_DESCRIPTIONS[label]] for label in LABELS]
scores: list[float] = self._reranker.compute_score(pairs, normalize=True)
return LABELS[scores.index(max(scores))]
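How classify() frames classification as reranking, in miniature: one (email, description) pair per label, then argmax over the pair scores. The scores here are fake stand-ins for `reranker.compute_score(pairs, normalize=True)`:

```python
labels = ["interview_scheduled", "offer_received", "rejected"]
descriptions = {
    "interview_scheduled": "scheduling an interview, phone screen, or video call",
    "offer_received": "a formal job offer or employment offer letter",
    "rejected": "application rejected or not moving forward with candidacy",
}
text = "Subject: Your application\n\nWe are moving forward with other candidates."
pairs = [[text, descriptions[label]] for label in labels]
fake_scores = [0.08, 0.03, 0.91]   # one score per pair, same order as labels
print(labels[fake_scores.index(max(fake_scores))])  # → rejected
```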

0
tests/__init__.py Normal file

@ -0,0 +1,94 @@
"""Tests for benchmark_classifier — no model downloads required."""
import pytest
def test_registry_has_thirteen_models():
from scripts.benchmark_classifier import MODEL_REGISTRY
assert len(MODEL_REGISTRY) == 13
def test_registry_default_count():
from scripts.benchmark_classifier import MODEL_REGISTRY
defaults = [k for k, v in MODEL_REGISTRY.items() if v["default"]]
assert len(defaults) == 7
def test_registry_entries_have_required_keys():
from scripts.benchmark_classifier import MODEL_REGISTRY
from scripts.classifier_adapters import ClassifierAdapter
for name, entry in MODEL_REGISTRY.items():
assert "adapter" in entry, f"{name} missing 'adapter'"
assert "model_id" in entry, f"{name} missing 'model_id'"
assert "params" in entry, f"{name} missing 'params'"
assert "default" in entry, f"{name} missing 'default'"
assert issubclass(entry["adapter"], ClassifierAdapter), \
f"{name} adapter must be a ClassifierAdapter subclass"
def test_load_scoring_jsonl(tmp_path):
from scripts.benchmark_classifier import load_scoring_jsonl
import json
f = tmp_path / "score.jsonl"
rows = [
{"subject": "Hi", "body": "Body text", "label": "neutral"},
{"subject": "Interview", "body": "Schedule a call", "label": "interview_scheduled"},
]
f.write_text("\n".join(json.dumps(r) for r in rows))
result = load_scoring_jsonl(str(f))
assert len(result) == 2
assert result[0]["label"] == "neutral"
def test_load_scoring_jsonl_missing_file():
from scripts.benchmark_classifier import load_scoring_jsonl
with pytest.raises(FileNotFoundError):
load_scoring_jsonl("/nonexistent/path.jsonl")
def test_run_scoring_with_mock_adapters(tmp_path):
"""run_scoring() returns per-model metrics using mock adapters."""
import json
from unittest.mock import MagicMock
from scripts.benchmark_classifier import run_scoring
score_file = tmp_path / "score.jsonl"
rows = [
{"subject": "Interview", "body": "Let's schedule", "label": "interview_scheduled"},
{"subject": "Sorry", "body": "We went with others", "label": "rejected"},
{"subject": "Offer", "body": "We are pleased", "label": "offer_received"},
]
score_file.write_text("\n".join(json.dumps(r) for r in rows))
perfect = MagicMock()
perfect.name = "perfect"
perfect.classify.side_effect = lambda s, b: (
"interview_scheduled" if "Interview" in s else
"rejected" if "Sorry" in s else "offer_received"
)
bad = MagicMock()
bad.name = "bad"
bad.classify.return_value = "neutral"
results = run_scoring([perfect, bad], str(score_file))
assert results["perfect"]["__accuracy__"] == pytest.approx(1.0)
assert results["bad"]["__accuracy__"] == pytest.approx(0.0)
assert "latency_ms" in results["perfect"]
def test_run_scoring_handles_classify_error(tmp_path):
"""run_scoring() falls back to 'neutral' on exception and continues."""
import json
from unittest.mock import MagicMock
from scripts.benchmark_classifier import run_scoring
score_file = tmp_path / "score.jsonl"
score_file.write_text(json.dumps({"subject": "Hi", "body": "Body", "label": "neutral"}))
broken = MagicMock()
broken.name = "broken"
broken.classify.side_effect = RuntimeError("model crashed")
results = run_scoring([broken], str(score_file))
assert "broken" in results


@ -0,0 +1,177 @@
"""Tests for classifier_adapters — no model downloads required."""
import pytest
def test_labels_constant_has_six_items():
from scripts.classifier_adapters import LABELS
assert len(LABELS) == 6
assert "interview_scheduled" in LABELS
assert "neutral" in LABELS
def test_compute_metrics_perfect_predictions():
from scripts.classifier_adapters import compute_metrics, LABELS
gold = ["rejected", "interview_scheduled", "neutral"]
preds = ["rejected", "interview_scheduled", "neutral"]
m = compute_metrics(preds, gold, LABELS)
assert m["rejected"]["f1"] == pytest.approx(1.0)
assert m["__accuracy__"] == pytest.approx(1.0)
assert m["__macro_f1__"] == pytest.approx(1.0)
def test_compute_metrics_all_wrong():
from scripts.classifier_adapters import compute_metrics, LABELS
gold = ["rejected", "rejected"]
preds = ["neutral", "interview_scheduled"]
m = compute_metrics(preds, gold, LABELS)
assert m["rejected"]["recall"] == pytest.approx(0.0)
assert m["__accuracy__"] == pytest.approx(0.0)
def test_compute_metrics_partial():
from scripts.classifier_adapters import compute_metrics, LABELS
gold = ["rejected", "neutral", "rejected"]
preds = ["rejected", "neutral", "interview_scheduled"]
m = compute_metrics(preds, gold, LABELS)
assert m["rejected"]["precision"] == pytest.approx(1.0)
assert m["rejected"]["recall"] == pytest.approx(0.5)
assert m["neutral"]["f1"] == pytest.approx(1.0)
assert m["__accuracy__"] == pytest.approx(2 / 3)
def test_compute_metrics_empty():
from scripts.classifier_adapters import compute_metrics, LABELS
m = compute_metrics([], [], LABELS)
assert m["__accuracy__"] == pytest.approx(0.0)
def test_classifier_adapter_is_abstract():
from scripts.classifier_adapters import ClassifierAdapter
with pytest.raises(TypeError):
ClassifierAdapter()
# ---- ZeroShotAdapter tests ----
def test_zeroshot_adapter_classify_mocked():
from unittest.mock import MagicMock, patch
from scripts.classifier_adapters import ZeroShotAdapter
# Two-level mock: factory call returns pipeline instance; instance call returns inference result.
mock_pipe_factory = MagicMock()
mock_pipe_factory.return_value = MagicMock(return_value={
"labels": ["rejected", "neutral", "interview_scheduled"],
"scores": [0.85, 0.10, 0.05],
})
with patch("scripts.classifier_adapters.pipeline", mock_pipe_factory):
adapter = ZeroShotAdapter("test-zs", "some/model")
adapter.load()
result = adapter.classify("We went with another candidate", "Thank you for applying.")
assert result == "rejected"
# Factory was called with the correct task type
assert mock_pipe_factory.call_args[0][0] == "zero-shot-classification"
# Pipeline instance was called with the email text
assert "We went with another candidate" in mock_pipe_factory.return_value.call_args[0][0]
def test_zeroshot_adapter_unload_clears_pipeline():
from unittest.mock import MagicMock, patch
from scripts.classifier_adapters import ZeroShotAdapter
with patch("scripts.classifier_adapters.pipeline", MagicMock()):
adapter = ZeroShotAdapter("test-zs", "some/model")
adapter.load()
assert adapter._pipeline is not None
adapter.unload()
assert adapter._pipeline is None
def test_zeroshot_adapter_lazy_loads():
from unittest.mock import MagicMock, patch
from scripts.classifier_adapters import ZeroShotAdapter
mock_pipe_factory = MagicMock()
mock_pipe_factory.return_value = MagicMock(return_value={
"labels": ["neutral"], "scores": [1.0]
})
with patch("scripts.classifier_adapters.pipeline", mock_pipe_factory):
adapter = ZeroShotAdapter("test-zs", "some/model")
adapter.classify("subject", "body")
mock_pipe_factory.assert_called_once()
# ---- GLiClassAdapter tests ----
def test_gliclass_adapter_classify_mocked():
from unittest.mock import MagicMock, patch
from scripts.classifier_adapters import GLiClassAdapter
mock_pipeline_instance = MagicMock()
mock_pipeline_instance.return_value = [[
{"label": "interview_scheduled", "score": 0.91},
{"label": "neutral", "score": 0.05},
{"label": "rejected", "score": 0.04},
]]
with patch("scripts.classifier_adapters.GLiClassModel") as _mc, \
patch("scripts.classifier_adapters.AutoTokenizer") as _mt, \
patch("scripts.classifier_adapters.ZeroShotClassificationPipeline",
return_value=mock_pipeline_instance):
adapter = GLiClassAdapter("test-gli", "some/gliclass-model")
adapter.load()
result = adapter.classify("Interview invitation", "Let's schedule a call.")
assert result == "interview_scheduled"
def test_gliclass_adapter_returns_highest_score():
from unittest.mock import MagicMock, patch
from scripts.classifier_adapters import GLiClassAdapter
mock_pipeline_instance = MagicMock()
mock_pipeline_instance.return_value = [[
{"label": "neutral", "score": 0.02},
{"label": "offer_received", "score": 0.88},
{"label": "rejected", "score": 0.10},
]]
with patch("scripts.classifier_adapters.GLiClassModel"), \
patch("scripts.classifier_adapters.AutoTokenizer"), \
patch("scripts.classifier_adapters.ZeroShotClassificationPipeline",
return_value=mock_pipeline_instance):
adapter = GLiClassAdapter("test-gli", "some/model")
adapter.load()
result = adapter.classify("Offer letter enclosed", "Dear Meghan, we are pleased to offer...")
assert result == "offer_received"
# ---- RerankerAdapter tests ----
def test_reranker_adapter_picks_highest_score():
from unittest.mock import MagicMock, patch
from scripts.classifier_adapters import RerankerAdapter, LABELS
mock_reranker = MagicMock()
mock_reranker.compute_score.return_value = [0.1, 0.05, 0.85, 0.05, 0.02, 0.03]
with patch("scripts.classifier_adapters.FlagReranker", return_value=mock_reranker):
adapter = RerankerAdapter("test-rr", "BAAI/bge-reranker-v2-m3")
adapter.load()
result = adapter.classify(
"We regret to inform you",
"After careful consideration we are moving forward with other candidates.",
)
assert result == "rejected"
pairs = mock_reranker.compute_score.call_args[0][0]
assert len(pairs) == len(LABELS)
def test_reranker_adapter_descriptions_cover_all_labels():
from scripts.classifier_adapters import LABEL_DESCRIPTIONS, LABELS
assert set(LABEL_DESCRIPTIONS.keys()) == set(LABELS)