peregrine/scripts/classifier_adapters.py

244 lines
7.9 KiB
Python

"""Classifier adapters for email classification benchmark.
Each adapter wraps a HuggingFace model and normalizes output to LABELS.
Models load lazily on first classify() call; call unload() to free VRAM.
"""
from __future__ import annotations
import abc
from collections import defaultdict
from typing import Any
# Candidate labels handed to every adapter's model. RerankerAdapter maps the
# best score's position back into this list, so the order is significant there.
LABELS: list[str] = [
    "interview_scheduled",
    "offer_received",
    "rejected",
    "positive_response",
    "survey_received",
    "neutral",
]
# Natural-language descriptions used by the RerankerAdapter.
# It looks up LABEL_DESCRIPTIONS[label] for each entry of LABELS, so every
# label above needs a key here.
LABEL_DESCRIPTIONS: dict[str, str] = {
    "interview_scheduled": "scheduling an interview, phone screen, or video call",
    "offer_received": "a formal job offer or employment offer letter",
    "rejected": "application rejected or not moving forward with candidacy",
    "positive_response": "positive recruiter interest or request to connect",
    "survey_received": "invitation to complete a culture-fit survey or assessment",
    "neutral": "automated ATS confirmation or unrelated email",
}
# Optional import shims: each name falls back to None when its library is missing, so tests can patch it without the libs installed.
try:
from transformers import pipeline # type: ignore[assignment]
except ImportError:
pipeline = None # type: ignore[assignment]
try:
from gliclass import GLiClassModel, ZeroShotClassificationPipeline # type: ignore
from transformers import AutoTokenizer
except ImportError:
GLiClassModel = None # type: ignore
ZeroShotClassificationPipeline = None # type: ignore
AutoTokenizer = None # type: ignore
try:
from FlagEmbedding import FlagReranker # type: ignore
except ImportError:
FlagReranker = None # type: ignore
def _cuda_available() -> bool:
try:
import torch
return torch.cuda.is_available()
except ImportError:
return False
def compute_metrics(
    predictions: list[str],
    gold: list[str],
    labels: list[str],
) -> dict[str, Any]:
    """Return per-label precision/recall/F1 plus macro-F1 and accuracy.

    Args:
        predictions: Predicted label for each example.
        gold: True label for each example, aligned index-by-index with
            ``predictions``.
        labels: Labels to report on. Values in ``predictions``/``gold`` that
            are not listed here get no entry of their own, but a wrong
            prediction still counts against the listed label it was
            confused with.

    Returns:
        A dict with one entry per label in ``labels`` holding ``precision``,
        ``recall``, ``f1`` and ``support`` (number of gold examples with that
        label), plus two aggregate keys:

        * ``__macro_f1__``: mean F1 over the labels whose support is > 0
          (0.0 when no label has support).
        * ``__accuracy__``: exact matches divided by ``len(predictions)``
          (0.0 for empty input).

    Raises:
        ValueError: If ``predictions`` and ``gold`` differ in length. Without
            this check ``zip`` would silently drop the unmatched tail while
            accuracy still divided by the full prediction count.
    """
    if len(predictions) != len(gold):
        raise ValueError(
            "predictions and gold must have the same length "
            f"(got {len(predictions)} and {len(gold)})"
        )

    tp: dict[str, int] = defaultdict(int)
    fp: dict[str, int] = defaultdict(int)
    fn: dict[str, int] = defaultdict(int)
    for pred, true in zip(predictions, gold):
        if pred == true:
            tp[pred] += 1
        else:
            # One mistake is a false positive for the predicted label and a
            # false negative for the true one.
            fp[pred] += 1
            fn[true] += 1

    result: dict[str, Any] = {}
    for label in labels:
        denom_p = tp[label] + fp[label]
        denom_r = tp[label] + fn[label]
        # Zero denominators (label never predicted / never present) yield 0.0
        # instead of raising ZeroDivisionError.
        p = tp[label] / denom_p if denom_p else 0.0
        r = tp[label] / denom_r if denom_r else 0.0
        f1 = 2 * p * r / (p + r) if (p + r) else 0.0
        result[label] = {
            "precision": p,
            "recall": r,
            "f1": f1,
            "support": denom_r,
        }

    # Labels absent from the gold set would drag the macro average towards
    # zero without saying anything about the model, so they are left out.
    supported_f1 = [
        result[label]["f1"] for label in labels if result[label]["support"] > 0
    ]
    result["__macro_f1__"] = (
        sum(supported_f1) / len(supported_f1) if supported_f1 else 0.0
    )
    result["__accuracy__"] = sum(tp.values()) / len(predictions) if predictions else 0.0
    return result
class ClassifierAdapter(abc.ABC):
    """Interface shared by every email classifier adapter.

    Concrete adapters must provide the ``name`` and ``model_id`` properties
    and the ``load``/``unload``/``classify`` methods; the class cannot be
    instantiated until all five are overridden.
    """

    @property
    @abc.abstractmethod
    def name(self) -> str:
        """Name of this adapter."""

    @property
    @abc.abstractmethod
    def model_id(self) -> str:
        """Identifier of the model this adapter wraps."""

    @abc.abstractmethod
    def load(self) -> None:
        """Fetch the model if needed and bring it into memory."""

    @abc.abstractmethod
    def unload(self) -> None:
        """Drop the model so its memory can be reclaimed."""

    @abc.abstractmethod
    def classify(self, subject: str, body: str) -> str:
        """Classify one email and return one of LABELS."""
