fix(avocet): FineTunedAdapter GPU device routing + precise body truncation test

This commit is contained in:
pyr0ball 2026-03-15 10:56:47 -07:00
parent 7a4ca422ca
commit da8478082e
2 changed files with 6 additions and 4 deletions

View file

@ -291,8 +291,9 @@ class FineTunedAdapter(ClassifierAdapter):
import scripts.classifier_adapters as _mod # noqa: PLC0415 import scripts.classifier_adapters as _mod # noqa: PLC0415
_pipe_fn = _mod.pipeline _pipe_fn = _mod.pipeline
if _pipe_fn is None: if _pipe_fn is None:
raise ImportError("transformers not installed") raise ImportError("transformers not installed — run: pip install transformers")
self._pipeline = _pipe_fn("text-classification", model=self._model_dir) device = 0 if _cuda_available() else -1
self._pipeline = _pipe_fn("text-classification", model=self._model_dir, device=device)
def unload(self) -> None: def unload(self) -> None:
self._pipeline = None self._pipeline = None

View file

@ -219,8 +219,9 @@ def test_finetuned_adapter_truncates_body_to_400():
adapter.classify("Subject", long_body) adapter.classify("Subject", long_body)
call_text = mock_pipe_instance.call_args[0][0] call_text = mock_pipe_instance.call_args[0][0]
# "Subject [SEP] " prefix + 400 body chars = 414 chars max parts = call_text.split(" [SEP] ", 1)
assert len(call_text) <= 420 assert len(parts) == 2, "Input must contain ' [SEP] ' separator"
assert len(parts[1]) == 400, f"Body must be exactly 400 chars, got {len(parts[1])}"
def test_finetuned_adapter_returns_label_string(): def test_finetuned_adapter_returns_label_string():