diff --git a/scripts/finetune_classifier.py b/scripts/finetune_classifier.py new file mode 100644 index 0000000..c225b71 --- /dev/null +++ b/scripts/finetune_classifier.py @@ -0,0 +1,166 @@ +"""Fine-tune email classifiers on the labeled dataset. + +CLI entry point. All prints use flush=True so stdout is SSE-streamable. + +Usage: + python scripts/finetune_classifier.py --model deberta-small [--epochs 5] + +Supported --model values: deberta-small, bge-m3 +""" +from __future__ import annotations + +import argparse +import json +import sys +from collections import Counter +from datetime import datetime, timezone +from pathlib import Path +from typing import Any + +import torch +import torch.nn.functional as F +from sklearn.model_selection import train_test_split +from sklearn.metrics import f1_score, accuracy_score +from transformers import ( + AutoTokenizer, + AutoModelForSequenceClassification, + EvalPrediction, + Trainer, + TrainingArguments, + EarlyStoppingCallback, +) + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from scripts.classifier_adapters import LABELS + +_ROOT = Path(__file__).parent.parent + +_MODEL_CONFIG: dict[str, dict[str, Any]] = { + "deberta-small": { + "base_model_id": "cross-encoder/nli-deberta-v3-small", + "max_tokens": 512, + "fp16": False, + "batch_size": 16, + "grad_accum": 1, + "gradient_checkpointing": False, + }, + "bge-m3": { + "base_model_id": "MoritzLaurer/bge-m3-zeroshot-v2.0", + "max_tokens": 512, + "fp16": True, + "batch_size": 4, + "grad_accum": 4, + "gradient_checkpointing": True, + }, +} + + +def load_and_prepare_data(score_file: Path) -> tuple[list[str], list[str]]: + """Load labeled JSONL and return (texts, labels) filtered to canonical LABELS. + + Drops rows with non-canonical labels (with warning), and drops entire classes + that have fewer than 2 total samples (required for stratified split). + Warns (but continues) for classes with fewer than 5 samples. 
+ """ + if not score_file.exists(): + raise FileNotFoundError( + f"Labeled data not found: {score_file}\n" + "Run the label tool first to generate email_score.jsonl." + ) + + label_set = set(LABELS) + rows: list[dict] = [] + + with score_file.open() as fh: + for line in fh: + line = line.strip() + if not line: + continue + try: + r = json.loads(line) + except json.JSONDecodeError: + continue + lbl = r.get("label", "") + if lbl not in label_set: + print( + f"[data] WARNING: Dropping row with non-canonical label {lbl!r}", + flush=True, + ) + continue + rows.append(r) + + # Count samples per class + counts: Counter = Counter(r["label"] for r in rows) + + # Drop classes with < 2 total samples (cannot stratify-split) + drop_classes: set[str] = set() + for lbl, cnt in counts.items(): + if cnt < 2: + print( + f"[data] WARNING: Dropping class {lbl!r} — only {counts[lbl]} total " + f"sample(s). Need at least 2 for stratified split.", + flush=True, + ) + drop_classes.add(lbl) + + # Warn for classes with < 5 samples (unreliable eval F1) + for lbl, cnt in counts.items(): + if lbl not in drop_classes and cnt < 5: + print( + f"[data] WARNING: Class {lbl!r} has only {cnt} sample(s). " + f"Eval F1 for this class will be unreliable.", + flush=True, + ) + + # Filter rows + rows = [r for r in rows if r["label"] not in drop_classes] + + texts = [f"{r['subject']} [SEP] {r['body'][:400]}" for r in rows] + labels = [r["label"] for r in rows] + + return texts, labels + + +def compute_class_weights(label_ids: list[int], n_classes: int) -> torch.Tensor: + """Compute inverse-frequency class weights. + + Formula: total / (n_classes * class_count) per class. + Unseen classes (count=0) use count=1 to avoid division by zero. + + Returns a CPU float32 tensor of shape (n_classes,). 
+ """ + counts = Counter(label_ids) + total = len(label_ids) + weights = [] + for cls in range(n_classes): + cnt = counts.get(cls, 1) # use 1 for unseen to avoid div-by-zero + weights.append(total / (n_classes * cnt)) + return torch.tensor(weights, dtype=torch.float32) + + +def compute_metrics_for_trainer(eval_pred: EvalPrediction) -> dict: + """Compute macro F1 and accuracy from EvalPrediction. + + Called by Hugging Face Trainer at each evaluation step. + """ + logits, label_ids = eval_pred.predictions, eval_pred.label_ids + preds = logits.argmax(axis=-1) + macro_f1 = f1_score(label_ids, preds, average="macro", zero_division=0) + acc = accuracy_score(label_ids, preds) + return {"macro_f1": float(macro_f1), "accuracy": float(acc)} + + +class WeightedTrainer(Trainer): + """Trainer subclass that applies per-class weights to the cross-entropy loss.""" + + def compute_loss(self, model, inputs, return_outputs=False, **kwargs): + # **kwargs is required — absorbs num_items_in_batch added in Transformers 4.38. + # Do not remove it; removing it causes TypeError on the first training step. + labels = inputs.pop("labels") + outputs = model(**inputs) + # Move class_weights to the same device as logits — required for GPU training. + # class_weights is created on CPU; logits are on cuda:0 during training. 
weight = self.class_weights.to(outputs.logits.device) + loss = F.cross_entropy(outputs.logits, labels, weight=weight) + return (loss, outputs) if return_outputs else loss diff --git a/tests/test_finetune.py b/tests/test_finetune.py new file mode 100644 index 0000000..d0fc569 --- /dev/null +++ b/tests/test_finetune.py @@ -0,0 +1,250 @@ +"""Tests for finetune_classifier — no model downloads required.""" +from __future__ import annotations + +import json +import pytest + + +# ---- Data loading tests ---- + +def test_load_and_prepare_data_drops_non_canonical_labels(tmp_path): + """Rows with labels not in LABELS must be dropped (with a warning).""" + from scripts.finetune_classifier import load_and_prepare_data + from scripts.classifier_adapters import LABELS + + # Two samples per canonical label so they survive the < 2 class-drop rule. + rows = [ + {"subject": "s1", "body": "b1", "label": "digest"}, + {"subject": "s2", "body": "b2", "label": "digest"}, + {"subject": "s3", "body": "b3", "label": "profile_alert"}, # non-canonical + {"subject": "s4", "body": "b4", "label": "neutral"}, + {"subject": "s5", "body": "b5", "label": "neutral"}, + ] + score_file = tmp_path / "email_score.jsonl" + score_file.write_text("\n".join(json.dumps(r) for r in rows)) + + texts, labels = load_and_prepare_data(score_file) + assert len(texts) == 4 + assert all(l in LABELS for l in labels) + + +def test_load_and_prepare_data_formats_input_as_sep(tmp_path): + """Input text must be 'subject [SEP] body[:400]'.""" + # Two samples with the same label so the class survives the < 2 drop rule. 
+ rows = [ + {"subject": "Hello", "body": "World" * 100, "label": "neutral"}, + {"subject": "Hello2", "body": "World" * 100, "label": "neutral"}, + ] + score_file = tmp_path / "email_score.jsonl" + score_file.write_text("\n".join(json.dumps(r) for r in rows)) + + from scripts.finetune_classifier import load_and_prepare_data + texts, labels = load_and_prepare_data(score_file) + + assert texts[0].startswith("Hello [SEP] ") + assert len(texts[0]) <= len("Hello [SEP] ") + 400 + 5 + + +def test_load_and_prepare_data_raises_on_missing_file(): + """FileNotFoundError must be raised with actionable message.""" + from pathlib import Path + from scripts.finetune_classifier import load_and_prepare_data + + with pytest.raises(FileNotFoundError, match="email_score.jsonl"): + load_and_prepare_data(Path("/nonexistent/email_score.jsonl")) + + +def test_load_and_prepare_data_drops_class_with_fewer_than_2_samples(tmp_path, capsys): + """Classes with < 2 total samples must be dropped with a warning.""" + from scripts.finetune_classifier import load_and_prepare_data + + rows = [ + {"subject": "s1", "body": "b", "label": "digest"}, + {"subject": "s2", "body": "b", "label": "digest"}, + {"subject": "s3", "body": "b", "label": "new_lead"}, # only 1 sample — drop + ] + score_file = tmp_path / "email_score.jsonl" + score_file.write_text("\n".join(json.dumps(r) for r in rows)) + + texts, labels = load_and_prepare_data(score_file) + captured = capsys.readouterr() + + assert "new_lead" not in labels + assert "new_lead" in captured.out # warning printed + + +# ---- Class weights tests ---- + +def test_compute_class_weights_returns_tensor_for_each_class(): + """compute_class_weights must return a float tensor of length n_classes.""" + import torch + from scripts.finetune_classifier import compute_class_weights + + label_ids = [0, 0, 0, 1, 1, 2] # 3 classes, imbalanced + weights = compute_class_weights(label_ids, n_classes=3) + + assert isinstance(weights, torch.Tensor) + assert weights.shape == 
(3,) + assert all(w > 0 for w in weights) + + +def test_compute_class_weights_upweights_minority(): + """Minority classes must receive higher weight than majority classes.""" + from scripts.finetune_classifier import compute_class_weights + + # Class 0: 10 samples, Class 1: 2 samples + label_ids = [0] * 10 + [1] * 2 + weights = compute_class_weights(label_ids, n_classes=2) + + assert weights[1] > weights[0] + + +# ---- compute_metrics_for_trainer tests ---- + +def test_compute_metrics_for_trainer_returns_macro_f1_key(): + """Must return a dict with 'macro_f1' key.""" + import numpy as np + from scripts.finetune_classifier import compute_metrics_for_trainer + from transformers import EvalPrediction + + logits = np.array([[2.0, 0.1], [0.1, 2.0], [2.0, 0.1]]) + labels = np.array([0, 1, 0]) + pred = EvalPrediction(predictions=logits, label_ids=labels) + + result = compute_metrics_for_trainer(pred) + assert "macro_f1" in result + assert result["macro_f1"] == pytest.approx(1.0) + + +def test_compute_metrics_for_trainer_returns_accuracy_key(): + """Must also return 'accuracy' key.""" + import numpy as np + from scripts.finetune_classifier import compute_metrics_for_trainer + from transformers import EvalPrediction + + logits = np.array([[2.0, 0.1], [0.1, 2.0]]) + labels = np.array([0, 1]) + pred = EvalPrediction(predictions=logits, label_ids=labels) + + result = compute_metrics_for_trainer(pred) + assert "accuracy" in result + assert result["accuracy"] == pytest.approx(1.0) + + +# ---- WeightedTrainer tests ---- + +def test_weighted_trainer_compute_loss_returns_scalar(): + """compute_loss must return a scalar tensor when return_outputs=False.""" + import torch + from unittest.mock import MagicMock + from scripts.finetune_classifier import WeightedTrainer + + n_classes = 3 + batch = 4 + logits = torch.randn(batch, n_classes) + + mock_outputs = MagicMock() + mock_outputs.logits = logits + mock_model = MagicMock(return_value=mock_outputs) + + trainer = 
WeightedTrainer.__new__(WeightedTrainer) + trainer.class_weights = torch.ones(n_classes) + + inputs = { + "input_ids": torch.zeros(batch, 10, dtype=torch.long), + "labels": torch.randint(0, n_classes, (batch,)), + } + + loss = trainer.compute_loss(mock_model, inputs, return_outputs=False) + assert isinstance(loss, torch.Tensor) + assert loss.ndim == 0 # scalar + + +def test_weighted_trainer_compute_loss_accepts_kwargs(): + """compute_loss must not raise TypeError when called with num_items_in_batch kwarg.""" + import torch + from unittest.mock import MagicMock + from scripts.finetune_classifier import WeightedTrainer + + n_classes = 3 + batch = 2 + logits = torch.randn(batch, n_classes) + + mock_outputs = MagicMock() + mock_outputs.logits = logits + mock_model = MagicMock(return_value=mock_outputs) + + trainer = WeightedTrainer.__new__(WeightedTrainer) + trainer.class_weights = torch.ones(n_classes) + + inputs = { + "input_ids": torch.zeros(batch, 5, dtype=torch.long), + "labels": torch.randint(0, n_classes, (batch,)), + } + + loss = trainer.compute_loss(mock_model, inputs, return_outputs=False, + num_items_in_batch=batch) + assert isinstance(loss, torch.Tensor) + + +def test_weighted_trainer_weighted_loss_differs_from_unweighted(): + """Weighted loss must differ from uniform-weight loss for imbalanced inputs.""" + import torch + from unittest.mock import MagicMock + from scripts.finetune_classifier import WeightedTrainer + + n_classes = 2 + batch = 4 + # Mixed labels: 3× class-0, 1× class-1. + # Asymmetric logits (class-0 samples predicted well, class-1 predicted poorly) + # ensure per-class CE values differ, so re-weighting changes the weighted mean. 
+ labels = torch.tensor([0, 0, 0, 1], dtype=torch.long) + logits = torch.tensor([[3.0, -1.0], [3.0, -1.0], [3.0, -1.0], [0.5, 0.5]]) + + mock_outputs = MagicMock() + mock_outputs.logits = logits + + trainer_uniform = WeightedTrainer.__new__(WeightedTrainer) + trainer_uniform.class_weights = torch.ones(n_classes) + inputs_uniform = {"input_ids": torch.zeros(batch, 5, dtype=torch.long), "labels": labels.clone()} + loss_uniform = trainer_uniform.compute_loss(MagicMock(return_value=mock_outputs), + inputs_uniform) + + trainer_weighted = WeightedTrainer.__new__(WeightedTrainer) + trainer_weighted.class_weights = torch.tensor([0.1, 10.0]) + inputs_weighted = {"input_ids": torch.zeros(batch, 5, dtype=torch.long), "labels": labels.clone()} + + mock_outputs2 = MagicMock() + mock_outputs2.logits = logits.clone() + loss_weighted = trainer_weighted.compute_loss(MagicMock(return_value=mock_outputs2), + inputs_weighted) + + assert not torch.isclose(loss_uniform, loss_weighted) + + +def test_weighted_trainer_compute_loss_returns_outputs_when_requested(): + """compute_loss with return_outputs=True must return (loss, outputs) tuple.""" + import torch + from unittest.mock import MagicMock + from scripts.finetune_classifier import WeightedTrainer + + n_classes = 3 + batch = 2 + logits = torch.randn(batch, n_classes) + + mock_outputs = MagicMock() + mock_outputs.logits = logits + mock_model = MagicMock(return_value=mock_outputs) + + trainer = WeightedTrainer.__new__(WeightedTrainer) + trainer.class_weights = torch.ones(n_classes) + + inputs = { + "input_ids": torch.zeros(batch, 5, dtype=torch.long), + "labels": torch.randint(0, n_classes, (batch,)), + } + + result = trainer.compute_loss(mock_model, inputs, return_outputs=True) + assert isinstance(result, tuple) + loss, outputs = result + assert isinstance(loss, torch.Tensor)