fix(classifier): majority-vote key, partial-load guard, sparse label test

This commit is contained in:
pyr0ball 2026-05-05 11:39:24 -07:00
parent 88bc6bed67
commit e823b5e76d
2 changed files with 17 additions and 9 deletions

View file

@ -399,7 +399,7 @@ class EmbeddingKNNAdapter(ClassifierAdapter):
classify(subject, body):
Embeds the input email, computes cosine similarity against all stored exemplar
vectors, and majority-votes the top-k labels (default k=3). Tie-break: label
with the highest mean similarity score among tied vote counts wins.
with the highest total similarity score among tied vote counts wins.
unload():
Releases the cf-orch allocation (DELETE .../allocations/{id}) and clears state.
@ -489,10 +489,14 @@ class EmbeddingKNNAdapter(ClassifierAdapter):
orch_url_used = ""
self._node_url = node_url
self._orch_url_used = orch_url_used
embeddings: dict[str, list[list[float]]] = {}
for label, texts in self._exemplar_texts.items():
embeddings[label] = self._embed(node_url, texts)
self._exemplar_embeddings = embeddings
try:
embeddings: dict[str, list[list[float]]] = {}
for label, texts in self._exemplar_texts.items():
embeddings[label] = self._embed(node_url, texts)
self._exemplar_embeddings = embeddings
except Exception:
self.unload()
raise
def unload(self) -> None:
if self._allocation_id and self._orch_url_used:
@ -523,4 +527,7 @@ class EmbeddingKNNAdapter(ClassifierAdapter):
votes: dict[str, list[float]] = {}
for score, label in top_k:
votes.setdefault(label, []).append(score)
return max(votes, key=lambda lbl: sum(votes[lbl]))
return max(
votes,
key=lambda lbl: (len(votes[lbl]), sum(votes[lbl])),
)

View file

@ -592,13 +592,13 @@ def test_classify_tiebreak_by_mean_score():
def test_classify_sparse_label_can_win():
from unittest.mock import patch
# "hired" has only 1 exemplar; query vector is closest to it
# "hired" has only 1 exemplar; with k=1, the single closest match wins
adapter = _adapter_with_embeddings({
"rejected": [[0.0, 0.0, 1.0], [0.0, 0.1, 0.9]],
"hired": [[1.0, 0.0, 0.0]],
}, k=3)
}, k=1)
# Query [1,0,0] → hired exemplar scores 1.0; rejected exemplars score ~0
# Query [1,0,0] → hired exemplar scores 1.0; closest single match wins
with patch("httpx.post", return_value=_embed_resp([1.0, 0.0, 0.0])):
result = adapter.classify("Welcome aboard", "Your first day details")
@ -637,3 +637,4 @@ def test_classify_lazy_loads_when_not_loaded():
assert result == "rejected"
assert any("/allocate" in u for u in post_urls), "lazy load must call allocate"
assert adapter._exemplar_embeddings != {}
assert adapter._node_url == "http://navi:11434"