feat(avocet): add FineTunedAdapter for local checkpoint inference
This commit is contained in:
parent
71d0bfafe6
commit
7a4ca422ca
2 changed files with 127 additions and 0 deletions
|
|
@@ -17,6 +17,7 @@ __all__ = [
|
||||||
"ZeroShotAdapter",
|
"ZeroShotAdapter",
|
||||||
"GLiClassAdapter",
|
"GLiClassAdapter",
|
||||||
"RerankerAdapter",
|
"RerankerAdapter",
|
||||||
|
"FineTunedAdapter",
|
||||||
]
|
]
|
||||||
|
|
||||||
LABELS: list[str] = [
|
LABELS: list[str] = [
|
||||||
|
|
@@ -263,3 +264,42 @@ class RerankerAdapter(ClassifierAdapter):
|
||||||
pairs = [[text, LABEL_DESCRIPTIONS.get(label, label.replace("_", " "))] for label in LABELS]
|
pairs = [[text, LABEL_DESCRIPTIONS.get(label, label.replace("_", " "))] for label in LABELS]
|
||||||
scores: list[float] = self._reranker.compute_score(pairs, normalize=True)
|
scores: list[float] = self._reranker.compute_score(pairs, normalize=True)
|
||||||
return LABELS[scores.index(max(scores))]
|
return LABELS[scores.index(max(scores))]
|
||||||
|
|
||||||
|
|
||||||
|
class FineTunedAdapter(ClassifierAdapter):
    """Classifier backed by a fine-tuned checkpoint in a local models/ directory.

    Runs a single forward pass through pipeline("text-classification").
    The input string is built as 'subject [SEP] body[:400]' and must match
    the training format exactly.
    Expected inference speed: ~10–20ms/email vs 111–338ms for zero-shot.
    """

    def __init__(self, name: str, model_dir: str) -> None:
        self._name = name
        self._model_dir = model_dir
        # Built lazily on the first load()/classify() call.
        self._pipeline: Any = None

    @property
    def name(self) -> str:
        return self._name

    @property
    def model_id(self) -> str:
        return self._model_dir

    def load(self) -> None:
        """Construct the HF pipeline from the local checkpoint directory.

        Raises ImportError when transformers is unavailable.
        """
        # Resolve `pipeline` through the module attribute so tests can patch
        # scripts.classifier_adapters.pipeline.
        import scripts.classifier_adapters as _mod  # noqa: PLC0415

        factory = _mod.pipeline
        if factory is None:
            raise ImportError("transformers not installed")
        self._pipeline = factory("text-classification", model=self._model_dir)

    def unload(self) -> None:
        """Drop the pipeline reference so model memory can be reclaimed."""
        self._pipeline = None

    def classify(self, subject: str, body: str) -> str:
        """Return the predicted label string for one email."""
        if self._pipeline is None:
            self.load()
        # Must mirror the training input format exactly.
        text = f"{subject} [SEP] {body[:400]}"
        prediction = self._pipeline(text)
        return prediction[0]["label"]
|
||||||
|
|
|
||||||
|
|
@@ -180,3 +180,90 @@ def test_reranker_adapter_picks_highest_score():
|
||||||
def test_reranker_adapter_descriptions_cover_all_labels():
    from scripts.classifier_adapters import LABEL_DESCRIPTIONS, LABELS

    # Every label needs a description entry, with no extras on either side.
    assert set(LABELS) == set(LABEL_DESCRIPTIONS)
|
||||||
|
|
||||||
|
|
||||||
|
# ---- FineTunedAdapter tests ----
|
||||||
|
|
||||||
|
def test_finetuned_adapter_classify_calls_pipeline_with_sep_format(tmp_path):
    """classify() must format input as 'subject [SEP] body[:400]' — not the zero-shot format."""
    from unittest.mock import MagicMock, patch

    from scripts.classifier_adapters import FineTunedAdapter

    pipe = MagicMock(return_value=[{"label": "digest", "score": 0.95}])
    factory = MagicMock(return_value=pipe)

    adapter = FineTunedAdapter("avocet-deberta-small", str(tmp_path))
    with patch("scripts.classifier_adapters.pipeline", factory):
        label = adapter.classify("Test subject", "Test body")

    assert label == "digest"
    # Inspect the single positional argument passed to the pipeline.
    sent_text = pipe.call_args[0][0]
    assert "[SEP]" in sent_text
    assert "Test subject" in sent_text
    assert "Test body" in sent_text
|
||||||
|
|
||||||
|
|
||||||
|
def test_finetuned_adapter_truncates_body_to_400():
    """Body must be truncated to 400 chars in the [SEP] format.

    Fix: the original imported LABELS alongside FineTunedAdapter but never
    used it; the unused import is removed.
    """
    from unittest.mock import MagicMock, patch

    from scripts.classifier_adapters import FineTunedAdapter

    long_body = "x" * 800
    pipe = MagicMock(return_value=[{"label": "neutral", "score": 0.9}])
    factory = MagicMock(return_value=pipe)

    adapter = FineTunedAdapter("avocet-deberta-small", "/fake/path")
    with patch("scripts.classifier_adapters.pipeline", factory):
        adapter.classify("Subject", long_body)

    sent_text = pipe.call_args[0][0]
    # "Subject [SEP] " prefix + 400 body chars = 414 chars max.
    assert len(sent_text) <= 420
|
||||||
|
|
||||||
|
|
||||||
|
def test_finetuned_adapter_returns_label_string():
    """classify() must return a plain string, not a dict."""
    from unittest.mock import MagicMock, patch

    from scripts.classifier_adapters import FineTunedAdapter

    pipe = MagicMock(return_value=[{"label": "interview_scheduled", "score": 0.87}])
    factory = MagicMock(return_value=pipe)

    adapter = FineTunedAdapter("avocet-deberta-small", "/fake/path")
    with patch("scripts.classifier_adapters.pipeline", factory):
        label = adapter.classify("S", "B")

    # The adapter unwraps the pipeline's [{"label": ..., "score": ...}] result.
    assert isinstance(label, str)
    assert label == "interview_scheduled"
|
||||||
|
|
||||||
|
|
||||||
|
def test_finetuned_adapter_lazy_loads_pipeline():
    """Pipeline factory must not be called until classify() is first called."""
    from unittest.mock import MagicMock, patch

    from scripts.classifier_adapters import FineTunedAdapter

    pipe = MagicMock(return_value=[{"label": "neutral", "score": 0.9}])
    factory = MagicMock(return_value=pipe)

    with patch("scripts.classifier_adapters.pipeline", factory):
        adapter = FineTunedAdapter("avocet-deberta-small", "/fake/path")
        # Construction alone must not touch the factory.
        assert factory.call_count == 0
        adapter.classify("s", "b")
        assert factory.call_count > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_finetuned_adapter_unload_clears_pipeline():
    """unload() must set _pipeline to None so memory is released."""
    from unittest.mock import MagicMock, patch

    from scripts.classifier_adapters import FineTunedAdapter

    pipe = MagicMock(return_value=[{"label": "neutral", "score": 0.9}])
    factory = MagicMock(return_value=pipe)

    with patch("scripts.classifier_adapters.pipeline", factory):
        adapter = FineTunedAdapter("avocet-deberta-small", "/fake/path")
        adapter.classify("s", "b")
        # classify() lazily populated the pipeline...
        assert adapter._pipeline is not None
        adapter.unload()
        # ...and unload() must clear it again.
        assert adapter._pipeline is None
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue