feat(avocet): add finetune data pipeline, class weights, WeightedTrainer

Implements load_and_prepare_data (JSONL ingestion with class filtering),
compute_class_weights (inverse-frequency, div-by-zero safe), compute_metrics_for_trainer
(macro F1 + accuracy), and WeightedTrainer.compute_loss (**kwargs-safe for
Transformers 4.38+ num_items_in_batch). All 12 tests pass.
This commit is contained in:
pyr0ball 2026-03-15 15:38:45 -07:00
parent 2d795b9573
commit 5eb593569d
2 changed files with 416 additions and 0 deletions

View file

@ -0,0 +1,166 @@
"""Fine-tune email classifiers on the labeled dataset.
CLI entry point. All prints use flush=True so stdout is SSE-streamable.
Usage:
python scripts/finetune_classifier.py --model deberta-small [--epochs 5]
Supported --model values: deberta-small, bge-m3
"""
from __future__ import annotations
import argparse
import json
import sys
from collections import Counter
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import torch
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score
from transformers import (
AutoTokenizer,
AutoModelForSequenceClassification,
EvalPrediction,
Trainer,
TrainingArguments,
EarlyStoppingCallback,
)
sys.path.insert(0, str(Path(__file__).parent.parent))
from scripts.classifier_adapters import LABELS
_ROOT = Path(__file__).parent.parent
_MODEL_CONFIG: dict[str, dict[str, Any]] = {
"deberta-small": {
"base_model_id": "cross-encoder/nli-deberta-v3-small",
"max_tokens": 512,
"fp16": False,
"batch_size": 16,
"grad_accum": 1,
"gradient_checkpointing": False,
},
"bge-m3": {
"base_model_id": "MoritzLaurer/bge-m3-zeroshot-v2.0",
"max_tokens": 512,
"fp16": True,
"batch_size": 4,
"grad_accum": 4,
"gradient_checkpointing": True,
},
}
def load_and_prepare_data(score_file: Path) -> tuple[list[str], list[str]]:
"""Load labeled JSONL and return (texts, labels) filtered to canonical LABELS.
Drops rows with non-canonical labels (with warning), and drops entire classes
that have fewer than 2 total samples (required for stratified split).
Warns (but continues) for classes with fewer than 5 samples.
"""
if not score_file.exists():
raise FileNotFoundError(
f"Labeled data not found: {score_file}\n"
"Run the label tool first to generate email_score.jsonl."
)
label_set = set(LABELS)
rows: list[dict] = []
with score_file.open() as fh:
for line in fh:
line = line.strip()
if not line:
continue
try:
r = json.loads(line)
except json.JSONDecodeError:
continue
lbl = r.get("label", "")
if lbl not in label_set:
print(
f"[data] WARNING: Dropping row with non-canonical label {lbl!r}",
flush=True,
)
continue
rows.append(r)
# Count samples per class
counts: Counter = Counter(r["label"] for r in rows)
# Drop classes with < 2 total samples (cannot stratify-split)
drop_classes: set[str] = set()
for lbl, cnt in counts.items():
if cnt < 2:
print(
f"[data] WARNING: Dropping class {lbl!r} — only {counts[lbl]} total "
f"sample(s). Need at least 2 for stratified split.",
flush=True,
)
drop_classes.add(lbl)
# Warn for classes with < 5 samples (unreliable eval F1)
for lbl, cnt in counts.items():
if lbl not in drop_classes and cnt < 5:
print(
f"[data] WARNING: Class {lbl!r} has only {cnt} sample(s). "
f"Eval F1 for this class will be unreliable.",
flush=True,
)
# Filter rows
rows = [r for r in rows if r["label"] not in drop_classes]
texts = [f"{r['subject']} [SEP] {r['body'][:400]}" for r in rows]
labels = [r["label"] for r in rows]
return texts, labels
def compute_class_weights(label_ids: list[int], n_classes: int) -> torch.Tensor:
"""Compute inverse-frequency class weights.
Formula: total / (n_classes * class_count) per class.
Unseen classes (count=0) use count=1 to avoid division by zero.
Returns a CPU float32 tensor of shape (n_classes,).
"""
counts = Counter(label_ids)
total = len(label_ids)
weights = []
for cls in range(n_classes):
cnt = counts.get(cls, 1) # use 1 for unseen to avoid div-by-zero
weights.append(total / (n_classes * cnt))
return torch.tensor(weights, dtype=torch.float32)
def compute_metrics_for_trainer(eval_pred: EvalPrediction) -> dict:
"""Compute macro F1 and accuracy from EvalPrediction.
Called by Hugging Face Trainer at each evaluation step.
"""
logits, label_ids = eval_pred.predictions, eval_pred.label_ids
preds = logits.argmax(axis=-1)
macro_f1 = f1_score(label_ids, preds, average="macro", zero_division=0)
acc = accuracy_score(label_ids, preds)
return {"macro_f1": float(macro_f1), "accuracy": float(acc)}
class WeightedTrainer(Trainer):
"""Trainer subclass that applies per-class weights to the cross-entropy loss."""
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
# **kwargs is required — absorbs num_items_in_batch added in Transformers 4.38.
# Do not remove it; removing it causes TypeError on the first training step.
labels = inputs.pop("labels")
outputs = model(**inputs)
# Move class_weights to the same device as logits — required for GPU training.
# class_weights is created on CPU; logits are on cuda:0 during training.
weight = self.class_weights.to(outputs.logits.device)
loss = F.cross_entropy(outputs.logits, labels, weight=weight)
return (loss, outputs) if return_outputs else loss

