feat(benchmark): add build_exemplars_from_jsonl() for k-NN seed
This commit is contained in:
parent
e823b5e76d
commit
1d4c07e4a0
2 changed files with 90 additions and 0 deletions
|
|
@ -184,6 +184,37 @@ def discover_finetuned_models(models_dir: Path | None = None) -> list[dict]:
|
||||||
return found
|
return found
|
||||||
|
|
||||||
|
|
||||||
|
def build_exemplars_from_jsonl(path: str, k_per_label: int = 10) -> dict[str, list[str]]:
    """Sample up to k_per_label formatted email texts per label from a scored JSONL.

    Each usable row becomes the text "Subject: {subject}", a blank line, then
    the first 600 characters of the body — the same format
    EmbeddingKNNAdapter uses at classify() time. The first k_per_label rows
    seen for a label are kept, in file order.

    Rows are skipped silently when the line is blank, is not valid JSON, is
    not a JSON object, or has a missing, empty, or non-string 'label'.
    A missing or null 'subject' / 'body' is treated as an empty string;
    other non-string values are converted with str().

    Args:
        path: Location of the JSONL file (one JSON object per line, UTF-8).
        k_per_label: Maximum number of exemplar texts collected per label.

    Returns:
        dict[label, list[str]] ready for EmbeddingKNNAdapter(exemplar_texts=...).

    Raises:
        OSError: If the file cannot be opened.
    """
    result: dict[str, list[str]] = {}
    with Path(path).open(encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            try:
                row = json.loads(line)
            except json.JSONDecodeError:
                continue
            # Valid JSON is not necessarily an object (e.g. a list or a bare
            # string); only objects can carry label/subject/body fields.
            if not isinstance(row, dict):
                continue
            label = row.get("label")
            # Labels become dict keys in a dict[str, ...] result, so anything
            # that is not a non-empty string is treated like a missing label.
            if not isinstance(label, str) or not label:
                continue
            texts = result.setdefault(label, [])
            if len(texts) >= k_per_label:
                continue
            subject = row.get("subject")
            body = row.get("body")
            # JSON null must not render as "None" or break the slice below.
            subject_text = "" if subject is None else str(subject)
            body_text = "" if body is None else str(body)
            texts.append(f"Subject: {subject_text}\n\n{body_text[:600]}")
    return result
|
||||||
|
|
||||||
|
|
||||||
def _active_models(include_slow: bool = False) -> dict[str, dict[str, Any]]:
|
def _active_models(include_slow: bool = False) -> dict[str, dict[str, Any]]:
|
||||||
"""Return the active model registry, merged with any discovered fine-tuned models."""
|
"""Return the active model registry, merged with any discovered fine-tuned models."""
|
||||||
active: dict[str, dict[str, Any]] = {
|
active: dict[str, dict[str, Any]] = {
|
||||||
|
|
|
||||||
|
|
@ -166,3 +166,62 @@ def test_active_models_includes_discovered_finetuned(tmp_path):
|
||||||
|
|
||||||
assert "avocet-deberta-small" in models
|
assert "avocet-deberta-small" in models
|
||||||
assert isinstance(models["avocet-deberta-small"]["adapter_instance"], FineTunedAdapter)
|
assert isinstance(models["avocet-deberta-small"]["adapter_instance"], FineTunedAdapter)
|
||||||
|
|
||||||
|
|
||||||
|
# ---- build_exemplars_from_jsonl() tests ----
|
||||||
|
|
||||||
|
def test_build_exemplars_samples_up_to_k_per_label(tmp_path):
    """Labels with more than k rows are capped at k; rarer labels keep every row."""
    import json

    from scripts.benchmark_classifier import build_exemplars_from_jsonl

    # 15 "rejected" rows exceed the cap of 10; the lone "hired" row does not.
    rejected_rows = [
        {"subject": f"S{i}", "body": f"B{i}", "label": "rejected"}
        for i in range(15)
    ]
    hired_row = {"subject": "Hire", "body": "Welcome", "label": "hired"}
    jsonl_path = tmp_path / "score.jsonl"
    jsonl_path.write_text("\n".join(map(json.dumps, [*rejected_rows, hired_row])))

    exemplars = build_exemplars_from_jsonl(str(jsonl_path), k_per_label=10)

    assert len(exemplars["rejected"]) == 10
    assert len(exemplars["hired"]) == 1
    assert exemplars["rejected"][0].startswith("Subject: S")
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_exemplars_formats_text_correctly(tmp_path):
    """A row is rendered as the subject line, a blank line, then the body."""
    import json

    from scripts.benchmark_classifier import build_exemplars_from_jsonl

    jsonl_path = tmp_path / "score.jsonl"
    jsonl_path.write_text(
        json.dumps({"subject": "My Subject", "body": "My Body", "label": "neutral"})
    )

    exemplars = build_exemplars_from_jsonl(str(jsonl_path))

    first_neutral = exemplars["neutral"][0]
    assert first_neutral == "Subject: My Subject\n\nMy Body"
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_exemplars_skips_rows_missing_label(tmp_path):
    """Rows without a 'label' key contribute nothing to the result."""
    import json

    from scripts.benchmark_classifier import build_exemplars_from_jsonl

    labelled = {"subject": "A", "body": "B", "label": "neutral"}
    unlabelled = {"subject": "No label here", "body": "Body"}
    jsonl_path = tmp_path / "score.jsonl"
    jsonl_path.write_text("\n".join(json.dumps(entry) for entry in (labelled, unlabelled)))

    exemplars = build_exemplars_from_jsonl(str(jsonl_path))

    # Only the labelled row may introduce a key.
    assert list(exemplars) == ["neutral"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_exemplars_truncates_body_at_600(tmp_path):
    """Bodies longer than 600 characters are cut to exactly 600."""
    import json

    from scripts.benchmark_classifier import build_exemplars_from_jsonl

    oversized = {"subject": "S", "body": "x" * 800, "label": "neutral"}
    jsonl_path = tmp_path / "score.jsonl"
    jsonl_path.write_text(json.dumps(oversized))

    exemplars = build_exemplars_from_jsonl(str(jsonl_path))

    # Everything after the first blank line is the (truncated) body.
    _subject_line, body_part = exemplars["neutral"][0].split("\n\n", 1)
    assert len(body_part) == 600
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue