fix(avocet): FineTunedAdapter GPU device routing + precise body truncation test
This commit is contained in:
parent
7a4ca422ca
commit
da8478082e
2 changed files with 6 additions and 4 deletions
|
|
@ -291,8 +291,9 @@ class FineTunedAdapter(ClassifierAdapter):
|
|||
import scripts.classifier_adapters as _mod # noqa: PLC0415
|
||||
_pipe_fn = _mod.pipeline
|
||||
if _pipe_fn is None:
|
||||
raise ImportError("transformers not installed")
|
||||
self._pipeline = _pipe_fn("text-classification", model=self._model_dir)
|
||||
raise ImportError("transformers not installed — run: pip install transformers")
|
||||
device = 0 if _cuda_available() else -1
|
||||
self._pipeline = _pipe_fn("text-classification", model=self._model_dir, device=device)
|
||||
|
||||
def unload(self) -> None:
|
||||
self._pipeline = None
|
||||
|
|
|
|||
|
|
@ -219,8 +219,9 @@ def test_finetuned_adapter_truncates_body_to_400():
|
|||
adapter.classify("Subject", long_body)
|
||||
|
||||
call_text = mock_pipe_instance.call_args[0][0]
|
||||
# "Subject [SEP] " prefix + 400 body chars = 414 chars max
|
||||
assert len(call_text) <= 420
|
||||
parts = call_text.split(" [SEP] ", 1)
|
||||
assert len(parts) == 2, "Input must contain ' [SEP] ' separator"
|
||||
assert len(parts[1]) == 400, f"Body must be exactly 400 chars, got {len(parts[1])}"
|
||||
|
||||
|
||||
def test_finetuned_adapter_returns_label_string():
|
||||
|
|
|
|||
Loading…
Reference in a new issue