feat(classifier): implement EmbeddingKNNAdapter.classify() with k-NN vote
This commit is contained in:
parent
4a64a6686d
commit
88bc6bed67
2 changed files with 115 additions and 1 deletions
|
|
@ -510,4 +510,17 @@ class EmbeddingKNNAdapter(ClassifierAdapter):
|
||||||
self._orch_url_used = ""
|
self._orch_url_used = ""
|
||||||
|
|
||||||
def classify(self, subject: str, body: str) -> str:
    """Classify a message with a similarity-weighted k-nearest-neighbour vote.

    The subject and the first 600 characters of the body are combined into
    one query text and embedded through ``self._embed`` against
    ``self._node_url``. The query vector is scored with ``_cosine`` against
    every stored exemplar vector, the ``self._k`` best-scoring exemplars
    are kept, and the label whose kept exemplars have the largest summed
    similarity is returned.

    ``self.load()`` is called first when no exemplar embeddings are held.

    Args:
        subject: Subject line of the message.
        body: Message body; only its first 600 characters are embedded.

    Returns:
        The label with the highest summed similarity among the top-k
        neighbours.

    Raises:
        ValueError: If no neighbour is available to vote, i.e. the
            exemplar embeddings are still empty after ``load()`` or
            ``self._k`` selects nothing.
    """
    if not self._exemplar_embeddings:
        self.load()

    text = f"Subject: {subject}\n\n{body[:600]}"
    # Exactly one text is sent, so exactly one vector is expected back;
    # the one-element unpacking fails loudly otherwise.
    [query_vec] = self._embed(self._node_url, [text])

    scored: list[tuple[float, str]] = [
        (_cosine(query_vec, vec), label)
        for label, vecs in self._exemplar_embeddings.items()
        for vec in vecs
    ]
    top_k = sorted(scored, reverse=True)[: self._k]

    # Sum similarities per label instead of counting votes, so a label with
    # a single very close exemplar can beat one with several distant ones.
    totals: dict[str, float] = {}
    for score, label in top_k:
        totals[label] = totals.get(label, 0) + score

    if not totals:
        # Same exception type max() would raise on an empty mapping, but
        # with a message that names the actual cause.
        raise ValueError(
            "cannot classify: no exemplar embeddings available to vote "
            "(exemplars empty after load(), or k selects no neighbours)"
        )
    return max(totals, key=totals.__getitem__)
|
||||||
|
|
|
||||||
|
|
@ -536,3 +536,104 @@ def test_unload_skips_delete_on_ollama_fallback_path():
|
||||||
assert len(delete_calls) == 0
|
assert len(delete_calls) == 0
|
||||||
assert adapter._exemplar_embeddings == {}
|
assert adapter._exemplar_embeddings == {}
|
||||||
assert adapter._node_url == ""
|
assert adapter._node_url == ""
|
||||||
|
|
||||||
|
|
||||||
|
# ---- EmbeddingKNNAdapter.classify() tests ----
|
||||||
|
|
||||||
|
def _adapter_with_embeddings(exemplar_embeddings, k=3):
    """Build an EmbeddingKNNAdapter with exemplar vectors injected directly.

    The adapter is constructed normally, then its ``_exemplar_embeddings``
    and ``_node_url`` attributes are assigned by hand so that tests can call
    ``classify()`` without going through ``load()``.

    Args:
        exemplar_embeddings: Mapping of label to a list of embedding vectors.
        k: Neighbour count passed to the adapter constructor.

    Returns:
        The configured adapter instance.
    """
    from scripts.classifier_adapters import EmbeddingKNNAdapter

    constructor_kwargs = {
        "k": k,
        "orch_url": "http://orch:7700",
        "ollama_url": "http://ollama:11434",
    }
    knn = EmbeddingKNNAdapter("test", "nomic-embed-text", **constructor_kwargs)

    # Pretend load() already ran: vectors are present and a node is assigned.
    knn._exemplar_embeddings = exemplar_embeddings
    knn._node_url = "http://navi:11434"
    return knn
