feat(classifier): add EmbeddingKNNAdapter skeleton and constructor tests
This commit is contained in:
parent
c177fb1628
commit
72449561cf
2 changed files with 113 additions and 0 deletions
|
|
@ -20,6 +20,7 @@ __all__ = [
|
|||
"GLiClassAdapter",
|
||||
"RerankerAdapter",
|
||||
"FineTunedAdapter",
|
||||
"EmbeddingKNNAdapter",
|
||||
]
|
||||
|
||||
LABELS: list[str] = [
|
||||
|
|
@ -381,3 +382,75 @@ class FineTunedAdapter(ClassifierAdapter):
|
|||
text = f"{subject} [SEP] {body[:400]}"
|
||||
result = self._pipeline(text)
|
||||
return result[0]["label"]
|
||||
|
||||
|
||||
class EmbeddingKNNAdapter(ClassifierAdapter):
|
||||
"""k-NN email classifier using Ollama /v1/embeddings via cf-orch allocation.
|
||||
|
||||
load():
|
||||
1. Allocates an Ollama instance from cf-orch (POST /api/services/ollama/allocate).
|
||||
Falls back to ollama_url directly if orch allocation fails or is not configured.
|
||||
2. Pre-embeds all exemplar texts and stores per-label vector lists.
|
||||
|
||||
classify(subject, body):
|
||||
Embeds the input email, computes cosine similarity against all stored exemplar
|
||||
vectors, and majority-votes the top-k labels (default k=3). Tie-break: label
|
||||
with the highest mean similarity score among tied vote counts wins.
|
||||
|
||||
unload():
|
||||
Releases the cf-orch allocation (DELETE .../allocations/{id}) and clears state.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
model_id: str,
|
||||
*,
|
||||
k: int = 3,
|
||||
orch_url: str = "",
|
||||
ollama_url: str = "",
|
||||
exemplar_texts: dict[str, list[str]] | None = None,
|
||||
) -> None:
|
||||
self._name = name
|
||||
self._model_id = model_id
|
||||
self._k = k
|
||||
self._orch_url = orch_url
|
||||
self._ollama_url = ollama_url
|
||||
self._exemplar_texts: dict[str, list[str]] = (
|
||||
exemplar_texts if exemplar_texts is not None else DEFAULT_EXEMPLARS
|
||||
)
|
||||
self._exemplar_embeddings: dict[str, list[list[float]]] = {}
|
||||
self._node_url: str = ""
|
||||
self._allocation_id: str = ""
|
||||
self._orch_url_used: str = ""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def model_id(self) -> str:
|
||||
return self._model_id
|
||||
|
||||
def _resolve_urls(self) -> tuple[str, str]:
|
||||
if self._orch_url or self._ollama_url:
|
||||
return self._orch_url, self._ollama_url
|
||||
import yaml # noqa: PLC0415
|
||||
cfg_path = Path(__file__).parent.parent / "config" / "label_tool.yaml"
|
||||
cfg: dict = {}
|
||||
if cfg_path.exists():
|
||||
try:
|
||||
cfg = yaml.safe_load(cfg_path.read_text(encoding="utf-8")) or {}
|
||||
except yaml.YAMLError:
|
||||
pass
|
||||
cforch = cfg.get("cforch", {}) or {}
|
||||
return cforch.get("coordinator_url", ""), cforch.get("ollama_url", "")
|
||||
|
||||
def load(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def unload(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
def classify(self, subject: str, body: str) -> str:
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
|
|
@ -322,3 +322,43 @@ def test_default_exemplars_strings_are_formatted_correctly():
|
|||
assert "\n\n" in text, (
|
||||
f"{label!r} exemplar missing double-newline separator: {text[:50]!r}"
|
||||
)
|
||||
|
||||
# ---- EmbeddingKNNAdapter constructor tests ----
|
||||
|
||||
def test_embedding_knn_is_classifier_adapter():
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter, ClassifierAdapter
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test-knn", "nomic-embed-text",
|
||||
k=3, orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
)
|
||||
assert isinstance(adapter, ClassifierAdapter)
|
||||
|
||||
|
||||
def test_embedding_knn_name_and_model_id():
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"embed-knn-nomic", "nomic-embed-text",
|
||||
k=3, orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
)
|
||||
assert adapter.name == "embed-knn-nomic"
|
||||
assert adapter.model_id == "nomic-embed-text"
|
||||
|
||||
|
||||
def test_embedding_knn_uses_default_exemplars_when_none_given():
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter, DEFAULT_EXEMPLARS
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test", "nomic-embed-text",
|
||||
k=3, orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
)
|
||||
assert adapter._exemplar_texts is DEFAULT_EXEMPLARS
|
||||
|
||||
|
||||
def test_embedding_knn_accepts_custom_exemplars():
|
||||
from scripts.classifier_adapters import EmbeddingKNNAdapter
|
||||
custom = {"rejected": ["Sorry, we went with others."]}
|
||||
adapter = EmbeddingKNNAdapter(
|
||||
"test", "nomic-embed-text",
|
||||
k=3, orch_url="http://orch:7700", ollama_url="http://ollama:11434",
|
||||
exemplar_texts=custom,
|
||||
)
|
||||
assert adapter._exemplar_texts is custom
|
||||
|
|
|
|||
Loading…
Reference in a new issue