fix(classifier): majority-vote key, partial-load guard, sparse label test
This commit is contained in:
parent
88bc6bed67
commit
e823b5e76d
2 changed files with 17 additions and 9 deletions
|
|
@ -399,7 +399,7 @@ class EmbeddingKNNAdapter(ClassifierAdapter):
|
||||||
classify(subject, body):
|
classify(subject, body):
|
||||||
Embeds the input email, computes cosine similarity against all stored exemplar
|
Embeds the input email, computes cosine similarity against all stored exemplar
|
||||||
vectors, and majority-votes the top-k labels (default k=3). Tie-break: label
|
vectors, and majority-votes the top-k labels (default k=3). Tie-break: label
|
||||||
with the highest mean similarity score among tied vote counts wins.
|
with the highest total similarity score among tied vote counts wins.
|
||||||
|
|
||||||
unload():
|
unload():
|
||||||
Releases the cf-orch allocation (DELETE .../allocations/{id}) and clears state.
|
Releases the cf-orch allocation (DELETE .../allocations/{id}) and clears state.
|
||||||
|
|
@ -489,10 +489,14 @@ class EmbeddingKNNAdapter(ClassifierAdapter):
|
||||||
orch_url_used = ""
|
orch_url_used = ""
|
||||||
self._node_url = node_url
|
self._node_url = node_url
|
||||||
self._orch_url_used = orch_url_used
|
self._orch_url_used = orch_url_used
|
||||||
embeddings: dict[str, list[list[float]]] = {}
|
try:
|
||||||
for label, texts in self._exemplar_texts.items():
|
embeddings: dict[str, list[list[float]]] = {}
|
||||||
embeddings[label] = self._embed(node_url, texts)
|
for label, texts in self._exemplar_texts.items():
|
||||||
self._exemplar_embeddings = embeddings
|
embeddings[label] = self._embed(node_url, texts)
|
||||||
|
self._exemplar_embeddings = embeddings
|
||||||
|
except Exception:
|
||||||
|
self.unload()
|
||||||
|
raise
|
||||||
|
|
||||||
def unload(self) -> None:
|
def unload(self) -> None:
|
||||||
if self._allocation_id and self._orch_url_used:
|
if self._allocation_id and self._orch_url_used:
|
||||||
|
|
@ -523,4 +527,7 @@ class EmbeddingKNNAdapter(ClassifierAdapter):
|
||||||
votes: dict[str, list[float]] = {}
|
votes: dict[str, list[float]] = {}
|
||||||
for score, label in top_k:
|
for score, label in top_k:
|
||||||
votes.setdefault(label, []).append(score)
|
votes.setdefault(label, []).append(score)
|
||||||
return max(votes, key=lambda lbl: sum(votes[lbl]))
|
return max(
|
||||||
|
votes,
|
||||||
|
key=lambda lbl: (len(votes[lbl]), sum(votes[lbl])),
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -592,13 +592,13 @@ def test_classify_tiebreak_by_mean_score():
|
||||||
|
|
||||||
def test_classify_sparse_label_can_win():
|
def test_classify_sparse_label_can_win():
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
# "hired" has only 1 exemplar; query vector is closest to it
|
# "hired" has only 1 exemplar; with k=1, the single closest match wins
|
||||||
adapter = _adapter_with_embeddings({
|
adapter = _adapter_with_embeddings({
|
||||||
"rejected": [[0.0, 0.0, 1.0], [0.0, 0.1, 0.9]],
|
"rejected": [[0.0, 0.0, 1.0], [0.0, 0.1, 0.9]],
|
||||||
"hired": [[1.0, 0.0, 0.0]],
|
"hired": [[1.0, 0.0, 0.0]],
|
||||||
}, k=3)
|
}, k=1)
|
||||||
|
|
||||||
# Query [1,0,0] → hired exemplar scores 1.0; rejected exemplars score ~0
|
# Query [1,0,0] → hired exemplar scores 1.0; closest single match wins
|
||||||
with patch("httpx.post", return_value=_embed_resp([1.0, 0.0, 0.0])):
|
with patch("httpx.post", return_value=_embed_resp([1.0, 0.0, 0.0])):
|
||||||
result = adapter.classify("Welcome aboard", "Your first day details")
|
result = adapter.classify("Welcome aboard", "Your first day details")
|
||||||
|
|
||||||
|
|
@ -637,3 +637,4 @@ def test_classify_lazy_loads_when_not_loaded():
|
||||||
assert result == "rejected"
|
assert result == "rejected"
|
||||||
assert any("/allocate" in u for u in post_urls), "lazy load must call allocate"
|
assert any("/allocate" in u for u in post_urls), "lazy load must call allocate"
|
||||||
assert adapter._exemplar_embeddings != {}
|
assert adapter._exemplar_embeddings != {}
|
||||||
|
assert adapter._node_url == "http://navi:11434"
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue