feat(classifier): add EmbeddingKNNAdapter skeleton and constructor tests
This commit is contained in:
parent
c177fb1628
commit
72449561cf
2 changed files with 113 additions and 0 deletions
|
|
@ -20,6 +20,7 @@ __all__ = [
|
||||||
"GLiClassAdapter",
|
"GLiClassAdapter",
|
||||||
"RerankerAdapter",
|
"RerankerAdapter",
|
||||||
"FineTunedAdapter",
|
"FineTunedAdapter",
|
||||||
|
"EmbeddingKNNAdapter",
|
||||||
]
|
]
|
||||||
|
|
||||||
LABELS: list[str] = [
|
LABELS: list[str] = [
|
||||||
|
|
@ -381,3 +382,75 @@ class FineTunedAdapter(ClassifierAdapter):
|
||||||
text = f"{subject} [SEP] {body[:400]}"
|
text = f"{subject} [SEP] {body[:400]}"
|
||||||
result = self._pipeline(text)
|
result = self._pipeline(text)
|
||||||
return result[0]["label"]
|
return result[0]["label"]
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingKNNAdapter(ClassifierAdapter):
|
||||||
|
"""k-NN email classifier using Ollama /v1/embeddings via cf-orch allocation.
|
||||||
|
|
||||||
|
load():
|
||||||
|
1. Allocates an Ollama instance from cf-orch (POST /api/services/ollama/allocate).
|
||||||
|
Falls back to ollama_url directly if orch allocation fails or is not configured.
|
||||||
|
2. Pre-embeds all exemplar texts and stores per-label vector lists.
|
||||||
|
|
||||||
|
classify(subject, body):
|
||||||
|
Embeds the input email, computes cosine similarity against all stored exemplar
|
||||||
|
vectors, and majority-votes the top-k labels (default k=3). Tie-break: label
|
||||||
|
with the highest mean similarity score among tied vote counts wins.
|
||||||
|
|
||||||
|
unload():
|
||||||
|
Releases the cf-orch allocation (DELETE .../allocations/{id}) and clears state.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
model_id: str,
|
||||||
|
*,
|
||||||
|
k: int = 3,
|
||||||
|
orch_url: str = "",
|
||||||
|
ollama_url: str = "",
|
||||||
|
exemplar_texts: dict[str, list[str]] | None = None,
|
||||||
|
) -> None:
|
||||||
|
self._name = name
|
||||||
|
self._model_id = model_id
|
||||||
|
self._k = k
|
||||||
|
self._orch_url = orch_url
|
||||||
|
self._ollama_url = ollama_url
|
||||||
|
self._exemplar_texts: dict[str, list[str]] = (
|
||||||
|
exemplar_texts if exemplar_texts is not None else DEFAULT_EXEMPLARS
|
||||||
|
)
|
||||||
|
self._exemplar_embeddings: dict[str, list[list[float]]] = {}
|
||||||
|
self._node_url: str = ""
|
||||||
|
self._allocation_id: str = ""
|
||||||
|
self._orch_url_used: str = ""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return self._name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def model_id(self) -> str:
|
||||||
|
return self._model_id
|
||||||
|
|
||||||
|
def _resolve_urls(self) -> tuple[str, str]:
|
||||||
|
if self._orch_url or self._ollama_url:
|
||||||
|
return self._orch_url, self._ollama_url
|
||||||
|
import yaml # noqa: PLC0415
|
||||||
|
cfg_path = Path(__file__).parent.parent / "config" / "label_tool.yaml"
|
||||||
|
cfg: dict = {}
|
||||||
|
if cfg_path.exists():
|
||||||
|
try:
|
||||||
|
cfg = yaml.safe_load(cfg_path.read_text(encoding="utf-8")) or {}
|
||||||
|
except yaml.YAMLError:
|
||||||
|
pass
|
||||||
|
cforch = cfg.get("cforch", {}) or {}
|
||||||
|
return cforch.get("coordinator_url", ""), cforch.get("ollama_url", "")
|
||||||
|
|
||||||
|
def load(self) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def unload(self) -> None:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def classify(self, subject: str, body: str) -> str:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
|
||||||
|
|
@ -322,3 +322,43 @@ def test_default_exemplars_strings_are_formatted_correctly():
|
||||||
assert "\n\n" in text, (
|
assert "\n\n" in text, (
|
||||||
f"{label!r} exemplar missing double-newline separator: {text[:50]!r}"
|
f"{label!r} exemplar missing double-newline separator: {text[:50]!r}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ---- EmbeddingKNNAdapter constructor tests ----
|
||||||
|
|
||||||
|
def test_embedding_knn_is_classifier_adapter():
|
||||||
|
from scripts.classifier_adapters import EmbeddingKNNAdapter, ClassifierAdapter
|
||||||
|
adapter = EmbeddingKNNAdapter(
|
||||||
|
"test-knn", "nomic-embed-text",
|
||||||
|
k=3, orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||||
|
)
|
||||||
|
assert isinstance(adapter, ClassifierAdapter)
|
||||||
|
|
||||||
|
|
||||||
|
def test_embedding_knn_name_and_model_id():
|
||||||
|
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||||
|
adapter = EmbeddingKNNAdapter(
|
||||||
|
"embed-knn-nomic", "nomic-embed-text",
|
||||||
|
k=3, orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||||
|
)
|
||||||
|
assert adapter.name == "embed-knn-nomic"
|
||||||
|
assert adapter.model_id == "nomic-embed-text"
|
||||||
|
|
||||||
|
|
||||||
|
def test_embedding_knn_uses_default_exemplars_when_none_given():
|
||||||
|
from scripts.classifier_adapters import EmbeddingKNNAdapter, DEFAULT_EXEMPLARS
|
||||||
|
adapter = EmbeddingKNNAdapter(
|
||||||
|
"test", "nomic-embed-text",
|
||||||
|
k=3, orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||||
|
)
|
||||||
|
assert adapter._exemplar_texts is DEFAULT_EXEMPLARS
|
||||||
|
|
||||||
|
|
||||||
|
def test_embedding_knn_accepts_custom_exemplars():
|
||||||
|
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||||
|
custom = {"rejected": ["Sorry, we went with others."]}
|
||||||
|
adapter = EmbeddingKNNAdapter(
|
||||||
|
"test", "nomic-embed-text",
|
||||||
|
k=3, orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||||
|
exemplar_texts=custom,
|
||||||
|
)
|
||||||
|
assert adapter._exemplar_texts is custom
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue