504 lines
20 KiB
Python
504 lines
20 KiB
Python
"""Classifier adapters for the email classification benchmark.

Each adapter wraps a model backend (a transformers pipeline, a GLiClass model,
a FlagEmbedding reranker, or an Ollama embedding endpoint) and normalizes its
output to one of LABELS.
Models load lazily on the first classify() call; call unload() to free VRAM.
"""
|
||
from __future__ import annotations
|
||
|
||
import abc
|
||
from collections import defaultdict
|
||
from pathlib import Path
|
||
from typing import Any
|
||
|
||
__all__ = [
    # Label vocabulary and reference data.
    "LABELS",
    "LABEL_DESCRIPTIONS",
    "DEFAULT_EXEMPLARS",
    # Evaluation helper.
    "compute_metrics",
    # Adapter interface and its implementations.
    "ClassifierAdapter",
    "ZeroShotAdapter",
    "GLiClassAdapter",
    "RerankerAdapter",
    "FineTunedAdapter",
    "EmbeddingKNNAdapter",
]
|
||
|
||
# Canonical label vocabulary shared by every adapter. Order is significant:
# RerankerAdapter maps its score list back to a label by position, and the
# zero-shot adapters pass this list as their candidate labels.
LABELS: list[str] = [
    "interview_scheduled",
    "offer_received",
    "rejected",
    "positive_response",
    "survey_received",
    "neutral",
    "event_rescheduled",
    "digest",
    "new_lead",
    "hired",
]
|
||
|
||
# Natural-language hypothesis for each label. RerankerAdapter pairs the email
# text with these descriptions and picks the best-scoring one; a label missing
# from this mapping falls back to its name with underscores turned into spaces.
LABEL_DESCRIPTIONS: dict[str, str] = {
    "interview_scheduled": "scheduling an interview, phone screen, or video call",
    "offer_received": "a formal job offer or employment offer letter",
    "rejected": "application rejected or not moving forward with candidacy",
    "positive_response": "positive recruiter interest or request to connect",
    "survey_received": "invitation to complete a culture-fit survey or assessment",
    "neutral": "automated ATS confirmation such as application received",
    "event_rescheduled": "an interview or scheduled event moved to a new time",
    "digest": "job digest or multi-listing email with multiple job postings",
    "new_lead": "unsolicited recruiter outreach or cold contact about a new opportunity",
    "hired": (
        "job offer accepted, onboarding logistics, welcome email, "
        "or start date confirmation"
    ),
}
|
||
|
||
# Lazy import shims — allow tests to patch without requiring the libs installed.
# Each optional backend is imported once at module load. On ImportError the
# names are bound to None instead, so importing this module never fails and the
# adapters' load() methods can raise a "not installed" ImportError themselves.
try:
    from transformers import pipeline  # type: ignore[assignment]
except ImportError:
    pipeline = None  # type: ignore[assignment]

try:
    from gliclass import GLiClassModel, ZeroShotClassificationPipeline  # type: ignore
    # AutoTokenizer is needed alongside gliclass by GLiClassAdapter.load(), so it
    # shares this guard: if any of these imports fails, all three names are None.
    from transformers import AutoTokenizer
except ImportError:
    GLiClassModel = None  # type: ignore
    ZeroShotClassificationPipeline = None  # type: ignore
    AutoTokenizer = None  # type: ignore

try:
    from FlagEmbedding import FlagReranker  # type: ignore
except ImportError:
    FlagReranker = None  # type: ignore
|
||
|
||
|
||
def _cuda_available() -> bool:
    """Report whether torch can see a CUDA device.

    torch is imported lazily so this module stays importable without it; a
    missing torch installation simply means "no CUDA".
    """
    try:
        import torch

        has_cuda = torch.cuda.is_available()
    except ImportError:
        has_cuda = False
    return has_cuda
