fix(classifier): atomic embed assignment, logging on orch failure, guard double load
This commit is contained in:
parent
f2f150b4fb
commit
4a64a6686d
2 changed files with 46 additions and 6 deletions
|
|
@ -7,6 +7,8 @@ from __future__ import annotations
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
import httpx
|
||||||
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
|
@ -23,6 +25,8 @@ __all__ = [
|
||||||
"EmbeddingKNNAdapter",
|
"EmbeddingKNNAdapter",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
LABELS: list[str] = [
|
LABELS: list[str] = [
|
||||||
"interview_scheduled",
|
"interview_scheduled",
|
||||||
"offer_received",
|
"offer_received",
|
||||||
|
|
@ -447,7 +451,6 @@ class EmbeddingKNNAdapter(ClassifierAdapter):
|
||||||
return cforch.get("coordinator_url", ""), cforch.get("ollama_url", "")
|
return cforch.get("coordinator_url", ""), cforch.get("ollama_url", "")
|
||||||
|
|
||||||
def _embed(self, node_url: str, texts: list[str]) -> list[list[float]]:
|
def _embed(self, node_url: str, texts: list[str]) -> list[list[float]]:
|
||||||
import httpx # noqa: PLC0415
|
|
||||||
resp = httpx.post(
|
resp = httpx.post(
|
||||||
f"{node_url}/v1/embeddings",
|
f"{node_url}/v1/embeddings",
|
||||||
json={"model": self._model_id, "input": texts},
|
json={"model": self._model_id, "input": texts},
|
||||||
|
|
@ -457,7 +460,10 @@ class EmbeddingKNNAdapter(ClassifierAdapter):
|
||||||
return [item["embedding"] for item in resp.json()["data"]]
|
return [item["embedding"] for item in resp.json()["data"]]
|
||||||
|
|
||||||
def load(self) -> None:
|
def load(self) -> None:
|
||||||
import httpx # noqa: PLC0415
|
if self._allocation_id or self._exemplar_embeddings:
|
||||||
|
raise RuntimeError(
|
||||||
|
"EmbeddingKNNAdapter.load() called while already loaded — call unload() first"
|
||||||
|
)
|
||||||
orch_url, ollama_url = self._resolve_urls()
|
orch_url, ollama_url = self._resolve_urls()
|
||||||
node_url = ""
|
node_url = ""
|
||||||
orch_url_used = ""
|
orch_url_used = ""
|
||||||
|
|
@ -473,21 +479,24 @@ class EmbeddingKNNAdapter(ClassifierAdapter):
|
||||||
node_url = data["url"]
|
node_url = data["url"]
|
||||||
self._allocation_id = data["allocation_id"]
|
self._allocation_id = data["allocation_id"]
|
||||||
orch_url_used = orch_url
|
orch_url_used = orch_url
|
||||||
except Exception:
|
except Exception as exc:
|
||||||
pass
|
_logger.warning(
|
||||||
|
"cf-orch allocation failed, falling back to direct ollama_url: %s", exc
|
||||||
|
)
|
||||||
if not node_url:
|
if not node_url:
|
||||||
node_url = ollama_url
|
node_url = ollama_url
|
||||||
self._allocation_id = ""
|
self._allocation_id = ""
|
||||||
orch_url_used = ""
|
orch_url_used = ""
|
||||||
self._node_url = node_url
|
self._node_url = node_url
|
||||||
self._orch_url_used = orch_url_used
|
self._orch_url_used = orch_url_used
|
||||||
|
embeddings: dict[str, list[list[float]]] = {}
|
||||||
for label, texts in self._exemplar_texts.items():
|
for label, texts in self._exemplar_texts.items():
|
||||||
self._exemplar_embeddings[label] = self._embed(node_url, texts)
|
embeddings[label] = self._embed(node_url, texts)
|
||||||
|
self._exemplar_embeddings = embeddings
|
||||||
|
|
||||||
def unload(self) -> None:
|
def unload(self) -> None:
|
||||||
if self._allocation_id and self._orch_url_used:
|
if self._allocation_id and self._orch_url_used:
|
||||||
try:
|
try:
|
||||||
import httpx # noqa: PLC0415
|
|
||||||
httpx.request(
|
httpx.request(
|
||||||
"DELETE",
|
"DELETE",
|
||||||
f"{self._orch_url_used}/api/services/ollama/allocations/{self._allocation_id}",
|
f"{self._orch_url_used}/api/services/ollama/allocations/{self._allocation_id}",
|
||||||
|
|
|
||||||
|
|
@ -418,6 +418,8 @@ def test_load_calls_allocate_then_embeds_each_label():
|
||||||
assert "hired" in adapter._exemplar_embeddings
|
assert "hired" in adapter._exemplar_embeddings
|
||||||
assert len(adapter._exemplar_embeddings["rejected"]) == 1
|
assert len(adapter._exemplar_embeddings["rejected"]) == 1
|
||||||
assert len(adapter._exemplar_embeddings["hired"]) == 2
|
assert len(adapter._exemplar_embeddings["hired"]) == 2
|
||||||
|
assert adapter._exemplar_embeddings["rejected"][0] == [0.1, 0.2, 0.3]
|
||||||
|
assert adapter._exemplar_embeddings["hired"][0] == [0.1, 0.2, 0.3]
|
||||||
|
|
||||||
|
|
||||||
def test_load_falls_back_to_ollama_when_allocate_fails():
|
def test_load_falls_back_to_ollama_when_allocate_fails():
|
||||||
|
|
@ -450,6 +452,35 @@ def test_load_falls_back_to_ollama_when_allocate_fails():
|
||||||
assert "rejected" in adapter._exemplar_embeddings
|
assert "rejected" in adapter._exemplar_embeddings
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_falls_back_to_ollama_when_allocate_raises():
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
import httpx as _httpx
|
||||||
|
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||||
|
|
||||||
|
exemplars = {"rejected": ["We went with others"]}
|
||||||
|
adapter = EmbeddingKNNAdapter(
|
||||||
|
"test", "nomic-embed-text", k=3,
|
||||||
|
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||||
|
exemplar_texts=exemplars,
|
||||||
|
)
|
||||||
|
|
||||||
|
def raising_mock(url, *, json=None, timeout=None, **kwargs):
|
||||||
|
if "/allocate" in url:
|
||||||
|
raise _httpx.ConnectError("connection refused")
|
||||||
|
resp = MagicMock()
|
||||||
|
resp.raise_for_status.return_value = None
|
||||||
|
resp.json.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}
|
||||||
|
return resp
|
||||||
|
|
||||||
|
with patch("httpx.post", side_effect=raising_mock):
|
||||||
|
adapter.load()
|
||||||
|
|
||||||
|
assert adapter._allocation_id == ""
|
||||||
|
assert adapter._orch_url_used == ""
|
||||||
|
assert adapter._node_url == "http://ollama:11434"
|
||||||
|
assert "rejected" in adapter._exemplar_embeddings
|
||||||
|
|
||||||
|
|
||||||
# ---- EmbeddingKNNAdapter.unload() tests ----
|
# ---- EmbeddingKNNAdapter.unload() tests ----
|
||||||
|
|
||||||
def test_unload_releases_orch_allocation_and_clears_state():
|
def test_unload_releases_orch_allocation_and_clears_state():
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue