diff --git a/scripts/classifier_adapters.py b/scripts/classifier_adapters.py
index 3b3933c..e2a8c9c 100644
--- a/scripts/classifier_adapters.py
+++ b/scripts/classifier_adapters.py
@@ -510,4 +510,37 @@ class EmbeddingKNNAdapter(ClassifierAdapter):
         self._orch_url_used = ""
 
     def classify(self, subject: str, body: str) -> str:
-        raise NotImplementedError
+        """Return the label favoured by the k most similar exemplars.
+
+        The subject and the first 600 characters of the body are embedded as
+        one query and scored against every stored exemplar vector with
+        ``_cosine``.  The k highest-scoring exemplars then vote for their
+        label, each vote weighted by its score; the label with the largest
+        summed score is returned.
+
+        Raises:
+            ValueError: if no exemplar is available to vote (no stored
+                vectors, or a ``k`` that selects none of them).
+        """
+        if not self._exemplar_embeddings:
+            # Lazy load so classify() also works before an explicit load().
+            self.load()
+        text = f"Subject: {subject}\n\n{body[:600]}"
+        [query_vec] = self._embed(self._node_url, [text])
+        scored: list[tuple[float, str]] = [
+            (_cosine(query_vec, vec), label)
+            for label, vecs in self._exemplar_embeddings.items()
+            for vec in vecs
+        ]
+        top_k = sorted(scored, reverse=True)[: self._k]
+        if not top_k:
+            # max() on an empty dict would raise an opaque
+            # "arg is an empty sequence" error; say what is actually wrong.
+            raise ValueError(
+                "EmbeddingKNNAdapter has no exemplar embeddings to classify against"
+            )
+        # Similarity-weighted vote: each neighbour adds its score to its label.
+        votes: dict[str, float] = {}
+        for score, label in top_k:
+            votes[label] = votes.get(label, 0.0) + score
+        return max(votes, key=votes.__getitem__)
diff --git a/tests/test_classifier_adapters.py b/tests/test_classifier_adapters.py
index feec32a..3163362 100644
--- a/tests/test_classifier_adapters.py
+++ b/tests/test_classifier_adapters.py
@@ -536,3 +536,122 @@ def test_unload_skips_delete_on_ollama_fallback_path():
     assert len(delete_calls) == 0
     assert adapter._exemplar_embeddings == {}
     assert adapter._node_url == ""
+
+
+# ---- EmbeddingKNNAdapter.classify() tests ----
+
+def _adapter_with_embeddings(exemplar_embeddings, k=3):
+    """Return a pre-loaded EmbeddingKNNAdapter (bypass load()) with given per-label vectors."""
+    from scripts.classifier_adapters import EmbeddingKNNAdapter
+    adapter = EmbeddingKNNAdapter(
+        "test", "nomic-embed-text", k=k,
+        orch_url="http://orch:7700", ollama_url="http://ollama:11434",
+    )
+    adapter._exemplar_embeddings = exemplar_embeddings
+    adapter._node_url = "http://navi:11434"
+    return adapter
+
+
+def _embed_resp(vec):
+    """Return a mock httpx response for /v1/embeddings returning a single vector."""
+    from unittest.mock import MagicMock
+    resp = MagicMock()
+    resp.raise_for_status.return_value = None
+    resp.json.return_value = {"data": [{"embedding": vec}]}
+    return resp
+
+
+def test_classify_returns_majority_vote_label():
+    """The label that owns all k nearest exemplars is returned."""
+    from unittest.mock import patch
+    adapter = _adapter_with_embeddings({
+        "rejected": [[1.0, 0.0, 0.0], [0.9, 0.1, 0.0], [0.85, 0.15, 0.0]],
+        "neutral": [[0.0, 1.0, 0.0]],
+    }, k=3)
+
+    # Query [1,0,0] is closest to all three "rejected" exemplars
+    with patch("httpx.post", return_value=_embed_resp([1.0, 0.0, 0.0])):
+        result = adapter.classify("We went with others", "Thank you for applying.")
+
+    assert result == "rejected"
+
+
+def test_classify_tiebreak_by_mean_score():
+    """With one neighbour per label, the higher similarity decides."""
+    from unittest.mock import patch
+    # k=2: each label contributes exactly one neighbour to the top k, so the
+    # label with the larger similarity wins the weighted vote.
+    # [1,0] query: cosine to [1,0] = 1.0 ("rejected"), cosine to [0.6,0.8] ≈ 0.6 ("neutral")
+    adapter = _adapter_with_embeddings({
+        "rejected": [[1.0, 0.0]],
+        "neutral": [[0.6, 0.8]],
+    }, k=2)
+
+    with patch("httpx.post", return_value=_embed_resp([1.0, 0.0])):
+        result = adapter.classify("Rejection", "Sorry")
+
+    assert result == "rejected"
+
+
+def test_classify_sparse_label_can_win():
+    """A label with a single exemplar wins when that exemplar matches best."""
+    from unittest.mock import patch
+    # "hired" has only 1 exemplar; query vector is closest to it
+    adapter = _adapter_with_embeddings({
+        "rejected": [[0.0, 0.0, 1.0], [0.0, 0.1, 0.9]],
+        "hired": [[1.0, 0.0, 0.0]],
+    }, k=3)
+
+    # Query [1,0,0] → hired exemplar scores 1.0; rejected exemplars score ~0
+    with patch("httpx.post", return_value=_embed_resp([1.0, 0.0, 0.0])):
+        result = adapter.classify("Welcome aboard", "Your first day details")
+
+    assert result == "hired"
+
+
+def test_classify_lazy_loads_when_not_loaded():
+    """classify() triggers load() when no exemplar embeddings are present."""
+    from unittest.mock import patch
+    from scripts.classifier_adapters import EmbeddingKNNAdapter
+
+    exemplars = {"rejected": ["We went with others"]}
+    adapter = EmbeddingKNNAdapter(
+        "test", "nomic-embed-text", k=1,
+        orch_url="http://orch:7700", ollama_url="http://ollama:11434",
+        exemplar_texts=exemplars,
+    )
+    assert adapter._exemplar_embeddings == {}
+
+    post_urls = []
+    def mock_post(url, *, json=None, timeout=None, **kwargs):
+        post_urls.append(url)
+        from unittest.mock import MagicMock
+        resp = MagicMock()
+        resp.raise_for_status.return_value = None
+        if "/allocate" in url:
+            resp.status_code = 200
+            resp.json.return_value = {"allocation_id": "a1", "url": "http://navi:11434"}
+        else:
+            n = len((json or {}).get("input", []))
+            resp.json.return_value = {"data": [{"embedding": [1.0, 0.0]}] * n}
+        return resp
+
+    with patch("httpx.post", side_effect=mock_post):
+        result = adapter.classify("Rejection", "Sorry")
+
+    assert result == "rejected"
+    assert any("/allocate" in u for u in post_urls), "lazy load must call allocate"
+    assert adapter._exemplar_embeddings != {}
+
+
+def test_classify_raises_when_no_exemplar_vectors():
+    """A label map without any vectors yields a clear ValueError."""
+    from unittest.mock import patch
+    import pytest
+    # Non-empty dict, so classify() does not trigger load(); there is simply
+    # nothing to score against.
+    adapter = _adapter_with_embeddings({"rejected": []}, k=3)
+
+    with patch("httpx.post", return_value=_embed_resp([1.0, 0.0])):
+        with pytest.raises(ValueError, match="no exemplar embeddings"):
+            adapter.classify("Hello", "World")