feat: initial avocet repo — email classifier training tool
Scrape → Store → Process pipeline for building email classifier
benchmark data across the CircuitForge menagerie.
- app/label_tool.py — Streamlit card-stack UI, multi-account IMAP fetch,
6-bucket labeling, undo/skip, keyboard shortcuts (1-6/S/U)
- scripts/classifier_adapters.py — ZeroShotAdapter (+ two_pass),
GLiClassAdapter, RerankerAdapter; ABC with lazy model loading
- scripts/benchmark_classifier.py — 13-model registry, --score,
--compare, --list-models, --export-db; uses label_tool.yaml for IMAP
- tests/ — 20 tests, all passing, zero model downloads required
- config/label_tool.yaml.example — multi-account IMAP template
- data/email_score.jsonl.example — sample labeled data for CI
Labels: interview_scheduled, offer_received, rejected,
positive_response, survey_received, neutral
This commit is contained in:
commit
d68754d432
13 changed files with 1623 additions and 0 deletions
16
.gitignore
vendored
Normal file
16
.gitignore
vendored
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
__pycache__/
|
||||
*.pyc
|
||||
.pytest_cache/
|
||||
.coverage
|
||||
*.egg-info/
|
||||
|
||||
# Secrets and personal data
|
||||
config/label_tool.yaml
|
||||
|
||||
# Data files (user-generated, not for version control)
|
||||
data/email_score.jsonl
|
||||
data/email_label_queue.jsonl
|
||||
data/email_compare_sample.jsonl
|
||||
|
||||
# Conda/pip artifacts
|
||||
.env
|
||||
568
app/label_tool.py
Normal file
568
app/label_tool.py
Normal file
|
|
@ -0,0 +1,568 @@
|
|||
"""Email Label Tool — card-stack UI for building classifier benchmark data.
|
||||
|
||||
Philosophy: Scrape → Store → Process
|
||||
Fetch (IMAP, wide search, multi-account) → data/email_label_queue.jsonl
|
||||
Label (card stack) → data/email_score.jsonl
|
||||
|
||||
Run:
|
||||
conda run -n job-seeker streamlit run app/label_tool.py --server.port 8503
|
||||
|
||||
Config: config/label_tool.yaml (gitignored — see config/label_tool.yaml.example)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import email as _email_lib
|
||||
import hashlib
|
||||
import imaplib
|
||||
import json
|
||||
import sys
|
||||
from datetime import datetime, timedelta
|
||||
from email.header import decode_header as _raw_decode
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import streamlit as st
|
||||
import yaml
|
||||
|
||||
# ── Path setup ─────────────────────────────────────────────────────────────
|
||||
_ROOT = Path(__file__).parent.parent
|
||||
sys.path.insert(0, str(_ROOT))
|
||||
|
||||
_QUEUE_FILE = _ROOT / "data" / "email_label_queue.jsonl"
|
||||
_SCORE_FILE = _ROOT / "data" / "email_score.jsonl"
|
||||
_CFG_FILE = _ROOT / "config" / "label_tool.yaml"
|
||||
|
||||
# ── Labels ─────────────────────────────────────────────────────────────────
|
||||
LABELS = [
|
||||
"interview_scheduled",
|
||||
"offer_received",
|
||||
"rejected",
|
||||
"positive_response",
|
||||
"survey_received",
|
||||
"neutral",
|
||||
]
|
||||
|
||||
_LABEL_META: dict[str, dict] = {
|
||||
"interview_scheduled": {"emoji": "🗓️", "color": "#4CAF50", "key": "1"},
|
||||
"offer_received": {"emoji": "🎉", "color": "#2196F3", "key": "2"},
|
||||
"rejected": {"emoji": "❌", "color": "#F44336", "key": "3"},
|
||||
"positive_response": {"emoji": "👍", "color": "#FF9800", "key": "4"},
|
||||
"survey_received": {"emoji": "📋", "color": "#9C27B0", "key": "5"},
|
||||
"neutral": {"emoji": "⬜", "color": "#607D8B", "key": "6"},
|
||||
}
|
||||
|
||||
# ── Wide IMAP search terms (cast a net across all 6 categories) ─────────────
|
||||
_WIDE_TERMS = [
|
||||
# interview_scheduled
|
||||
"interview", "phone screen", "video call", "zoom link", "schedule a call",
|
||||
# offer_received
|
||||
"offer letter", "job offer", "offer of employment", "pleased to offer",
|
||||
# rejected
|
||||
"unfortunately", "not moving forward", "other candidates", "regret to inform",
|
||||
"no longer", "decided not to", "decided to go with",
|
||||
# positive_response
|
||||
"opportunity", "interested in your background", "reached out", "great fit",
|
||||
"exciting role", "love to connect",
|
||||
# survey_received
|
||||
"assessment", "questionnaire", "culture fit", "culture-fit", "online assessment",
|
||||
# neutral / ATS confirms
|
||||
"application received", "thank you for applying", "application confirmation",
|
||||
"you applied", "your application for",
|
||||
# general recruitment
|
||||
"application", "recruiter", "recruiting", "hiring", "candidate",
|
||||
]
|
||||
|
||||
|
||||
# ── IMAP helpers ────────────────────────────────────────────────────────────
|
||||
|
||||
def _decode_str(value: str | None) -> str:
|
||||
if not value:
|
||||
return ""
|
||||
parts = _raw_decode(value)
|
||||
out = []
|
||||
for part, enc in parts:
|
||||
if isinstance(part, bytes):
|
||||
out.append(part.decode(enc or "utf-8", errors="replace"))
|
||||
else:
|
||||
out.append(str(part))
|
||||
return " ".join(out).strip()
|
||||
|
||||
|
||||
def _extract_body(msg: Any) -> str:
|
||||
if msg.is_multipart():
|
||||
for part in msg.walk():
|
||||
if part.get_content_type() == "text/plain":
|
||||
try:
|
||||
charset = part.get_content_charset() or "utf-8"
|
||||
return part.get_payload(decode=True).decode(charset, errors="replace")
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
try:
|
||||
charset = msg.get_content_charset() or "utf-8"
|
||||
return msg.get_payload(decode=True).decode(charset, errors="replace")
|
||||
except Exception:
|
||||
pass
|
||||
return ""
|
||||
|
||||
|
||||
def _fetch_account(cfg: dict, days: int, limit: int, known_keys: set[str],
|
||||
progress_cb=None) -> list[dict]:
|
||||
"""Fetch emails from one IMAP account using wide recruitment search terms."""
|
||||
since = (datetime.now() - timedelta(days=days)).strftime("%d-%b-%Y")
|
||||
host = cfg.get("host", "imap.gmail.com")
|
||||
port = int(cfg.get("port", 993))
|
||||
use_ssl = cfg.get("use_ssl", True)
|
||||
username = cfg["username"]
|
||||
password = cfg["password"]
|
||||
name = cfg.get("name", username)
|
||||
|
||||
conn = (imaplib.IMAP4_SSL if use_ssl else imaplib.IMAP4)(host, port)
|
||||
conn.login(username, password)
|
||||
|
||||
seen_uids: dict[bytes, None] = {}
|
||||
conn.select("INBOX", readonly=True)
|
||||
for term in _WIDE_TERMS:
|
||||
try:
|
||||
_, data = conn.search(None, f'(SUBJECT "{term}" SINCE "{since}")')
|
||||
for uid in (data[0] or b"").split():
|
||||
seen_uids[uid] = None
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
emails: list[dict] = []
|
||||
uids = list(seen_uids.keys())[:limit * 3] # overfetch; filter after dedup
|
||||
for i, uid in enumerate(uids):
|
||||
if len(emails) >= limit:
|
||||
break
|
||||
if progress_cb:
|
||||
progress_cb(i / len(uids), f"{name}: {len(emails)} fetched…")
|
||||
try:
|
||||
_, raw_data = conn.fetch(uid, "(RFC822)")
|
||||
if not raw_data or not raw_data[0]:
|
||||
continue
|
||||
msg = _email_lib.message_from_bytes(raw_data[0][1])
|
||||
subj = _decode_str(msg.get("Subject", ""))
|
||||
from_addr = _decode_str(msg.get("From", ""))
|
||||
date = _decode_str(msg.get("Date", ""))
|
||||
body = _extract_body(msg)[:800]
|
||||
entry = {
|
||||
"subject": subj,
|
||||
"body": body,
|
||||
"from_addr": from_addr,
|
||||
"date": date,
|
||||
"account": name,
|
||||
}
|
||||
key = _entry_key(entry)
|
||||
if key not in known_keys:
|
||||
known_keys.add(key)
|
||||
emails.append(entry)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
conn.logout()
|
||||
except Exception:
|
||||
pass
|
||||
return emails
|
||||
|
||||
|
||||
# ── Queue / score file helpers ───────────────────────────────────────────────
|
||||
|
||||
def _entry_key(e: dict) -> str:
|
||||
return hashlib.md5(
|
||||
(e.get("subject", "") + (e.get("body") or "")[:100]).encode()
|
||||
).hexdigest()
|
||||
|
||||
|
||||
def _load_jsonl(path: Path) -> list[dict]:
|
||||
if not path.exists():
|
||||
return []
|
||||
rows = []
|
||||
with path.open() as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
try:
|
||||
rows.append(json.loads(line))
|
||||
except Exception:
|
||||
pass
|
||||
return rows
|
||||
|
||||
|
||||
def _save_jsonl(path: Path, rows: list[dict]) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with path.open("w") as f:
|
||||
for row in rows:
|
||||
f.write(json.dumps(row, ensure_ascii=False) + "\n")
|
||||
|
||||
|
||||
def _append_jsonl(path: Path, row: dict) -> None:
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with path.open("a") as f:
|
||||
f.write(json.dumps(row, ensure_ascii=False) + "\n")
|
||||
|
||||
|
||||
# ── Config ──────────────────────────────────────────────────────────────────
|
||||
|
||||
def _load_config() -> list[dict]:
|
||||
if not _CFG_FILE.exists():
|
||||
return []
|
||||
cfg = yaml.safe_load(_CFG_FILE.read_text()) or {}
|
||||
return cfg.get("accounts", [])
|
||||
|
||||
|
||||
# ── Page setup ──────────────────────────────────────────────────────────────
|
||||
|
||||
st.set_page_config(
|
||||
page_title="Avocet — Email Labeler",
|
||||
page_icon="📬",
|
||||
layout="wide",
|
||||
)
|
||||
|
||||
st.markdown("""
|
||||
<style>
|
||||
/* Card stack */
|
||||
.email-card {
|
||||
border: 1px solid rgba(128,128,128,0.25);
|
||||
border-radius: 14px;
|
||||
padding: 28px 32px;
|
||||
box-shadow: 0 6px 24px rgba(0,0,0,0.18);
|
||||
margin-bottom: 4px;
|
||||
position: relative;
|
||||
}
|
||||
.card-stack-hint {
|
||||
height: 10px;
|
||||
border-radius: 0 0 12px 12px;
|
||||
border: 1px solid rgba(128,128,128,0.15);
|
||||
margin: 0 16px;
|
||||
box-shadow: 0 4px 12px rgba(0,0,0,0.10);
|
||||
}
|
||||
.card-stack-hint2 {
|
||||
height: 8px;
|
||||
border-radius: 0 0 10px 10px;
|
||||
border: 1px solid rgba(128,128,128,0.08);
|
||||
margin: 0 32px;
|
||||
}
|
||||
/* Subject line */
|
||||
.card-subject { font-size: 1.3rem; font-weight: 700; margin-bottom: 6px; }
|
||||
.card-meta { font-size: 0.82rem; opacity: 0.6; margin-bottom: 16px; }
|
||||
.card-body { font-size: 0.92rem; opacity: 0.85; white-space: pre-wrap; line-height: 1.5; }
|
||||
/* Bucket buttons */
|
||||
div[data-testid="stButton"] > button.bucket-btn {
|
||||
height: 70px;
|
||||
font-size: 1.05rem;
|
||||
font-weight: 600;
|
||||
border-radius: 12px;
|
||||
}
|
||||
</style>
|
||||
""", unsafe_allow_html=True)
|
||||
|
||||
st.title("📬 Avocet — Email Label Tool")
|
||||
st.caption("Scrape → Store → Process | card-stack edition")
|
||||
|
||||
# ── Session state init ───────────────────────────────────────────────────────
|
||||
|
||||
if "queue" not in st.session_state:
|
||||
st.session_state.queue: list[dict] = _load_jsonl(_QUEUE_FILE)
|
||||
|
||||
if "labeled" not in st.session_state:
|
||||
st.session_state.labeled: list[dict] = _load_jsonl(_SCORE_FILE)
|
||||
st.session_state.labeled_keys: set[str] = {
|
||||
_entry_key(r) for r in st.session_state.labeled
|
||||
}
|
||||
|
||||
if "idx" not in st.session_state:
|
||||
# Start past already-labeled entries in the queue
|
||||
labeled_keys = st.session_state.labeled_keys
|
||||
for i, entry in enumerate(st.session_state.queue):
|
||||
if _entry_key(entry) not in labeled_keys:
|
||||
st.session_state.idx = i
|
||||
break
|
||||
else:
|
||||
st.session_state.idx = len(st.session_state.queue)
|
||||
|
||||
if "history" not in st.session_state:
|
||||
st.session_state.history: list[tuple[int, str]] = [] # (queue_idx, label)
|
||||
|
||||
|
||||
# ── Sidebar stats ────────────────────────────────────────────────────────────
|
||||
|
||||
with st.sidebar:
|
||||
labeled = st.session_state.labeled
|
||||
queue = st.session_state.queue
|
||||
unlabeled = [e for e in queue if _entry_key(e) not in st.session_state.labeled_keys]
|
||||
|
||||
st.metric("✅ Labeled", len(labeled))
|
||||
st.metric("📥 Queue", len(unlabeled))
|
||||
|
||||
if labeled:
|
||||
st.caption("**Label distribution**")
|
||||
counts = {lbl: 0 for lbl in LABELS}
|
||||
for r in labeled:
|
||||
counts[r.get("label", "")] = counts.get(r.get("label", ""), 0) + 1
|
||||
for lbl in LABELS:
|
||||
m = _LABEL_META[lbl]
|
||||
st.caption(f"{m['emoji']} {lbl}: **{counts[lbl]}**")
|
||||
|
||||
|
||||
# ── Tabs ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
tab_label, tab_fetch, tab_stats = st.tabs(["🃏 Label", "📥 Fetch", "📊 Stats"])
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
# FETCH TAB
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
with tab_fetch:
|
||||
accounts = _load_config()
|
||||
|
||||
if not accounts:
|
||||
st.warning(
|
||||
f"No accounts configured. Copy `config/label_tool.yaml.example` → "
|
||||
f"`config/label_tool.yaml` and add your IMAP accounts.",
|
||||
icon="⚠️",
|
||||
)
|
||||
else:
|
||||
st.markdown(f"**{len(accounts)} account(s) configured:**")
|
||||
for acc in accounts:
|
||||
st.caption(f"• {acc.get('name', acc.get('username'))} ({acc.get('host')})")
|
||||
|
||||
col_days, col_limit = st.columns(2)
|
||||
days = col_days.number_input("Days back", min_value=7, max_value=730, value=180)
|
||||
limit = col_limit.number_input("Max emails per account", min_value=10, max_value=1000, value=150)
|
||||
|
||||
all_accs = [a.get("name", a.get("username")) for a in accounts]
|
||||
selected = st.multiselect("Accounts to fetch", all_accs, default=all_accs)
|
||||
|
||||
if st.button("📥 Fetch from IMAP", disabled=not accounts or not selected, type="primary"):
|
||||
existing_keys = {_entry_key(e) for e in st.session_state.queue}
|
||||
existing_keys.update(st.session_state.labeled_keys)
|
||||
|
||||
fetched_all: list[dict] = []
|
||||
status = st.status("Fetching…", expanded=True)
|
||||
|
||||
for acc in accounts:
|
||||
name = acc.get("name", acc.get("username"))
|
||||
if name not in selected:
|
||||
continue
|
||||
status.write(f"Connecting to **{name}**…")
|
||||
try:
|
||||
emails = _fetch_account(
|
||||
acc, days=int(days), limit=int(limit),
|
||||
known_keys=existing_keys,
|
||||
progress_cb=lambda p, msg: status.write(msg),
|
||||
)
|
||||
fetched_all.extend(emails)
|
||||
status.write(f"✓ {name}: {len(emails)} new emails")
|
||||
except Exception as e:
|
||||
status.write(f"✗ {name}: {e}")
|
||||
|
||||
if fetched_all:
|
||||
_save_jsonl(_QUEUE_FILE, st.session_state.queue + fetched_all)
|
||||
st.session_state.queue = _load_jsonl(_QUEUE_FILE)
|
||||
# Reset idx to first unlabeled
|
||||
labeled_keys = st.session_state.labeled_keys
|
||||
for i, entry in enumerate(st.session_state.queue):
|
||||
if _entry_key(entry) not in labeled_keys:
|
||||
st.session_state.idx = i
|
||||
break
|
||||
status.update(label=f"Done — {len(fetched_all)} new emails added to queue", state="complete")
|
||||
else:
|
||||
status.update(label="No new emails found (all already in queue or score file)", state="complete")
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
# LABEL TAB
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
with tab_label:
|
||||
queue = st.session_state.queue
|
||||
labeled_keys = st.session_state.labeled_keys
|
||||
idx = st.session_state.idx
|
||||
|
||||
# Advance idx past already-labeled entries
|
||||
while idx < len(queue) and _entry_key(queue[idx]) in labeled_keys:
|
||||
idx += 1
|
||||
st.session_state.idx = idx
|
||||
|
||||
unlabeled = [e for e in queue if _entry_key(e) not in labeled_keys]
|
||||
total_in_queue = len(queue)
|
||||
n_labeled = len(st.session_state.labeled)
|
||||
|
||||
if not queue:
|
||||
st.info("Queue is empty — go to **Fetch** to pull emails from IMAP.", icon="📥")
|
||||
elif not unlabeled:
|
||||
st.success(
|
||||
f"🎉 All {n_labeled} emails labeled! Go to **Stats** to review and export.",
|
||||
icon="✅",
|
||||
)
|
||||
else:
|
||||
# Progress
|
||||
labeled_in_queue = total_in_queue - len(unlabeled)
|
||||
progress_pct = labeled_in_queue / total_in_queue if total_in_queue else 0
|
||||
st.progress(progress_pct, text=f"{labeled_in_queue} / {total_in_queue} labeled in queue")
|
||||
|
||||
# Current email
|
||||
entry = queue[idx]
|
||||
|
||||
# Card HTML
|
||||
subj = entry.get("subject", "(no subject)") or "(no subject)"
|
||||
from_ = entry.get("from_addr", "") or ""
|
||||
date_ = entry.get("date", "") or ""
|
||||
acct = entry.get("account", "") or ""
|
||||
body = (entry.get("body") or "").strip()
|
||||
|
||||
st.markdown(
|
||||
f"""<div class="email-card">
|
||||
<div class="card-meta">{from_} · {date_[:16]} · <em>{acct}</em></div>
|
||||
<div class="card-subject">{subj}</div>
|
||||
<div class="card-body">{body[:500].replace(chr(10), '<br>')}</div>
|
||||
</div>""",
|
||||
unsafe_allow_html=True,
|
||||
)
|
||||
if len(body) > 500:
|
||||
with st.expander("Show full body"):
|
||||
st.text(body)
|
||||
|
||||
# Stack hint (visual depth)
|
||||
st.markdown('<div class="card-stack-hint"></div>', unsafe_allow_html=True)
|
||||
st.markdown('<div class="card-stack-hint2"></div>', unsafe_allow_html=True)
|
||||
|
||||
st.markdown("") # spacer
|
||||
|
||||
# ── Bucket buttons ────────────────────────────────────────────────
|
||||
def _do_label(label: str) -> None:
|
||||
row = {"subject": entry.get("subject", ""), "body": body[:600], "label": label}
|
||||
st.session_state.labeled.append(row)
|
||||
st.session_state.labeled_keys.add(_entry_key(entry))
|
||||
_append_jsonl(_SCORE_FILE, row)
|
||||
st.session_state.history.append((idx, label))
|
||||
# Advance
|
||||
next_idx = idx + 1
|
||||
while next_idx < len(queue) and _entry_key(queue[next_idx]) in labeled_keys:
|
||||
next_idx += 1
|
||||
st.session_state.idx = next_idx
|
||||
|
||||
row1_cols = st.columns(3)
|
||||
row2_cols = st.columns(3)
|
||||
bucket_pairs = [
|
||||
(row1_cols[0], "interview_scheduled"),
|
||||
(row1_cols[1], "offer_received"),
|
||||
(row1_cols[2], "rejected"),
|
||||
(row2_cols[0], "positive_response"),
|
||||
(row2_cols[1], "survey_received"),
|
||||
(row2_cols[2], "neutral"),
|
||||
]
|
||||
for col, lbl in bucket_pairs:
|
||||
m = _LABEL_META[lbl]
|
||||
counts = {l: 0 for l in LABELS}
|
||||
for r in st.session_state.labeled:
|
||||
counts[r.get("label", "")] = counts.get(r.get("label", ""), 0) + 1
|
||||
label_display = f"{m['emoji']} **{lbl}** [{counts[lbl]}]\n`{m['key']}`"
|
||||
if col.button(label_display, key=f"lbl_{lbl}", use_container_width=True):
|
||||
_do_label(lbl)
|
||||
st.rerun()
|
||||
|
||||
# ── Navigation ────────────────────────────────────────────────────
|
||||
st.markdown("")
|
||||
nav_cols = st.columns([2, 1, 1])
|
||||
|
||||
remaining = len(unlabeled) - 1
|
||||
nav_cols[0].caption(f"**{remaining}** remaining · Keys: 1–6 = label, S = skip, U = undo")
|
||||
|
||||
if nav_cols[1].button("↩ Undo", disabled=not st.session_state.history, use_container_width=True):
|
||||
prev_idx, prev_label = st.session_state.history.pop()
|
||||
# Remove the last labeled entry
|
||||
if st.session_state.labeled:
|
||||
removed = st.session_state.labeled.pop()
|
||||
st.session_state.labeled_keys.discard(_entry_key(removed))
|
||||
_save_jsonl(_SCORE_FILE, st.session_state.labeled)
|
||||
st.session_state.idx = prev_idx
|
||||
st.rerun()
|
||||
|
||||
if nav_cols[2].button("→ Skip", use_container_width=True):
|
||||
next_idx = idx + 1
|
||||
while next_idx < len(queue) and _entry_key(queue[next_idx]) in labeled_keys:
|
||||
next_idx += 1
|
||||
st.session_state.idx = next_idx
|
||||
st.rerun()
|
||||
|
||||
# Keyboard shortcut capture (JS → hidden button click)
|
||||
st.components.v1.html(
|
||||
"""<script>
|
||||
document.addEventListener('keydown', function(e) {
|
||||
if (e.target.tagName === 'INPUT' || e.target.tagName === 'TEXTAREA') return;
|
||||
const keyToLabel = {
|
||||
'1':'interview_scheduled','2':'offer_received','3':'rejected',
|
||||
'4':'positive_response','5':'survey_received','6':'neutral'
|
||||
};
|
||||
const label = keyToLabel[e.key];
|
||||
if (label) {
|
||||
const btns = window.parent.document.querySelectorAll('button');
|
||||
for (const btn of btns) {
|
||||
if (btn.innerText.toLowerCase().includes(label.replace('_',' '))) {
|
||||
btn.click(); break;
|
||||
}
|
||||
}
|
||||
} else if (e.key.toLowerCase() === 's') {
|
||||
const btns = window.parent.document.querySelectorAll('button');
|
||||
for (const btn of btns) {
|
||||
if (btn.innerText.includes('Skip')) { btn.click(); break; }
|
||||
}
|
||||
} else if (e.key.toLowerCase() === 'u') {
|
||||
const btns = window.parent.document.querySelectorAll('button');
|
||||
for (const btn of btns) {
|
||||
if (btn.innerText.includes('Undo')) { btn.click(); break; }
|
||||
}
|
||||
}
|
||||
});
|
||||
</script>""",
|
||||
height=0,
|
||||
)
|
||||
|
||||
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
# STATS TAB
|
||||
# ══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
with tab_stats:
|
||||
labeled = st.session_state.labeled
|
||||
|
||||
if not labeled:
|
||||
st.info("No labeled emails yet.")
|
||||
else:
|
||||
counts = {lbl: 0 for lbl in LABELS}
|
||||
for r in labeled:
|
||||
lbl = r.get("label", "")
|
||||
if lbl in counts:
|
||||
counts[lbl] += 1
|
||||
|
||||
st.markdown(f"**{len(labeled)} labeled emails total**")
|
||||
|
||||
for lbl in LABELS:
|
||||
m = _LABEL_META[lbl]
|
||||
col_name, col_bar, col_n = st.columns([3, 5, 1])
|
||||
col_name.markdown(f"{m['emoji']} {lbl}")
|
||||
col_bar.progress(counts[lbl] / max(counts.values()) if counts.values() else 0)
|
||||
col_n.markdown(f"**{counts[lbl]}**")
|
||||
|
||||
st.divider()
|
||||
|
||||
st.caption(
|
||||
f"Score file: `{_SCORE_FILE.relative_to(_ROOT)}` "
|
||||
f"({_SCORE_FILE.stat().st_size if _SCORE_FILE.exists() else 0:,} bytes)"
|
||||
)
|
||||
if st.button("🔄 Re-sync from disk"):
|
||||
st.session_state.labeled = _load_jsonl(_SCORE_FILE)
|
||||
st.session_state.labeled_keys = {_entry_key(r) for r in st.session_state.labeled}
|
||||
st.rerun()
|
||||
|
||||
if _SCORE_FILE.exists():
|
||||
st.download_button(
|
||||
"⬇️ Download email_score.jsonl",
|
||||
data=_SCORE_FILE.read_bytes(),
|
||||
file_name="email_score.jsonl",
|
||||
mime="application/jsonlines",
|
||||
)
|
||||
23
config/label_tool.yaml.example
Normal file
23
config/label_tool.yaml.example
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
# config/label_tool.yaml — Multi-account IMAP config for the email label tool
|
||||
# Copy to config/label_tool.yaml and fill in your credentials.
|
||||
# This file is gitignored.
|
||||
|
||||
accounts:
|
||||
- name: "Gmail"
|
||||
host: "imap.gmail.com"
|
||||
port: 993
|
||||
username: "you@gmail.com"
|
||||
password: "your-app-password" # Use an App Password, not your login password
|
||||
folder: "INBOX"
|
||||
days_back: 90
|
||||
|
||||
- name: "Outlook"
|
||||
host: "outlook.office365.com"
|
||||
port: 993
|
||||
username: "you@outlook.com"
|
||||
password: "your-app-password"
|
||||
folder: "INBOX"
|
||||
days_back: 90
|
||||
|
||||
# Optional: limit emails fetched per account per run (0 = unlimited)
|
||||
max_per_account: 500
|
||||
0
data/.gitkeep
Normal file
0
data/.gitkeep
Normal file
8
data/email_score.jsonl.example
Normal file
8
data/email_score.jsonl.example
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
{"subject": "Interview Invitation — Senior Engineer", "body": "Hi Meghan, we'd love to schedule a 30-min phone screen. Are you available Thursday at 2pm? Please reply to confirm.", "label": "interview_scheduled"}
|
||||
{"subject": "Your application to Acme Corp", "body": "Thank you for your interest in the Senior Engineer role. After careful consideration, we have decided to move forward with other candidates whose experience more closely matches our current needs.", "label": "rejected"}
|
||||
{"subject": "Offer Letter — Product Manager at Initech", "body": "Dear Meghan, we are thrilled to extend an offer of employment for the Product Manager position. Please find the attached offer letter outlining compensation and start date.", "label": "offer_received"}
|
||||
{"subject": "Quick question about your background", "body": "Hi Meghan, I came across your profile and would love to connect. We have a few roles that seem like a great match. Would you be open to a brief chat this week?", "label": "positive_response"}
|
||||
{"subject": "Company Culture Survey — Acme Corp", "body": "Meghan, as part of our evaluation process, we invite all candidates to complete our culture fit assessment. The survey takes approximately 15 minutes. Please click the link below.", "label": "survey_received"}
|
||||
{"subject": "Application Received — DataCo", "body": "Thank you for submitting your application for the Data Engineer role at DataCo. We have received your materials and will be in touch if your qualifications match our needs.", "label": "neutral"}
|
||||
{"subject": "Following up on your application", "body": "Hi Meghan, I wanted to follow up on your recent application. Your background looks interesting and we'd like to learn more. Can we set up a quick call?", "label": "positive_response"}
|
||||
{"subject": "We're moving forward with other candidates", "body": "Dear Meghan, thank you for taking the time to interview with us. After thoughtful consideration, we have decided not to move forward with your candidacy at this time.", "label": "rejected"}
|
||||
25
environment.yml
Normal file
25
environment.yml
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
name: job-seeker-classifiers
|
||||
channels:
|
||||
- conda-forge
|
||||
- defaults
|
||||
dependencies:
|
||||
- python=3.11
|
||||
- pip
|
||||
- pip:
|
||||
# UI
|
||||
- streamlit>=1.32
|
||||
- pyyaml>=6.0
|
||||
|
||||
# Classifier backends (heavy — install selectively)
|
||||
- transformers>=4.40
|
||||
- torch>=2.2
|
||||
- accelerate>=0.27
|
||||
|
||||
# Optional: GLiClass adapter
|
||||
# - gliclass
|
||||
|
||||
# Optional: BGE reranker adapter
|
||||
# - FlagEmbedding
|
||||
|
||||
# Dev
|
||||
- pytest>=8.0
|
||||
5
pytest.ini
Normal file
5
pytest.ini
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
[pytest]
|
||||
testpaths = tests
|
||||
python_files = test_*.py
|
||||
python_classes = Test*
|
||||
python_functions = test_*
|
||||
0
scripts/__init__.py
Normal file
0
scripts/__init__.py
Normal file
450
scripts/benchmark_classifier.py
Normal file
450
scripts/benchmark_classifier.py
Normal file
|
|
@ -0,0 +1,450 @@
|
|||
#!/usr/bin/env python
|
||||
"""
|
||||
Email classifier benchmark — compare HuggingFace models against our 6 labels.
|
||||
|
||||
Usage:
|
||||
# List available models
|
||||
conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --list-models
|
||||
|
||||
# Score against labeled JSONL
|
||||
conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --score
|
||||
|
||||
# Visual comparison on live IMAP emails (uses first account in label_tool.yaml)
|
||||
conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --compare --limit 20
|
||||
|
||||
# Include slow/large models
|
||||
conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --score --include-slow
|
||||
|
||||
# Export DB-labeled emails (⚠️ LLM-generated labels)
|
||||
conda run -n job-seeker-classifiers python scripts/benchmark_classifier.py --export-db --db /path/to/staging.db
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import email as _email_lib
|
||||
import imaplib
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from scripts.classifier_adapters import (
|
||||
LABELS,
|
||||
LABEL_DESCRIPTIONS,
|
||||
ClassifierAdapter,
|
||||
GLiClassAdapter,
|
||||
RerankerAdapter,
|
||||
ZeroShotAdapter,
|
||||
compute_metrics,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Model registry
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
MODEL_REGISTRY: dict[str, dict[str, Any]] = {
|
||||
"deberta-zeroshot": {
|
||||
"adapter": ZeroShotAdapter,
|
||||
"model_id": "MoritzLaurer/DeBERTa-v3-large-zeroshot-v2.0",
|
||||
"params": "400M",
|
||||
"default": True,
|
||||
},
|
||||
"deberta-small": {
|
||||
"adapter": ZeroShotAdapter,
|
||||
"model_id": "cross-encoder/nli-deberta-v3-small",
|
||||
"params": "100M",
|
||||
"default": True,
|
||||
},
|
||||
"gliclass-large": {
|
||||
"adapter": GLiClassAdapter,
|
||||
"model_id": "knowledgator/gliclass-instruct-large-v1.0",
|
||||
"params": "400M",
|
||||
"default": True,
|
||||
},
|
||||
"bart-mnli": {
|
||||
"adapter": ZeroShotAdapter,
|
||||
"model_id": "facebook/bart-large-mnli",
|
||||
"params": "400M",
|
||||
"default": True,
|
||||
},
|
||||
"bge-m3-zeroshot": {
|
||||
"adapter": ZeroShotAdapter,
|
||||
"model_id": "MoritzLaurer/bge-m3-zeroshot-v2.0",
|
||||
"params": "600M",
|
||||
"default": True,
|
||||
},
|
||||
"deberta-small-2pass": {
|
||||
"adapter": ZeroShotAdapter,
|
||||
"model_id": "cross-encoder/nli-deberta-v3-small",
|
||||
"params": "100M",
|
||||
"default": True,
|
||||
"kwargs": {"two_pass": True},
|
||||
},
|
||||
"deberta-base-anli": {
|
||||
"adapter": ZeroShotAdapter,
|
||||
"model_id": "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli",
|
||||
"params": "200M",
|
||||
"default": True,
|
||||
},
|
||||
"deberta-large-ling": {
|
||||
"adapter": ZeroShotAdapter,
|
||||
"model_id": "MoritzLaurer/DeBERTa-v3-large-mnli-fever-anli-ling-wanli",
|
||||
"params": "400M",
|
||||
"default": False,
|
||||
},
|
||||
"mdeberta-xnli-2m": {
|
||||
"adapter": ZeroShotAdapter,
|
||||
"model_id": "MoritzLaurer/mDeBERTa-v3-base-xnli-multilingual-nli-2mil7",
|
||||
"params": "300M",
|
||||
"default": False,
|
||||
},
|
||||
"bge-reranker": {
|
||||
"adapter": RerankerAdapter,
|
||||
"model_id": "BAAI/bge-reranker-v2-m3",
|
||||
"params": "600M",
|
||||
"default": False,
|
||||
},
|
||||
"deberta-xlarge": {
|
||||
"adapter": ZeroShotAdapter,
|
||||
"model_id": "microsoft/deberta-xlarge-mnli",
|
||||
"params": "750M",
|
||||
"default": False,
|
||||
},
|
||||
"mdeberta-mnli": {
|
||||
"adapter": ZeroShotAdapter,
|
||||
"model_id": "MoritzLaurer/mDeBERTa-v3-base-mnli-xnli",
|
||||
"params": "300M",
|
||||
"default": False,
|
||||
},
|
||||
"xlm-roberta-anli": {
|
||||
"adapter": ZeroShotAdapter,
|
||||
"model_id": "vicgalle/xlm-roberta-large-xnli-anli",
|
||||
"params": "600M",
|
||||
"default": False,
|
||||
},
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def load_scoring_jsonl(path: str) -> list[dict[str, str]]:
|
||||
"""Load labeled examples from a JSONL file for benchmark scoring."""
|
||||
p = Path(path)
|
||||
if not p.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Scoring file not found: {path}\n"
|
||||
f"Copy data/email_score.jsonl.example → data/email_score.jsonl "
|
||||
f"or use the label tool (app/label_tool.py) to label your own emails."
|
||||
)
|
||||
rows = []
|
||||
with p.open() as f:
|
||||
for line in f:
|
||||
line = line.strip()
|
||||
if line:
|
||||
rows.append(json.loads(line))
|
||||
return rows
|
||||
|
||||
|
||||
def _active_models(include_slow: bool) -> dict[str, dict[str, Any]]:
|
||||
return {k: v for k, v in MODEL_REGISTRY.items() if v["default"] or include_slow}
|
||||
|
||||
|
||||
def run_scoring(
|
||||
adapters: list[ClassifierAdapter],
|
||||
score_file: str,
|
||||
) -> dict[str, Any]:
|
||||
"""Run all adapters against a labeled JSONL. Returns per-adapter metrics."""
|
||||
rows = load_scoring_jsonl(score_file)
|
||||
gold = [r["label"] for r in rows]
|
||||
results: dict[str, Any] = {}
|
||||
|
||||
for adapter in adapters:
|
||||
preds: list[str] = []
|
||||
t0 = time.monotonic()
|
||||
for row in rows:
|
||||
try:
|
||||
pred = adapter.classify(row["subject"], row["body"])
|
||||
except Exception as exc:
|
||||
print(f" [{adapter.name}] ERROR on '{row['subject'][:40]}': {exc}", flush=True)
|
||||
pred = "neutral"
|
||||
preds.append(pred)
|
||||
elapsed_ms = (time.monotonic() - t0) * 1000
|
||||
metrics = compute_metrics(preds, gold, LABELS)
|
||||
metrics["latency_ms"] = round(elapsed_ms / len(rows), 1)
|
||||
results[adapter.name] = metrics
|
||||
adapter.unload()
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# IMAP helpers (stdlib only — reads label_tool.yaml, uses first account)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_BROAD_TERMS = [
|
||||
"interview", "opportunity", "offer letter",
|
||||
"job offer", "application", "recruiting",
|
||||
]
|
||||
|
||||
|
||||
def _load_imap_config() -> dict[str, Any]:
|
||||
"""Load IMAP config from label_tool.yaml, returning first account as a flat dict."""
|
||||
import yaml
|
||||
cfg_path = Path(__file__).parent.parent / "config" / "label_tool.yaml"
|
||||
if not cfg_path.exists():
|
||||
raise FileNotFoundError(
|
||||
f"IMAP config not found: {cfg_path}\n"
|
||||
f"Copy config/label_tool.yaml.example → config/label_tool.yaml"
|
||||
)
|
||||
cfg = yaml.safe_load(cfg_path.read_text()) or {}
|
||||
accounts = cfg.get("accounts", [])
|
||||
if not accounts:
|
||||
raise ValueError("No accounts configured in config/label_tool.yaml")
|
||||
return accounts[0]
|
||||
|
||||
|
||||
def _imap_connect(cfg: dict[str, Any]) -> imaplib.IMAP4_SSL:
|
||||
conn = imaplib.IMAP4_SSL(cfg["host"], cfg.get("port", 993))
|
||||
conn.login(cfg["username"], cfg["password"])
|
||||
return conn
|
||||
|
||||
|
||||
def _decode_part(part: Any) -> str:
|
||||
charset = part.get_content_charset() or "utf-8"
|
||||
try:
|
||||
return part.get_payload(decode=True).decode(charset, errors="replace")
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
||||
def _parse_uid(conn: imaplib.IMAP4_SSL, uid: bytes) -> dict[str, str] | None:
|
||||
try:
|
||||
_, data = conn.uid("fetch", uid, "(RFC822)")
|
||||
raw = data[0][1]
|
||||
msg = _email_lib.message_from_bytes(raw)
|
||||
subject = str(msg.get("subject", "")).strip()
|
||||
body = ""
|
||||
if msg.is_multipart():
|
||||
for part in msg.walk():
|
||||
if part.get_content_type() == "text/plain":
|
||||
body = _decode_part(part)
|
||||
break
|
||||
else:
|
||||
body = _decode_part(msg)
|
||||
return {"subject": subject, "body": body}
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _fetch_imap_sample(limit: int, days: int) -> list[dict[str, str]]:
|
||||
cfg = _load_imap_config()
|
||||
conn = _imap_connect(cfg)
|
||||
since = (datetime.now() - timedelta(days=days)).strftime("%d-%b-%Y")
|
||||
conn.select("INBOX")
|
||||
|
||||
seen_uids: dict[bytes, None] = {}
|
||||
for term in _BROAD_TERMS:
|
||||
_, data = conn.uid("search", None, f'(SUBJECT "{term}" SINCE {since})')
|
||||
for uid in (data[0] or b"").split():
|
||||
seen_uids[uid] = None
|
||||
|
||||
sample = list(seen_uids.keys())[:limit]
|
||||
emails = []
|
||||
for uid in sample:
|
||||
parsed = _parse_uid(conn, uid)
|
||||
if parsed:
|
||||
emails.append(parsed)
|
||||
try:
|
||||
conn.logout()
|
||||
except Exception:
|
||||
pass
|
||||
return emails
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DB export
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def cmd_export_db(args: argparse.Namespace) -> None:
|
||||
"""Export LLM-labeled emails from a peregrine-style job_contacts table → scoring JSONL."""
|
||||
import sqlite3
|
||||
|
||||
db_path = Path(args.db)
|
||||
if not db_path.exists():
|
||||
print(f"ERROR: Database not found: {args.db}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
conn = sqlite3.connect(db_path)
|
||||
cur = conn.cursor()
|
||||
cur.execute("""
|
||||
SELECT subject, body, stage_signal
|
||||
FROM job_contacts
|
||||
WHERE stage_signal IS NOT NULL
|
||||
AND stage_signal != ''
|
||||
AND direction = 'inbound'
|
||||
ORDER BY received_at
|
||||
""")
|
||||
rows = cur.fetchall()
|
||||
conn.close()
|
||||
|
||||
if not rows:
|
||||
print("No labeled emails in job_contacts. Run imap_sync first to populate.")
|
||||
return
|
||||
|
||||
out_path = Path(args.score_file)
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
written = 0
|
||||
skipped = 0
|
||||
label_counts: dict[str, int] = {}
|
||||
with out_path.open("w") as f:
|
||||
for subject, body, label in rows:
|
||||
if label not in LABELS:
|
||||
print(f" SKIP unknown label '{label}': {subject[:50]}")
|
||||
skipped += 1
|
||||
continue
|
||||
json.dump({"subject": subject or "", "body": (body or "")[:600], "label": label}, f)
|
||||
f.write("\n")
|
||||
label_counts[label] = label_counts.get(label, 0) + 1
|
||||
written += 1
|
||||
|
||||
print(f"\nExported {written} emails → {out_path}" + (f" ({skipped} skipped)" if skipped else ""))
|
||||
print("\nLabel distribution:")
|
||||
for label in LABELS:
|
||||
count = label_counts.get(label, 0)
|
||||
bar = "█" * count
|
||||
print(f" {label:<25} {count:>3} {bar}")
|
||||
print(
|
||||
"\nNOTE: Labels are LLM predictions from imap_sync — review before treating as ground truth."
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Subcommands
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def cmd_list_models(_args: argparse.Namespace) -> None:
|
||||
print(f"\n{'Name':<24} {'Params':<8} {'Default':<20} {'Adapter':<15} Model ID")
|
||||
print("-" * 104)
|
||||
for name, entry in MODEL_REGISTRY.items():
|
||||
adapter_name = entry["adapter"].__name__
|
||||
if entry.get("kwargs", {}).get("two_pass"):
|
||||
adapter_name += " (2-pass)"
|
||||
default_flag = "yes" if entry["default"] else "(--include-slow)"
|
||||
print(f"{name:<24} {entry['params']:<8} {default_flag:<20} {adapter_name:<15} {entry['model_id']}")
|
||||
print()
|
||||
|
||||
|
||||
def cmd_score(args: argparse.Namespace) -> None:
|
||||
active = _active_models(args.include_slow)
|
||||
if args.models:
|
||||
active = {k: v for k, v in active.items() if k in args.models}
|
||||
|
||||
adapters = [
|
||||
entry["adapter"](name, entry["model_id"], **entry.get("kwargs", {}))
|
||||
for name, entry in active.items()
|
||||
]
|
||||
|
||||
print(f"\nScoring {len(adapters)} model(s) against {args.score_file} …\n")
|
||||
results = run_scoring(adapters, args.score_file)
|
||||
|
||||
col = 12
|
||||
print(f"{'Model':<22}" + f"{'macro-F1':>{col}} {'Accuracy':>{col}} {'ms/email':>{col}}")
|
||||
print("-" * (22 + col * 3 + 2))
|
||||
for name, m in results.items():
|
||||
print(
|
||||
f"{name:<22}"
|
||||
f"{m['__macro_f1__']:>{col}.3f}"
|
||||
f"{m['__accuracy__']:>{col}.3f}"
|
||||
f"{m['latency_ms']:>{col}.1f}"
|
||||
)
|
||||
|
||||
print("\nPer-label F1:")
|
||||
names = list(results.keys())
|
||||
print(f"{'Label':<25}" + "".join(f"{n[:11]:>{col}}" for n in names))
|
||||
print("-" * (25 + col * len(names)))
|
||||
for label in LABELS:
|
||||
row_str = f"{label:<25}"
|
||||
for m in results.values():
|
||||
row_str += f"{m[label]['f1']:>{col}.3f}"
|
||||
print(row_str)
|
||||
print()
|
||||
|
||||
|
||||
def cmd_compare(args: argparse.Namespace) -> None:
|
||||
active = _active_models(args.include_slow)
|
||||
if args.models:
|
||||
active = {k: v for k, v in active.items() if k in args.models}
|
||||
|
||||
print(f"Fetching up to {args.limit} emails from IMAP …")
|
||||
emails = _fetch_imap_sample(args.limit, args.days)
|
||||
print(f"Fetched {len(emails)} emails. Loading {len(active)} model(s) …\n")
|
||||
|
||||
adapters = [
|
||||
entry["adapter"](name, entry["model_id"], **entry.get("kwargs", {}))
|
||||
for name, entry in active.items()
|
||||
]
|
||||
model_names = [a.name for a in adapters]
|
||||
|
||||
col = 22
|
||||
subj_w = 50
|
||||
print(f"{'Subject':<{subj_w}}" + "".join(f"{n:<{col}}" for n in model_names))
|
||||
print("-" * (subj_w + col * len(model_names)))
|
||||
|
||||
for row in emails:
|
||||
short_subj = row["subject"][:subj_w - 1] if len(row["subject"]) > subj_w else row["subject"]
|
||||
line = f"{short_subj:<{subj_w}}"
|
||||
for adapter in adapters:
|
||||
try:
|
||||
label = adapter.classify(row["subject"], row["body"])
|
||||
except Exception as exc:
|
||||
label = f"ERR:{str(exc)[:8]}"
|
||||
line += f"{label:<{col}}"
|
||||
print(line, flush=True)
|
||||
|
||||
for adapter in adapters:
|
||||
adapter.unload()
|
||||
print()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Benchmark HuggingFace email classifiers against our 6 labels."
|
||||
)
|
||||
parser.add_argument("--list-models", action="store_true", help="Show model registry and exit")
|
||||
parser.add_argument("--score", action="store_true", help="Score against labeled JSONL")
|
||||
parser.add_argument("--compare", action="store_true", help="Visual table on live IMAP emails")
|
||||
parser.add_argument("--export-db", action="store_true",
|
||||
help="Export labeled emails from a staging.db → score JSONL")
|
||||
parser.add_argument("--score-file", default="data/email_score.jsonl", help="Path to labeled JSONL")
|
||||
parser.add_argument("--db", default="data/staging.db", help="Path to staging.db for --export-db")
|
||||
parser.add_argument("--limit", type=int, default=20, help="Max emails for --compare")
|
||||
parser.add_argument("--days", type=int, default=90, help="Days back for IMAP search")
|
||||
parser.add_argument("--include-slow", action="store_true", help="Include non-default heavy models")
|
||||
parser.add_argument("--models", nargs="+", help="Override: run only these model names")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.list_models:
|
||||
cmd_list_models(args)
|
||||
elif args.score:
|
||||
cmd_score(args)
|
||||
elif args.compare:
|
||||
cmd_compare(args)
|
||||
elif args.export_db:
|
||||
cmd_export_db(args)
|
||||
else:
|
||||
parser.print_help()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
257
scripts/classifier_adapters.py
Normal file
257
scripts/classifier_adapters.py
Normal file
|
|
@ -0,0 +1,257 @@
|
|||
"""Classifier adapters for email classification benchmark.
|
||||
|
||||
Each adapter wraps a HuggingFace model and normalizes output to LABELS.
|
||||
Models load lazily on first classify() call; call unload() to free VRAM.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
__all__ = [
|
||||
"LABELS",
|
||||
"LABEL_DESCRIPTIONS",
|
||||
"compute_metrics",
|
||||
"ClassifierAdapter",
|
||||
"ZeroShotAdapter",
|
||||
"GLiClassAdapter",
|
||||
"RerankerAdapter",
|
||||
]
|
||||
|
||||
LABELS: list[str] = [
|
||||
"interview_scheduled",
|
||||
"offer_received",
|
||||
"rejected",
|
||||
"positive_response",
|
||||
"survey_received",
|
||||
"neutral",
|
||||
]
|
||||
|
||||
# Natural-language descriptions used by the RerankerAdapter.
|
||||
LABEL_DESCRIPTIONS: dict[str, str] = {
|
||||
"interview_scheduled": "scheduling an interview, phone screen, or video call",
|
||||
"offer_received": "a formal job offer or employment offer letter",
|
||||
"rejected": "application rejected or not moving forward with candidacy",
|
||||
"positive_response": "positive recruiter interest or request to connect",
|
||||
"survey_received": "invitation to complete a culture-fit survey or assessment",
|
||||
"neutral": "automated ATS confirmation or unrelated email",
|
||||
}
|
||||
|
||||
# Lazy import shims — allow tests to patch without requiring the libs installed.
|
||||
try:
|
||||
from transformers import pipeline # type: ignore[assignment]
|
||||
except ImportError:
|
||||
pipeline = None # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
from gliclass import GLiClassModel, ZeroShotClassificationPipeline # type: ignore
|
||||
from transformers import AutoTokenizer
|
||||
except ImportError:
|
||||
GLiClassModel = None # type: ignore
|
||||
ZeroShotClassificationPipeline = None # type: ignore
|
||||
AutoTokenizer = None # type: ignore
|
||||
|
||||
try:
|
||||
from FlagEmbedding import FlagReranker # type: ignore
|
||||
except ImportError:
|
||||
FlagReranker = None # type: ignore
|
||||
|
||||
|
||||
def _cuda_available() -> bool:
|
||||
try:
|
||||
import torch
|
||||
return torch.cuda.is_available()
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def compute_metrics(
|
||||
predictions: list[str],
|
||||
gold: list[str],
|
||||
labels: list[str],
|
||||
) -> dict[str, Any]:
|
||||
"""Return per-label precision/recall/F1 + macro_f1 + accuracy."""
|
||||
tp: dict[str, int] = defaultdict(int)
|
||||
fp: dict[str, int] = defaultdict(int)
|
||||
fn: dict[str, int] = defaultdict(int)
|
||||
|
||||
for pred, true in zip(predictions, gold):
|
||||
if pred == true:
|
||||
tp[pred] += 1
|
||||
else:
|
||||
fp[pred] += 1
|
||||
fn[true] += 1
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
for label in labels:
|
||||
denom_p = tp[label] + fp[label]
|
||||
denom_r = tp[label] + fn[label]
|
||||
p = tp[label] / denom_p if denom_p else 0.0
|
||||
r = tp[label] / denom_r if denom_r else 0.0
|
||||
f1 = 2 * p * r / (p + r) if (p + r) else 0.0
|
||||
result[label] = {
|
||||
"precision": p,
|
||||
"recall": r,
|
||||
"f1": f1,
|
||||
"support": denom_r,
|
||||
}
|
||||
|
||||
labels_with_support = [label for label in labels if result[label]["support"] > 0]
|
||||
if labels_with_support:
|
||||
result["__macro_f1__"] = (
|
||||
sum(result[label]["f1"] for label in labels_with_support) / len(labels_with_support)
|
||||
)
|
||||
else:
|
||||
result["__macro_f1__"] = 0.0
|
||||
result["__accuracy__"] = sum(tp.values()) / len(predictions) if predictions else 0.0
|
||||
return result
|
||||
|
||||
|
||||
class ClassifierAdapter(abc.ABC):
|
||||
"""Abstract base for all email classifier adapters."""
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def name(self) -> str: ...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def model_id(self) -> str: ...
|
||||
|
||||
@abc.abstractmethod
|
||||
def load(self) -> None:
|
||||
"""Download/load the model into memory."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def unload(self) -> None:
|
||||
"""Release model from memory."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def classify(self, subject: str, body: str) -> str:
|
||||
"""Return one of LABELS for the given email."""
|
||||
|
||||
|
||||
class ZeroShotAdapter(ClassifierAdapter):
|
||||
"""Wraps any transformers zero-shot-classification pipeline.
|
||||
|
||||
load() calls pipeline("zero-shot-classification", model=..., device=...) to get
|
||||
an inference callable, stored as self._pipeline. classify() then calls
|
||||
self._pipeline(text, LABELS, multi_label=False). In tests, patch
|
||||
'scripts.classifier_adapters.pipeline' with a MagicMock whose .return_value is
|
||||
itself a MagicMock(return_value={...}) to simulate both the factory call and the
|
||||
inference call.
|
||||
|
||||
two_pass: if True, classify() runs a second pass restricted to the top-2 labels
|
||||
from the first pass, forcing a binary choice. This typically improves confidence
|
||||
without the accuracy cost of a full 6-label second run.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, model_id: str, two_pass: bool = False) -> None:
|
||||
self._name = name
|
||||
self._model_id = model_id
|
||||
self._pipeline: Any = None
|
||||
self._two_pass = two_pass
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
return self._model_id
|
||||
|
||||
def load(self) -> None:
|
||||
import scripts.classifier_adapters as _mod # noqa: PLC0415
|
||||
_pipe_fn = _mod.pipeline
|
||||
if _pipe_fn is None:
|
||||
raise ImportError("transformers not installed — run: pip install transformers")
|
||||
device = 0 if _cuda_available() else -1
|
||||
# Instantiate the pipeline once; classify() calls the resulting object on each text.
|
||||
self._pipeline = _pipe_fn("zero-shot-classification", model=self._model_id, device=device)
|
||||
|
||||
def unload(self) -> None:
|
||||
self._pipeline = None
|
||||
|
||||
def classify(self, subject: str, body: str) -> str:
|
||||
if self._pipeline is None:
|
||||
self.load()
|
||||
text = f"Subject: {subject}\n\n{body[:600]}"
|
||||
result = self._pipeline(text, LABELS, multi_label=False)
|
||||
if self._two_pass and len(result["labels"]) >= 2:
|
||||
top2 = result["labels"][:2]
|
||||
result = self._pipeline(text, top2, multi_label=False)
|
||||
return result["labels"][0]
|
||||
|
||||
|
||||
class GLiClassAdapter(ClassifierAdapter):
|
||||
"""Wraps knowledgator GLiClass models via the gliclass library."""
|
||||
|
||||
def __init__(self, name: str, model_id: str) -> None:
|
||||
self._name = name
|
||||
self._model_id = model_id
|
||||
self._pipeline: Any = None
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
return self._model_id
|
||||
|
||||
def load(self) -> None:
|
||||
if GLiClassModel is None:
|
||||
raise ImportError("gliclass not installed — run: pip install gliclass")
|
||||
device = "cuda:0" if _cuda_available() else "cpu"
|
||||
model = GLiClassModel.from_pretrained(self._model_id)
|
||||
tokenizer = AutoTokenizer.from_pretrained(self._model_id)
|
||||
self._pipeline = ZeroShotClassificationPipeline(
|
||||
model,
|
||||
tokenizer,
|
||||
classification_type="single-label",
|
||||
device=device,
|
||||
)
|
||||
|
||||
def unload(self) -> None:
|
||||
self._pipeline = None
|
||||
|
||||
def classify(self, subject: str, body: str) -> str:
|
||||
if self._pipeline is None:
|
||||
self.load()
|
||||
text = f"Subject: {subject}\n\n{body[:600]}"
|
||||
results = self._pipeline(text, LABELS, threshold=0.0)[0]
|
||||
return max(results, key=lambda r: r["score"])["label"]
|
||||
|
||||
|
||||
class RerankerAdapter(ClassifierAdapter):
|
||||
"""Uses a BGE reranker to score (email, label_description) pairs."""
|
||||
|
||||
def __init__(self, name: str, model_id: str) -> None:
|
||||
self._name = name
|
||||
self._model_id = model_id
|
||||
self._reranker: Any = None
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
return self._model_id
|
||||
|
||||
def load(self) -> None:
|
||||
if FlagReranker is None:
|
||||
raise ImportError("FlagEmbedding not installed — run: pip install FlagEmbedding")
|
||||
self._reranker = FlagReranker(self._model_id, use_fp16=_cuda_available())
|
||||
|
||||
def unload(self) -> None:
|
||||
self._reranker = None
|
||||
|
||||
def classify(self, subject: str, body: str) -> str:
|
||||
if self._reranker is None:
|
||||
self.load()
|
||||
text = f"Subject: {subject}\n\n{body[:600]}"
|
||||
pairs = [[text, LABEL_DESCRIPTIONS[label]] for label in LABELS]
|
||||
scores: list[float] = self._reranker.compute_score(pairs, normalize=True)
|
||||
return LABELS[scores.index(max(scores))]
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
94
tests/test_benchmark_classifier.py
Normal file
94
tests/test_benchmark_classifier.py
Normal file
|
|
@ -0,0 +1,94 @@
|
|||
"""Tests for benchmark_classifier — no model downloads required."""
|
||||
import pytest
|
||||
|
||||
|
||||
def test_registry_has_thirteen_models():
|
||||
from scripts.benchmark_classifier import MODEL_REGISTRY
|
||||
assert len(MODEL_REGISTRY) == 13
|
||||
|
||||
|
||||
def test_registry_default_count():
|
||||
from scripts.benchmark_classifier import MODEL_REGISTRY
|
||||
defaults = [k for k, v in MODEL_REGISTRY.items() if v["default"]]
|
||||
assert len(defaults) == 7
|
||||
|
||||
|
||||
def test_registry_entries_have_required_keys():
|
||||
from scripts.benchmark_classifier import MODEL_REGISTRY
|
||||
from scripts.classifier_adapters import ClassifierAdapter
|
||||
for name, entry in MODEL_REGISTRY.items():
|
||||
assert "adapter" in entry, f"{name} missing 'adapter'"
|
||||
assert "model_id" in entry, f"{name} missing 'model_id'"
|
||||
assert "params" in entry, f"{name} missing 'params'"
|
||||
assert "default" in entry, f"{name} missing 'default'"
|
||||
assert issubclass(entry["adapter"], ClassifierAdapter), \
|
||||
f"{name} adapter must be a ClassifierAdapter subclass"
|
||||
|
||||
|
||||
def test_load_scoring_jsonl(tmp_path):
|
||||
from scripts.benchmark_classifier import load_scoring_jsonl
|
||||
import json
|
||||
f = tmp_path / "score.jsonl"
|
||||
rows = [
|
||||
{"subject": "Hi", "body": "Body text", "label": "neutral"},
|
||||
{"subject": "Interview", "body": "Schedule a call", "label": "interview_scheduled"},
|
||||
]
|
||||
f.write_text("\n".join(json.dumps(r) for r in rows))
|
||||
result = load_scoring_jsonl(str(f))
|
||||
assert len(result) == 2
|
||||
assert result[0]["label"] == "neutral"
|
||||
|
||||
|
||||
def test_load_scoring_jsonl_missing_file():
|
||||
from scripts.benchmark_classifier import load_scoring_jsonl
|
||||
with pytest.raises(FileNotFoundError):
|
||||
load_scoring_jsonl("/nonexistent/path.jsonl")
|
||||
|
||||
|
||||
def test_run_scoring_with_mock_adapters(tmp_path):
|
||||
"""run_scoring() returns per-model metrics using mock adapters."""
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
from scripts.benchmark_classifier import run_scoring
|
||||
|
||||
score_file = tmp_path / "score.jsonl"
|
||||
rows = [
|
||||
{"subject": "Interview", "body": "Let's schedule", "label": "interview_scheduled"},
|
||||
{"subject": "Sorry", "body": "We went with others", "label": "rejected"},
|
||||
{"subject": "Offer", "body": "We are pleased", "label": "offer_received"},
|
||||
]
|
||||
score_file.write_text("\n".join(json.dumps(r) for r in rows))
|
||||
|
||||
perfect = MagicMock()
|
||||
perfect.name = "perfect"
|
||||
perfect.classify.side_effect = lambda s, b: (
|
||||
"interview_scheduled" if "Interview" in s else
|
||||
"rejected" if "Sorry" in s else "offer_received"
|
||||
)
|
||||
|
||||
bad = MagicMock()
|
||||
bad.name = "bad"
|
||||
bad.classify.return_value = "neutral"
|
||||
|
||||
results = run_scoring([perfect, bad], str(score_file))
|
||||
|
||||
assert results["perfect"]["__accuracy__"] == pytest.approx(1.0)
|
||||
assert results["bad"]["__accuracy__"] == pytest.approx(0.0)
|
||||
assert "latency_ms" in results["perfect"]
|
||||
|
||||
|
||||
def test_run_scoring_handles_classify_error(tmp_path):
|
||||
"""run_scoring() falls back to 'neutral' on exception and continues."""
|
||||
import json
|
||||
from unittest.mock import MagicMock
|
||||
from scripts.benchmark_classifier import run_scoring
|
||||
|
||||
score_file = tmp_path / "score.jsonl"
|
||||
score_file.write_text(json.dumps({"subject": "Hi", "body": "Body", "label": "neutral"}))
|
||||
|
||||
broken = MagicMock()
|
||||
broken.name = "broken"
|
||||
broken.classify.side_effect = RuntimeError("model crashed")
|
||||
|
||||
results = run_scoring([broken], str(score_file))
|
||||
assert "broken" in results
|
||||
177
tests/test_classifier_adapters.py
Normal file
177
tests/test_classifier_adapters.py
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
"""Tests for classifier_adapters — no model downloads required."""
|
||||
import pytest
|
||||
|
||||
|
||||
def test_labels_constant_has_six_items():
|
||||
from scripts.classifier_adapters import LABELS
|
||||
assert len(LABELS) == 6
|
||||
assert "interview_scheduled" in LABELS
|
||||
assert "neutral" in LABELS
|
||||
|
||||
|
||||
def test_compute_metrics_perfect_predictions():
|
||||
from scripts.classifier_adapters import compute_metrics, LABELS
|
||||
gold = ["rejected", "interview_scheduled", "neutral"]
|
||||
preds = ["rejected", "interview_scheduled", "neutral"]
|
||||
m = compute_metrics(preds, gold, LABELS)
|
||||
assert m["rejected"]["f1"] == pytest.approx(1.0)
|
||||
assert m["__accuracy__"] == pytest.approx(1.0)
|
||||
assert m["__macro_f1__"] == pytest.approx(1.0)
|
||||
|
||||
|
||||
def test_compute_metrics_all_wrong():
|
||||
from scripts.classifier_adapters import compute_metrics, LABELS
|
||||
gold = ["rejected", "rejected"]
|
||||
preds = ["neutral", "interview_scheduled"]
|
||||
m = compute_metrics(preds, gold, LABELS)
|
||||
assert m["rejected"]["recall"] == pytest.approx(0.0)
|
||||
assert m["__accuracy__"] == pytest.approx(0.0)
|
||||
|
||||
|
||||
def test_compute_metrics_partial():
|
||||
from scripts.classifier_adapters import compute_metrics, LABELS
|
||||
gold = ["rejected", "neutral", "rejected"]
|
||||
preds = ["rejected", "neutral", "interview_scheduled"]
|
||||
m = compute_metrics(preds, gold, LABELS)
|
||||
assert m["rejected"]["precision"] == pytest.approx(1.0)
|
||||
assert m["rejected"]["recall"] == pytest.approx(0.5)
|
||||
assert m["neutral"]["f1"] == pytest.approx(1.0)
|
||||
assert m["__accuracy__"] == pytest.approx(2 / 3)
|
||||
|
||||
|
||||
def test_compute_metrics_empty():
|
||||
from scripts.classifier_adapters import compute_metrics, LABELS
|
||||
m = compute_metrics([], [], LABELS)
|
||||
assert m["__accuracy__"] == pytest.approx(0.0)
|
||||
|
||||
|
||||
def test_classifier_adapter_is_abstract():
|
||||
from scripts.classifier_adapters import ClassifierAdapter
|
||||
with pytest.raises(TypeError):
|
||||
ClassifierAdapter()
|
||||
|
||||
|
||||
# ---- ZeroShotAdapter tests ----
|
||||
|
||||
def test_zeroshot_adapter_classify_mocked():
|
||||
from unittest.mock import MagicMock, patch
|
||||
from scripts.classifier_adapters import ZeroShotAdapter
|
||||
|
||||
# Two-level mock: factory call returns pipeline instance; instance call returns inference result.
|
||||
mock_pipe_factory = MagicMock()
|
||||
mock_pipe_factory.return_value = MagicMock(return_value={
|
||||
"labels": ["rejected", "neutral", "interview_scheduled"],
|
||||
"scores": [0.85, 0.10, 0.05],
|
||||
})
|
||||
|
||||
with patch("scripts.classifier_adapters.pipeline", mock_pipe_factory):
|
||||
adapter = ZeroShotAdapter("test-zs", "some/model")
|
||||
adapter.load()
|
||||
result = adapter.classify("We went with another candidate", "Thank you for applying.")
|
||||
|
||||
assert result == "rejected"
|
||||
# Factory was called with the correct task type
|
||||
assert mock_pipe_factory.call_args[0][0] == "zero-shot-classification"
|
||||
# Pipeline instance was called with the email text
|
||||
assert "We went with another candidate" in mock_pipe_factory.return_value.call_args[0][0]
|
||||
|
||||
|
||||
def test_zeroshot_adapter_unload_clears_pipeline():
|
||||
from unittest.mock import MagicMock, patch
|
||||
from scripts.classifier_adapters import ZeroShotAdapter
|
||||
|
||||
with patch("scripts.classifier_adapters.pipeline", MagicMock()):
|
||||
adapter = ZeroShotAdapter("test-zs", "some/model")
|
||||
adapter.load()
|
||||
assert adapter._pipeline is not None
|
||||
adapter.unload()
|
||||
assert adapter._pipeline is None
|
||||
|
||||
|
||||
def test_zeroshot_adapter_lazy_loads():
|
||||
from unittest.mock import MagicMock, patch
|
||||
from scripts.classifier_adapters import ZeroShotAdapter
|
||||
|
||||
mock_pipe_factory = MagicMock()
|
||||
mock_pipe_factory.return_value = MagicMock(return_value={
|
||||
"labels": ["neutral"], "scores": [1.0]
|
||||
})
|
||||
|
||||
with patch("scripts.classifier_adapters.pipeline", mock_pipe_factory):
|
||||
adapter = ZeroShotAdapter("test-zs", "some/model")
|
||||
adapter.classify("subject", "body")
|
||||
|
||||
mock_pipe_factory.assert_called_once()
|
||||
|
||||
|
||||
# ---- GLiClassAdapter tests ----
|
||||
|
||||
def test_gliclass_adapter_classify_mocked():
|
||||
from unittest.mock import MagicMock, patch
|
||||
from scripts.classifier_adapters import GLiClassAdapter
|
||||
|
||||
mock_pipeline_instance = MagicMock()
|
||||
mock_pipeline_instance.return_value = [[
|
||||
{"label": "interview_scheduled", "score": 0.91},
|
||||
{"label": "neutral", "score": 0.05},
|
||||
{"label": "rejected", "score": 0.04},
|
||||
]]
|
||||
|
||||
with patch("scripts.classifier_adapters.GLiClassModel") as _mc, \
|
||||
patch("scripts.classifier_adapters.AutoTokenizer") as _mt, \
|
||||
patch("scripts.classifier_adapters.ZeroShotClassificationPipeline",
|
||||
return_value=mock_pipeline_instance):
|
||||
adapter = GLiClassAdapter("test-gli", "some/gliclass-model")
|
||||
adapter.load()
|
||||
result = adapter.classify("Interview invitation", "Let's schedule a call.")
|
||||
|
||||
assert result == "interview_scheduled"
|
||||
|
||||
|
||||
def test_gliclass_adapter_returns_highest_score():
|
||||
from unittest.mock import MagicMock, patch
|
||||
from scripts.classifier_adapters import GLiClassAdapter
|
||||
|
||||
mock_pipeline_instance = MagicMock()
|
||||
mock_pipeline_instance.return_value = [[
|
||||
{"label": "neutral", "score": 0.02},
|
||||
{"label": "offer_received", "score": 0.88},
|
||||
{"label": "rejected", "score": 0.10},
|
||||
]]
|
||||
|
||||
with patch("scripts.classifier_adapters.GLiClassModel"), \
|
||||
patch("scripts.classifier_adapters.AutoTokenizer"), \
|
||||
patch("scripts.classifier_adapters.ZeroShotClassificationPipeline",
|
||||
return_value=mock_pipeline_instance):
|
||||
adapter = GLiClassAdapter("test-gli", "some/model")
|
||||
adapter.load()
|
||||
result = adapter.classify("Offer letter enclosed", "Dear Meghan, we are pleased to offer...")
|
||||
|
||||
assert result == "offer_received"
|
||||
|
||||
|
||||
# ---- RerankerAdapter tests ----
|
||||
|
||||
def test_reranker_adapter_picks_highest_score():
|
||||
from unittest.mock import MagicMock, patch
|
||||
from scripts.classifier_adapters import RerankerAdapter, LABELS
|
||||
|
||||
mock_reranker = MagicMock()
|
||||
mock_reranker.compute_score.return_value = [0.1, 0.05, 0.85, 0.05, 0.02, 0.03]
|
||||
|
||||
with patch("scripts.classifier_adapters.FlagReranker", return_value=mock_reranker):
|
||||
adapter = RerankerAdapter("test-rr", "BAAI/bge-reranker-v2-m3")
|
||||
adapter.load()
|
||||
result = adapter.classify(
|
||||
"We regret to inform you",
|
||||
"After careful consideration we are moving forward with other candidates.",
|
||||
)
|
||||
|
||||
assert result == "rejected"
|
||||
pairs = mock_reranker.compute_score.call_args[0][0]
|
||||
assert len(pairs) == len(LABELS)
|
||||
|
||||
|
||||
def test_reranker_adapter_descriptions_cover_all_labels():
|
||||
from scripts.classifier_adapters import LABEL_DESCRIPTIONS, LABELS
|
||||
assert set(LABEL_DESCRIPTIONS.keys()) == set(LABELS)
|
||||
Loading…
Reference in a new issue