|
||||||
|
|
||||||
|
|
||||||
|
def _embed_resp(vec):
|
||||||
|
"""Return a mock httpx response for /v1/embeddings returning a single vector."""
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
resp = MagicMock()
|
||||||
|
resp.raise_for_status.return_value = None
|
||||||
|
resp.json.return_value = {"data": [{"embedding": vec}]}
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_returns_majority_vote_label():
    """classify() returns the label that owns all of the top-k neighbours."""
    from unittest.mock import patch

    exemplars = {
        "rejected": [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.85, 0.15, 0.0]],
        "neutral": [[0.0, 1.0, 0.0]],
    }
    adapter = _adapter_with_embeddings(exemplars, k=3)

    # The query [1,0,0] is closest to the three "rejected" exemplars, so with
    # k=3 every kept neighbour carries the "rejected" label.
    query_response = _embed_resp([1.0, 0.0, 0.0])
    with patch("httpx.post", return_value=query_response):
        label = adapter.classify("We went with others", "Thank you for applying.")

    assert label == "rejected"
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_tiebreak_by_mean_score():
    """With one neighbour per label, the more similar label wins."""
    from unittest.mock import patch

    # k=2 keeps both exemplars, one per label, so the decision falls to the
    # similarity scores. For the query [1,0]: cosine to [1,0] is 1.0
    # ("rejected") and cosine to [0.6,0.8] is about 0.6 ("neutral").
    exemplars = {
        "rejected": [[1.0, 0.0]],
        "neutral": [[0.6, 0.8]],
    }
    adapter = _adapter_with_embeddings(exemplars, k=2)

    query_response = _embed_resp([1.0, 0.0])
    with patch("httpx.post", return_value=query_response):
        label = adapter.classify("Rejection", "Sorry")

    assert label == "rejected"
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_sparse_label_can_win():
    """A label with a single, very close exemplar beats a larger distant one."""
    from unittest.mock import patch

    # "hired" has only one exemplar, but the query vector points straight at it.
    exemplars = {
        "rejected": [[0.0, 0.0, 1.0], [0.0, 0.1, 0.9]],
        "hired": [[1.0, 0.0, 0.0]],
    }
    adapter = _adapter_with_embeddings(exemplars, k=3)

    # Query [1,0,0]: the "hired" exemplar scores 1.0, while both "rejected"
    # exemplars are orthogonal to it and score 0.
    query_response = _embed_resp([1.0, 0.0, 0.0])
    with patch("httpx.post", return_value=query_response):
        label = adapter.classify("Welcome aboard", "Your first day details")

    assert label == "hired"
|
||||||
|
|
||||||
|
|
||||||
|
def test_classify_lazy_loads_when_not_loaded():
    """classify() on an adapter with no embeddings loads them before voting."""
    from unittest.mock import MagicMock, patch
    from scripts.classifier_adapters import EmbeddingKNNAdapter

    adapter = EmbeddingKNNAdapter(
        "test",
        "nomic-embed-text",
        k=1,
        orch_url="http://orch:7700",
        ollama_url="http://ollama:11434",
        exemplar_texts={"rejected": ["We went with others"]},
    )
    # Precondition: nothing has been embedded yet.
    assert adapter._exemplar_embeddings == {}

    seen_urls = []

    def fake_post(url, *, json=None, timeout=None, **kwargs):
        """Answer allocation and embedding POSTs, recording every URL hit."""
        seen_urls.append(url)
        reply = MagicMock()
        reply.raise_for_status.return_value = None
        if "/allocate" not in url:
            # Embedding request: return one vector per input text.
            count = len((json or {}).get("input", []))
            reply.json.return_value = {"data": [{"embedding": [1.0, 0.0]}] * count}
        else:
            reply.status_code = 200
            reply.json.return_value = {"allocation_id": "a1", "url": "http://navi:11434"}
        return reply

    with patch("httpx.post", side_effect=fake_post):
        label = adapter.classify("Rejection", "Sorry")

    assert label == "rejected"
    assert any("/allocate" in u for u in seen_urls), "lazy load must call allocate"
    assert adapter._exemplar_embeddings != {}
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue