diff --git a/scripts/classifier_adapters.py b/scripts/classifier_adapters.py
index 5704de1..c0b3177 100644
--- a/scripts/classifier_adapters.py
+++ b/scripts/classifier_adapters.py
@@ -17,6 +17,7 @@ __all__ = [
     "ZeroShotAdapter",
     "GLiClassAdapter",
     "RerankerAdapter",
+    "FineTunedAdapter",
 ]
 
 LABELS: list[str] = [
@@ -263,3 +264,42 @@ class RerankerAdapter(ClassifierAdapter):
         pairs = [[text, LABEL_DESCRIPTIONS.get(label, label.replace("_", " "))] for label in LABELS]
         scores: list[float] = self._reranker.compute_score(pairs, normalize=True)
         return LABELS[scores.index(max(scores))]
+
+
+class FineTunedAdapter(ClassifierAdapter):
+    """Loads a fine-tuned checkpoint from a local models/ directory.
+
+    Uses pipeline("text-classification") for a single forward pass.
+    Input format: 'subject [SEP] body[:400]' — must match training format exactly.
+    Expected inference speed: ~10–20ms/email vs 111–338ms for zero-shot.
+    """
+
+    def __init__(self, name: str, model_dir: str) -> None:
+        self._name = name
+        self._model_dir = model_dir
+        self._pipeline: Any = None
+
+    @property
+    def name(self) -> str:
+        return self._name
+
+    @property
+    def model_id(self) -> str:
+        return self._model_dir
+
+    def load(self) -> None:
+        import scripts.classifier_adapters as _mod  # noqa: PLC0415
+        _pipe_fn = _mod.pipeline
+        if _pipe_fn is None:
+            raise ImportError("transformers not installed")
+        self._pipeline = _pipe_fn("text-classification", model=self._model_dir)
+
+    def unload(self) -> None:
+        self._pipeline = None
+
+    def classify(self, subject: str, body: str) -> str:
+        if self._pipeline is None:
+            self.load()
+        text = f"{subject} [SEP] {body[:400]}"
+        result = self._pipeline(text)
+        return result[0]["label"]
diff --git a/tests/test_classifier_adapters.py b/tests/test_classifier_adapters.py
index 85f9dc1..9741949 100644
--- a/tests/test_classifier_adapters.py
+++ b/tests/test_classifier_adapters.py
@@ -180,3 +180,90 @@ def test_reranker_adapter_descriptions_cover_all_labels():
     from scripts.classifier_adapters import LABEL_DESCRIPTIONS, LABELS
     assert set(LABEL_DESCRIPTIONS.keys()) == set(LABELS)
 
+
+
+# ---- FineTunedAdapter tests ----
+
+def test_finetuned_adapter_classify_calls_pipeline_with_sep_format(tmp_path):
+    """classify() must format input as 'subject [SEP] body[:400]' — not the zero-shot format."""
+    from unittest.mock import MagicMock, patch
+    from scripts.classifier_adapters import FineTunedAdapter
+
+    mock_result = [{"label": "digest", "score": 0.95}]
+    mock_pipe_instance = MagicMock(return_value=mock_result)
+    mock_pipe_factory = MagicMock(return_value=mock_pipe_instance)
+
+    adapter = FineTunedAdapter("avocet-deberta-small", str(tmp_path))
+    with patch("scripts.classifier_adapters.pipeline", mock_pipe_factory):
+        result = adapter.classify("Test subject", "Test body")
+
+    assert result == "digest"
+    call_args = mock_pipe_instance.call_args[0][0]
+    assert "[SEP]" in call_args
+    assert "Test subject" in call_args
+    assert "Test body" in call_args
+
+
+def test_finetuned_adapter_truncates_body_to_400():
+    """Body must be truncated to 400 chars in the [SEP] format."""
+    from unittest.mock import MagicMock, patch
+    from scripts.classifier_adapters import FineTunedAdapter, LABELS
+
+    long_body = "x" * 800
+    mock_result = [{"label": "neutral", "score": 0.9}]
+    mock_pipe_instance = MagicMock(return_value=mock_result)
+    mock_pipe_factory = MagicMock(return_value=mock_pipe_instance)
+
+    adapter = FineTunedAdapter("avocet-deberta-small", "/fake/path")
+    with patch("scripts.classifier_adapters.pipeline", mock_pipe_factory):
+        adapter.classify("Subject", long_body)
+
+    call_text = mock_pipe_instance.call_args[0][0]
+    # "Subject [SEP] " prefix + 400 body chars = 414 chars max
+    assert len(call_text) <= 420
+
+
+def test_finetuned_adapter_returns_label_string():
+    """classify() must return a plain string, not a dict."""
+    from unittest.mock import MagicMock, patch
+    from scripts.classifier_adapters import FineTunedAdapter
+
+    mock_result = [{"label": "interview_scheduled", "score": 0.87}]
+    mock_pipe_instance = MagicMock(return_value=mock_result)
+    mock_pipe_factory = MagicMock(return_value=mock_pipe_instance)
+
+    adapter = FineTunedAdapter("avocet-deberta-small", "/fake/path")
+    with patch("scripts.classifier_adapters.pipeline", mock_pipe_factory):
+        result = adapter.classify("S", "B")
+
+    assert isinstance(result, str)
+    assert result == "interview_scheduled"
+
+
+def test_finetuned_adapter_lazy_loads_pipeline():
+    """Pipeline factory must not be called until classify() is first called."""
+    from unittest.mock import MagicMock, patch
+    from scripts.classifier_adapters import FineTunedAdapter
+
+    mock_pipe_factory = MagicMock(return_value=MagicMock(return_value=[{"label": "neutral", "score": 0.9}]))
+
+    with patch("scripts.classifier_adapters.pipeline", mock_pipe_factory):
+        adapter = FineTunedAdapter("avocet-deberta-small", "/fake/path")
+        assert not mock_pipe_factory.called
+        adapter.classify("s", "b")
+        assert mock_pipe_factory.called
+
+
+def test_finetuned_adapter_unload_clears_pipeline():
+    """unload() must set _pipeline to None so memory is released."""
+    from unittest.mock import MagicMock, patch
+    from scripts.classifier_adapters import FineTunedAdapter
+
+    mock_pipe_factory = MagicMock(return_value=MagicMock(return_value=[{"label": "neutral", "score": 0.9}]))
+
+    with patch("scripts.classifier_adapters.pipeline", mock_pipe_factory):
+        adapter = FineTunedAdapter("avocet-deberta-small", "/fake/path")
+        adapter.classify("s", "b")
+        assert adapter._pipeline is not None
+        adapter.unload()
+        assert adapter._pipeline is None