fix(avocet): FineTunedAdapter GPU device routing + precise body truncation test
This commit is contained in:
parent
7a4ca422ca
commit
da8478082e
2 changed files with 6 additions and 4 deletions
|
|
@ -291,8 +291,9 @@ class FineTunedAdapter(ClassifierAdapter):
|
||||||
import scripts.classifier_adapters as _mod # noqa: PLC0415
|
import scripts.classifier_adapters as _mod # noqa: PLC0415
|
||||||
_pipe_fn = _mod.pipeline
|
_pipe_fn = _mod.pipeline
|
||||||
if _pipe_fn is None:
|
if _pipe_fn is None:
|
||||||
raise ImportError("transformers not installed")
|
raise ImportError("transformers not installed — run: pip install transformers")
|
||||||
self._pipeline = _pipe_fn("text-classification", model=self._model_dir)
|
device = 0 if _cuda_available() else -1
|
||||||
|
self._pipeline = _pipe_fn("text-classification", model=self._model_dir, device=device)
|
||||||
|
|
||||||
def unload(self) -> None:
|
def unload(self) -> None:
|
||||||
self._pipeline = None
|
self._pipeline = None
|
||||||
|
|
|
||||||
|
|
@ -219,8 +219,9 @@ def test_finetuned_adapter_truncates_body_to_400():
|
||||||
adapter.classify("Subject", long_body)
|
adapter.classify("Subject", long_body)
|
||||||
|
|
||||||
call_text = mock_pipe_instance.call_args[0][0]
|
call_text = mock_pipe_instance.call_args[0][0]
|
||||||
# "Subject [SEP] " prefix + 400 body chars = 414 chars max
|
parts = call_text.split(" [SEP] ", 1)
|
||||||
assert len(call_text) <= 420
|
assert len(parts) == 2, "Input must contain ' [SEP] ' separator"
|
||||||
|
assert len(parts[1]) == 400, f"Body must be exactly 400 chars, got {len(parts[1])}"
|
||||||
|
|
||||||
|
|
||||||
def test_finetuned_adapter_returns_label_string():
|
def test_finetuned_adapter_returns_label_string():
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue