diff --git a/scripts/classifier_adapters.py b/scripts/classifier_adapters.py index b612652..61bf1d2 100644 --- a/scripts/classifier_adapters.py +++ b/scripts/classifier_adapters.py @@ -446,11 +446,59 @@ class EmbeddingKNNAdapter(ClassifierAdapter): cforch = cfg.get("cforch", {}) or {} return cforch.get("coordinator_url", ""), cforch.get("ollama_url", "") + def _embed(self, node_url: str, texts: list[str]) -> list[list[float]]: + import httpx # noqa: PLC0415 + resp = httpx.post( + f"{node_url}/v1/embeddings", + json={"model": self._model_id, "input": texts}, + timeout=30.0, + ) + resp.raise_for_status() + return [item["embedding"] for item in resp.json()["data"]] + def load(self) -> None: - raise NotImplementedError + import httpx # noqa: PLC0415 + orch_url, ollama_url = self._resolve_urls() + node_url = "" + orch_url_used = "" + if orch_url: + try: + resp = httpx.post( + f"{orch_url}/api/services/ollama/allocate", + json={"model": self._model_id}, + timeout=15.0, + ) + if resp.status_code == 200: + data = resp.json() + node_url = data["url"] + self._allocation_id = data["allocation_id"] + orch_url_used = orch_url + except Exception: + pass + if not node_url: + node_url = ollama_url + self._allocation_id = "" + orch_url_used = "" + self._node_url = node_url + self._orch_url_used = orch_url_used + for label, texts in self._exemplar_texts.items(): + self._exemplar_embeddings[label] = self._embed(node_url, texts) def unload(self) -> None: - raise NotImplementedError + if self._allocation_id and self._orch_url_used: + try: + import httpx # noqa: PLC0415 + httpx.request( + "DELETE", + f"{self._orch_url_used}/api/services/ollama/allocations/{self._allocation_id}", + timeout=10.0, + ) + except Exception: + pass + self._exemplar_embeddings = {} + self._node_url = "" + self._allocation_id = "" + self._orch_url_used = "" def classify(self, subject: str, body: str) -> str: raise NotImplementedError diff --git a/tests/test_classifier_adapters.py b/tests/test_classifier_adapters.py index b19dc51..780e551 100644 --- a/tests/test_classifier_adapters.py +++ b/tests/test_classifier_adapters.py @@ -362,3 +362,146 @@ def test_embedding_knn_accepts_custom_exemplars(): exemplar_texts=custom, ) assert adapter._exemplar_texts is custom + + +# ---- EmbeddingKNNAdapter.load() tests ---- + +def _make_post_mock(alloc_url="http://navi:11434", alloc_id="alloc-abc"): + """Return a side_effect function for patching httpx.post. + + Allocate calls get alloc_url/alloc_id; embed calls return one [0.1,0.2,0.3] + embedding per input text. + """ + def _side_effect(url, *, json=None, timeout=None, **kwargs): + from unittest.mock import MagicMock + resp = MagicMock() + resp.raise_for_status.return_value = None + if "/allocate" in url: + resp.status_code = 200 + resp.json.return_value = {"allocation_id": alloc_id, "url": alloc_url} + else: + n = len((json or {}).get("input", [])) + resp.status_code = 200 + resp.json.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}] * n} + return resp + return _side_effect + + +def test_load_calls_allocate_then_embeds_each_label(): + from unittest.mock import patch + from scripts.classifier_adapters import EmbeddingKNNAdapter + + exemplars = { + "rejected": ["We went with others"], + "hired": ["Welcome aboard!", "First day info"], + } + adapter = EmbeddingKNNAdapter( + "test", "nomic-embed-text", k=3, + orch_url="http://orch:7700", ollama_url="http://ollama:11434", + exemplar_texts=exemplars, + ) + + post_urls = [] + def capturing_mock(url, *, json=None, timeout=None, **kwargs): + post_urls.append(url) + return _make_post_mock()(url, json=json, timeout=timeout) + + with patch("httpx.post", side_effect=capturing_mock): + adapter.load() + + assert any("/allocate" in u for u in post_urls), "expected allocate call" + assert any("/v1/embeddings" in u for u in post_urls), "expected embed call" + assert adapter._allocation_id == "alloc-abc" + assert adapter._node_url == "http://navi:11434" + assert adapter._orch_url_used == "http://orch:7700" + assert "rejected" in adapter._exemplar_embeddings + assert "hired" in adapter._exemplar_embeddings + assert len(adapter._exemplar_embeddings["rejected"]) == 1 + assert len(adapter._exemplar_embeddings["hired"]) == 2 + + +def test_load_falls_back_to_ollama_when_allocate_fails(): + from unittest.mock import patch, MagicMock + from scripts.classifier_adapters import EmbeddingKNNAdapter + + exemplars = {"rejected": ["We went with others"]} + adapter = EmbeddingKNNAdapter( + "test", "nomic-embed-text", k=3, + orch_url="http://orch:7700", ollama_url="http://ollama:11434", + exemplar_texts=exemplars, + ) + + def failing_allocate_mock(url, *, json=None, timeout=None, **kwargs): + resp = MagicMock() + if "/allocate" in url: + resp.status_code = 503 + resp.json.return_value = {} + else: + resp.raise_for_status.return_value = None + resp.json.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}]} + return resp + + with patch("httpx.post", side_effect=failing_allocate_mock): + adapter.load() + + assert adapter._allocation_id == "" + assert adapter._orch_url_used == "" + assert adapter._node_url == "http://ollama:11434" + assert "rejected" in adapter._exemplar_embeddings + + +# ---- EmbeddingKNNAdapter.unload() tests ---- + +def test_unload_releases_orch_allocation_and_clears_state(): + from unittest.mock import patch, MagicMock + from scripts.classifier_adapters import EmbeddingKNNAdapter + + adapter = EmbeddingKNNAdapter( + "test", "nomic-embed-text", k=3, + orch_url="http://orch:7700", ollama_url="http://ollama:11434", + ) + adapter._exemplar_embeddings = {"rejected": [[1.0, 0.0]]} + adapter._node_url = "http://navi:11434" + adapter._allocation_id = "alloc-abc" + adapter._orch_url_used = "http://orch:7700" + + delete_calls = [] + def mock_request(method, url, **kwargs): + delete_calls.append((method, url)) + resp = MagicMock() + resp.status_code = 200 + return resp + + with patch("httpx.request", side_effect=mock_request): + adapter.unload() + + assert len(delete_calls) == 1 + method, url = delete_calls[0] + assert method == "DELETE" + assert "alloc-abc" in url + assert adapter._exemplar_embeddings == {} + assert adapter._allocation_id == "" + assert adapter._node_url == "" + assert adapter._orch_url_used == "" + + +def test_unload_skips_delete_on_ollama_fallback_path(): + from unittest.mock import patch + from scripts.classifier_adapters import EmbeddingKNNAdapter + + adapter = EmbeddingKNNAdapter( + "test", "nomic-embed-text", k=3, + orch_url="http://orch:7700", ollama_url="http://ollama:11434", + ) + adapter._exemplar_embeddings = {"rejected": [[1.0, 0.0]]} + adapter._node_url = "http://ollama:11434" + adapter._allocation_id = "" # fallback path: no allocation was made + adapter._orch_url_used = "" + + delete_calls = [] + with patch("httpx.request", side_effect=lambda *a, **k: delete_calls.append(a)): + adapter.unload() + + assert len(delete_calls) == 0 + assert adapter._exemplar_embeddings == {} + assert adapter._node_url == ""