From 72449561cfc1a5ed8a553472e5243b265448ce59 Mon Sep 17 00:00:00 2001 From: pyr0ball Date: Tue, 5 May 2026 06:08:21 -0700 Subject: [PATCH] feat(classifier): add EmbeddingKNNAdapter skeleton and constructor tests --- scripts/classifier_adapters.py | 73 +++++++++++++++++++++++++++++++ tests/test_classifier_adapters.py | 40 +++++++++++++++++ 2 files changed, 113 insertions(+) diff --git a/scripts/classifier_adapters.py b/scripts/classifier_adapters.py index 3021ec4..b612652 100644 --- a/scripts/classifier_adapters.py +++ b/scripts/classifier_adapters.py @@ -20,6 +20,7 @@ __all__ = [ "GLiClassAdapter", "RerankerAdapter", "FineTunedAdapter", + "EmbeddingKNNAdapter", ] LABELS: list[str] = [ @@ -381,3 +382,75 @@ class FineTunedAdapter(ClassifierAdapter): text = f"{subject} [SEP] {body[:400]}" result = self._pipeline(text) return result[0]["label"] + + +class EmbeddingKNNAdapter(ClassifierAdapter): + """k-NN email classifier using Ollama /v1/embeddings via cf-orch allocation. + + load(): + 1. Allocates an Ollama instance from cf-orch (POST /api/services/ollama/allocate). + Falls back to ollama_url directly if orch allocation fails or is not configured. + 2. Pre-embeds all exemplar texts and stores per-label vector lists. + + classify(subject, body): + Embeds the input email, computes cosine similarity against all stored exemplar + vectors, and majority-votes the top-k labels (default k=3). Tie-break: label + with the highest mean similarity score among tied vote counts wins. + + unload(): + Releases the cf-orch allocation (DELETE .../allocations/{id}) and clears state. + """ + + def __init__( + self, + name: str, + model_id: str, + *, + k: int = 3, + orch_url: str = "", + ollama_url: str = "", + exemplar_texts: dict[str, list[str]] | None = None, + ) -> None: + self._name = name + self._model_id = model_id + self._k = k + self._orch_url = orch_url + self._ollama_url = ollama_url + self._exemplar_texts: dict[str, list[str]] = ( + exemplar_texts if exemplar_texts is not None else DEFAULT_EXEMPLARS + ) + self._exemplar_embeddings: dict[str, list[list[float]]] = {} + self._node_url: str = "" + self._allocation_id: str = "" + self._orch_url_used: str = "" + + @property + def name(self) -> str: + return self._name + + @property + def model_id(self) -> str: + return self._model_id + + def _resolve_urls(self) -> tuple[str, str]: + if self._orch_url or self._ollama_url: + return self._orch_url, self._ollama_url + import yaml # noqa: PLC0415 + cfg_path = Path(__file__).parent.parent / "config" / "label_tool.yaml" + cfg: dict = {} + if cfg_path.exists(): + try: + cfg = yaml.safe_load(cfg_path.read_text(encoding="utf-8")) or {} + except yaml.YAMLError: + pass + cforch = cfg.get("cforch", {}) or {} + return cforch.get("coordinator_url", ""), cforch.get("ollama_url", "") + + def load(self) -> None: + raise NotImplementedError + + def unload(self) -> None: + raise NotImplementedError + + def classify(self, subject: str, body: str) -> str: + raise NotImplementedError diff --git a/tests/test_classifier_adapters.py b/tests/test_classifier_adapters.py index aaea9ba..b19dc51 100644 --- a/tests/test_classifier_adapters.py +++ b/tests/test_classifier_adapters.py @@ -322,3 +322,43 @@ def test_default_exemplars_strings_are_formatted_correctly(): assert "\n\n" in text, ( f"{label!r} exemplar missing double-newline separator: {text[:50]!r}" ) + +# ---- EmbeddingKNNAdapter constructor tests ---- + +def test_embedding_knn_is_classifier_adapter(): + from scripts.classifier_adapters import EmbeddingKNNAdapter, ClassifierAdapter + adapter = EmbeddingKNNAdapter( + "test-knn", "nomic-embed-text", + k=3, orch_url="http://orch:7700", ollama_url="http://ollama:11434", + ) + assert isinstance(adapter, ClassifierAdapter) + + +def test_embedding_knn_name_and_model_id(): + from scripts.classifier_adapters import EmbeddingKNNAdapter + adapter = EmbeddingKNNAdapter( + "embed-knn-nomic", "nomic-embed-text", + k=3, orch_url="http://orch:7700", ollama_url="http://ollama:11434", + ) + assert adapter.name == "embed-knn-nomic" + assert adapter.model_id == "nomic-embed-text" + + +def test_embedding_knn_uses_default_exemplars_when_none_given(): + from scripts.classifier_adapters import EmbeddingKNNAdapter, DEFAULT_EXEMPLARS + adapter = EmbeddingKNNAdapter( + "test", "nomic-embed-text", + k=3, orch_url="http://orch:7700", ollama_url="http://ollama:11434", + ) + assert adapter._exemplar_texts is DEFAULT_EXEMPLARS + + +def test_embedding_knn_accepts_custom_exemplars(): + from scripts.classifier_adapters import EmbeddingKNNAdapter + custom = {"rejected": ["Sorry, we went with others."]} + adapter = EmbeddingKNNAdapter( + "test", "nomic-embed-text", + k=3, orch_url="http://orch:7700", ollama_url="http://ollama:11434", + exemplar_texts=custom, + ) + assert adapter._exemplar_texts is custom