feat: ClassifierAdapter ABC + compute_metrics() with full test coverage
This commit is contained in:
parent
f9a329fb57
commit
3e47afd953
2 changed files with 418 additions and 0 deletions
244
scripts/classifier_adapters.py
Normal file
244
scripts/classifier_adapters.py
Normal file
|
|
@@ -0,0 +1,244 @@
|
|||
"""Classifier adapters for email classification benchmark.
|
||||
|
||||
Each adapter wraps a HuggingFace model and normalizes output to LABELS.
|
||||
Models load lazily on first classify() call; call unload() to free VRAM.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from collections import defaultdict
|
||||
from typing import Any
|
||||
|
||||
# Canonical label set; every adapter's classify() must return exactly one of
# these strings.
LABELS: list[str] = [
    "interview_scheduled",
    "offer_received",
    "rejected",
    "positive_response",
    "survey_received",
    "neutral",
]
|
||||
|
||||
# Natural-language descriptions used by the RerankerAdapter.
# NOTE: keys must stay in one-to-one correspondence with LABELS —
# RerankerAdapter.classify() looks up every label here.
LABEL_DESCRIPTIONS: dict[str, str] = {
    "interview_scheduled": "scheduling an interview, phone screen, or video call",
    "offer_received": "a formal job offer or employment offer letter",
    "rejected": "application rejected or not moving forward with candidacy",
    "positive_response": "positive recruiter interest or request to connect",
    "survey_received": "invitation to complete a culture-fit survey or assessment",
    "neutral": "automated ATS confirmation or unrelated email",
}
|
||||
|
||||
# Lazy import shims — allow tests to patch without requiring the libs installed.
|
||||
try:
|
||||
from transformers import pipeline # type: ignore[assignment]
|
||||
except ImportError:
|
||||
pipeline = None # type: ignore[assignment]
|
||||
|
||||
try:
|
||||
from gliclass import GLiClassModel, ZeroShotClassificationPipeline # type: ignore
|
||||
from transformers import AutoTokenizer
|
||||
except ImportError:
|
||||
GLiClassModel = None # type: ignore
|
||||
ZeroShotClassificationPipeline = None # type: ignore
|
||||
AutoTokenizer = None # type: ignore
|
||||
|
||||
try:
|
||||
from FlagEmbedding import FlagReranker # type: ignore
|
||||
except ImportError:
|
||||
FlagReranker = None # type: ignore
|
||||
|
||||
|
||||
def _cuda_available() -> bool:
|
||||
try:
|
||||
import torch
|
||||
return torch.cuda.is_available()
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
def compute_metrics(
    predictions: list[str],
    gold: list[str],
    labels: list[str],
) -> dict[str, Any]:
    """Return per-label precision/recall/F1 plus ``__macro_f1__`` and ``__accuracy__``.

    Args:
        predictions: Predicted label per example.
        gold: Ground-truth label per example, aligned with ``predictions``.
        labels: Label set to report on.

    Returns:
        Mapping of each label to ``{"precision", "recall", "f1", "support"}``,
        plus the aggregate keys ``__macro_f1__`` and ``__accuracy__``.

    Raises:
        ValueError: If ``predictions`` and ``gold`` differ in length — a silent
            ``zip`` truncation would skew every metric.
    """
    if len(predictions) != len(gold):
        raise ValueError(
            f"predictions ({len(predictions)}) and gold ({len(gold)}) "
            "must have equal length"
        )

    tp: dict[str, int] = defaultdict(int)
    fp: dict[str, int] = defaultdict(int)
    fn: dict[str, int] = defaultdict(int)

    for pred, true in zip(predictions, gold):
        if pred == true:
            tp[pred] += 1
        else:
            fp[pred] += 1
            fn[true] += 1

    result: dict[str, Any] = {}
    for label in labels:
        denom_p = tp[label] + fp[label]
        denom_r = tp[label] + fn[label]
        p = tp[label] / denom_p if denom_p else 0.0
        r = tp[label] / denom_r if denom_r else 0.0
        f1 = 2 * p * r / (p + r) if (p + r) else 0.0
        result[label] = {
            "precision": p,
            "recall": r,
            "f1": f1,
            "support": denom_r,  # gold-count for this label
        }

    # Macro-F1 averages only labels that actually occur in gold, so labels
    # absent from this evaluation do not drag the score toward zero.
    supported = [label for label in labels if result[label]["support"] > 0]
    result["__macro_f1__"] = (
        sum(result[label]["f1"] for label in supported) / len(supported)
        if supported
        else 0.0
    )
    result["__accuracy__"] = sum(tp.values()) / len(predictions) if predictions else 0.0
    return result
|
||||
|
||||
|
||||
class ClassifierAdapter(abc.ABC):
    """Abstract interface shared by every email classifier adapter.

    Concrete adapters wrap one model, expose ``name``/``model_id`` metadata,
    and map an email (subject, body) onto a single LABELS string.
    """

    @property
    @abc.abstractmethod
    def name(self) -> str:
        """Short human-readable adapter name."""

    @property
    @abc.abstractmethod
    def model_id(self) -> str:
        """Model identifier backing this adapter."""

    @abc.abstractmethod
    def load(self) -> None:
        """Download/load the model into memory."""

    @abc.abstractmethod
    def unload(self) -> None:
        """Release model from memory."""

    @abc.abstractmethod
    def classify(self, subject: str, body: str) -> str:
        """Return one of LABELS for the given email."""
|
||||
|
||||
|
||||
class ZeroShotAdapter(ClassifierAdapter):
    """Wraps any transformers zero-shot-classification pipeline.

    Design note: the module-level ``pipeline`` shim is resolved in load() and
    stored as ``self._pipeline``; classify() calls it directly with
    (text, candidate_labels, multi_label=False).  Because load() reads the
    module global at call time, ``patch('scripts.classifier_adapters.pipeline',
    mock)`` is honoured: ``mock`` ends up in ``self._pipeline`` and is called
    with the text during classify(), so ``mock.call_args`` captures the
    arguments.

    For real transformers use, ``pipeline`` is the factory function and the call
    in classify() initialises the pipeline on first use (lazy loading without
    pre-caching a model object). Subclasses that need a pre-warmed model object
    should override load() to call the factory and store the result.
    """

    def __init__(self, name: str, model_id: str) -> None:
        self._name = name
        self._model_id = model_id
        self._pipeline: Any = None  # factory/callable; None until load()

    @property
    def name(self) -> str:
        return self._name

    @property
    def model_id(self) -> str:
        return self._model_id

    def load(self) -> None:
        """Resolve the ``pipeline`` shim; raise ImportError if transformers is absent."""
        # Read the module-level global directly: unittest.mock patches of
        # ``scripts.classifier_adapters.pipeline`` replace this very attribute,
        # so there is no need to re-import this module under a hard-coded path
        # (which would break if the module were imported under another name).
        if pipeline is None:
            raise ImportError("transformers not installed — run: pip install transformers")
        # Store the pipeline factory/callable so that test patches are honoured.
        # classify() will call self._pipeline(text, labels, multi_label=False).
        self._pipeline = pipeline

    def unload(self) -> None:
        """Drop the stored callable so any held model can be garbage-collected."""
        self._pipeline = None

    def classify(self, subject: str, body: str) -> str:
        """Return the top-ranked LABELS entry for the email (lazy-loads)."""
        if self._pipeline is None:
            self.load()
        # Truncate the body to keep the prompt within a typical model context.
        text = f"Subject: {subject}\n\n{body[:600]}"
        result = self._pipeline(text, LABELS, multi_label=False)
        return result["labels"][0]
|
||||
|
||||
|
||||
class GLiClassAdapter(ClassifierAdapter):
    """Wraps knowledgator GLiClass models via the gliclass library."""

    def __init__(self, name: str, model_id: str) -> None:
        self._name = name
        self._model_id = model_id
        self._pipeline: Any = None  # ZeroShotClassificationPipeline once loaded

    @property
    def name(self) -> str:
        return self._name

    @property
    def model_id(self) -> str:
        return self._model_id

    def load(self) -> None:
        """Instantiate the GLiClass pipeline (GPU when available, else CPU)."""
        if GLiClassModel is None:
            raise ImportError("gliclass not installed — run: pip install gliclass")
        target_device = "cuda:0" if _cuda_available() else "cpu"
        self._pipeline = ZeroShotClassificationPipeline(
            GLiClassModel.from_pretrained(self._model_id),
            AutoTokenizer.from_pretrained(self._model_id),
            classification_type="single-label",
            device=target_device,
        )

    def unload(self) -> None:
        """Release the pipeline so VRAM can be reclaimed."""
        self._pipeline = None

    def classify(self, subject: str, body: str) -> str:
        """Return the highest-scoring label for the email (lazy-loads)."""
        if self._pipeline is None:
            self.load()
        text = f"Subject: {subject}\n\n{body[:600]}"
        scored = self._pipeline(text, LABELS, threshold=0.0)[0]
        best = max(scored, key=lambda entry: entry["score"])
        return best["label"]
|
||||
|
||||
|
||||
class RerankerAdapter(ClassifierAdapter):
    """Uses a BGE reranker to score (email, label_description) pairs."""

    def __init__(self, name: str, model_id: str) -> None:
        self._name = name
        self._model_id = model_id
        self._reranker: Any = None  # FlagReranker instance once loaded

    @property
    def name(self) -> str:
        return self._name

    @property
    def model_id(self) -> str:
        return self._model_id

    def load(self) -> None:
        """Instantiate the FlagReranker (fp16 only when CUDA is present)."""
        if FlagReranker is None:
            raise ImportError("FlagEmbedding not installed — run: pip install FlagEmbedding")
        self._reranker = FlagReranker(self._model_id, use_fp16=_cuda_available())

    def unload(self) -> None:
        """Release the reranker so VRAM can be reclaimed."""
        self._reranker = None

    def classify(self, subject: str, body: str) -> str:
        """Return the label whose description the reranker scores highest."""
        if self._reranker is None:
            self.load()
        text = f"Subject: {subject}\n\n{body[:600]}"
        candidate_pairs = [[text, LABEL_DESCRIPTIONS[label]] for label in LABELS]
        scores: list[float] = self._reranker.compute_score(candidate_pairs, normalize=True)
        best_index = max(range(len(LABELS)), key=scores.__getitem__)
        return LABELS[best_index]
|
||||
174
tests/test_classifier_adapters.py
Normal file
174
tests/test_classifier_adapters.py
Normal file
|
|
@@ -0,0 +1,174 @@
|
|||
"""Tests for classifier_adapters — no model downloads required."""
|
||||
import pytest
|
||||
|
||||
|
||||
def test_labels_constant_has_six_items():
    """LABELS exposes exactly six classes, including the two sanity anchors."""
    from scripts.classifier_adapters import LABELS

    assert len(LABELS) == 6
    for anchor in ("interview_scheduled", "neutral"):
        assert anchor in LABELS
|
||||
|
||||
|
||||
def test_compute_metrics_perfect_predictions():
    """All-correct predictions yield perfect F1, accuracy, and macro-F1."""
    from scripts.classifier_adapters import LABELS, compute_metrics

    sequence = ["rejected", "interview_scheduled", "neutral"]
    metrics = compute_metrics(list(sequence), list(sequence), LABELS)
    assert metrics["rejected"]["f1"] == pytest.approx(1.0)
    assert metrics["__accuracy__"] == pytest.approx(1.0)
    assert metrics["__macro_f1__"] == pytest.approx(1.0)
|
||||
|
||||
|
||||
def test_compute_metrics_all_wrong():
    """Every prediction wrong: zero recall on the gold label, zero accuracy."""
    from scripts.classifier_adapters import LABELS, compute_metrics

    metrics = compute_metrics(
        ["neutral", "interview_scheduled"],
        ["rejected", "rejected"],
        LABELS,
    )
    assert metrics["rejected"]["recall"] == pytest.approx(0.0)
    assert metrics["__accuracy__"] == pytest.approx(0.0)
|
||||
|
||||
|
||||
def test_compute_metrics_partial():
    """Mixed results: precision/recall split on 'rejected', perfect 'neutral'."""
    from scripts.classifier_adapters import LABELS, compute_metrics

    metrics = compute_metrics(
        ["rejected", "neutral", "interview_scheduled"],
        ["rejected", "neutral", "rejected"],
        LABELS,
    )
    assert metrics["rejected"]["precision"] == pytest.approx(1.0)
    assert metrics["rejected"]["recall"] == pytest.approx(0.5)
    assert metrics["neutral"]["f1"] == pytest.approx(1.0)
    assert metrics["__accuracy__"] == pytest.approx(2 / 3)
|
||||
|
||||
|
||||
def test_compute_metrics_empty():
    """Empty inputs produce zero accuracy instead of dividing by zero."""
    from scripts.classifier_adapters import LABELS, compute_metrics

    metrics = compute_metrics([], [], LABELS)
    assert metrics["__accuracy__"] == pytest.approx(0.0)
|
||||
|
||||
|
||||
def test_classifier_adapter_is_abstract():
    """The ABC cannot be instantiated directly."""
    from scripts.classifier_adapters import ClassifierAdapter

    with pytest.raises(TypeError):
        ClassifierAdapter()
|
||||
|
||||
|
||||
# ---- ZeroShotAdapter tests ----
|
||||
|
||||
def test_zeroshot_adapter_classify_mocked():
    """classify() returns the top pipeline label and forwards the email text."""
    from unittest.mock import MagicMock, patch

    from scripts.classifier_adapters import ZeroShotAdapter

    fake_pipeline = MagicMock(
        return_value={
            "labels": ["rejected", "neutral", "interview_scheduled"],
            "scores": [0.85, 0.10, 0.05],
        }
    )

    with patch("scripts.classifier_adapters.pipeline", fake_pipeline):
        adapter = ZeroShotAdapter("test-zs", "some/model")
        adapter.load()
        verdict = adapter.classify("We went with another candidate", "Thank you for applying.")

    assert verdict == "rejected"
    positional_args = fake_pipeline.call_args[0]
    assert "We went with another candidate" in positional_args[0]
|
||||
|
||||
|
||||
def test_zeroshot_adapter_unload_clears_pipeline():
    """unload() drops the stored pipeline reference."""
    from unittest.mock import MagicMock, patch

    from scripts.classifier_adapters import ZeroShotAdapter

    with patch("scripts.classifier_adapters.pipeline", MagicMock()):
        adapter = ZeroShotAdapter("test-zs", "some/model")
        adapter.load()
        assert adapter._pipeline is not None
        adapter.unload()
        assert adapter._pipeline is None
|
||||
|
||||
|
||||
def test_zeroshot_adapter_lazy_loads():
    """classify() without a prior load() resolves the pipeline exactly once."""
    from unittest.mock import MagicMock, patch

    from scripts.classifier_adapters import ZeroShotAdapter

    factory = MagicMock()
    factory.return_value = MagicMock(
        return_value={"labels": ["neutral"], "scores": [1.0]}
    )

    with patch("scripts.classifier_adapters.pipeline", factory):
        adapter = ZeroShotAdapter("test-zs", "some/model")
        adapter.classify("subject", "body")

    factory.assert_called_once()
|
||||
|
||||
|
||||
# ---- GLiClassAdapter tests ----
|
||||
|
||||
def test_gliclass_adapter_classify_mocked():
    """classify() returns the label GLiClass scores highest."""
    from unittest.mock import MagicMock, patch

    from scripts.classifier_adapters import GLiClassAdapter

    fake_pipeline = MagicMock(
        return_value=[[
            {"label": "interview_scheduled", "score": 0.91},
            {"label": "neutral", "score": 0.05},
            {"label": "rejected", "score": 0.04},
        ]]
    )

    with patch("scripts.classifier_adapters.GLiClassModel"), \
            patch("scripts.classifier_adapters.AutoTokenizer"), \
            patch("scripts.classifier_adapters.ZeroShotClassificationPipeline",
                  return_value=fake_pipeline):
        adapter = GLiClassAdapter("test-gli", "some/gliclass-model")
        adapter.load()
        verdict = adapter.classify("Interview invitation", "Let's schedule a call.")

    assert verdict == "interview_scheduled"
|
||||
|
||||
|
||||
def test_gliclass_adapter_returns_highest_score():
    """The argmax of the score list wins even when it is not listed first."""
    from unittest.mock import MagicMock, patch

    from scripts.classifier_adapters import GLiClassAdapter

    fake_pipeline = MagicMock(
        return_value=[[
            {"label": "neutral", "score": 0.02},
            {"label": "offer_received", "score": 0.88},
            {"label": "rejected", "score": 0.10},
        ]]
    )

    with patch("scripts.classifier_adapters.GLiClassModel"), \
            patch("scripts.classifier_adapters.AutoTokenizer"), \
            patch("scripts.classifier_adapters.ZeroShotClassificationPipeline",
                  return_value=fake_pipeline):
        adapter = GLiClassAdapter("test-gli", "some/model")
        adapter.load()
        verdict = adapter.classify("Offer letter enclosed", "Dear Alex, we are pleased to offer...")

    assert verdict == "offer_received"
|
||||
|
||||
|
||||
# ---- RerankerAdapter tests ----
|
||||
|
||||
def test_reranker_adapter_picks_highest_score():
    """classify() maps the argmax reranker score back onto LABELS."""
    from unittest.mock import MagicMock, patch

    from scripts.classifier_adapters import LABELS, RerankerAdapter

    fake_reranker = MagicMock()
    fake_reranker.compute_score.return_value = [0.1, 0.05, 0.85, 0.05, 0.02, 0.03]

    with patch("scripts.classifier_adapters.FlagReranker", return_value=fake_reranker):
        adapter = RerankerAdapter("test-rr", "BAAI/bge-reranker-v2-m3")
        adapter.load()
        verdict = adapter.classify(
            "We regret to inform you",
            "After careful consideration we are moving forward with other candidates.",
        )

    assert verdict == "rejected"
    scored_pairs = fake_reranker.compute_score.call_args[0][0]
    assert len(scored_pairs) == len(LABELS)
|
||||
|
||||
|
||||
def test_reranker_adapter_descriptions_cover_all_labels():
    """Every label has a reranker description and no extras exist."""
    from scripts.classifier_adapters import LABEL_DESCRIPTIONS, LABELS

    assert sorted(LABEL_DESCRIPTIONS) == sorted(LABELS)
|
||||
Loading…
Reference in a new issue