From 5df33b0f4182a9aabf6f6b3caca1c4da54b08c59 Mon Sep 17 00:00:00 2001
From: pyr0ball
Date: Tue, 5 May 2026 12:43:48 -0700
Subject: [PATCH] feat(benchmark): wire EmbeddingKNNAdapter into MODEL_REGISTRY as embed-knn-nomic

---
Reviewer notes (ignored by `git am`):
- Imports EmbeddingKNNAdapter and registers it in MODEL_REGISTRY under
  "embed-knn-nomic" (model_id "nomic-embed-text", kwargs {"k": 3},
  default False).
- Replaces the 13-model registry-size test with a 14-model one and adds a
  test that pins the fields of the new entry.
- NOTE(review): leading whitespace inside the hunks was reconstructed from a
  flattened copy of this patch; run `git apply --check` before applying.

 scripts/benchmark_classifier.py    |  8 ++++++++
 tests/test_benchmark_classifier.py | 19 ++++++++++++++-----
 2 files changed, 22 insertions(+), 5 deletions(-)

diff --git a/scripts/benchmark_classifier.py b/scripts/benchmark_classifier.py
index be935b2..c33fb4d 100644
--- a/scripts/benchmark_classifier.py
+++ b/scripts/benchmark_classifier.py
@@ -39,6 +39,7 @@ from scripts.classifier_adapters import (
     LABELS,
     LABEL_DESCRIPTIONS,
     ClassifierAdapter,
+    EmbeddingKNNAdapter,
     FineTunedAdapter,
     GLiClassAdapter,
     RerankerAdapter,
@@ -130,6 +131,13 @@ MODEL_REGISTRY: dict[str, dict[str, Any]] = {
         "params": "600M",
         "default": False,
     },
+    "embed-knn-nomic": {
+        "adapter": EmbeddingKNNAdapter,
+        "model_id": "nomic-embed-text",
+        "params": "local-embed",
+        "default": False,  # requires orch or ollama; use --include-slow
+        "kwargs": {"k": 3},
+    },
 }
 
 # ---------------------------------------------------------------------------
diff --git a/tests/test_benchmark_classifier.py b/tests/test_benchmark_classifier.py
index d0b2155..f5dcfb0 100644
--- a/tests/test_benchmark_classifier.py
+++ b/tests/test_benchmark_classifier.py
@@ -2,11 +2,6 @@
 import pytest
 
 
-def test_registry_has_thirteen_models():
-    from scripts.benchmark_classifier import MODEL_REGISTRY
-    assert len(MODEL_REGISTRY) == 13
-
-
 def test_registry_default_count():
     from scripts.benchmark_classifier import MODEL_REGISTRY
     defaults = [k for k, v in MODEL_REGISTRY.items() if v["default"]]
@@ -243,3 +238,17 @@ def test_build_exemplars_skips_rows_with_no_content(tmp_path):
     result = build_exemplars_from_jsonl(str(f))
     assert list(result.keys()) == ["neutral"]
     assert len(result["neutral"]) == 1
+
+def test_registry_has_fourteen_models():
+    from scripts.benchmark_classifier import MODEL_REGISTRY
+    assert len(MODEL_REGISTRY) == 14
+
+
+def test_embed_knn_nomic_registry_entry():
+    from scripts.benchmark_classifier import MODEL_REGISTRY
+    from scripts.classifier_adapters import EmbeddingKNNAdapter
+    entry = MODEL_REGISTRY["embed-knn-nomic"]
+    assert entry["adapter"] is EmbeddingKNNAdapter
+    assert entry["model_id"] == "nomic-embed-text"
+    assert entry["default"] is False
+    assert entry.get("kwargs", {}).get("k") == 3