fix(classifier): majority-vote key, partial-load guard, sparse label test
This commit is contained in:
parent
88bc6bed67
commit
e823b5e76d
2 changed files with 17 additions and 9 deletions
|
|
@ -399,7 +399,7 @@ class EmbeddingKNNAdapter(ClassifierAdapter):
|
|||
classify(subject, body):
|
||||
Embeds the input email, computes cosine similarity against all stored exemplar
|
||||
vectors, and majority-votes the top-k labels (default k=3). Tie-break: label
|
||||
with the highest mean similarity score among tied vote counts wins.
|
||||
with the highest total similarity score among tied vote counts wins.
|
||||
|
||||
unload():
|
||||
Releases the cf-orch allocation (DELETE .../allocations/{id}) and clears state.
|
||||
|
|
@ -489,10 +489,14 @@ class EmbeddingKNNAdapter(ClassifierAdapter):
|
|||
orch_url_used = ""
|
||||
self._node_url = node_url
|
||||
self._orch_url_used = orch_url_used
|
||||
try:
|
||||
embeddings: dict[str, list[list[float]]] = {}
|
||||
for label, texts in self._exemplar_texts.items():
|
||||
embeddings[label] = self._embed(node_url, texts)
|
||||
self._exemplar_embeddings = embeddings
|
||||
except Exception:
|
||||
self.unload()
|
||||
raise
|
||||
|
||||
def unload(self) -> None:
|
||||
if self._allocation_id and self._orch_url_used:
|
||||
|
|
@ -523,4 +527,7 @@ class EmbeddingKNNAdapter(ClassifierAdapter):
|
|||
votes: dict[str, list[float]] = {}
|
||||
for score, label in top_k:
|
||||
votes.setdefault(label, []).append(score)
|
||||
return max(votes, key=lambda lbl: sum(votes[lbl]))
|
||||
return max(
|
||||
votes,
|
||||
key=lambda lbl: (len(votes[lbl]), sum(votes[lbl])),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -592,13 +592,13 @@ def test_classify_tiebreak_by_mean_score():
|
|||
|
||||
def test_classify_sparse_label_can_win():
|
||||
from unittest.mock import patch
|
||||
# "hired" has only 1 exemplar; query vector is closest to it
|
||||
# "hired" has only 1 exemplar; with k=1, the single closest match wins
|
||||
adapter = _adapter_with_embeddings({
|
||||
"rejected": [[0.0, 0.0, 1.0], [0.0, 0.1, 0.9]],
|
||||
"hired": [[1.0, 0.0, 0.0]],
|
||||
}, k=3)
|
||||
}, k=1)
|
||||
|
||||
# Query [1,0,0] → hired exemplar scores 1.0; rejected exemplars score ~0
|
||||
# Query [1,0,0] → hired exemplar scores 1.0; closest single match wins
|
||||
with patch("httpx.post", return_value=_embed_resp([1.0, 0.0, 0.0])):
|
||||
result = adapter.classify("Welcome aboard", "Your first day details")
|
||||
|
||||
|
|
@ -637,3 +637,4 @@ def test_classify_lazy_loads_when_not_loaded():
|
|||
assert result == "rejected"
|
||||
assert any("/allocate" in u for u in post_urls), "lazy load must call allocate"
|
||||
assert adapter._exemplar_embeddings != {}
|
||||
assert adapter._node_url == "http://navi:11434"
|
||||
|
|
|
|||
Loading…
Reference in a new issue