diff --git a/scripts/classifier_adapters.py b/scripts/classifier_adapters.py index 61bf1d2..3b3933c 100644 --- a/scripts/classifier_adapters.py +++ b/scripts/classifier_adapters.py @@ -7,6 +7,8 @@ from __future__ import annotations import abc from collections import defaultdict +import httpx +import logging from pathlib import Path from typing import Any @@ -23,6 +25,8 @@ __all__ = [ "EmbeddingKNNAdapter", ] +_logger = logging.getLogger(__name__) + LABELS: list[str] = [ "interview_scheduled", "offer_received", @@ -447,7 +451,6 @@ class EmbeddingKNNAdapter(ClassifierAdapter): return cforch.get("coordinator_url", ""), cforch.get("ollama_url", "") def _embed(self, node_url: str, texts: list[str]) -> list[list[float]]: - import httpx # noqa: PLC0415 resp = httpx.post( f"{node_url}/v1/embeddings", json={"model": self._model_id, "input": texts}, @@ -457,7 +460,10 @@ class EmbeddingKNNAdapter(ClassifierAdapter): return [item["embedding"] for item in resp.json()["data"]] def load(self) -> None: - import httpx # noqa: PLC0415 + if self._allocation_id or self._exemplar_embeddings: + raise RuntimeError( + "EmbeddingKNNAdapter.load() called while already loaded — call unload() first" + ) orch_url, ollama_url = self._resolve_urls() node_url = "" orch_url_used = "" @@ -473,21 +479,24 @@ class EmbeddingKNNAdapter(ClassifierAdapter): node_url = data["url"] self._allocation_id = data["allocation_id"] orch_url_used = orch_url - except Exception: - pass + except Exception as exc: + _logger.warning( + "cf-orch allocation failed, falling back to direct ollama_url: %s", exc + ) if not node_url: node_url = ollama_url self._allocation_id = "" orch_url_used = "" self._node_url = node_url self._orch_url_used = orch_url_used + embeddings: dict[str, list[list[float]]] = {} for label, texts in self._exemplar_texts.items(): - self._exemplar_embeddings[label] = self._embed(node_url, texts) + embeddings[label] = self._embed(node_url, texts) + self._exemplar_embeddings = embeddings def unload(self) -> None: if self._allocation_id and self._orch_url_used: try: - import httpx # noqa: PLC0415 httpx.request( "DELETE", f"{self._orch_url_used}/api/services/ollama/allocations/{self._allocation_id}", diff --git a/tests/test_classifier_adapters.py b/tests/test_classifier_adapters.py index 780e551..feec32a 100644 --- a/tests/test_classifier_adapters.py +++ b/tests/test_classifier_adapters.py @@ -418,6 +418,8 @@ def test_load_calls_allocate_then_embeds_each_label(): assert "hired" in adapter._exemplar_embeddings assert len(adapter._exemplar_embeddings["rejected"]) == 1 assert len(adapter._exemplar_embeddings["hired"]) == 2 + assert adapter._exemplar_embeddings["rejected"][0] == [0.1, 0.2, 0.3] + assert adapter._exemplar_embeddings["hired"][0] == [0.1, 0.2, 0.3] def test_load_falls_back_to_ollama_when_allocate_fails(): @@ -450,6 +452,35 @@ def test_load_falls_back_to_ollama_when_allocate_fails(): assert "rejected" in adapter._exemplar_embeddings +def test_load_falls_back_to_ollama_when_allocate_raises(): + from unittest.mock import patch, MagicMock + import httpx as _httpx + from scripts.classifier_adapters import EmbeddingKNNAdapter + + exemplars = {"rejected": ["We went with others"]} + adapter = EmbeddingKNNAdapter( + "test", "nomic-embed-text", k=3, + orch_url="http://orch:7700", ollama_url="http://ollama:11434", + exemplar_texts=exemplars, + ) + + def raising_mock(url, *, json=None, timeout=None, **kwargs): + if "/allocate" in url: + raise _httpx.ConnectError("connection refused") + resp = MagicMock() + resp.raise_for_status.return_value = None + resp.json.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}]} + return resp + + with patch("httpx.post", side_effect=raising_mock): + adapter.load() + + assert adapter._allocation_id == "" + assert adapter._orch_url_used == "" + assert adapter._node_url == "http://ollama:11434" + assert "rejected" in adapter._exemplar_embeddings + + # ---- EmbeddingKNNAdapter.unload() tests ---- def test_unload_releases_orch_allocation_and_clears_state():