feat(classifier): add EmbeddingKNNAdapter skeleton and constructor tests

This commit is contained in:
pyr0ball 2026-05-05 06:08:21 -07:00
parent c177fb1628
commit 72449561cf
2 changed files with 113 additions and 0 deletions

View file

@ -20,6 +20,7 @@ __all__ = [
"GLiClassAdapter",
"RerankerAdapter",
"FineTunedAdapter",
"EmbeddingKNNAdapter",
]
LABELS: list[str] = [
@ -381,3 +382,75 @@ class FineTunedAdapter(ClassifierAdapter):
text = f"{subject} [SEP] {body[:400]}"
result = self._pipeline(text)
return result[0]["label"]
class EmbeddingKNNAdapter(ClassifierAdapter):
    """k-NN email classifier using Ollama /v1/embeddings via cf-orch allocation.

    This is currently a skeleton: the constructor and URL resolution are
    implemented, while ``load``, ``unload`` and ``classify`` raise
    ``NotImplementedError``. The intended behaviour of those methods is:

    load():
        1. Allocates an Ollama instance from cf-orch
           (POST /api/services/ollama/allocate). Falls back to ollama_url
           directly if orch allocation fails or is not configured.
        2. Pre-embeds all exemplar texts and stores per-label vector lists.

    classify(subject, body):
        Embeds the input email, computes cosine similarity against all stored
        exemplar vectors, and majority-votes the top-k labels (default k=3).
        Tie-break: label with the highest mean similarity score among tied
        vote counts wins.

    unload():
        Releases the cf-orch allocation (DELETE .../allocations/{id}) and
        clears state.
    """

    def __init__(
        self,
        name: str,
        model_id: str,
        *,
        k: int = 3,
        orch_url: str = "",
        ollama_url: str = "",
        exemplar_texts: dict[str, list[str]] | None = None,
    ) -> None:
        """Store configuration; performs no network or file I/O.

        Args:
            name: Adapter name, exposed through the ``name`` property.
            model_id: Embedding model identifier, exposed through ``model_id``.
            k: Number of nearest neighbours intended for the vote.
            orch_url: cf-orch coordinator URL; empty means "not given".
            ollama_url: Direct Ollama URL; empty means "not given".
            exemplar_texts: Mapping of label -> exemplar strings. When ``None``
                the module-level ``DEFAULT_EXEMPLARS`` object itself is used
                (not a copy).
        """
        self._name = name
        self._model_id = model_id
        self._k = k
        self._orch_url = orch_url
        self._ollama_url = ollama_url
        # Keep the caller's (or the default) mapping by reference, not by copy.
        self._exemplar_texts: dict[str, list[str]] = (
            exemplar_texts if exemplar_texts is not None else DEFAULT_EXEMPLARS
        )
        # Runtime state; starts empty until load() is implemented and run.
        self._exemplar_embeddings: dict[str, list[list[float]]] = {}
        self._node_url: str = ""
        self._allocation_id: str = ""
        self._orch_url_used: str = ""

    @property
    def name(self) -> str:
        """Adapter name as passed to the constructor."""
        return self._name

    @property
    def model_id(self) -> str:
        """Model identifier as passed to the constructor."""
        return self._model_id

    def _resolve_urls(self) -> tuple[str, str]:
        """Return ``(orch_url, ollama_url)`` for this adapter.

        If either URL was supplied to the constructor, both constructor values
        are returned unchanged (the one not supplied stays ``""``) and the
        config file is not consulted.

        Otherwise the ``cforch`` mapping of ``config/label_tool.yaml`` (located
        two directories above this file) is read, using its
        ``coordinator_url`` and ``ollama_url`` keys. Lookup is best-effort: a
        missing, unreadable, undecodable or malformed file, a non-mapping
        document or ``cforch`` section, and absent or null keys all resolve to
        empty strings rather than raising.
        """
        if self._orch_url or self._ollama_url:
            return self._orch_url, self._ollama_url

        import yaml  # noqa: PLC0415

        cfg_path = Path(__file__).parent.parent / "config" / "label_tool.yaml"
        try:
            # EAFP: read directly instead of exists()+read, which can race.
            cfg = yaml.safe_load(cfg_path.read_text(encoding="utf-8"))
        except (OSError, UnicodeDecodeError, yaml.YAMLError):
            return "", ""

        # safe_load may yield None, a list or a scalar; only mappings have .get.
        cforch = cfg.get("cforch") if isinstance(cfg, dict) else None
        if not isinstance(cforch, dict):
            return "", ""

        # `or ""` normalises keys that are present but null/empty in the YAML.
        return cforch.get("coordinator_url") or "", cforch.get("ollama_url") or ""

    def load(self) -> None:
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def unload(self) -> None:
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError

    def classify(self, subject: str, body: str) -> str:
        """Not implemented yet; always raises ``NotImplementedError``."""
        raise NotImplementedError

View file

@ -322,3 +322,43 @@ def test_default_exemplars_strings_are_formatted_correctly():
assert "\n\n" in text, (
f"{label!r} exemplar missing double-newline separator: {text[:50]!r}"
)
# ---- EmbeddingKNNAdapter constructor tests ----
def test_embedding_knn_is_classifier_adapter():
    """An EmbeddingKNNAdapter instance is a ClassifierAdapter."""
    from scripts.classifier_adapters import ClassifierAdapter, EmbeddingKNNAdapter

    # Explicit URLs are passed so construction never needs the config file.
    ctor_kwargs = {
        "k": 3,
        "orch_url": "http://orch:7700",
        "ollama_url": "http://ollama:11434",
    }
    knn = EmbeddingKNNAdapter("test-knn", "nomic-embed-text", **ctor_kwargs)

    assert isinstance(knn, ClassifierAdapter)
def test_embedding_knn_name_and_model_id():
    """The name and model_id properties echo the constructor arguments."""
    from scripts.classifier_adapters import EmbeddingKNNAdapter

    expected_name = "embed-knn-nomic"
    expected_model = "nomic-embed-text"

    knn = EmbeddingKNNAdapter(
        expected_name,
        expected_model,
        k=3,
        orch_url="http://orch:7700",
        ollama_url="http://ollama:11434",
    )

    assert knn.name == expected_name
    assert knn.model_id == expected_model
def test_embedding_knn_uses_default_exemplars_when_none_given():
    """Without exemplar_texts the adapter holds the DEFAULT_EXEMPLARS object itself."""
    from scripts.classifier_adapters import DEFAULT_EXEMPLARS, EmbeddingKNNAdapter

    knn = EmbeddingKNNAdapter(
        "test",
        "nomic-embed-text",
        k=3,
        orch_url="http://orch:7700",
        ollama_url="http://ollama:11434",
    )

    # Identity, not equality: the default mapping must be shared, not copied.
    stored = knn._exemplar_texts
    assert stored is DEFAULT_EXEMPLARS
def test_embedding_knn_accepts_custom_exemplars():
    """A caller-supplied exemplar mapping is stored by reference."""
    from scripts.classifier_adapters import EmbeddingKNNAdapter

    custom = {"rejected": ["Sorry, we went with others."]}
    ctor_kwargs = {
        "k": 3,
        "orch_url": "http://orch:7700",
        "ollama_url": "http://ollama:11434",
        "exemplar_texts": custom,
    }

    knn = EmbeddingKNNAdapter("test", "nomic-embed-text", **ctor_kwargs)

    # Identity check: the adapter must keep the very same dict object.
    assert knn._exemplar_texts is custom