fix(classifier): atomic embed assignment, logging on orch failure, guard double load

This commit is contained in:
pyr0ball 2026-05-05 07:53:15 -07:00
parent f2f150b4fb
commit 4a64a6686d
2 changed files with 46 additions and 6 deletions

View file

@ -7,6 +7,8 @@ from __future__ import annotations
import abc import abc
from collections import defaultdict from collections import defaultdict
import httpx
import logging
from pathlib import Path from pathlib import Path
from typing import Any from typing import Any
@ -23,6 +25,8 @@ __all__ = [
"EmbeddingKNNAdapter", "EmbeddingKNNAdapter",
] ]
_logger = logging.getLogger(__name__)
LABELS: list[str] = [ LABELS: list[str] = [
"interview_scheduled", "interview_scheduled",
"offer_received", "offer_received",
@ -447,7 +451,6 @@ class EmbeddingKNNAdapter(ClassifierAdapter):
return cforch.get("coordinator_url", ""), cforch.get("ollama_url", "") return cforch.get("coordinator_url", ""), cforch.get("ollama_url", "")
def _embed(self, node_url: str, texts: list[str]) -> list[list[float]]: def _embed(self, node_url: str, texts: list[str]) -> list[list[float]]:
import httpx # noqa: PLC0415
resp = httpx.post( resp = httpx.post(
f"{node_url}/v1/embeddings", f"{node_url}/v1/embeddings",
json={"model": self._model_id, "input": texts}, json={"model": self._model_id, "input": texts},
@ -457,7 +460,10 @@ class EmbeddingKNNAdapter(ClassifierAdapter):
return [item["embedding"] for item in resp.json()["data"]] return [item["embedding"] for item in resp.json()["data"]]
def load(self) -> None: def load(self) -> None:
import httpx # noqa: PLC0415 if self._allocation_id or self._exemplar_embeddings:
raise RuntimeError(
"EmbeddingKNNAdapter.load() called while already loaded — call unload() first"
)
orch_url, ollama_url = self._resolve_urls() orch_url, ollama_url = self._resolve_urls()
node_url = "" node_url = ""
orch_url_used = "" orch_url_used = ""
@ -473,21 +479,24 @@ class EmbeddingKNNAdapter(ClassifierAdapter):
node_url = data["url"] node_url = data["url"]
self._allocation_id = data["allocation_id"] self._allocation_id = data["allocation_id"]
orch_url_used = orch_url orch_url_used = orch_url
except Exception: except Exception as exc:
pass _logger.warning(
"cf-orch allocation failed, falling back to direct ollama_url: %s", exc
)
if not node_url: if not node_url:
node_url = ollama_url node_url = ollama_url
self._allocation_id = "" self._allocation_id = ""
orch_url_used = "" orch_url_used = ""
self._node_url = node_url self._node_url = node_url
self._orch_url_used = orch_url_used self._orch_url_used = orch_url_used
embeddings: dict[str, list[list[float]]] = {}
for label, texts in self._exemplar_texts.items(): for label, texts in self._exemplar_texts.items():
self._exemplar_embeddings[label] = self._embed(node_url, texts) embeddings[label] = self._embed(node_url, texts)
self._exemplar_embeddings = embeddings
def unload(self) -> None: def unload(self) -> None:
if self._allocation_id and self._orch_url_used: if self._allocation_id and self._orch_url_used:
try: try:
import httpx # noqa: PLC0415
httpx.request( httpx.request(
"DELETE", "DELETE",
f"{self._orch_url_used}/api/services/ollama/allocations/{self._allocation_id}", f"{self._orch_url_used}/api/services/ollama/allocations/{self._allocation_id}",

View file

@ -418,6 +418,8 @@ def test_load_calls_allocate_then_embeds_each_label():
assert "hired" in adapter._exemplar_embeddings assert "hired" in adapter._exemplar_embeddings
assert len(adapter._exemplar_embeddings["rejected"]) == 1 assert len(adapter._exemplar_embeddings["rejected"]) == 1
assert len(adapter._exemplar_embeddings["hired"]) == 2 assert len(adapter._exemplar_embeddings["hired"]) == 2
assert adapter._exemplar_embeddings["rejected"][0] == [0.1, 0.2, 0.3]
assert adapter._exemplar_embeddings["hired"][0] == [0.1, 0.2, 0.3]
def test_load_falls_back_to_ollama_when_allocate_fails(): def test_load_falls_back_to_ollama_when_allocate_fails():
@ -450,6 +452,35 @@ def test_load_falls_back_to_ollama_when_allocate_fails():
assert "rejected" in adapter._exemplar_embeddings assert "rejected" in adapter._exemplar_embeddings
def test_load_falls_back_to_ollama_when_allocate_raises():
from unittest.mock import patch, MagicMock
import httpx as _httpx
from scripts.classifier_adapters import EmbeddingKNNAdapter
exemplars = {"rejected": ["We went with others"]}
adapter = EmbeddingKNNAdapter(
"test", "nomic-embed-text", k=3,
orch_url="http://orch:7700", ollama_url="http://ollama:11434",
exemplar_texts=exemplars,
)
def raising_mock(url, *, json=None, timeout=None, **kwargs):
if "/allocate" in url:
raise _httpx.ConnectError("connection refused")
resp = MagicMock()
resp.raise_for_status.return_value = None
resp.json.return_value = {"data": [{"embedding": [0.1, 0.2, 0.3]}]}
return resp
with patch("httpx.post", side_effect=raising_mock):
adapter.load()
assert adapter._allocation_id == ""
assert adapter._orch_url_used == ""
assert adapter._node_url == "http://ollama:11434"
assert "rejected" in adapter._exemplar_embeddings
# ---- EmbeddingKNNAdapter.unload() tests ---- # ---- EmbeddingKNNAdapter.unload() tests ----
def test_unload_releases_orch_allocation_and_clears_state(): def test_unload_releases_orch_allocation_and_clears_state():