feat(avocet): add FineTunedAdapter for local checkpoint inference
This commit is contained in:
parent
71d0bfafe6
commit
7a4ca422ca
2 changed files with 127 additions and 0 deletions
|
|
@ -17,6 +17,7 @@ __all__ = [
|
|||
"ZeroShotAdapter",
|
||||
"GLiClassAdapter",
|
||||
"RerankerAdapter",
|
||||
"FineTunedAdapter",
|
||||
]
|
||||
|
||||
LABELS: list[str] = [
|
||||
|
|
@ -263,3 +264,42 @@ class RerankerAdapter(ClassifierAdapter):
|
|||
pairs = [[text, LABEL_DESCRIPTIONS.get(label, label.replace("_", " "))] for label in LABELS]
|
||||
scores: list[float] = self._reranker.compute_score(pairs, normalize=True)
|
||||
return LABELS[scores.index(max(scores))]
|
||||
|
||||
|
||||
class FineTunedAdapter(ClassifierAdapter):
|
||||
"""Loads a fine-tuned checkpoint from a local models/ directory.
|
||||
|
||||
Uses pipeline("text-classification") for a single forward pass.
|
||||
Input format: 'subject [SEP] body[:400]' — must match training format exactly.
|
||||
Expected inference speed: ~10–20ms/email vs 111–338ms for zero-shot.
|
||||
"""
|
||||
|
||||
def __init__(self, name: str, model_dir: str) -> None:
|
||||
self._name = name
|
||||
self._model_dir = model_dir
|
||||
self._pipeline: Any = None
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
return self._model_dir
|
||||
|
||||
def load(self) -> None:
|
||||
import scripts.classifier_adapters as _mod # noqa: PLC0415
|
||||
_pipe_fn = _mod.pipeline
|
||||
if _pipe_fn is None:
|
||||
raise ImportError("transformers not installed")
|
||||
self._pipeline = _pipe_fn("text-classification", model=self._model_dir)
|
||||
|
||||
def unload(self) -> None:
|
||||
self._pipeline = None
|
||||
|
||||
def classify(self, subject: str, body: str) -> str:
|
||||
if self._pipeline is None:
|
||||
self.load()
|
||||
text = f"{subject} [SEP] {body[:400]}"
|
||||
result = self._pipeline(text)
|
||||
return result[0]["label"]
|
||||
|
|
|
|||
|
|
@ -180,3 +180,90 @@ def test_reranker_adapter_picks_highest_score():
|
|||
def test_reranker_adapter_descriptions_cover_all_labels():
|
||||
from scripts.classifier_adapters import LABEL_DESCRIPTIONS, LABELS
|
||||
assert set(LABEL_DESCRIPTIONS.keys()) == set(LABELS)
|
||||
|
||||
|
||||
# ---- FineTunedAdapter tests ----
|
||||
|
||||
def test_finetuned_adapter_classify_calls_pipeline_with_sep_format(tmp_path):
|
||||
"""classify() must format input as 'subject [SEP] body[:400]' — not the zero-shot format."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
from scripts.classifier_adapters import FineTunedAdapter
|
||||
|
||||
mock_result = [{"label": "digest", "score": 0.95}]
|
||||
mock_pipe_instance = MagicMock(return_value=mock_result)
|
||||
mock_pipe_factory = MagicMock(return_value=mock_pipe_instance)
|
||||
|
||||
adapter = FineTunedAdapter("avocet-deberta-small", str(tmp_path))
|
||||
with patch("scripts.classifier_adapters.pipeline", mock_pipe_factory):
|
||||
result = adapter.classify("Test subject", "Test body")
|
||||
|
||||
assert result == "digest"
|
||||
call_args = mock_pipe_instance.call_args[0][0]
|
||||
assert "[SEP]" in call_args
|
||||
assert "Test subject" in call_args
|
||||
assert "Test body" in call_args
|
||||
|
||||
|
||||
def test_finetuned_adapter_truncates_body_to_400():
|
||||
"""Body must be truncated to 400 chars in the [SEP] format."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
from scripts.classifier_adapters import FineTunedAdapter, LABELS
|
||||
|
||||
long_body = "x" * 800
|
||||
mock_result = [{"label": "neutral", "score": 0.9}]
|
||||
mock_pipe_instance = MagicMock(return_value=mock_result)
|
||||
mock_pipe_factory = MagicMock(return_value=mock_pipe_instance)
|
||||
|
||||
adapter = FineTunedAdapter("avocet-deberta-small", "/fake/path")
|
||||
with patch("scripts.classifier_adapters.pipeline", mock_pipe_factory):
|
||||
adapter.classify("Subject", long_body)
|
||||
|
||||
call_text = mock_pipe_instance.call_args[0][0]
|
||||
# "Subject [SEP] " prefix + 400 body chars = 414 chars max
|
||||
assert len(call_text) <= 420
|
||||
|
||||
|
||||
def test_finetuned_adapter_returns_label_string():
|
||||
"""classify() must return a plain string, not a dict."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
from scripts.classifier_adapters import FineTunedAdapter
|
||||
|
||||
mock_result = [{"label": "interview_scheduled", "score": 0.87}]
|
||||
mock_pipe_instance = MagicMock(return_value=mock_result)
|
||||
mock_pipe_factory = MagicMock(return_value=mock_pipe_instance)
|
||||
|
||||
adapter = FineTunedAdapter("avocet-deberta-small", "/fake/path")
|
||||
with patch("scripts.classifier_adapters.pipeline", mock_pipe_factory):
|
||||
result = adapter.classify("S", "B")
|
||||
|
||||
assert isinstance(result, str)
|
||||
assert result == "interview_scheduled"
|
||||
|
||||
|
||||
def test_finetuned_adapter_lazy_loads_pipeline():
|
||||
"""Pipeline factory must not be called until classify() is first called."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
from scripts.classifier_adapters import FineTunedAdapter
|
||||
|
||||
mock_pipe_factory = MagicMock(return_value=MagicMock(return_value=[{"label": "neutral", "score": 0.9}]))
|
||||
|
||||
with patch("scripts.classifier_adapters.pipeline", mock_pipe_factory):
|
||||
adapter = FineTunedAdapter("avocet-deberta-small", "/fake/path")
|
||||
assert not mock_pipe_factory.called
|
||||
adapter.classify("s", "b")
|
||||
assert mock_pipe_factory.called
|
||||
|
||||
|
||||
def test_finetuned_adapter_unload_clears_pipeline():
|
||||
"""unload() must set _pipeline to None so memory is released."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
from scripts.classifier_adapters import FineTunedAdapter
|
||||
|
||||
mock_pipe_factory = MagicMock(return_value=MagicMock(return_value=[{"label": "neutral", "score": 0.9}]))
|
||||
|
||||
with patch("scripts.classifier_adapters.pipeline", mock_pipe_factory):
|
||||
adapter = FineTunedAdapter("avocet-deberta-small", "/fake/path")
|
||||
adapter.classify("s", "b")
|
||||
assert adapter._pipeline is not None
|
||||
adapter.unload()
|
||||
assert adapter._pipeline is None
|
||||
|
|
|
|||
Loading…
Reference in a new issue