feat(avocet): add finetune data pipeline, class weights, WeightedTrainer
Implements load_and_prepare_data (JSONL ingestion with class filtering), compute_class_weights (inverse-frequency, div-by-zero safe), compute_metrics_for_trainer (macro F1 + accuracy), and WeightedTrainer.compute_loss (**kwargs-safe for Transformers 4.38+ num_items_in_batch). All 12 tests pass.
This commit is contained in:
parent
2d795b9573
commit
5eb593569d
2 changed files with 416 additions and 0 deletions
166
scripts/finetune_classifier.py
Normal file
166
scripts/finetune_classifier.py
Normal file
|
|
@ -0,0 +1,166 @@
|
|||
"""Fine-tune email classifiers on the labeled dataset.
|
||||
|
||||
CLI entry point. All prints use flush=True so stdout is SSE-streamable.
|
||||
|
||||
Usage:
|
||||
python scripts/finetune_classifier.py --model deberta-small [--epochs 5]
|
||||
|
||||
Supported --model values: deberta-small, bge-m3
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
from collections import Counter
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.metrics import f1_score, accuracy_score
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
AutoModelForSequenceClassification,
|
||||
EvalPrediction,
|
||||
Trainer,
|
||||
TrainingArguments,
|
||||
EarlyStoppingCallback,
|
||||
)
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from scripts.classifier_adapters import LABELS
|
||||
|
||||
# Repository root: this file lives in scripts/, so go up two levels.
_ROOT = Path(__file__).parent.parent

# Per-model training configuration, keyed by the CLI --model value.
#   base_model_id:          Hugging Face hub id of the base checkpoint.
#   max_tokens:             tokenizer truncation length.
#   fp16 / batch_size / grad_accum / gradient_checkpointing:
#       memory/speed trade-offs — bge-m3 trains in fp16 with a small batch,
#       gradient accumulation, and checkpointing (presumably because it is
#       the larger model; confirm against available GPU memory).
_MODEL_CONFIG: dict[str, dict[str, Any]] = {
    "deberta-small": {
        "base_model_id": "cross-encoder/nli-deberta-v3-small",
        "max_tokens": 512,
        "fp16": False,
        "batch_size": 16,
        "grad_accum": 1,
        "gradient_checkpointing": False,
    },
    "bge-m3": {
        "base_model_id": "MoritzLaurer/bge-m3-zeroshot-v2.0",
        "max_tokens": 512,
        "fp16": True,
        "batch_size": 4,
        "grad_accum": 4,
        "gradient_checkpointing": True,
    },
}
|
||||
|
||||
|
||||
def load_and_prepare_data(score_file: Path) -> tuple[list[str], list[str]]:
    """Read labeled JSONL and return (texts, labels) restricted to canonical LABELS.

    Rows whose label is not in LABELS are dropped with a printed warning.
    Classes with fewer than 2 total samples are removed entirely (stratified
    splitting requires at least 2 per class); classes with fewer than 5
    samples are kept but trigger a reliability warning.
    """
    if not score_file.exists():
        raise FileNotFoundError(
            f"Labeled data not found: {score_file}\n"
            "Run the label tool first to generate email_score.jsonl."
        )

    valid = set(LABELS)
    rows: list[dict] = []

    with score_file.open() as fh:
        for raw in fh:
            raw = raw.strip()
            if not raw:
                continue
            try:
                rec = json.loads(raw)
            except json.JSONDecodeError:
                continue  # skip malformed lines silently
            lbl = rec.get("label", "")
            if lbl in valid:
                rows.append(rec)
            else:
                print(
                    f"[data] WARNING: Dropping row with non-canonical label {lbl!r}",
                    flush=True,
                )

    # Per-class sample counts over the surviving rows.
    counts: Counter = Counter(rec["label"] for rec in rows)

    # Classes with < 2 samples cannot be stratify-split — drop them.
    drop_classes: set[str] = set()
    for lbl, cnt in counts.items():
        if cnt >= 2:
            continue
        print(
            f"[data] WARNING: Dropping class {lbl!r} — only {cnt} total "
            f"sample(s). Need at least 2 for stratified split.",
            flush=True,
        )
        drop_classes.add(lbl)

    # Classes with < 5 samples survive but their eval F1 is noisy — warn.
    for lbl, cnt in counts.items():
        if lbl in drop_classes or cnt >= 5:
            continue
        print(
            f"[data] WARNING: Class {lbl!r} has only {cnt} sample(s). "
            f"Eval F1 for this class will be unreliable.",
            flush=True,
        )

    kept = [rec for rec in rows if rec["label"] not in drop_classes]

    texts = [f"{rec['subject']} [SEP] {rec['body'][:400]}" for rec in kept]
    labels = [rec["label"] for rec in kept]

    return texts, labels
|
||||
|
||||
|
||||
def compute_class_weights(label_ids: list[int], n_classes: int) -> torch.Tensor:
    """Compute inverse-frequency class weights.

    Formula: total / (n_classes * class_count) per class.
    Unseen classes (count=0) use count=1 to avoid division by zero.
    An empty ``label_ids`` returns uniform weights of 1.0 — the raw formula
    would yield all-zero weights (total=0), which silently zeroes out the
    weighted cross-entropy loss.

    Args:
        label_ids: encoded class id for each training sample.
        n_classes: number of classes; length of the returned tensor.

    Returns:
        A CPU float32 tensor of shape (n_classes,).
    """
    total = len(label_ids)
    if total == 0:
        # No data: fall back to uniform weighting instead of all-zero weights.
        return torch.ones(n_classes, dtype=torch.float32)
    counts = Counter(label_ids)
    weights = [
        total / (n_classes * counts.get(cls, 1))  # unseen -> 1: div-by-zero safe
        for cls in range(n_classes)
    ]
    return torch.tensor(weights, dtype=torch.float32)
|
||||
|
||||
|
||||
def compute_metrics_for_trainer(eval_pred: EvalPrediction) -> dict:
    """Return macro F1 and accuracy for a Trainer evaluation step.

    Invoked by the Hugging Face Trainer at each evaluation; predictions are
    raw logits, so argmax over the last axis picks the predicted class.
    """
    preds = eval_pred.predictions.argmax(axis=-1)
    refs = eval_pred.label_ids
    return {
        "macro_f1": float(f1_score(refs, preds, average="macro", zero_division=0)),
        "accuracy": float(accuracy_score(refs, preds)),
    }
|
||||
|
||||
|
||||
class WeightedTrainer(Trainer):
    """Trainer subclass that applies per-class weights to the cross-entropy loss."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # NOTE: **kwargs must stay — Transformers 4.38+ passes num_items_in_batch,
        # and removing it raises TypeError on the first training step.
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits
        # class_weights is created on CPU, while logits live on the training
        # device (e.g. cuda:0) — move the weights to match before the loss.
        loss = F.cross_entropy(
            logits, labels, weight=self.class_weights.to(logits.device)
        )
        if return_outputs:
            return loss, outputs
        return loss