|
||
|
||
|
||
def compute_metrics(
    predictions: list[str],
    gold: list[str],
    labels: list[str],
) -> dict[str, Any]:
    """Return per-label precision/recall/F1 plus macro F1 and accuracy.

    Args:
        predictions: Predicted label for each example.
        gold: True label for each example, aligned index-by-index with
            ``predictions``.
        labels: Labels to report. Predictions or gold values outside this list
            still count toward accuracy and toward the FP/FN counts of the
            listed labels they are confused with, but get no entry of their own.

    Returns:
        A dict mapping each label in ``labels`` to
        ``{"precision", "recall", "f1", "support"}`` (support = number of gold
        examples carrying that label), plus:

        * ``"__macro_f1__"`` — unweighted mean F1 over labels with support > 0
          (0.0 if none have support);
        * ``"__accuracy__"`` — fraction of exact matches (0.0 for empty input).

        Any zero denominator yields 0.0 rather than raising.

    Raises:
        ValueError: if ``predictions`` and ``gold`` differ in length.
    """
    if len(predictions) != len(gold):
        # zip() would silently truncate to the shorter list, corrupting every
        # count and the accuracy denominator (which uses len(predictions)).
        raise ValueError(
            "predictions and gold must be the same length "
            f"(got {len(predictions)} and {len(gold)})"
        )

    tp: dict[str, int] = defaultdict(int)
    fp: dict[str, int] = defaultdict(int)
    fn: dict[str, int] = defaultdict(int)
    for pred, true in zip(predictions, gold):
        if pred == true:
            tp[pred] += 1
        else:
            # A miss is a false positive for the predicted label and a false
            # negative for the true one.
            fp[pred] += 1
            fn[true] += 1

    result: dict[str, Any] = {}
    for label in labels:
        denom_p = tp[label] + fp[label]
        denom_r = tp[label] + fn[label]
        p = tp[label] / denom_p if denom_p else 0.0
        r = tp[label] / denom_r if denom_r else 0.0
        f1 = 2 * p * r / (p + r) if (p + r) else 0.0
        result[label] = {
            "precision": p,
            "recall": r,
            "f1": f1,
            "support": denom_r,
        }

    # Labels absent from gold would contribute a meaningless 0.0 F1, so the
    # macro average only covers labels that actually occur.
    labels_with_support = [label for label in labels if result[label]["support"] > 0]
    result["__macro_f1__"] = (
        sum(result[label]["f1"] for label in labels_with_support) / len(labels_with_support)
        if labels_with_support
        else 0.0
    )
    result["__accuracy__"] = sum(tp.values()) / len(predictions) if predictions else 0.0
    return result
|
||
|
||
|
||
|
||
def _cosine(a: list[float], b: list[float]) -> float:
    """Return the cosine similarity of ``a`` and ``b``.

    The dot product pairs components with zip(), so surplus components of the
    longer vector only affect that vector's own norm. A zero-magnitude vector
    gives 0.0 instead of dividing by zero.
    """

    def _magnitude(vec: list[float]) -> float:
        return sum(component * component for component in vec) ** 0.5

    dot_product = sum(left * right for left, right in zip(a, b))
    length_a = _magnitude(a)
    length_b = _magnitude(b)
    if not (length_a and length_b):
        return 0.0
    return dot_product / (length_a * length_b)
