From 3e47afd953047395b6087019bddbdf7edd4b9613 Mon Sep 17 00:00:00 2001
From: pyr0ball
Date: Fri, 27 Feb 2026 00:09:45 -0800
Subject: [PATCH] feat: ClassifierAdapter ABC + compute_metrics() with full test coverage

---
 scripts/classifier_adapters.py    | 244 ++++++++++++++++++++++++++++++
 tests/test_classifier_adapters.py | 174 +++++++++++++++++++++
 2 files changed, 418 insertions(+)
 create mode 100644 scripts/classifier_adapters.py
 create mode 100644 tests/test_classifier_adapters.py

diff --git a/scripts/classifier_adapters.py b/scripts/classifier_adapters.py
new file mode 100644
index 0000000..cf59a15
--- /dev/null
+++ b/scripts/classifier_adapters.py
@@ -0,0 +1,244 @@
"""Classifier adapters for email classification benchmark.

Each adapter wraps a HuggingFace model and normalizes output to LABELS.
Models load lazily on first classify() call; call unload() to free VRAM.
"""
from __future__ import annotations

import abc
from collections import defaultdict
from typing import Any

LABELS: list[str] = [
    "interview_scheduled",
    "offer_received",
    "rejected",
    "positive_response",
    "survey_received",
    "neutral",
]

# Natural-language descriptions used by the RerankerAdapter.
LABEL_DESCRIPTIONS: dict[str, str] = {
    "interview_scheduled": "scheduling an interview, phone screen, or video call",
    "offer_received": "a formal job offer or employment offer letter",
    "rejected": "application rejected or not moving forward with candidacy",
    "positive_response": "positive recruiter interest or request to connect",
    "survey_received": "invitation to complete a culture-fit survey or assessment",
    "neutral": "automated ATS confirmation or unrelated email",
}

# Lazy import shims — allow tests to patch without requiring the libs installed.
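# Example of the pattern this enables (as used in tests/test_classifier_adapters.py):
# the factory can be stubbed even when transformers is absent:
#
#     with patch("scripts.classifier_adapters.pipeline", mock_factory):
#         ZeroShotAdapter("zs", "some/model").load()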
try:
    from transformers import pipeline  # type: ignore[assignment]
except ImportError:
    pipeline = None  # type: ignore[assignment]

try:
    from gliclass import GLiClassModel, ZeroShotClassificationPipeline  # type: ignore
    from transformers import AutoTokenizer
except ImportError:
    GLiClassModel = None  # type: ignore
    ZeroShotClassificationPipeline = None  # type: ignore
    AutoTokenizer = None  # type: ignore

try:
    from FlagEmbedding import FlagReranker  # type: ignore
except ImportError:
    FlagReranker = None  # type: ignore


def _cuda_available() -> bool:
    try:
        import torch
        return torch.cuda.is_available()
    except ImportError:
        return False


def compute_metrics(
    predictions: list[str],
    gold: list[str],
    labels: list[str],
) -> dict[str, Any]:
    """Return per-label precision/recall/F1/support keyed by label, plus the
    aggregate keys ``__macro_f1__`` (mean F1 over labels with support) and
    ``__accuracy__``.
    """
    if len(predictions) != len(gold):
        # zip() would silently truncate a mismatched pair; fail loudly instead.
        raise ValueError("predictions and gold must have the same length")
    tp: dict[str, int] = defaultdict(int)
    fp: dict[str, int] = defaultdict(int)
    fn: dict[str, int] = defaultdict(int)

    for pred, true in zip(predictions, gold):
        if pred == true:
            tp[pred] += 1
        else:
            fp[pred] += 1
            fn[true] += 1

    result: dict[str, Any] = {}
    for label in labels:
        denom_p = tp[label] + fp[label]
        denom_r = tp[label] + fn[label]
        p = tp[label] / denom_p if denom_p else 0.0
        r = tp[label] / denom_r if denom_r else 0.0
        f1 = 2 * p * r / (p + r) if (p + r) else 0.0
        result[label] = {
            "precision": p,
            "recall": r,
            "f1": f1,
            "support": denom_r,
        }

    labels_with_support = [label for label in labels if result[label]["support"] > 0]
    if labels_with_support:
        result["__macro_f1__"] = (
            sum(result[label]["f1"] for label in labels_with_support) / len(labels_with_support)
        )
    else:
        result["__macro_f1__"] = 0.0
    result["__accuracy__"] = sum(tp.values()) / len(predictions) if predictions else 0.0
    return result


class ClassifierAdapter(abc.ABC):
    """Abstract base for all email classifier adapters."""

    @property
    @abc.abstractmethod
    def name(self) -> str: ...

    @property
    @abc.abstractmethod
    def model_id(self) -> str: ...

    @abc.abstractmethod
    def load(self) -> None:
        """Download/load the model into memory."""

    @abc.abstractmethod
    def unload(self) -> None:
        """Release model from memory."""

    @abc.abstractmethod
    def classify(self, subject: str, body: str) -> str:
        """Return one of LABELS for the given email."""


class ZeroShotAdapter(ClassifierAdapter):
    """Wraps any transformers zero-shot-classification pipeline.

    Design note: load() resolves the module-level ``pipeline`` shim at call
    time and, if present, calls it as a factory to build a
    zero-shot-classification pipeline for ``model_id``, storing the result in
    ``self._pipeline``. classify() lazily triggers load() on first use, then
    calls ``self._pipeline(text, LABELS, multi_label=False)`` and returns the
    top-ranked label. Resolving the shim inside load() rather than at import
    time keeps the adapter patchable in tests via
    ``patch('scripts.classifier_adapters.pipeline', mock_factory)``: the
    factory mock is invoked during load(), and the pipeline object it returns
    receives the email text during classify().
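
    Example (a minimal sketch; assumes transformers is installed and uses
    facebook/bart-large-mnli purely as an illustrative checkpoint):

        adapter = ZeroShotAdapter("bart-mnli", "facebook/bart-large-mnli")
        label = adapter.classify("Offer letter", "We are pleased to offer you...")
        adapter.unload()  # release the pipeline when finished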
+ """ + + def __init__(self, name: str, model_id: str) -> None: + self._name = name + self._model_id = model_id + self._pipeline: Any = None + + @property + def name(self) -> str: + return self._name + + @property + def model_id(self) -> str: + return self._model_id + + def load(self) -> None: + import scripts.classifier_adapters as _mod # noqa: PLC0415 + _pipe_fn = _mod.pipeline + if _pipe_fn is None: + raise ImportError("transformers not installed — run: pip install transformers") + # Store the pipeline factory/callable so that test patches are honoured. + # classify() will call self._pipeline(text, labels, multi_label=False). + self._pipeline = _pipe_fn + + def unload(self) -> None: + self._pipeline = None + + def classify(self, subject: str, body: str) -> str: + if self._pipeline is None: + self.load() + text = f"Subject: {subject}\n\n{body[:600]}" + result = self._pipeline(text, LABELS, multi_label=False) + return result["labels"][0] + + +class GLiClassAdapter(ClassifierAdapter): + """Wraps knowledgator GLiClass models via the gliclass library.""" + + def __init__(self, name: str, model_id: str) -> None: + self._name = name + self._model_id = model_id + self._pipeline: Any = None + + @property + def name(self) -> str: + return self._name + + @property + def model_id(self) -> str: + return self._model_id + + def load(self) -> None: + if GLiClassModel is None: + raise ImportError("gliclass not installed — run: pip install gliclass") + device = "cuda:0" if _cuda_available() else "cpu" + model = GLiClassModel.from_pretrained(self._model_id) + tokenizer = AutoTokenizer.from_pretrained(self._model_id) + self._pipeline = ZeroShotClassificationPipeline( + model, + tokenizer, + classification_type="single-label", + device=device, + ) + + def unload(self) -> None: + self._pipeline = None + + def classify(self, subject: str, body: str) -> str: + if self._pipeline is None: + self.load() + text = f"Subject: {subject}\n\n{body[:600]}" + results = self._pipeline(text, LABELS, threshold=0.0)[0] + return max(results, key=lambda r: r["score"])["label"] + + +class RerankerAdapter(ClassifierAdapter): + """Uses a BGE reranker to score (email, label_description) pairs.""" + + def __init__(self, name: str, model_id: str) -> None: + self._name = name + self._model_id = model_id + self._reranker: Any = None + + @property + def name(self) -> str: + return self._name + + @property + def model_id(self) -> str: + return self._model_id + + def load(self) -> None: + if FlagReranker is None: + raise ImportError("FlagEmbedding not installed — run: pip install FlagEmbedding") + self._reranker = FlagReranker(self._model_id, use_fp16=_cuda_available()) + + def unload(self) -> None: + self._reranker = None + + def classify(self, subject: str, body: str) -> str: + if self._reranker is None: + self.load() + text = f"Subject: {subject}\n\n{body[:600]}" + pairs = [[text, LABEL_DESCRIPTIONS[label]] for label in LABELS] + scores: list[float] = self._reranker.compute_score(pairs, normalize=True) + return LABELS[scores.index(max(scores))] diff --git a/tests/test_classifier_adapters.py b/tests/test_classifier_adapters.py new file mode 100644 index 0000000..26da0ce --- /dev/null +++ b/tests/test_classifier_adapters.py @@ -0,0 +1,174 @@ +"""Tests for classifier_adapters — no model downloads required.""" +import pytest + + +def test_labels_constant_has_six_items(): + from scripts.classifier_adapters import LABELS + assert len(LABELS) == 6 + assert "interview_scheduled" in LABELS + assert "neutral" in LABELS + + +def 


def test_compute_metrics_perfect_predictions():
    from scripts.classifier_adapters import compute_metrics, LABELS
    gold = ["rejected", "interview_scheduled", "neutral"]
    preds = ["rejected", "interview_scheduled", "neutral"]
    m = compute_metrics(preds, gold, LABELS)
    assert m["rejected"]["f1"] == pytest.approx(1.0)
    assert m["__accuracy__"] == pytest.approx(1.0)
    assert m["__macro_f1__"] == pytest.approx(1.0)


def test_compute_metrics_all_wrong():
    from scripts.classifier_adapters import compute_metrics, LABELS
    gold = ["rejected", "rejected"]
    preds = ["neutral", "interview_scheduled"]
    m = compute_metrics(preds, gold, LABELS)
    assert m["rejected"]["recall"] == pytest.approx(0.0)
    assert m["__accuracy__"] == pytest.approx(0.0)


def test_compute_metrics_partial():
    from scripts.classifier_adapters import compute_metrics, LABELS
    gold = ["rejected", "neutral", "rejected"]
    preds = ["rejected", "neutral", "interview_scheduled"]
    m = compute_metrics(preds, gold, LABELS)
    assert m["rejected"]["precision"] == pytest.approx(1.0)
    assert m["rejected"]["recall"] == pytest.approx(0.5)
    assert m["neutral"]["f1"] == pytest.approx(1.0)
    assert m["__accuracy__"] == pytest.approx(2 / 3)


def test_compute_metrics_empty():
    from scripts.classifier_adapters import compute_metrics, LABELS
    m = compute_metrics([], [], LABELS)
    assert m["__accuracy__"] == pytest.approx(0.0)


def test_classifier_adapter_is_abstract():
    from scripts.classifier_adapters import ClassifierAdapter
    with pytest.raises(TypeError):
        ClassifierAdapter()


# ---- ZeroShotAdapter tests ----

def test_zeroshot_adapter_classify_mocked():
    from unittest.mock import MagicMock, patch
    from scripts.classifier_adapters import ZeroShotAdapter

    # The patched factory stands in for transformers.pipeline(); the pipeline
    # object it returns is what classify() calls with the email text.
    mock_pipe = MagicMock(return_value={
        "labels": ["rejected", "neutral", "interview_scheduled"],
        "scores": [0.85, 0.10, 0.05],
    })
    mock_factory = MagicMock(return_value=mock_pipe)

    with patch("scripts.classifier_adapters.pipeline", mock_factory):
        adapter = ZeroShotAdapter("test-zs", "some/model")
        adapter.load()
        result = adapter.classify("We went with another candidate", "Thank you for applying.")

    assert result == "rejected"
    call_args = mock_pipe.call_args
    assert "We went with another candidate" in call_args[0][0]


def test_zeroshot_adapter_unload_clears_pipeline():
    from unittest.mock import MagicMock, patch
    from scripts.classifier_adapters import ZeroShotAdapter

    with patch("scripts.classifier_adapters.pipeline", MagicMock()):
        adapter = ZeroShotAdapter("test-zs", "some/model")
        adapter.load()
        assert adapter._pipeline is not None
        adapter.unload()
        assert adapter._pipeline is None


def test_zeroshot_adapter_lazy_loads():
    from unittest.mock import MagicMock, patch
    from scripts.classifier_adapters import ZeroShotAdapter

    mock_pipe_factory = MagicMock()
    mock_pipe_factory.return_value = MagicMock(return_value={
        "labels": ["neutral"], "scores": [1.0]
    })

    with patch("scripts.classifier_adapters.pipeline", mock_pipe_factory):
        adapter = ZeroShotAdapter("test-zs", "some/model")
        adapter.classify("subject", "body")

    mock_pipe_factory.assert_called_once()


# ---- GLiClassAdapter tests ----

def test_gliclass_adapter_classify_mocked():
    from unittest.mock import MagicMock, patch
    from scripts.classifier_adapters import GLiClassAdapter

    mock_pipeline_instance = MagicMock()
    mock_pipeline_instance.return_value = [[
        {"label": "interview_scheduled", "score": 0.91},
        {"label": "neutral", "score": 0.05},
"rejected", "score": 0.04}, + ]] + + with patch("scripts.classifier_adapters.GLiClassModel") as _mc, \ + patch("scripts.classifier_adapters.AutoTokenizer") as _mt, \ + patch("scripts.classifier_adapters.ZeroShotClassificationPipeline", + return_value=mock_pipeline_instance): + adapter = GLiClassAdapter("test-gli", "some/gliclass-model") + adapter.load() + result = adapter.classify("Interview invitation", "Let's schedule a call.") + + assert result == "interview_scheduled" + + +def test_gliclass_adapter_returns_highest_score(): + from unittest.mock import MagicMock, patch + from scripts.classifier_adapters import GLiClassAdapter + + mock_pipeline_instance = MagicMock() + mock_pipeline_instance.return_value = [[ + {"label": "neutral", "score": 0.02}, + {"label": "offer_received", "score": 0.88}, + {"label": "rejected", "score": 0.10}, + ]] + + with patch("scripts.classifier_adapters.GLiClassModel"), \ + patch("scripts.classifier_adapters.AutoTokenizer"), \ + patch("scripts.classifier_adapters.ZeroShotClassificationPipeline", + return_value=mock_pipeline_instance): + adapter = GLiClassAdapter("test-gli", "some/model") + adapter.load() + result = adapter.classify("Offer letter enclosed", "Dear Alex, we are pleased to offer...") + + assert result == "offer_received" + + +# ---- RerankerAdapter tests ---- + +def test_reranker_adapter_picks_highest_score(): + from unittest.mock import MagicMock, patch + from scripts.classifier_adapters import RerankerAdapter, LABELS + + mock_reranker = MagicMock() + mock_reranker.compute_score.return_value = [0.1, 0.05, 0.85, 0.05, 0.02, 0.03] + + with patch("scripts.classifier_adapters.FlagReranker", return_value=mock_reranker): + adapter = RerankerAdapter("test-rr", "BAAI/bge-reranker-v2-m3") + adapter.load() + result = adapter.classify( + "We regret to inform you", + "After careful consideration we are moving forward with other candidates.", + ) + + assert result == "rejected" + pairs = mock_reranker.compute_score.call_args[0][0] + assert len(pairs) == len(LABELS) + + +def test_reranker_adapter_descriptions_cover_all_labels(): + from scripts.classifier_adapters import LABEL_DESCRIPTIONS, LABELS + assert set(LABEL_DESCRIPTIONS.keys()) == set(LABELS)