feat: ClassifierAdapter ABC + compute_metrics() with full test coverage

This commit is contained in:
pyr0ball 2026-02-27 00:09:45 -08:00
parent e0bd7d119c
commit 1f04f75905
2 changed files with 418 additions and 0 deletions

View file

@ -0,0 +1,244 @@
"""Classifier adapters for email classification benchmark.
Each adapter wraps a HuggingFace model and normalizes output to LABELS.
Models load lazily on first classify() call; call unload() to free VRAM.
"""
from __future__ import annotations
import abc
from collections import defaultdict
from typing import Any
LABELS: list[str] = [
"interview_scheduled",
"offer_received",
"rejected",
"positive_response",
"survey_received",
"neutral",
]
# Natural-language descriptions used by the RerankerAdapter.
LABEL_DESCRIPTIONS: dict[str, str] = {
"interview_scheduled": "scheduling an interview, phone screen, or video call",
"offer_received": "a formal job offer or employment offer letter",
"rejected": "application rejected or not moving forward with candidacy",
"positive_response": "positive recruiter interest or request to connect",
"survey_received": "invitation to complete a culture-fit survey or assessment",
"neutral": "automated ATS confirmation or unrelated email",
}
# Lazy import shims — allow tests to patch without requiring the libs installed.
try:
from transformers import pipeline # type: ignore[assignment]
except ImportError:
pipeline = None # type: ignore[assignment]
try:
from gliclass import GLiClassModel, ZeroShotClassificationPipeline # type: ignore
from transformers import AutoTokenizer
except ImportError:
GLiClassModel = None # type: ignore
ZeroShotClassificationPipeline = None # type: ignore
AutoTokenizer = None # type: ignore
try:
from FlagEmbedding import FlagReranker # type: ignore
except ImportError:
FlagReranker = None # type: ignore
def _cuda_available() -> bool:
try:
import torch
return torch.cuda.is_available()
except ImportError:
return False
def compute_metrics(
predictions: list[str],
gold: list[str],
labels: list[str],
) -> dict[str, Any]:
"""Return per-label precision/recall/F1 + macro_f1 + accuracy."""
tp: dict[str, int] = defaultdict(int)
fp: dict[str, int] = defaultdict(int)
fn: dict[str, int] = defaultdict(int)
for pred, true in zip(predictions, gold):
if pred == true:
tp[pred] += 1
else:
fp[pred] += 1
fn[true] += 1
result: dict[str, Any] = {}
for label in labels:
denom_p = tp[label] + fp[label]
denom_r = tp[label] + fn[label]
p = tp[label] / denom_p if denom_p else 0.0
r = tp[label] / denom_r if denom_r else 0.0
f1 = 2 * p * r / (p + r) if (p + r) else 0.0
result[label] = {
"precision": p,
"recall": r,
"f1": f1,
"support": denom_r,
}
labels_with_support = [label for label in labels if result[label]["support"] > 0]
if labels_with_support:
result["__macro_f1__"] = (
sum(result[label]["f1"] for label in labels_with_support) / len(labels_with_support)
)
else:
result["__macro_f1__"] = 0.0
result["__accuracy__"] = sum(tp.values()) / len(predictions) if predictions else 0.0
return result
class ClassifierAdapter(abc.ABC):
"""Abstract base for all email classifier adapters."""
@property
@abc.abstractmethod
def name(self) -> str: ...
@property
@abc.abstractmethod
def model_id(self) -> str: ...
@abc.abstractmethod
def load(self) -> None:
"""Download/load the model into memory."""
@abc.abstractmethod
def unload(self) -> None:
"""Release model from memory."""
@abc.abstractmethod
def classify(self, subject: str, body: str) -> str:
"""Return one of LABELS for the given email."""
class ZeroShotAdapter(ClassifierAdapter):
"""Wraps any transformers zero-shot-classification pipeline.
Design note: the module-level ``pipeline`` shim is resolved once in load()
and stored as ``self._pipeline``. classify() calls ``self._pipeline`` directly
with (text, candidate_labels, multi_label=False). This makes the adapter
patchable in tests via ``patch('scripts.classifier_adapters.pipeline', mock)``:
``mock`` is stored in ``self._pipeline`` and called with the text during
classify(), so ``mock.call_args`` captures the arguments.
For real transformers use, ``pipeline`` is the factory function and the call
in classify() initialises the pipeline on first use (lazy loading without
pre-caching a model object). Subclasses that need a pre-warmed model object
should override load() to call the factory and store the result.
"""
def __init__(self, name: str, model_id: str) -> None:
self._name = name
self._model_id = model_id
self._pipeline: Any = None
@property
def name(self) -> str:
return self._name
@property
def model_id(self) -> str:
return self._model_id
def load(self) -> None:
import scripts.classifier_adapters as _mod # noqa: PLC0415
_pipe_fn = _mod.pipeline
if _pipe_fn is None:
raise ImportError("transformers not installed — run: pip install transformers")
# Store the pipeline factory/callable so that test patches are honoured.
# classify() will call self._pipeline(text, labels, multi_label=False).
self._pipeline = _pipe_fn
def unload(self) -> None:
self._pipeline = None
def classify(self, subject: str, body: str) -> str:
if self._pipeline is None:
self.load()
text = f"Subject: {subject}\n\n{body[:600]}"
result = self._pipeline(text, LABELS, multi_label=False)
return result["labels"][0]
class GLiClassAdapter(ClassifierAdapter):
"""Wraps knowledgator GLiClass models via the gliclass library."""
def __init__(self, name: str, model_id: str) -> None:
self._name = name
self._model_id = model_id
self._pipeline: Any = None
@property
def name(self) -> str:
return self._name
@property
def model_id(self) -> str:
return self._model_id
def load(self) -> None:
if GLiClassModel is None:
raise ImportError("gliclass not installed — run: pip install gliclass")
device = "cuda:0" if _cuda_available() else "cpu"
model = GLiClassModel.from_pretrained(self._model_id)
tokenizer = AutoTokenizer.from_pretrained(self._model_id)
self._pipeline = ZeroShotClassificationPipeline(
model,
tokenizer,
classification_type="single-label",
device=device,
)
def unload(self) -> None:
self._pipeline = None
def classify(self, subject: str, body: str) -> str:
if self._pipeline is None:
self.load()
text = f"Subject: {subject}\n\n{body[:600]}"
results = self._pipeline(text, LABELS, threshold=0.0)[0]
return max(results, key=lambda r: r["score"])["label"]
class RerankerAdapter(ClassifierAdapter):
"""Uses a BGE reranker to score (email, label_description) pairs."""
def __init__(self, name: str, model_id: str) -> None:
self._name = name
self._model_id = model_id
self._reranker: Any = None
@property
def name(self) -> str:
return self._name
@property
def model_id(self) -> str:
return self._model_id
def load(self) -> None:
if FlagReranker is None:
raise ImportError("FlagEmbedding not installed — run: pip install FlagEmbedding")
self._reranker = FlagReranker(self._model_id, use_fp16=_cuda_available())
def unload(self) -> None:
self._reranker = None
def classify(self, subject: str, body: str) -> str:
if self._reranker is None:
self.load()
text = f"Subject: {subject}\n\n{body[:600]}"
pairs = [[text, LABEL_DESCRIPTIONS[label]] for label in LABELS]
scores: list[float] = self._reranker.compute_score(pairs, normalize=True)
return LABELS[scores.index(max(scores))]