|
||
|
||
|
||
# Default few-shot reference emails per label, used by EmbeddingKNNAdapter when
# no exemplar_texts are supplied. Each entry follows the adapters' input layout:
# "Subject: <subject>", a blank line, then the body. The subject and body halves
# are written as adjacent string literals, which Python joins at compile time.
DEFAULT_EXEMPLARS: dict[str, list[str]] = {
    "interview_scheduled": [
        "Subject: Interview Invitation\n\n"
        "We would like to invite you for a phone screen next week.",
        "Subject: Schedule a call\n\n"
        "Could you be available for a video interview on Tuesday?",
        "Subject: Next Steps\n\n"
        "We'd like to move forward with a technical interview. Please select a time.",
        "Subject: Interview Details\n\n"
        "Here are the dial-in instructions for your interview tomorrow.",
    ],
    "offer_received": [
        "Subject: Offer Letter Enclosed\n\n"
        "We are pleased to extend you an offer of employment.",
        "Subject: Job Offer\n\n"
        "Dear candidate, we are excited to offer you the position of Software Engineer.",
        "Subject: Employment Offer\n\n"
        "Please find attached your formal offer letter and compensation details.",
        "Subject: Offer of Employment\n\n"
        "Congratulations! We would like to offer you a full-time position.",
    ],
    "rejected": [
        "Subject: Your Application\n\n"
        "After careful consideration, we have decided to move forward with other candidates.",
        "Subject: Application Status\n\n"
        "We regret to inform you that your application has not been selected.",
        "Subject: Thank you for applying\n\n"
        "We appreciate your interest but have chosen not to proceed.",
        "Subject: Update on your candidacy\n\n"
        "We will not be moving forward with your application at this time.",
    ],
    "positive_response": [
        "Subject: Your profile\n\n"
        "I came across your LinkedIn and think you would be a great fit for our team.",
        "Subject: Exciting opportunity\n\n"
        "We were impressed by your background and would love to connect.",
        "Subject: Following up\n\n"
        "Thank you for your interest — we'd like to learn more about your experience.",
        "Subject: Great fit\n\n"
        "Your skills align well with what we are looking for. Let's set up a call.",
    ],
    "survey_received": [
        "Subject: Candidate Experience Survey\n\n"
        "Please complete this brief survey about your application experience.",
        "Subject: Culture Fit Assessment\n\n"
        "As part of our process, we ask all candidates to complete a short assessment.",
        "Subject: Skills Assessment\n\n"
        "We'd like you to complete our online coding assessment before proceeding.",
        "Subject: Personality Assessment\n\n"
        "Please complete the following assessment as the next step in our process.",
        "Subject: Pre-interview questionnaire\n\n"
        "Before we schedule your interview, please complete this brief skills survey.",
    ],
    "neutral": [
        "Subject: Application Received\n\n"
        "We have received your application and will be in touch.",
        "Subject: Thank you for applying\n\n"
        "Your application is under review. We will contact you if needed.",
        "Subject: Confirmation\n\n"
        "This email confirms receipt of your application to our company.",
        "Subject: Application Confirmation\n\n"
        "Thank you for your interest. We will review your materials and follow up.",
    ],
    "event_rescheduled": [
        "Subject: Interview Rescheduled\n\n"
        "Due to a conflict, we need to move your interview to a new time.",
        "Subject: Change of interview time\n\n"
        "We apologize — your interview has been rescheduled to Thursday.",
        "Subject: Updated interview details\n\n"
        "Your interview has been moved from Monday to Wednesday at 2pm.",
        "Subject: Reschedule request\n\n"
        "Would you be available to reschedule to a different time slot?",
        "Subject: New interview time\n\n"
        "Your phone screen has been moved from tomorrow to next week.",
    ],
    "digest": [
        "Subject: 15 new jobs matching your search\n\n"
        "Here are the latest job postings that match your profile.",
        "Subject: Weekly Job Digest\n\n"
        "This week's top opportunities for Software Engineers in your area.",
        "Subject: Jobs you might like\n\n"
        "Based on your profile, here are some positions we recommend.",
        "Subject: New jobs for you\n\n"
        "See the latest openings from companies on your watchlist.",
    ],
    "new_lead": [
        "Subject: Exciting opportunity at our company\n\n"
        "Hi, I noticed your background and think you'd be a great fit.",
        "Subject: Are you open to new opportunities?\n\n"
        "I'm a recruiter reaching out about a role matching your experience.",
        "Subject: Quick question\n\n"
        "Would you be interested in hearing about a senior engineering role?",
        "Subject: Recruiting outreach\n\n"
        "I came across your profile and wanted to share an exciting opening.",
    ],
    "hired": [
        "Subject: Welcome to the team!\n\n"
        "We are thrilled to have you join us. Here are your onboarding details.",
        "Subject: Onboarding information\n\n"
        "Congratulations on accepting our offer. Your start date is confirmed.",
        "Subject: First day information\n\n"
        "We look forward to your first day. Please arrive at 9am and ask for HR.",
        "Subject: Background check initiated\n\n"
        "As part of your onboarding, we have initiated a background check.",
        "Subject: Equipment setup\n\n"
        "Your laptop and equipment will be ready for pickup on your first day.",
    ],
}
|
||
|
||
|
||
class ClassifierAdapter(abc.ABC):
    """Interface every email classifier adapter implements.

    Concrete adapters expose a display ``name`` and the ``model_id`` they wrap,
    manage the model's lifetime with ``load``/``unload``, and map one email to a
    single label with ``classify``.
    """

    @property
    @abc.abstractmethod
    def name(self) -> str:
        """Human-readable adapter name."""

    @property
    @abc.abstractmethod
    def model_id(self) -> str:
        """Identifier (or local checkpoint path) of the wrapped model."""

    @abc.abstractmethod
    def load(self) -> None:
        """Download/load the model into memory."""

    @abc.abstractmethod
    def unload(self) -> None:
        """Release model from memory."""

    @abc.abstractmethod
    def classify(self, subject: str, body: str) -> str:
        """Return one of LABELS for the given email."""
