feat(classifier): implement EmbeddingKNNAdapter.load() and unload()
This commit is contained in:
parent
72449561cf
commit
f2f150b4fb
2 changed files with 193 additions and 2 deletions
|
|
@ -446,11 +446,59 @@ class EmbeddingKNNAdapter(ClassifierAdapter):
|
||||||
cforch = cfg.get("cforch", {}) or {}
|
cforch = cfg.get("cforch", {}) or {}
|
||||||
return cforch.get("coordinator_url", ""), cforch.get("ollama_url", "")
|
return cforch.get("coordinator_url", ""), cforch.get("ollama_url", "")
|
||||||
|
|
||||||
|
def _embed(self, node_url: str, texts: list[str]) -> list[list[float]]:
|
||||||
|
import httpx # noqa: PLC0415
|
||||||
|
resp = httpx.post(
|
||||||
|
f"{node_url}/v1/embeddings",
|
||||||
|
json={"model": self._model_id, "input": texts},
|
||||||
|
timeout=30.0,
|
||||||
|
)
|
||||||
|
resp.raise_for_status()
|
||||||
|
return [item["embedding"] for item in resp.json()["data"]]
|
||||||
|
|
||||||
def load(self) -> None:
    """Choose an embedding node and embed every exemplar text.

    When ``_resolve_urls()`` yields an orchestrator URL, an allocation is
    requested from ``{orch_url}/api/services/ollama/allocate``. A 200
    response that carries both ``"url"`` and ``"allocation_id"`` selects
    that node. Any other outcome (error status, exception, incomplete
    payload) falls back to the plain Ollama URL with no allocation.

    Side effects:
        Sets ``_node_url``, ``_allocation_id`` and ``_orch_url_used``, then
        stores ``self._embed(node_url, texts)`` in ``_exemplar_embeddings``
        for every label of ``_exemplar_texts``.

    Raises:
        Whatever ``self._embed`` raises; allocation failures never raise.
    """
    import logging  # noqa: PLC0415

    log = logging.getLogger(__name__)
    orch_url, ollama_url = self._resolve_urls()

    node_url = ""
    allocation_id = ""
    if orch_url:
        import httpx  # noqa: PLC0415

        try:
            resp = httpx.post(
                f"{orch_url}/api/services/ollama/allocate",
                json={"model": self._model_id},
                timeout=15.0,
            )
            if resp.status_code == 200:
                data = resp.json()
                # Both keys are read before either name is bound, so an
                # incomplete payload cannot leave a node URL selected
                # without its allocation id.
                node_url, allocation_id = data["url"], data["allocation_id"]
            else:
                log.warning(
                    "orchestrator allocate returned HTTP %s; using ollama_url",
                    resp.status_code,
                )
        except Exception:  # noqa: BLE001 - allocation is deliberately best-effort
            log.warning(
                "orchestrator allocate failed; using ollama_url", exc_info=True
            )

    if node_url:
        orch_url_used = orch_url
    else:
        # Fallback path: direct Ollama access, nothing to release later.
        node_url, allocation_id, orch_url_used = ollama_url, "", ""

    self._node_url = node_url
    self._allocation_id = allocation_id
    self._orch_url_used = orch_url_used

    for label, texts in self._exemplar_texts.items():
        self._exemplar_embeddings[label] = self._embed(node_url, texts)
|
||||||
|
|
||||||
def unload(self) -> None:
    """Release the orchestrator allocation, if any, and reset loaded state.

    A ``DELETE`` is sent to
    ``{_orch_url_used}/api/services/ollama/allocations/{_allocation_id}``
    only when both values are non-empty. The release is best-effort: a
    transport error or an error status is logged, never raised.

    Side effects:
        Always resets ``_exemplar_embeddings`` to ``{}`` and ``_node_url``,
        ``_allocation_id`` and ``_orch_url_used`` to ``""``.
    """
    allocation_id = self._allocation_id
    orch_url = self._orch_url_used

    if allocation_id and orch_url:
        try:
            import httpx  # noqa: PLC0415

            resp = httpx.request(
                "DELETE",
                f"{orch_url}/api/services/ollama/allocations/{allocation_id}",
                timeout=10.0,
            )
            # Surface 4xx/5xx answers so a failed release shows up in the log.
            resp.raise_for_status()
        except Exception:  # noqa: BLE001 - release must never break unload
            import logging  # noqa: PLC0415

            logging.getLogger(__name__).warning(
                "failed to release orchestrator allocation %s",
                allocation_id,
                exc_info=True,
            )

    self._exemplar_embeddings = {}
    self._node_url = ""
    self._allocation_id = ""
    self._orch_url_used = ""