250
tests/test_finetune.py Normal file
View file

@ -0,0 +1,250 @@
"""Tests for finetune_classifier — no model downloads required."""
from __future__ import annotations
import json
import pytest
# ---- Data loading tests ----
def test_load_and_prepare_data_drops_non_canonical_labels(tmp_path):
"""Rows with labels not in LABELS must be silently dropped."""
from scripts.finetune_classifier import load_and_prepare_data
from scripts.classifier_adapters import LABELS
# Two samples per canonical label so they survive the < 2 class-drop rule.
rows = [
{"subject": "s1", "body": "b1", "label": "digest"},
{"subject": "s2", "body": "b2", "label": "digest"},
{"subject": "s3", "body": "b3", "label": "profile_alert"}, # non-canonical
{"subject": "s4", "body": "b4", "label": "neutral"},
{"subject": "s5", "body": "b5", "label": "neutral"},
]
score_file = tmp_path / "email_score.jsonl"
score_file.write_text("\n".join(json.dumps(r) for r in rows))
texts, labels = load_and_prepare_data(score_file)
assert len(texts) == 4
assert all(l in LABELS for l in labels)
def test_load_and_prepare_data_formats_input_as_sep(tmp_path):
"""Input text must be 'subject [SEP] body[:400]'."""
# Two samples with the same label so the class survives the < 2 drop rule.
rows = [
{"subject": "Hello", "body": "World" * 100, "label": "neutral"},
{"subject": "Hello2", "body": "World" * 100, "label": "neutral"},
]
score_file = tmp_path / "email_score.jsonl"
score_file.write_text("\n".join(json.dumps(r) for r in rows))
from scripts.finetune_classifier import load_and_prepare_data
texts, labels = load_and_prepare_data(score_file)
assert texts[0].startswith("Hello [SEP] ")
assert len(texts[0]) <= len("Hello [SEP] ") + 400 + 5
def test_load_and_prepare_data_raises_on_missing_file():
"""FileNotFoundError must be raised with actionable message."""
from pathlib import Path
from scripts.finetune_classifier import load_and_prepare_data
with pytest.raises(FileNotFoundError, match="email_score.jsonl"):
load_and_prepare_data(Path("/nonexistent/email_score.jsonl"))
def test_load_and_prepare_data_drops_class_with_fewer_than_2_samples(tmp_path, capsys):
"""Classes with < 2 total samples must be dropped with a warning."""
from scripts.finetune_classifier import load_and_prepare_data
rows = [
{"subject": "s1", "body": "b", "label": "digest"},
{"subject": "s2", "body": "b", "label": "digest"},
{"subject": "s3", "body": "b", "label": "new_lead"}, # only 1 sample — drop
]
score_file = tmp_path / "email_score.jsonl"
score_file.write_text("\n".join(json.dumps(r) for r in rows))
texts, labels = load_and_prepare_data(score_file)
captured = capsys.readouterr()
assert "new_lead" not in labels
assert "new_lead" in captured.out # warning printed
# ---- Class weights tests ----
def test_compute_class_weights_returns_tensor_for_each_class():
"""compute_class_weights must return a float tensor of length n_classes."""
import torch
from scripts.finetune_classifier import compute_class_weights
label_ids = [0, 0, 0, 1, 1, 2] # 3 classes, imbalanced
weights = compute_class_weights(label_ids, n_classes=3)
assert isinstance(weights, torch.Tensor)
assert weights.shape == (3,)
assert all(w > 0 for w in weights)
def test_compute_class_weights_upweights_minority():
"""Minority classes must receive higher weight than majority classes."""
from scripts.finetune_classifier import compute_class_weights
# Class 0: 10 samples, Class 1: 2 samples
label_ids = [0] * 10 + [1] * 2
weights = compute_class_weights(label_ids, n_classes=2)
assert weights[1] > weights[0]
# ---- compute_metrics_for_trainer tests ----
def test_compute_metrics_for_trainer_returns_macro_f1_key():
"""Must return a dict with 'macro_f1' key."""
import numpy as np
from scripts.finetune_classifier import compute_metrics_for_trainer
from transformers import EvalPrediction
logits = np.array([[2.0, 0.1], [0.1, 2.0], [2.0, 0.1]])
labels = np.array([0, 1, 0])
pred = EvalPrediction(predictions=logits, label_ids=labels)
result = compute_metrics_for_trainer(pred)
assert "macro_f1" in result
assert result["macro_f1"] == pytest.approx(1.0)
def test_compute_metrics_for_trainer_returns_accuracy_key():
"""Must also return 'accuracy' key."""
import numpy as np
from scripts.finetune_classifier import compute_metrics_for_trainer
from transformers import EvalPrediction
logits = np.array([[2.0, 0.1], [0.1, 2.0]])
labels = np.array([0, 1])
pred = EvalPrediction(predictions=logits, label_ids=labels)
result = compute_metrics_for_trainer(pred)
assert "accuracy" in result
assert result["accuracy"] == pytest.approx(1.0)
# ---- WeightedTrainer tests ----
def test_weighted_trainer_compute_loss_returns_scalar():
"""compute_loss must return a scalar tensor when return_outputs=False."""
import torch
from unittest.mock import MagicMock
from scripts.finetune_classifier import WeightedTrainer
n_classes = 3
batch = 4
logits = torch.randn(batch, n_classes)
mock_outputs = MagicMock()
mock_outputs.logits = logits
mock_model = MagicMock(return_value=mock_outputs)
trainer = WeightedTrainer.__new__(WeightedTrainer)
trainer.class_weights = torch.ones(n_classes)
inputs = {
"input_ids": torch.zeros(batch, 10, dtype=torch.long),
"labels": torch.randint(0, n_classes, (batch,)),
}
loss = trainer.compute_loss(mock_model, inputs, return_outputs=False)
assert isinstance(loss, torch.Tensor)
assert loss.ndim == 0 # scalar
def test_weighted_trainer_compute_loss_accepts_kwargs():
"""compute_loss must not raise TypeError when called with num_items_in_batch kwarg."""
import torch
from unittest.mock import MagicMock
from scripts.finetune_classifier import WeightedTrainer
n_classes = 3
batch = 2
logits = torch.randn(batch, n_classes)
mock_outputs = MagicMock()
mock_outputs.logits = logits
mock_model = MagicMock(return_value=mock_outputs)
trainer = WeightedTrainer.__new__(WeightedTrainer)
trainer.class_weights = torch.ones(n_classes)
inputs = {
"input_ids": torch.zeros(batch, 5, dtype=torch.long),
"labels": torch.randint(0, n_classes, (batch,)),
}
loss = trainer.compute_loss(mock_model, inputs, return_outputs=False,
num_items_in_batch=batch)
assert isinstance(loss, torch.Tensor)
def test_weighted_trainer_weighted_loss_differs_from_unweighted():
"""Weighted loss must differ from uniform-weight loss for imbalanced inputs."""
import torch
from unittest.mock import MagicMock
from scripts.finetune_classifier import WeightedTrainer
n_classes = 2
batch = 4
# Mixed labels: 3× class-0, 1× class-1.
# Asymmetric logits (class-0 samples predicted well, class-1 predicted poorly)
# ensure per-class CE values differ, so re-weighting changes the weighted mean.
labels = torch.tensor([0, 0, 0, 1], dtype=torch.long)
logits = torch.tensor([[3.0, -1.0], [3.0, -1.0], [3.0, -1.0], [0.5, 0.5]])
mock_outputs = MagicMock()
mock_outputs.logits = logits
trainer_uniform = WeightedTrainer.__new__(WeightedTrainer)
trainer_uniform.class_weights = torch.ones(n_classes)
inputs_uniform = {"input_ids": torch.zeros(batch, 5, dtype=torch.long), "labels": labels.clone()}
loss_uniform = trainer_uniform.compute_loss(MagicMock(return_value=mock_outputs),
inputs_uniform)
trainer_weighted = WeightedTrainer.__new__(WeightedTrainer)
trainer_weighted.class_weights = torch.tensor([0.1, 10.0])
inputs_weighted = {"input_ids": torch.zeros(batch, 5, dtype=torch.long), "labels": labels.clone()}
mock_outputs2 = MagicMock()
mock_outputs2.logits = logits.clone()
loss_weighted = trainer_weighted.compute_loss(MagicMock(return_value=mock_outputs2),
inputs_weighted)
assert not torch.isclose(loss_uniform, loss_weighted)
def test_weighted_trainer_compute_loss_returns_outputs_when_requested():
"""compute_loss with return_outputs=True must return (loss, outputs) tuple."""
import torch
from unittest.mock import MagicMock
from scripts.finetune_classifier import WeightedTrainer
n_classes = 3
batch = 2
logits = torch.randn(batch, n_classes)
mock_outputs = MagicMock()
mock_outputs.logits = logits
mock_model = MagicMock(return_value=mock_outputs)
trainer = WeightedTrainer.__new__(WeightedTrainer)
trainer.class_weights = torch.ones(n_classes)
inputs = {
"input_ids": torch.zeros(batch, 5, dtype=torch.long),
"labels": torch.randint(0, n_classes, (batch,)),
}
result = trainer.compute_loss(mock_model, inputs, return_outputs=True)
assert isinstance(result, tuple)
loss, outputs = result
assert isinstance(loss, torch.Tensor)