|
||
|
||
|
||
class ZeroShotAdapter(ClassifierAdapter):
    """Wraps any transformers zero-shot-classification pipeline.

    load() calls pipeline("zero-shot-classification", model=..., device=...) to get
    an inference callable, stored as self._pipeline. classify() then calls
    self._pipeline(text, LABELS, multi_label=False). In tests, patch
    'scripts.classifier_adapters.pipeline' with a MagicMock whose .return_value is
    itself a MagicMock(return_value={...}) to simulate both the factory call and the
    inference call.

    two_pass: if True, classify() runs a second pass restricted to the top-2 labels
    from the first pass, forcing a binary choice between them. The second pass
    re-scores only those two candidates rather than every entry in LABELS.
    """

    def __init__(self, name: str, model_id: str, two_pass: bool = False) -> None:
        """
        Args:
            name: Display name returned by the ``name`` property.
            model_id: Model identifier handed to ``pipeline(model=...)``.
            two_pass: Enable the top-2 re-scoring pass described on this class.
        """
        self._name = name
        self._model_id = model_id
        self._pipeline: Any = None
        self._two_pass = two_pass

    @property
    def name(self) -> str:
        return self._name

    @property
    def model_id(self) -> str:
        return self._model_id

    def load(self) -> None:
        """Build the zero-shot pipeline on GPU 0 when CUDA is available, else CPU.

        Raises:
            ImportError: if transformers is not installed.
        """
        # Resolve ``pipeline`` through the module path that tests patch
        # ('scripts.classifier_adapters.pipeline'), not the name bound here at
        # import time.
        import scripts.classifier_adapters as _mod  # noqa: PLC0415

        _pipe_fn = _mod.pipeline
        if _pipe_fn is None:
            raise ImportError("transformers not installed — run: pip install transformers")
        # transformers pipelines take an int device: 0 = first GPU, -1 = CPU.
        device = 0 if _cuda_available() else -1
        # Instantiate the pipeline once; classify() calls the resulting object on each text.
        self._pipeline = _pipe_fn("zero-shot-classification", model=self._model_id, device=device)

    def unload(self) -> None:
        """Drop the pipeline reference so its memory can be reclaimed."""
        self._pipeline = None

    def classify(self, subject: str, body: str) -> str:
        """Return the top-scoring label, loading the pipeline on first use.

        Only the first 600 characters of ``body`` are sent to the model.
        """
        if self._pipeline is None:
            self.load()
        text = f"Subject: {subject}\n\n{body[:600]}"
        result = self._pipeline(text, LABELS, multi_label=False)
        if self._two_pass and len(result["labels"]) >= 2:
            # Re-score just the two best candidates from the first pass.
            top2 = result["labels"][:2]
            result = self._pipeline(text, top2, multi_label=False)
        return result["labels"][0]