class ZeroShotAdapter(ClassifierAdapter):
    """Wraps any transformers zero-shot-classification pipeline.

    load() calls the module-level ``pipeline`` factory once to build a
    ``zero-shot-classification`` pipeline for ``model_id`` and stores the
    resulting callable in ``self._pipeline``. classify() then calls that
    object with ``(text, LABELS, multi_label=False)``.

    The factory is read from module globals when load() runs, so tests can
    swap it with ``patch('scripts.classifier_adapters.pipeline', mock)``.
    Note that ``mock`` then stands in for the *factory*: the email text passed
    by classify() is recorded on ``mock.return_value`` (the pipeline object),
    while ``mock.call_args`` holds the task name, model id and device.
    """

    def __init__(self, name: str, model_id: str) -> None:
        """Store the adapter name and model id; the model is not loaded yet."""
        self._name = name
        self._model_id = model_id
        # Built pipeline object, or None while the model is not loaded.
        self._pipeline: Any = None

    @property
    def name(self) -> str:
        """Name given to this adapter at construction."""
        return self._name

    @property
    def model_id(self) -> str:
        """Model identifier passed to the transformers pipeline factory."""
        return self._model_id

    def load(self) -> None:
        """Build the zero-shot pipeline for ``model_id``.

        Raises:
            ImportError: If transformers is not installed (the module-level
                ``pipeline`` shim is None).
        """
        if pipeline is None:
            raise ImportError("transformers not installed — run: pip install transformers")
        # transformers convention: device index 0 is the first GPU, -1 is CPU.
        device = 0 if _cuda_available() else -1
        # The factory's first positional argument is the task name, so it must
        # be called here rather than stored and later called with the email text.
        self._pipeline = pipeline(
            "zero-shot-classification",
            model=self._model_id,
            device=device,
        )

    def unload(self) -> None:
        """Drop the reference to the pipeline so it can be garbage-collected."""
        self._pipeline = None

    def classify(self, subject: str, body: str) -> str:
        """Return the highest-ranked label for the email.

        Loads the pipeline on first use. Only the first 600 characters of
        ``body`` are sent to the model.
        """
        if self._pipeline is None:
            self.load()
        text = f"Subject: {subject}\n\n{body[:600]}"
        result = self._pipeline(text, LABELS, multi_label=False)
        # The zero-shot pipeline returns labels sorted by descending score.
        return result["labels"][0]
class GLiClassAdapter(ClassifierAdapter):
    """Adapter for knowledgator GLiClass models, driven through the gliclass library."""

    def __init__(self, name: str, model_id: str) -> None:
        """Record the adapter name and model id; nothing is loaded yet."""
        self._pipeline: Any = None
        self._model_id = model_id
        self._name = name

    @property
    def name(self) -> str:
        """Name given to this adapter at construction."""
        return self._name

    @property
    def model_id(self) -> str:
        """Model identifier handed to ``from_pretrained``."""
        return self._model_id

    def load(self) -> None:
        """Load model and tokenizer and wrap them in a single-label pipeline.

        Raises:
            ImportError: If the gliclass import shim is None.
        """
        if GLiClassModel is None:
            raise ImportError("gliclass not installed — run: pip install gliclass")
        gli_model = GLiClassModel.from_pretrained(self._model_id)
        gli_tokenizer = AutoTokenizer.from_pretrained(self._model_id)
        self._pipeline = ZeroShotClassificationPipeline(
            gli_model,
            gli_tokenizer,
            classification_type="single-label",
            # Use the first GPU when PyTorch reports CUDA, otherwise the CPU.
            device="cuda:0" if _cuda_available() else "cpu",
        )

    def unload(self) -> None:
        """Drop the reference to the pipeline."""
        self._pipeline = None

    def classify(self, subject: str, body: str) -> str:
        """Return the label with the highest score for the email.

        Loads the pipeline on first use. Only the first 600 characters of
        ``body`` are sent to the model.
        """
        if self._pipeline is None:
            self.load()
        email_text = f"Subject: {subject}\n\n{body[:600]}"
        # The pipeline result is indexed per input text; take the entry for our
        # single text. threshold=0.0 is passed so no label is cut by score.
        scored = self._pipeline(email_text, LABELS, threshold=0.0)[0]
        winner = max(scored, key=lambda item: item["score"])
        return winner["label"]
class RerankerAdapter(ClassifierAdapter):
    """Classifies by scoring (email, label description) pairs with a BGE reranker."""

    def __init__(self, name: str, model_id: str) -> None:
        """Record the adapter name and model id; nothing is loaded yet."""
        self._reranker: Any = None
        self._model_id = model_id
        self._name = name

    @property
    def name(self) -> str:
        """Name given to this adapter at construction."""
        return self._name

    @property
    def model_id(self) -> str:
        """Model identifier handed to ``FlagReranker``."""
        return self._model_id

    def load(self) -> None:
        """Instantiate the reranker, enabling fp16 only when CUDA is available.

        Raises:
            ImportError: If the FlagEmbedding import shim is None.
        """
        if FlagReranker is None:
            raise ImportError("FlagEmbedding not installed — run: pip install FlagEmbedding")
        half_precision = _cuda_available()
        self._reranker = FlagReranker(self._model_id, use_fp16=half_precision)

    def unload(self) -> None:
        """Drop the reference to the reranker."""
        self._reranker = None

    def classify(self, subject: str, body: str) -> str:
        """Return the label whose description scores highest against the email.

        Loads the reranker on first use. Only the first 600 characters of
        ``body`` are scored. On ties the earliest label in LABELS wins.
        """
        if self._reranker is None:
            self.load()
        email_text = f"Subject: {subject}\n\n{body[:600]}"
        # One [query, passage] pair per label, in LABELS order, so a score's
        # position maps straight back to its label.
        candidates = [[email_text, LABEL_DESCRIPTIONS[label]] for label in LABELS]
        scores: list[float] = self._reranker.compute_score(candidates, normalize=True)
        top_score = max(scores)
        best_index = scores.index(top_score)
        return LABELS[best_index]