|
||||||
|
|
||||||
def classify(self, subject: str, body: str) -> str:
    """Classify a message given its subject and body (placeholder).

    Raises:
        NotImplementedError: Always; this adapter has no classification
            logic yet.
    """
    raise NotImplementedError(
        "EmbeddingKNNAdapter.classify() is not implemented yet"
    )
|
||||||
|
|
|
||||||
|
|
@ -362,3 +362,146 @@ def test_embedding_knn_accepts_custom_exemplars():
|
||||||
exemplar_texts=custom,
|
exemplar_texts=custom,
|
||||||
)
|
)
|
||||||
assert adapter._exemplar_texts is custom
|
assert adapter._exemplar_texts is custom
|
||||||
|
|
||||||
|
|
||||||
|
# ---- EmbeddingKNNAdapter.load() tests ----
|
||||||
|
|
||||||
|
def _make_post_mock(alloc_url="http://navi:11434", alloc_id="alloc-abc"):
|
||||||
|
"""Return a side_effect function for patching httpx.post.
|
||||||
|
|
||||||
|
Allocate calls get alloc_url/alloc_id; embed calls return one [0.1,0.2,0.3]
|
||||||
|
embedding per input text.
|
||||||
|
"""
|
||||||
|
def _side_effect(url, *, json=None, timeout=None, **kwargs):
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
resp = MagicMock()
|
||||||
|
resp.raise_for_status.return_value = None
|
||||||
|
if "/allocate" in url:
|
||||||
|
resp.status_code = 200
|
||||||
|
resp.json.return_value = {"allocation_id": alloc_id, "url": alloc_url}
|
||||||
|
else:
|
||||||
|
n = len((json or {}).get("input", []))
|
||||||
|
resp.status_code = 200
|
||||||
|
resp.json.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}] * n}
|
||||||
|
return resp
|
||||||
|
return _side_effect
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_calls_allocate_then_embeds_each_label():
    """load() allocates first, then embeds every label on the allocated node."""
    from unittest.mock import patch
    from scripts.classifier_adapters import EmbeddingKNNAdapter

    exemplars = {
        "rejected": ["We went with others"],
        "hired": ["Welcome aboard!", "First day info"],
    }
    adapter = EmbeddingKNNAdapter(
        "test", "nomic-embed-text", k=3,
        orch_url="http://orch:7700", ollama_url="http://ollama:11434",
        exemplar_texts=exemplars,
    )

    post_urls = []
    fake_post = _make_post_mock()

    def capturing_mock(url, *, json=None, timeout=None, **kwargs):
        post_urls.append(url)
        return fake_post(url, json=json, timeout=timeout)

    with patch("httpx.post", side_effect=capturing_mock):
        adapter.load()

    assert any("/allocate" in u for u in post_urls), "expected allocate call"
    assert any("/v1/embeddings" in u for u in post_urls), "expected embed call"
    # The test name promises an order, so pin it: allocation comes first and
    # every later request is an embed call against the allocated node.
    assert "/allocate" in post_urls[0], "allocate must be the first request"
    assert all(
        u == "http://navi:11434/v1/embeddings" for u in post_urls[1:]
    ), "embed calls must target the allocated node"

    assert adapter._allocation_id == "alloc-abc"
    assert adapter._node_url == "http://navi:11434"
    assert adapter._orch_url_used == "http://orch:7700"
    assert "rejected" in adapter._exemplar_embeddings
    assert "hired" in adapter._exemplar_embeddings
    assert len(adapter._exemplar_embeddings["rejected"]) == 1
    assert len(adapter._exemplar_embeddings["hired"]) == 2
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_falls_back_to_ollama_when_allocate_fails():
    """A 503 from the allocate endpoint makes load() use the Ollama URL."""
    from unittest.mock import MagicMock, patch
    from scripts.classifier_adapters import EmbeddingKNNAdapter

    adapter = EmbeddingKNNAdapter(
        "test", "nomic-embed-text", k=3,
        orch_url="http://orch:7700", ollama_url="http://ollama:11434",
        exemplar_texts={"rejected": ["We went with others"]},
    )

    def failing_allocate_mock(url, *, json=None, timeout=None, **kwargs):
        reply = MagicMock()
        if "/allocate" not in url:
            # Embed call: succeed with a single fixed vector.
            reply.raise_for_status.return_value = None
            reply.json.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}
            return reply
        # Allocate call: the orchestrator is unavailable.
        reply.status_code = 503
        reply.json.return_value = {}
        return reply

    with patch("httpx.post", side_effect=failing_allocate_mock):
        adapter.load()

    # No allocation was recorded and the direct Ollama node is in use.
    assert adapter._allocation_id == ""
    assert adapter._orch_url_used == ""
    assert adapter._node_url == "http://ollama:11434"
    assert "rejected" in adapter._exemplar_embeddings