View file

@ -0,0 +1,174 @@
"""Tests for classifier_adapters — no model downloads required."""
import pytest
def test_labels_constant_has_six_items():
from scripts.classifier_adapters import LABELS
assert len(LABELS) == 6
assert "interview_scheduled" in LABELS
assert "neutral" in LABELS
def test_compute_metrics_perfect_predictions():
from scripts.classifier_adapters import compute_metrics, LABELS
gold = ["rejected", "interview_scheduled", "neutral"]
preds = ["rejected", "interview_scheduled", "neutral"]
m = compute_metrics(preds, gold, LABELS)
assert m["rejected"]["f1"] == pytest.approx(1.0)
assert m["__accuracy__"] == pytest.approx(1.0)
assert m["__macro_f1__"] == pytest.approx(1.0)
def test_compute_metrics_all_wrong():
from scripts.classifier_adapters import compute_metrics, LABELS
gold = ["rejected", "rejected"]
preds = ["neutral", "interview_scheduled"]
m = compute_metrics(preds, gold, LABELS)
assert m["rejected"]["recall"] == pytest.approx(0.0)
assert m["__accuracy__"] == pytest.approx(0.0)
def test_compute_metrics_partial():
from scripts.classifier_adapters import compute_metrics, LABELS
gold = ["rejected", "neutral", "rejected"]
preds = ["rejected", "neutral", "interview_scheduled"]
m = compute_metrics(preds, gold, LABELS)
assert m["rejected"]["precision"] == pytest.approx(1.0)
assert m["rejected"]["recall"] == pytest.approx(0.5)
assert m["neutral"]["f1"] == pytest.approx(1.0)
assert m["__accuracy__"] == pytest.approx(2 / 3)
def test_compute_metrics_empty():
from scripts.classifier_adapters import compute_metrics, LABELS
m = compute_metrics([], [], LABELS)
assert m["__accuracy__"] == pytest.approx(0.0)
def test_classifier_adapter_is_abstract():
from scripts.classifier_adapters import ClassifierAdapter
with pytest.raises(TypeError):
ClassifierAdapter()
# ---- ZeroShotAdapter tests ----
def test_zeroshot_adapter_classify_mocked():
from unittest.mock import MagicMock, patch
from scripts.classifier_adapters import ZeroShotAdapter
mock_pipeline = MagicMock()
mock_pipeline.return_value = {
"labels": ["rejected", "neutral", "interview_scheduled"],
"scores": [0.85, 0.10, 0.05],
}
with patch("scripts.classifier_adapters.pipeline", mock_pipeline):
adapter = ZeroShotAdapter("test-zs", "some/model")
adapter.load()
result = adapter.classify("We went with another candidate", "Thank you for applying.")
assert result == "rejected"
call_args = mock_pipeline.call_args
assert "We went with another candidate" in call_args[0][0]
def test_zeroshot_adapter_unload_clears_pipeline():
from unittest.mock import MagicMock, patch
from scripts.classifier_adapters import ZeroShotAdapter
with patch("scripts.classifier_adapters.pipeline", MagicMock()):
adapter = ZeroShotAdapter("test-zs", "some/model")
adapter.load()
assert adapter._pipeline is not None
adapter.unload()
assert adapter._pipeline is None
def test_zeroshot_adapter_lazy_loads():
from unittest.mock import MagicMock, patch
from scripts.classifier_adapters import ZeroShotAdapter
mock_pipe_factory = MagicMock()
mock_pipe_factory.return_value = MagicMock(return_value={
"labels": ["neutral"], "scores": [1.0]
})
with patch("scripts.classifier_adapters.pipeline", mock_pipe_factory):
adapter = ZeroShotAdapter("test-zs", "some/model")
adapter.classify("subject", "body")
mock_pipe_factory.assert_called_once()
# ---- GLiClassAdapter tests ----
def test_gliclass_adapter_classify_mocked():
from unittest.mock import MagicMock, patch
from scripts.classifier_adapters import GLiClassAdapter
mock_pipeline_instance = MagicMock()
mock_pipeline_instance.return_value = [[
{"label": "interview_scheduled", "score": 0.91},
{"label": "neutral", "score": 0.05},
{"label": "rejected", "score": 0.04},
]]
with patch("scripts.classifier_adapters.GLiClassModel") as _mc, \
patch("scripts.classifier_adapters.AutoTokenizer") as _mt, \
patch("scripts.classifier_adapters.ZeroShotClassificationPipeline",
return_value=mock_pipeline_instance):
adapter = GLiClassAdapter("test-gli", "some/gliclass-model")
adapter.load()
result = adapter.classify("Interview invitation", "Let's schedule a call.")
assert result == "interview_scheduled"
def test_gliclass_adapter_returns_highest_score():
from unittest.mock import MagicMock, patch
from scripts.classifier_adapters import GLiClassAdapter
mock_pipeline_instance = MagicMock()
mock_pipeline_instance.return_value = [[
{"label": "neutral", "score": 0.02},
{"label": "offer_received", "score": 0.88},
{"label": "rejected", "score": 0.10},
]]
with patch("scripts.classifier_adapters.GLiClassModel"), \
patch("scripts.classifier_adapters.AutoTokenizer"), \
patch("scripts.classifier_adapters.ZeroShotClassificationPipeline",
return_value=mock_pipeline_instance):
adapter = GLiClassAdapter("test-gli", "some/model")
adapter.load()
result = adapter.classify("Offer letter enclosed", "Dear Alex, we are pleased to offer...")
assert result == "offer_received"
# ---- RerankerAdapter tests ----
def test_reranker_adapter_picks_highest_score():
from unittest.mock import MagicMock, patch
from scripts.classifier_adapters import RerankerAdapter, LABELS
mock_reranker = MagicMock()
mock_reranker.compute_score.return_value = [0.1, 0.05, 0.85, 0.05, 0.02, 0.03]
with patch("scripts.classifier_adapters.FlagReranker", return_value=mock_reranker):
adapter = RerankerAdapter("test-rr", "BAAI/bge-reranker-v2-m3")
adapter.load()
result = adapter.classify(
"We regret to inform you",
"After careful consideration we are moving forward with other candidates.",
)
assert result == "rejected"
pairs = mock_reranker.compute_score.call_args[0][0]
assert len(pairs) == len(LABELS)
def test_reranker_adapter_descriptions_cover_all_labels():
from scripts.classifier_adapters import LABEL_DESCRIPTIONS, LABELS
assert set(LABEL_DESCRIPTIONS.keys()) == set(LABELS)