diff --git a/scripts/classifier_adapters.py b/scripts/classifier_adapters.py index e2a8c9c..939c739 100644 --- a/scripts/classifier_adapters.py +++ b/scripts/classifier_adapters.py @@ -399,7 +399,7 @@ class EmbeddingKNNAdapter(ClassifierAdapter): classify(subject, body): Embeds the input email, computes cosine similarity against all stored exemplar vectors, and majority-votes the top-k labels (default k=3). Tie-break: label - with the highest mean similarity score among tied vote counts wins. + with the highest total similarity score among tied vote counts wins. unload(): Releases the cf-orch allocation (DELETE .../allocations/{id}) and clears state. @@ -489,10 +489,14 @@ class EmbeddingKNNAdapter(ClassifierAdapter): orch_url_used = "" self._node_url = node_url self._orch_url_used = orch_url_used - embeddings: dict[str, list[list[float]]] = {} - for label, texts in self._exemplar_texts.items(): - embeddings[label] = self._embed(node_url, texts) - self._exemplar_embeddings = embeddings + try: + embeddings: dict[str, list[list[float]]] = {} + for label, texts in self._exemplar_texts.items(): + embeddings[label] = self._embed(node_url, texts) + self._exemplar_embeddings = embeddings + except Exception: + self.unload() + raise def unload(self) -> None: if self._allocation_id and self._orch_url_used: @@ -523,4 +527,7 @@ class EmbeddingKNNAdapter(ClassifierAdapter): votes: dict[str, list[float]] = {} for score, label in top_k: votes.setdefault(label, []).append(score) - return max(votes, key=lambda lbl: sum(votes[lbl])) + return max( + votes, + key=lambda lbl: (len(votes[lbl]), sum(votes[lbl])), + ) diff --git a/tests/test_classifier_adapters.py b/tests/test_classifier_adapters.py index 3163362..86e778b 100644 --- a/tests/test_classifier_adapters.py +++ b/tests/test_classifier_adapters.py @@ -592,13 +592,13 @@ def test_classify_tiebreak_by_mean_score(): def test_classify_sparse_label_can_win(): from unittest.mock import patch - # "hired" has only 1 exemplar; query vector is closest to it + # "hired" has only 1 exemplar; with k=1, the single closest match wins adapter = _adapter_with_embeddings({ "rejected": [[0.0, 0.0, 1.0], [0.0, 0.1, 0.9]], "hired": [[1.0, 0.0, 0.0]], - }, k=3) + }, k=1) - # Query [1,0,0] → hired exemplar scores 1.0; rejected exemplars score ~0 + # Query [1,0,0] → hired exemplar scores 1.0; closest single match wins with patch("httpx.post", return_value=_embed_resp([1.0, 0.0, 0.0])): result = adapter.classify("Welcome aboard", "Your first day details") @@ -637,3 +637,4 @@ def test_classify_lazy_loads_when_not_loaded(): assert result == "rejected" assert any("/allocate" in u for u in post_urls), "lazy load must call allocate" assert adapter._exemplar_embeddings != {} + assert adapter._node_url == "http://navi:11434"