diff --git a/scripts/classifier_adapters.py b/scripts/classifier_adapters.py
index 778e1d4..e6020e2 100644
--- a/scripts/classifier_adapters.py
+++ b/scripts/classifier_adapters.py
@@ -26,6 +26,9 @@ LABELS: list[str] = [
"positive_response",
"survey_received",
"neutral",
+ "event_rescheduled",
+ "unrelated",
+ "digest",
]
# Natural-language descriptions used by the RerankerAdapter.
@@ -35,7 +38,10 @@ LABEL_DESCRIPTIONS: dict[str, str] = {
"rejected": "application rejected or not moving forward with candidacy",
"positive_response": "positive recruiter interest or request to connect",
"survey_received": "invitation to complete a culture-fit survey or assessment",
- "neutral": "automated ATS confirmation or unrelated email",
+ "neutral": "automated ATS confirmation such as application received",
+ "event_rescheduled": "an interview or scheduled event moved to a new time",
+ "unrelated": "non-job-search email unrelated to any application or recruiter",
+ "digest": "job digest or multi-listing email with multiple job postings",
}
# Lazy import shims — allow tests to patch without requiring the libs installed.
@@ -135,23 +141,23 @@ class ClassifierAdapter(abc.ABC):
class ZeroShotAdapter(ClassifierAdapter):
"""Wraps any transformers zero-shot-classification pipeline.
- Design note: the module-level ``pipeline`` shim is resolved once in load()
- and stored as ``self._pipeline``. classify() calls ``self._pipeline`` directly
- with (text, candidate_labels, multi_label=False). This makes the adapter
- patchable in tests via ``patch('scripts.classifier_adapters.pipeline', mock)``:
- ``mock`` is stored in ``self._pipeline`` and called with the text during
- classify(), so ``mock.call_args`` captures the arguments.
+ load() calls pipeline("zero-shot-classification", model=..., device=...) to get
+ an inference callable, stored as self._pipeline. classify() then calls
+ self._pipeline(text, LABELS, multi_label=False). In tests, patch
+ 'scripts.classifier_adapters.pipeline' with a MagicMock whose .return_value is
+ itself a MagicMock(return_value={...}) to simulate both the factory call and the
+ inference call.
- For real transformers use, ``pipeline`` is the factory function and the call
- in classify() initialises the pipeline on first use (lazy loading without
- pre-caching a model object). Subclasses that need a pre-warmed model object
- should override load() to call the factory and store the result.
+ two_pass: if True, classify() runs a second pass restricted to the top-2 labels
+ from the first pass, forcing a binary choice. This typically improves confidence
+ without the accuracy cost of a full 6-label second run.
"""
- def __init__(self, name: str, model_id: str) -> None:
+ def __init__(self, name: str, model_id: str, two_pass: bool = False) -> None:
self._name = name
self._model_id = model_id
self._pipeline: Any = None
+ self._two_pass = two_pass
@property
def name(self) -> str:
@@ -166,9 +172,9 @@ class ZeroShotAdapter(ClassifierAdapter):
_pipe_fn = _mod.pipeline
if _pipe_fn is None:
raise ImportError("transformers not installed — run: pip install transformers")
- # Store the pipeline factory/callable so that test patches are honoured.
- # classify() will call self._pipeline(text, labels, multi_label=False).
- self._pipeline = _pipe_fn
+ device = 0 if _cuda_available() else -1
+ # Instantiate the pipeline once; classify() calls the resulting object on each text.
+ self._pipeline = _pipe_fn("zero-shot-classification", model=self._model_id, device=device)
def unload(self) -> None:
self._pipeline = None
@@ -178,6 +184,9 @@ class ZeroShotAdapter(ClassifierAdapter):
self.load()
text = f"Subject: {subject}\n\n{body[:600]}"
result = self._pipeline(text, LABELS, multi_label=False)
+ if self._two_pass and len(result["labels"]) >= 2:
+ top2 = result["labels"][:2]
+ result = self._pipeline(text, top2, multi_label=False)
return result["labels"][0]
diff --git a/tests/test_classifier_adapters.py b/tests/test_classifier_adapters.py
index 26da0ce..feb2f6a 100644
--- a/tests/test_classifier_adapters.py
+++ b/tests/test_classifier_adapters.py
@@ -2,11 +2,14 @@
import pytest
-def test_labels_constant_has_six_items():
+def test_labels_constant_has_nine_items():
from scripts.classifier_adapters import LABELS
- assert len(LABELS) == 6
+ assert len(LABELS) == 9
assert "interview_scheduled" in LABELS
assert "neutral" in LABELS
+ assert "event_rescheduled" in LABELS
+ assert "unrelated" in LABELS
+ assert "digest" in LABELS
def test_compute_metrics_perfect_predictions():
@@ -57,20 +60,23 @@ def test_zeroshot_adapter_classify_mocked():
from unittest.mock import MagicMock, patch
from scripts.classifier_adapters import ZeroShotAdapter
- mock_pipeline = MagicMock()
- mock_pipeline.return_value = {
+ # Two-level mock: factory call returns pipeline instance; instance call returns inference result.
+ mock_pipe_factory = MagicMock()
+ mock_pipe_factory.return_value = MagicMock(return_value={
"labels": ["rejected", "neutral", "interview_scheduled"],
"scores": [0.85, 0.10, 0.05],
- }
+ })
- with patch("scripts.classifier_adapters.pipeline", mock_pipeline):
+ with patch("scripts.classifier_adapters.pipeline", mock_pipe_factory):
adapter = ZeroShotAdapter("test-zs", "some/model")
adapter.load()
result = adapter.classify("We went with another candidate", "Thank you for applying.")
assert result == "rejected"
- call_args = mock_pipeline.call_args
- assert "We went with another candidate" in call_args[0][0]
+ # Factory was called with the correct task type
+ assert mock_pipe_factory.call_args[0][0] == "zero-shot-classification"
+ # Pipeline instance was called with the email text
+ assert "We went with another candidate" in mock_pipe_factory.return_value.call_args[0][0]
def test_zeroshot_adapter_unload_clears_pipeline():
diff --git a/tools/label_tool.py b/tools/label_tool.py
new file mode 100644
index 0000000..74d1857
--- /dev/null
+++ b/tools/label_tool.py
@@ -0,0 +1,648 @@
+"""Email Label Tool — card-stack UI for building classifier benchmark data.
+
+Philosophy: Scrape → Store → Process
+ Fetch (IMAP, wide search, multi-account) → data/email_label_queue.jsonl
+ Label (card stack) → data/email_score.jsonl
+
+Run:
+ conda run -n job-seeker streamlit run tools/label_tool.py --server.port 8503
+
+Config: config/label_tool.yaml (gitignored — see config/label_tool.yaml.example)
+"""
+from __future__ import annotations
+
+import email as _email_lib
+import hashlib
+import html as _html
+import imaplib
+import json
+import re
+import sys
+from datetime import datetime, timedelta
+from email.header import decode_header as _raw_decode
+from pathlib import Path
+from typing import Any
+
+import streamlit as st
+import yaml
+
+# ── Path setup ─────────────────────────────────────────────────────────────
+_ROOT = Path(__file__).parent.parent
+sys.path.insert(0, str(_ROOT))
+
+_QUEUE_FILE = _ROOT / "data" / "email_label_queue.jsonl"
+_SCORE_FILE = _ROOT / "data" / "email_score.jsonl"
+_CFG_FILE = _ROOT / "config" / "label_tool.yaml"
+
+# ── Labels ─────────────────────────────────────────────────────────────────
+LABELS = [
+ "interview_scheduled",
+ "offer_received",
+ "rejected",
+ "positive_response",
+ "survey_received",
+ "neutral",
+ "event_rescheduled",
+ "unrelated",
+ "digest",
+]
+
+_LABEL_META: dict[str, dict] = {
+ "interview_scheduled": {"emoji": "🗓️", "color": "#4CAF50", "key": "1"},
+ "offer_received": {"emoji": "🎉", "color": "#2196F3", "key": "2"},
+ "rejected": {"emoji": "❌", "color": "#F44336", "key": "3"},
+ "positive_response": {"emoji": "👍", "color": "#FF9800", "key": "4"},
+ "survey_received": {"emoji": "📋", "color": "#9C27B0", "key": "5"},
+ "neutral": {"emoji": "⬜", "color": "#607D8B", "key": "6"},
+ "event_rescheduled": {"emoji": "🔄", "color": "#FF5722", "key": "7"},
+ "unrelated": {"emoji": "🗑️", "color": "#757575", "key": "8"},
+ "digest": {"emoji": "📰", "color": "#00BCD4", "key": "9"},
+}
+
+# ── HTML sanitiser ───────────────────────────────────────────────────────────
+# Valid chars per XML 1.0 §2.2 (same set HTML5 innerHTML enforces):
+# #x9 | #xA | #xD | [#x20–#xD7FF] | [#xE000–#xFFFD] | [#x10000–#x10FFFF]
+# Anything outside this range causes InvalidCharacterError in the browser.
+_INVALID_XML_CHARS = re.compile(
+ r"[^\x09\x0A\x0D\x20-\uD7FF\uE000-\uFFFD\U00010000-\U0010FFFF]"
+)
+
+def _to_html(text: str, newlines_to_br: bool = False) -> str:
+    """Strip invalid XML chars, HTML-escape the result, optionally convert \\n → <br>."""
+ if not text:
+ return ""
+ cleaned = _INVALID_XML_CHARS.sub("", text)
+ escaped = _html.escape(cleaned)
+ if newlines_to_br:
+        escaped = escaped.replace("\n", "<br>")
+ return escaped
+
+
+# ── Wide IMAP search terms (cast a net across all 9 categories) ─────────────
+_WIDE_TERMS = [
+ # interview_scheduled
+ "interview", "phone screen", "video call", "zoom link", "schedule a call",
+ # offer_received
+ "offer letter", "job offer", "offer of employment", "pleased to offer",
+ # rejected
+ "unfortunately", "not moving forward", "other candidates", "regret to inform",
+ "no longer", "decided not to", "decided to go with",
+ # positive_response
+ "opportunity", "interested in your background", "reached out", "great fit",
+ "exciting role", "love to connect",
+ # survey_received
+ "assessment", "questionnaire", "culture fit", "culture-fit", "online assessment",
+ # neutral / ATS confirms
+ "application received", "thank you for applying", "application confirmation",
+ "you applied", "your application for",
+ # event_rescheduled
+ "reschedule", "rescheduled", "new time", "moved to", "postponed", "new date",
+ # digest
+ "job digest", "jobs you may like", "recommended jobs", "jobs for you",
+ "new jobs", "job alert",
+ # general recruitment
+ "application", "recruiter", "recruiting", "hiring", "candidate",
+]
+
+
+# ── IMAP helpers ────────────────────────────────────────────────────────────
+
+def _decode_str(value: str | None) -> str:
+ if not value:
+ return ""
+ parts = _raw_decode(value)
+ out = []
+ for part, enc in parts:
+ if isinstance(part, bytes):
+ out.append(part.decode(enc or "utf-8", errors="replace"))
+ else:
+ out.append(str(part))
+ return " ".join(out).strip()
+
+
+def _extract_body(msg: Any) -> str:
+ if msg.is_multipart():
+ for part in msg.walk():
+ if part.get_content_type() == "text/plain":
+ try:
+ charset = part.get_content_charset() or "utf-8"
+ return part.get_payload(decode=True).decode(charset, errors="replace")
+ except Exception:
+ pass
+ else:
+ try:
+ charset = msg.get_content_charset() or "utf-8"
+ return msg.get_payload(decode=True).decode(charset, errors="replace")
+ except Exception:
+ pass
+ return ""
+
+
+def _fetch_account(cfg: dict, days: int, limit: int, known_keys: set[str],
+ progress_cb=None) -> list[dict]:
+ """Fetch emails from one IMAP account using wide recruitment search terms."""
+ since = (datetime.now() - timedelta(days=days)).strftime("%d-%b-%Y")
+ host = cfg.get("host", "imap.gmail.com")
+ port = int(cfg.get("port", 993))
+ use_ssl = cfg.get("use_ssl", True)
+ username = cfg["username"]
+ password = cfg["password"]
+ name = cfg.get("name", username)
+
+ conn = (imaplib.IMAP4_SSL if use_ssl else imaplib.IMAP4)(host, port)
+ conn.login(username, password)
+
+ seen_uids: dict[bytes, None] = {}
+ conn.select("INBOX", readonly=True)
+ for term in _WIDE_TERMS:
+ try:
+ _, data = conn.search(None, f'(SUBJECT "{term}" SINCE "{since}")')
+ for uid in (data[0] or b"").split():
+ seen_uids[uid] = None
+ except Exception:
+ pass
+
+ emails: list[dict] = []
+ uids = list(seen_uids.keys())[:limit * 3] # overfetch; filter after dedup
+ for i, uid in enumerate(uids):
+ if len(emails) >= limit:
+ break
+ if progress_cb:
+ progress_cb(i / len(uids), f"{name}: {len(emails)} fetched…")
+ try:
+ _, raw_data = conn.fetch(uid, "(RFC822)")
+ if not raw_data or not raw_data[0]:
+ continue
+ msg = _email_lib.message_from_bytes(raw_data[0][1])
+ subj = _decode_str(msg.get("Subject", ""))
+ from_addr = _decode_str(msg.get("From", ""))
+ date = _decode_str(msg.get("Date", ""))
+ body = _extract_body(msg)[:800]
+ entry = {
+ "subject": subj,
+ "body": body,
+ "from_addr": from_addr,
+ "date": date,
+ "account": name,
+ }
+ key = _entry_key(entry)
+ if key not in known_keys:
+ known_keys.add(key)
+ emails.append(entry)
+ except Exception:
+ pass
+
+ try:
+ conn.logout()
+ except Exception:
+ pass
+ return emails
+
+
+# ── Queue / score file helpers ───────────────────────────────────────────────
+
+def _entry_key(e: dict) -> str:
+ return hashlib.md5(
+ (e.get("subject", "") + (e.get("body") or "")[:100]).encode()
+ ).hexdigest()
+
+
+def _load_jsonl(path: Path) -> list[dict]:
+ if not path.exists():
+ return []
+ rows = []
+ with path.open() as f:
+ for line in f:
+ line = line.strip()
+ if line:
+ try:
+ rows.append(json.loads(line))
+ except Exception:
+ pass
+ return rows
+
+
+def _save_jsonl(path: Path, rows: list[dict]) -> None:
+ path.parent.mkdir(parents=True, exist_ok=True)
+ with path.open("w") as f:
+ for row in rows:
+ f.write(json.dumps(row, ensure_ascii=False) + "\n")
+
+
+def _append_jsonl(path: Path, row: dict) -> None:
+ path.parent.mkdir(parents=True, exist_ok=True)
+ with path.open("a") as f:
+ f.write(json.dumps(row, ensure_ascii=False) + "\n")
+
+
+# ── Config ──────────────────────────────────────────────────────────────────
+
+def _load_config() -> list[dict]:
+ if not _CFG_FILE.exists():
+ return []
+ cfg = yaml.safe_load(_CFG_FILE.read_text()) or {}
+ return cfg.get("accounts", [])
+
+
+# ── Page setup ──────────────────────────────────────────────────────────────
+
+st.set_page_config(
+ page_title="Email Labeler",
+ page_icon="📬",
+ layout="wide",
+)
+
+st.markdown("""
+
+""", unsafe_allow_html=True)
+
+st.title("📬 Email Label Tool")
+st.caption("Scrape → Store → Process | card-stack edition")
+
+# ── Session state init ───────────────────────────────────────────────────────
+
+if "queue" not in st.session_state:
+ st.session_state.queue: list[dict] = _load_jsonl(_QUEUE_FILE)
+
+if "labeled" not in st.session_state:
+ st.session_state.labeled: list[dict] = _load_jsonl(_SCORE_FILE)
+ st.session_state.labeled_keys: set[str] = {
+ _entry_key(r) for r in st.session_state.labeled
+ }
+
+if "idx" not in st.session_state:
+ # Start past already-labeled entries in the queue
+ labeled_keys = st.session_state.labeled_keys
+ for i, entry in enumerate(st.session_state.queue):
+ if _entry_key(entry) not in labeled_keys:
+ st.session_state.idx = i
+ break
+ else:
+ st.session_state.idx = len(st.session_state.queue)
+
+if "history" not in st.session_state:
+ st.session_state.history: list[tuple[int, str]] = [] # (queue_idx, label)
+
+
+# ── Sidebar stats ────────────────────────────────────────────────────────────
+
+with st.sidebar:
+ labeled = st.session_state.labeled
+ queue = st.session_state.queue
+ unlabeled = [e for e in queue if _entry_key(e) not in st.session_state.labeled_keys]
+
+ st.metric("✅ Labeled", len(labeled))
+ st.metric("📥 Queue", len(unlabeled))
+
+ if labeled:
+ st.caption("**Label distribution**")
+ counts = {lbl: 0 for lbl in LABELS}
+ for r in labeled:
+ counts[r.get("label", "")] = counts.get(r.get("label", ""), 0) + 1
+ for lbl in LABELS:
+ m = _LABEL_META[lbl]
+ st.caption(f"{m['emoji']} {lbl}: **{counts[lbl]}**")
+
+
+# ── Tabs ─────────────────────────────────────────────────────────────────────
+
+tab_label, tab_fetch, tab_stats = st.tabs(["🃏 Label", "📥 Fetch", "📊 Stats"])
+
+
+# ══════════════════════════════════════════════════════════════════════════════
+# FETCH TAB
+# ══════════════════════════════════════════════════════════════════════════════
+
+with tab_fetch:
+ accounts = _load_config()
+
+ if not accounts:
+ st.warning(
+ f"No accounts configured. Copy `config/label_tool.yaml.example` → "
+ f"`config/label_tool.yaml` and add your IMAP accounts.",
+ icon="⚠️",
+ )
+ else:
+ st.markdown(f"**{len(accounts)} account(s) configured:**")
+ for acc in accounts:
+ st.caption(f"• {acc.get('name', acc.get('username'))} ({acc.get('host')})")
+
+ col_days, col_limit = st.columns(2)
+ days = col_days.number_input("Days back", min_value=7, max_value=730, value=180)
+ limit = col_limit.number_input("Max emails per account", min_value=10, max_value=1000, value=150)
+
+ all_accs = [a.get("name", a.get("username")) for a in accounts]
+ selected = st.multiselect("Accounts to fetch", all_accs, default=all_accs)
+
+ if st.button("📥 Fetch from IMAP", disabled=not accounts or not selected, type="primary"):
+ existing_keys = {_entry_key(e) for e in st.session_state.queue}
+ existing_keys.update(st.session_state.labeled_keys)
+
+ fetched_all: list[dict] = []
+ status = st.status("Fetching…", expanded=True)
+ _live = status.empty()
+
+ for acc in accounts:
+ name = acc.get("name", acc.get("username"))
+ if name not in selected:
+ continue
+ status.write(f"Connecting to **{name}**…")
+ try:
+ emails = _fetch_account(
+ acc, days=int(days), limit=int(limit),
+ known_keys=existing_keys,
+ progress_cb=lambda p, msg: _live.markdown(f"⏳ {msg}"),
+ )
+ _live.empty()
+ fetched_all.extend(emails)
+ status.write(f"✓ {name}: {len(emails)} new emails")
+ except Exception as e:
+ _live.empty()
+ status.write(f"✗ {name}: {e}")
+
+ if fetched_all:
+ _save_jsonl(_QUEUE_FILE, st.session_state.queue + fetched_all)
+ st.session_state.queue = _load_jsonl(_QUEUE_FILE)
+ # Reset idx to first unlabeled
+ labeled_keys = st.session_state.labeled_keys
+ for i, entry in enumerate(st.session_state.queue):
+ if _entry_key(entry) not in labeled_keys:
+ st.session_state.idx = i
+ break
+ status.update(label=f"Done — {len(fetched_all)} new emails added to queue", state="complete")
+ else:
+ status.update(label="No new emails found (all already in queue or score file)", state="complete")
+
+
+# ══════════════════════════════════════════════════════════════════════════════
+# LABEL TAB
+# ══════════════════════════════════════════════════════════════════════════════
+
+with tab_label:
+ queue = st.session_state.queue
+ labeled_keys = st.session_state.labeled_keys
+ idx = st.session_state.idx
+
+ # Advance idx past already-labeled entries
+ while idx < len(queue) and _entry_key(queue[idx]) in labeled_keys:
+ idx += 1
+ st.session_state.idx = idx
+
+ unlabeled = [e for e in queue if _entry_key(e) not in labeled_keys]
+ total_in_queue = len(queue)
+ n_labeled = len(st.session_state.labeled)
+
+ if not queue:
+ st.info("Queue is empty — go to **Fetch** to pull emails from IMAP.", icon="📥")
+ elif not unlabeled:
+ st.success(
+ f"🎉 All {n_labeled} emails labeled! Go to **Stats** to review and export.",
+ icon="✅",
+ )
+ else:
+ # Progress
+ labeled_in_queue = total_in_queue - len(unlabeled)
+ progress_pct = labeled_in_queue / total_in_queue if total_in_queue else 0
+ st.progress(progress_pct, text=f"{labeled_in_queue} / {total_in_queue} labeled in queue")
+
+ # Current email
+ entry = queue[idx]
+
+ # Card HTML
+ subj = entry.get("subject", "(no subject)") or "(no subject)"
+ from_ = entry.get("from_addr", "") or ""
+ date_ = entry.get("date", "") or ""
+ acct = entry.get("account", "") or ""
+ body = (entry.get("body") or "").strip()
+
+ st.markdown(
+ f"""