|
||||
250
tests/test_finetune.py
Normal file
250
tests/test_finetune.py
Normal file
|
|
@ -0,0 +1,250 @@
|
|||
"""Tests for finetune_classifier — no model downloads required."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import pytest
|
||||
|
||||
|
||||
# ---- Data loading tests ----
|
||||
|
||||
def test_load_and_prepare_data_drops_non_canonical_labels(tmp_path):
    """Rows whose label is outside LABELS are dropped; canonical rows survive."""
    from scripts.finetune_classifier import load_and_prepare_data
    from scripts.classifier_adapters import LABELS

    # Give every canonical label two samples so the < 2 class-drop rule
    # does not interfere with what this test measures.
    samples = [
        {"subject": "s1", "body": "b1", "label": "digest"},
        {"subject": "s2", "body": "b2", "label": "digest"},
        {"subject": "s3", "body": "b3", "label": "profile_alert"},  # non-canonical
        {"subject": "s4", "body": "b4", "label": "neutral"},
        {"subject": "s5", "body": "b5", "label": "neutral"},
    ]
    path = tmp_path / "email_score.jsonl"
    path.write_text("\n".join(json.dumps(s) for s in samples))

    texts, labels = load_and_prepare_data(path)

    assert len(texts) == 4
    assert all(lbl in LABELS for lbl in labels)
|
||||
|
||||
|
||||
def test_load_and_prepare_data_formats_input_as_sep(tmp_path):
    """Each input text is built as 'subject [SEP] body[:400]'."""
    from scripts.finetune_classifier import load_and_prepare_data

    # Duplicate the label so the class clears the < 2 drop threshold.
    long_body = "World" * 100
    samples = [
        {"subject": "Hello", "body": long_body, "label": "neutral"},
        {"subject": "Hello2", "body": long_body, "label": "neutral"},
    ]
    path = tmp_path / "email_score.jsonl"
    path.write_text("\n".join(json.dumps(s) for s in samples))

    texts, _ = load_and_prepare_data(path)

    prefix = "Hello [SEP] "
    assert texts[0].startswith(prefix)
    assert len(texts[0]) <= len(prefix) + 400 + 5
|
||||
|
||||
|
||||
def test_load_and_prepare_data_raises_on_missing_file():
    """A missing score file raises FileNotFoundError naming the expected file."""
    from pathlib import Path
    from scripts.finetune_classifier import load_and_prepare_data

    missing = Path("/nonexistent/email_score.jsonl")
    with pytest.raises(FileNotFoundError, match="email_score.jsonl"):
        load_and_prepare_data(missing)
|
||||
|
||||
|
||||
def test_load_and_prepare_data_drops_class_with_fewer_than_2_samples(tmp_path, capsys):
    """A class with a single sample is removed and a warning is printed."""
    from scripts.finetune_classifier import load_and_prepare_data

    samples = [
        {"subject": "s1", "body": "b", "label": "digest"},
        {"subject": "s2", "body": "b", "label": "digest"},
        {"subject": "s3", "body": "b", "label": "new_lead"},  # singleton — must go
    ]
    path = tmp_path / "email_score.jsonl"
    path.write_text("\n".join(json.dumps(s) for s in samples))

    _, labels = load_and_prepare_data(path)
    out = capsys.readouterr().out

    assert "new_lead" not in labels
    assert "new_lead" in out  # the drop warning names the class
|
||||
|
||||
|
||||
# ---- Class weights tests ----
|
||||
|
||||
def test_compute_class_weights_returns_tensor_for_each_class():
    """The result is a positive float tensor with one entry per class."""
    import torch
    from scripts.finetune_classifier import compute_class_weights

    # 3 classes with imbalanced counts: 3, 2, 1.
    weights = compute_class_weights([0, 0, 0, 1, 1, 2], n_classes=3)

    assert isinstance(weights, torch.Tensor)
    assert weights.shape == (3,)
    assert bool((weights > 0).all())
|
||||
|
||||
|
||||
def test_compute_class_weights_upweights_minority():
    """The rarer class gets the larger weight."""
    from scripts.finetune_classifier import compute_class_weights

    # Class 0: 10 samples, class 1: 2 samples.
    weights = compute_class_weights([0] * 10 + [1] * 2, n_classes=2)

    assert weights[1] > weights[0]
|
||||
|
||||
|
||||
# ---- compute_metrics_for_trainer tests ----
|
||||
|
||||
def test_compute_metrics_for_trainer_returns_macro_f1_key():
    """Perfect predictions yield 1.0 under the 'macro_f1' key."""
    import numpy as np
    from scripts.finetune_classifier import compute_metrics_for_trainer
    from transformers import EvalPrediction

    logits = np.array([[2.0, 0.1], [0.1, 2.0], [2.0, 0.1]])
    gold = np.array([0, 1, 0])
    metrics = compute_metrics_for_trainer(
        EvalPrediction(predictions=logits, label_ids=gold)
    )

    assert "macro_f1" in metrics
    assert metrics["macro_f1"] == pytest.approx(1.0)
|
||||
|
||||
|
||||
def test_compute_metrics_for_trainer_returns_accuracy_key():
    """Accuracy is reported under the 'accuracy' key."""
    import numpy as np
    from scripts.finetune_classifier import compute_metrics_for_trainer
    from transformers import EvalPrediction

    logits = np.array([[2.0, 0.1], [0.1, 2.0]])
    gold = np.array([0, 1])
    metrics = compute_metrics_for_trainer(
        EvalPrediction(predictions=logits, label_ids=gold)
    )

    assert "accuracy" in metrics
    assert metrics["accuracy"] == pytest.approx(1.0)
|
||||
|
||||
|
||||
# ---- WeightedTrainer tests ----
|
||||
|
||||
def test_weighted_trainer_compute_loss_returns_scalar():
    """With return_outputs=False the loss is a 0-dim tensor."""
    import torch
    from unittest.mock import MagicMock
    from scripts.finetune_classifier import WeightedTrainer

    n_classes, batch = 3, 4
    fake_outputs = MagicMock()
    fake_outputs.logits = torch.randn(batch, n_classes)
    fake_model = MagicMock(return_value=fake_outputs)

    # Bypass Trainer.__init__ — compute_loss only needs class_weights.
    trainer = WeightedTrainer.__new__(WeightedTrainer)
    trainer.class_weights = torch.ones(n_classes)

    batch_inputs = {
        "input_ids": torch.zeros(batch, 10, dtype=torch.long),
        "labels": torch.randint(0, n_classes, (batch,)),
    }

    loss = trainer.compute_loss(fake_model, batch_inputs, return_outputs=False)

    assert isinstance(loss, torch.Tensor)
    assert loss.ndim == 0  # scalar
|
||||
|
||||
|
||||
def test_weighted_trainer_compute_loss_accepts_kwargs():
    """Passing num_items_in_batch (Transformers 4.38+) must not raise TypeError."""
    import torch
    from unittest.mock import MagicMock
    from scripts.finetune_classifier import WeightedTrainer

    n_classes, batch = 3, 2
    fake_outputs = MagicMock()
    fake_outputs.logits = torch.randn(batch, n_classes)
    fake_model = MagicMock(return_value=fake_outputs)

    trainer = WeightedTrainer.__new__(WeightedTrainer)  # skip Trainer.__init__
    trainer.class_weights = torch.ones(n_classes)

    batch_inputs = {
        "input_ids": torch.zeros(batch, 5, dtype=torch.long),
        "labels": torch.randint(0, n_classes, (batch,)),
    }

    loss = trainer.compute_loss(
        fake_model, batch_inputs, return_outputs=False, num_items_in_batch=batch
    )

    assert isinstance(loss, torch.Tensor)
|
||||
|
||||
|
||||
def test_weighted_trainer_weighted_loss_differs_from_unweighted():
    """Skewed class weights must change the loss relative to uniform weights."""
    import torch
    from unittest.mock import MagicMock
    from scripts.finetune_classifier import WeightedTrainer

    n_classes, batch = 2, 4
    # 3x class-0 (predicted confidently) and 1x class-1 (predicted poorly):
    # the per-class CE values differ, so re-weighting shifts the mean.
    gold = torch.tensor([0, 0, 0, 1], dtype=torch.long)
    logits = torch.tensor([[3.0, -1.0], [3.0, -1.0], [3.0, -1.0], [0.5, 0.5]])

    def run_with(class_weights):
        # Fresh mocks per run: compute_loss pops "labels" from its inputs.
        fake_outputs = MagicMock()
        fake_outputs.logits = logits.clone()
        trainer = WeightedTrainer.__new__(WeightedTrainer)
        trainer.class_weights = class_weights
        batch_inputs = {
            "input_ids": torch.zeros(batch, 5, dtype=torch.long),
            "labels": gold.clone(),
        }
        return trainer.compute_loss(MagicMock(return_value=fake_outputs), batch_inputs)

    loss_uniform = run_with(torch.ones(n_classes))
    loss_skewed = run_with(torch.tensor([0.1, 10.0]))

    assert not torch.isclose(loss_uniform, loss_skewed)
|
||||
|
||||
|
||||
def test_weighted_trainer_compute_loss_returns_outputs_when_requested():
|
||||
"""compute_loss with return_outputs=True must return (loss, outputs) tuple."""
|
||||
import torch
|
||||
from unittest.mock import MagicMock
|
||||
from scripts.finetune_classifier import WeightedTrainer
|
||||
|
||||
n_classes = 3
|
||||
batch = 2
|
||||
logits = torch.randn(batch, n_classes)
|
||||
|
||||
mock_outputs = MagicMock()
|
||||
mock_outputs.logits = logits
|
||||
mock_model = MagicMock(return_value=mock_outputs)
|
||||
|
||||
trainer = WeightedTrainer.__new__(WeightedTrainer)
|
||||
trainer.class_weights = torch.ones(n_classes)
|
||||
|
||||
inputs = {
|
||||
"input_ids": torch.zeros(batch, 5, dtype=torch.long),
|
||||
"labels": torch.randint(0, n_classes, (batch,)),
|
||||
}
|
||||
|
||||
result = trainer.compute_loss(mock_model, inputs, return_outputs=True)
|
||||
assert isinstance(result, tuple)
|
||||
loss, outputs = result
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
Loading…
Reference in a new issue