|
||||||
|
|
||||||
|
|
||||||
|
# ---- EmbeddingKNNAdapter.unload() tests ----
|
||||||
|
|
||||||
|
def test_unload_releases_orch_allocation_and_clears_state():
    """unload() issues one DELETE for the allocation and resets all state."""
    from unittest.mock import MagicMock, patch
    from scripts.classifier_adapters import EmbeddingKNNAdapter

    adapter = EmbeddingKNNAdapter(
        "test", "nomic-embed-text", k=3,
        orch_url="http://orch:7700", ollama_url="http://ollama:11434",
    )
    # Put the adapter in the state load() leaves after a successful allocation.
    adapter._exemplar_embeddings = {"rejected": [[1.0, 0.0]]}
    adapter._node_url = "http://navi:11434"
    adapter._allocation_id = "alloc-abc"
    adapter._orch_url_used = "http://orch:7700"

    delete_calls = []

    def mock_request(method, url, **kwargs):
        delete_calls.append((method, url))
        ok = MagicMock()
        ok.status_code = 200
        return ok

    with patch("httpx.request", side_effect=mock_request):
        adapter.unload()

    assert len(delete_calls) == 1
    sent_method, sent_url = delete_calls[0]
    assert sent_method == "DELETE"
    assert "alloc-abc" in sent_url

    assert adapter._exemplar_embeddings == {}
    assert adapter._allocation_id == ""
    assert adapter._node_url == ""
    assert adapter._orch_url_used == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_unload_skips_delete_on_ollama_fallback_path():
    """With no allocation recorded, unload() sends no request but still resets."""
    from unittest.mock import patch
    from scripts.classifier_adapters import EmbeddingKNNAdapter

    adapter = EmbeddingKNNAdapter(
        "test", "nomic-embed-text", k=3,
        orch_url="http://orch:7700", ollama_url="http://ollama:11434",
    )
    adapter._exemplar_embeddings = {"rejected": [[1.0, 0.0]]}
    adapter._node_url = "http://ollama:11434"
    adapter._allocation_id = ""  # fallback path: no allocation was made
    adapter._orch_url_used = ""

    delete_calls = []

    def record_request(*args, **kwargs):
        # Only positional arguments are recorded; the return value is None.
        delete_calls.append(args)

    with patch("httpx.request", side_effect=record_request):
        adapter.unload()

    assert len(delete_calls) == 0
    assert adapter._exemplar_embeddings == {}
    assert adapter._node_url == ""
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue