From 1d4c07e4a005ce054cea22d9027e94c90e1b0a99 Mon Sep 17 00:00:00 2001
From: pyr0ball
Date: Tue, 5 May 2026 11:43:12 -0700
Subject: [PATCH] feat(benchmark): add build_exemplars_from_jsonl() for k-NN seed

---
Reviewer note: the hunks below were hand-edited after `git format-patch`
(null-safe subject/body handling, non-object rows and non-string labels
skipped, one regression test added). Hunk lengths and the diffstat were
updated to match; the post-image blob hashes on the `index` lines predate
these edits.

 scripts/benchmark_classifier.py    | 41 ++++++++++++++++
 tests/test_benchmark_classifier.py | 75 ++++++++++++++++++++++++++++++
 2 files changed, 116 insertions(+)

diff --git a/scripts/benchmark_classifier.py b/scripts/benchmark_classifier.py
index b3c5b49..d979bd3 100644
--- a/scripts/benchmark_classifier.py
+++ b/scripts/benchmark_classifier.py
@@ -184,6 +184,47 @@ def discover_finetuned_models(models_dir: Path | None = None) -> list[dict]:
     return found
 
 
+def build_exemplars_from_jsonl(path: str, k_per_label: int = 10) -> dict[str, list[str]]:
+    r"""Collect up to k_per_label formatted email texts per label from a scored JSONL.
+
+    Rows are taken in file order: the first k_per_label usable rows seen for
+    each label are kept; there is no random sampling.
+
+    Formats each row as 'Subject: {subject}\n\n{body[:600]}' — the same format
+    EmbeddingKNNAdapter uses at classify() time. A missing or null 'subject' or
+    'body' is treated as an empty string.
+
+    Skipped silently: blank lines, lines that are not valid JSON, rows that are
+    not JSON objects, and rows whose 'label' is missing, empty, or not a string.
+
+    Returns dict[label, list[str]] ready for EmbeddingKNNAdapter(exemplar_texts=...).
+    """
+    result: dict[str, list[str]] = {}
+    with Path(path).open(encoding="utf-8") as f:
+        for line in f:
+            line = line.strip()
+            if not line:
+                continue
+            try:
+                row = json.loads(line)
+            except json.JSONDecodeError:
+                continue
+            # json.loads() accepts any JSON value; only objects carry fields.
+            if not isinstance(row, dict):
+                continue
+            label = row.get("label")
+            if not isinstance(label, str) or not label:
+                continue
+            texts = result.setdefault(label, [])
+            if len(texts) < k_per_label:
+                # "or ''" maps JSON null to an empty string instead of
+                # crashing on None[:600] or emitting the text "None".
+                subject = row.get("subject") or ""
+                body = row.get("body") or ""
+                texts.append(f"Subject: {subject}\n\n{str(body)[:600]}")
+    return result
+
+
 def _active_models(include_slow: bool = False) -> dict[str, dict[str, Any]]:
     """Return the active model registry, merged with any discovered fine-tuned models."""
     active: dict[str, dict[str, Any]] = {
diff --git a/tests/test_benchmark_classifier.py b/tests/test_benchmark_classifier.py
index 1b1a71b..f076aff 100644
--- a/tests/test_benchmark_classifier.py
+++ b/tests/test_benchmark_classifier.py
@@ -166,3 +166,78 @@ def test_active_models_includes_discovered_finetuned(tmp_path):
 
     assert "avocet-deberta-small" in models
     assert isinstance(models["avocet-deberta-small"]["adapter_instance"], FineTunedAdapter)
+
+
+# ---- build_exemplars_from_jsonl() tests ----
+
+def test_build_exemplars_samples_up_to_k_per_label(tmp_path):
+    from scripts.benchmark_classifier import build_exemplars_from_jsonl
+    import json
+
+    rows = [{"subject": f"S{i}", "body": f"B{i}", "label": "rejected"} for i in range(15)]
+    rows.append({"subject": "Hire", "body": "Welcome", "label": "hired"})
+    f = tmp_path / "score.jsonl"
+    f.write_text("\n".join(json.dumps(r) for r in rows))
+
+    result = build_exemplars_from_jsonl(str(f), k_per_label=10)
+
+    assert len(result["rejected"]) == 10
+    assert len(result["hired"]) == 1
+    assert result["rejected"][0].startswith("Subject: S")
+
+
+def test_build_exemplars_formats_text_correctly(tmp_path):
+    from scripts.benchmark_classifier import build_exemplars_from_jsonl
+    import json
+
+    row = {"subject": "My Subject", "body": "My Body", "label": "neutral"}
+    f = tmp_path / "score.jsonl"
+    f.write_text(json.dumps(row))
+
+    result = build_exemplars_from_jsonl(str(f))
+
+    assert result["neutral"][0] == "Subject: My Subject\n\nMy Body"
+
+
+def test_build_exemplars_skips_rows_missing_label(tmp_path):
+    from scripts.benchmark_classifier import build_exemplars_from_jsonl
+    import json
+
+    rows = [
+        {"subject": "A", "body": "B", "label": "neutral"},
+        {"subject": "No label here", "body": "Body"},
+    ]
+    f = tmp_path / "score.jsonl"
+    f.write_text("\n".join(json.dumps(r) for r in rows))
+
+    result = build_exemplars_from_jsonl(str(f))
+    assert list(result.keys()) == ["neutral"]
+
+
+def test_build_exemplars_truncates_body_at_600(tmp_path):
+    from scripts.benchmark_classifier import build_exemplars_from_jsonl
+    import json
+
+    row = {"subject": "S", "body": "x" * 800, "label": "neutral"}
+    f = tmp_path / "score.jsonl"
+    f.write_text(json.dumps(row))
+
+    result = build_exemplars_from_jsonl(str(f))
+    body_part = result["neutral"][0].split("\n\n", 1)[1]
+    assert len(body_part) == 600
+
+
+def test_build_exemplars_tolerates_null_fields_and_non_object_rows(tmp_path):
+    from scripts.benchmark_classifier import build_exemplars_from_jsonl
+    import json
+
+    lines = [
+        json.dumps({"subject": None, "body": None, "label": "neutral"}),
+        json.dumps(["not", "an", "object"]),
+        json.dumps({"subject": "A", "body": "B", "label": ["unhashable"]}),
+    ]
+    f = tmp_path / "score.jsonl"
+    f.write_text("\n".join(lines))
+
+    result = build_exemplars_from_jsonl(str(f))
+    assert result == {"neutral": ["Subject: \n\n"]}