|
||
|
||
|
||
class GLiClassAdapter(ClassifierAdapter):
    """Adapter for knowledgator GLiClass models, driven through the gliclass library.

    The pipeline runs in single-label mode and is asked to score every entry of
    LABELS (threshold=0.0); the highest-scoring label is returned.
    """

    def __init__(self, name: str, model_id: str) -> None:
        self._name, self._model_id = name, model_id
        self._pipeline: Any = None

    @property
    def name(self) -> str:
        return self._name

    @property
    def model_id(self) -> str:
        return self._model_id

    def load(self) -> None:
        """Build the GLiClass pipeline on ``cuda:0`` when available, else CPU."""
        if GLiClassModel is None:
            raise ImportError("gliclass not installed — run: pip install gliclass")
        target_device = "cuda:0" if _cuda_available() else "cpu"
        gliclass_model = GLiClassModel.from_pretrained(self._model_id)
        gliclass_tokenizer = AutoTokenizer.from_pretrained(self._model_id)
        self._pipeline = ZeroShotClassificationPipeline(
            gliclass_model,
            gliclass_tokenizer,
            classification_type="single-label",
            device=target_device,
        )

    def unload(self) -> None:
        self._pipeline = None

    def classify(self, subject: str, body: str) -> str:
        """Return the best-scoring label, loading the model on first use."""
        if self._pipeline is None:
            self.load()
        prompt = f"Subject: {subject}\n\n{body[:600]}"
        # The pipeline returns one list of {"label", "score"} dicts per input text.
        per_text = self._pipeline(prompt, LABELS, threshold=0.0)
        candidates = per_text[0]
        winner = max(candidates, key=lambda entry: entry["score"])
        return winner["label"]
|
||
|
||
|
||
class RerankerAdapter(ClassifierAdapter):
    """Scores (email, label description) pairs with a BGE reranker.

    Every label in LABELS becomes a natural-language hypothesis (from
    LABEL_DESCRIPTIONS, or the label name with underscores replaced by spaces);
    the label whose pair scores highest is returned.
    """

    def __init__(self, name: str, model_id: str) -> None:
        self._name = name
        self._model_id = model_id
        self._reranker: Any = None

    @property
    def name(self) -> str:
        return self._name

    @property
    def model_id(self) -> str:
        return self._model_id

    def load(self) -> None:
        """Instantiate FlagReranker, enabling fp16 only when CUDA is available."""
        if FlagReranker is None:
            raise ImportError("FlagEmbedding not installed — run: pip install FlagEmbedding")
        half_precision = _cuda_available()
        self._reranker = FlagReranker(self._model_id, use_fp16=half_precision)

    def unload(self) -> None:
        self._reranker = None

    def classify(self, subject: str, body: str) -> str:
        """Return the label whose description best matches the email."""
        if self._reranker is None:
            self.load()
        email_text = f"Subject: {subject}\n\n{body[:600]}"
        hypotheses = [LABEL_DESCRIPTIONS.get(label, label.replace("_", " ")) for label in LABELS]
        pairs = [[email_text, hypothesis] for hypothesis in hypotheses]
        scores: list[float] = self._reranker.compute_score(pairs, normalize=True)
        # Scores come back in LABELS order; the first maximum wins ties.
        best_score = max(scores)
        return LABELS[scores.index(best_score)]
|
||
|
||
|
||
class FineTunedAdapter(ClassifierAdapter):
    """Classifier backed by a fine-tuned checkpoint in a local models/ directory.

    Inference is a single pipeline("text-classification") forward pass over the
    string 'subject [SEP] body[:400]', which must match the training format
    exactly. Expected inference speed: ~10–20ms/email vs 111–338ms for zero-shot.
    """

    def __init__(self, name: str, model_dir: str) -> None:
        self._name = name
        self._model_dir = model_dir
        self._pipeline: Any = None

    @property
    def name(self) -> str:
        return self._name

    @property
    def model_id(self) -> str:
        # The checkpoint directory doubles as the model identifier.
        return self._model_dir

    def load(self) -> None:
        """Build the text-classification pipeline on GPU 0 if available, else CPU."""
        # Looked up via the module path so patches of
        # 'scripts.classifier_adapters.pipeline' take effect.
        import scripts.classifier_adapters as _mod  # noqa: PLC0415

        factory = _mod.pipeline
        if factory is None:
            raise ImportError("transformers not installed — run: pip install transformers")
        self._pipeline = factory(
            "text-classification",
            model=self._model_dir,
            device=0 if _cuda_available() else -1,
        )

    def unload(self) -> None:
        self._pipeline = None

    def classify(self, subject: str, body: str) -> str:
        """Return the checkpoint's predicted label, loading it on first use."""
        if self._pipeline is None:
            self.load()
        predictions = self._pipeline(f"{subject} [SEP] {body[:400]}")
        top_prediction = predictions[0]
        return top_prediction["label"]
|
||
|
||
|
||
class EmbeddingKNNAdapter(ClassifierAdapter):
    """k-NN email classifier using Ollama /v1/embeddings via cf-orch allocation.

    load():
        1. Allocates an Ollama instance from cf-orch (POST /api/services/ollama/allocate).
           Falls back to ollama_url directly if orch allocation fails or is not configured.
        2. Pre-embeds all exemplar texts and stores per-label vector lists. If this
           step fails, any cf-orch allocation made in step 1 is released before the
           error propagates.

    classify(subject, body):
        Embeds the input email (laid out like the exemplars: "Subject: <subject>",
        a blank line, then the first 600 characters of the body), computes cosine
        similarity against all stored exemplar vectors, and majority-votes the
        labels of the k most similar exemplars (default k=3). Tie-break: among
        labels with the same vote count, the one with the highest mean similarity
        over its votes wins.

    unload():
        Releases the cf-orch allocation (DELETE .../allocations/{id}) and clears state.
    """

    def __init__(
        self,
        name: str,
        model_id: str,
        *,
        k: int = 3,
        orch_url: str = "",
        ollama_url: str = "",
        exemplar_texts: dict[str, list[str]] | None = None,
    ) -> None:
        """
        Args:
            name: Display name returned by the ``name`` property.
            model_id: Embedding model name, sent to cf-orch and in every
                /v1/embeddings request.
            k: Number of nearest exemplars that vote in classify(); must be >= 1.
            orch_url: cf-orch coordinator base URL.
            ollama_url: Direct Ollama base URL, used when no orch allocation is made.
                If both URLs are empty they are read from config/label_tool.yaml.
            exemplar_texts: Per-label example emails; defaults to DEFAULT_EXEMPLARS.

        Raises:
            ValueError: if ``k`` is less than 1.
        """
        if k < 1:
            raise ValueError(f"k must be >= 1, got {k!r}")
        self._name = name
        self._model_id = model_id
        self._k = k
        self._orch_url = orch_url
        self._ollama_url = ollama_url
        self._exemplar_texts: dict[str, list[str]] = (
            exemplar_texts if exemplar_texts is not None else DEFAULT_EXEMPLARS
        )
        self._exemplar_embeddings: dict[str, list[list[float]]] = {}
        self._node_url: str = ""
        self._allocation_id: str = ""
        self._orch_url_used: str = ""

    @property
    def name(self) -> str:
        return self._name

    @property
    def model_id(self) -> str:
        return self._model_id

    def _resolve_urls(self) -> tuple[str, str]:
        """Return ``(orch_url, ollama_url)``.

        Explicit constructor arguments win. If both are empty, read
        ``cforch.coordinator_url`` / ``cforch.ollama_url`` from
        config/label_tool.yaml; a missing file, unparsable YAML, or a missing
        or non-mapping section yields empty strings.
        """
        if self._orch_url or self._ollama_url:
            return self._orch_url, self._ollama_url
        import yaml  # noqa: PLC0415

        cfg_path = Path(__file__).parent.parent / "config" / "label_tool.yaml"
        cfg: Any = {}
        if cfg_path.exists():
            try:
                cfg = yaml.safe_load(cfg_path.read_text(encoding="utf-8")) or {}
            except yaml.YAMLError:
                cfg = {}
        cforch = cfg.get("cforch") if isinstance(cfg, dict) else None
        if not isinstance(cforch, dict):
            cforch = {}
        # `or ""` also normalizes keys present in YAML but left empty (None).
        return cforch.get("coordinator_url") or "", cforch.get("ollama_url") or ""

    def _embed(self, node_url: str, texts: list[str]) -> list[list[float]]:
        """POST ``texts`` to ``{node_url}/v1/embeddings`` and return one vector per text.

        Raises:
            httpx.HTTPStatusError: on a non-2xx response.
        """
        import httpx  # noqa: PLC0415

        resp = httpx.post(
            f"{node_url}/v1/embeddings",
            json={"model": self._model_id, "input": texts},
            timeout=30.0,
        )
        resp.raise_for_status()
        return [item["embedding"] for item in resp.json()["data"]]

    def load(self) -> None:
        """Resolve an embedding endpoint and pre-embed every exemplar text.

        Raises:
            RuntimeError: if neither cf-orch nor a direct ollama_url provides an
                endpoint.
            Exception: errors raised while embedding the exemplars (e.g. httpx
                errors) propagate after any cf-orch allocation made here has
                been released.
        """
        import httpx  # noqa: PLC0415

        # A second load() without unload() would otherwise leak the previous
        # allocation; unload() is a no-op apart from clearing state otherwise.
        self.unload()

        orch_url, ollama_url = self._resolve_urls()
        node_url = ""
        allocation_id = ""
        if orch_url:
            try:
                resp = httpx.post(
                    f"{orch_url}/api/services/ollama/allocate",
                    json={"model": self._model_id},
                    timeout=15.0,
                )
                if resp.status_code == 200:
                    data = resp.json()
                    # Bind both at once so a reply missing either key can't leave
                    # a node URL without its allocation id (or vice versa).
                    node_url, allocation_id = data["url"], data["allocation_id"]
            except Exception:
                # Best effort by design: any orch failure falls back to ollama_url.
                node_url, allocation_id = "", ""

        if not node_url:
            node_url, allocation_id = ollama_url, ""
        if not node_url:
            raise RuntimeError(
                "EmbeddingKNNAdapter has no embedding endpoint: pass orch_url/ollama_url "
                "or set cforch.coordinator_url/cforch.ollama_url in config/label_tool.yaml"
            )

        self._node_url = node_url
        self._allocation_id = allocation_id
        self._orch_url_used = orch_url if allocation_id else ""

        embeddings: dict[str, list[list[float]]] = {}
        try:
            for label, texts in self._exemplar_texts.items():
                if texts:  # skip labels with no exemplars rather than POST an empty batch
                    embeddings[label] = self._embed(node_url, texts)
        except Exception:
            # Release the cf-orch allocation before surfacing the failure.
            self.unload()
            raise
        self._exemplar_embeddings = embeddings

    def unload(self) -> None:
        """Best-effort release of the cf-orch allocation, then clear all state."""
        if self._allocation_id and self._orch_url_used:
            try:
                import httpx  # noqa: PLC0415

                httpx.request(
                    "DELETE",
                    f"{self._orch_url_used}/api/services/ollama/allocations/{self._allocation_id}",
                    timeout=10.0,
                )
            except Exception:
                # Releasing is best effort; state is cleared regardless.
                pass
        self._exemplar_embeddings = {}
        self._node_url = ""
        self._allocation_id = ""
        self._orch_url_used = ""

    def classify(self, subject: str, body: str) -> str:
        """Return the majority label among the k nearest exemplars.

        Loads (allocates and pre-embeds) on first use.

        Raises:
            RuntimeError: if there are no exemplar embeddings to compare against.
        """
        if not self._node_url:
            self.load()
        text = f"Subject: {subject}\n\n{body[:600]}"
        query = self._embed(self._node_url, [text])[0]

        scored = [
            (_cosine(query, vector), label)
            for label, vectors in self._exemplar_embeddings.items()
            for vector in vectors
        ]
        if not scored:
            raise RuntimeError("no exemplar embeddings available for k-NN classification")
        # Stable sort: exemplars with equal similarity keep their original order.
        scored.sort(key=lambda item: item[0], reverse=True)

        votes: dict[str, list[float]] = defaultdict(list)
        for similarity, label in scored[: self._k]:
            votes[label].append(similarity)
        # Most votes wins; equal vote counts are broken by mean similarity.
        return max(
            votes,
            key=lambda lbl: (len(votes[lbl]), sum(votes[lbl]) / len(votes[lbl])),
        )
|