# Implements load_and_prepare_data (JSONL ingestion with class filtering),
# compute_class_weights (inverse-frequency, div-by-zero safe),
# compute_metrics_for_trainer (macro F1 + accuracy), and
# WeightedTrainer.compute_loss (**kwargs-safe for Transformers 4.38+
# num_items_in_batch).
"""Fine-tune email classifiers on the labeled dataset.

CLI entry point. All prints use flush=True so stdout is SSE-streamable.

Usage:
    python scripts/finetune_classifier.py --model deberta-small [--epochs 5]

Supported --model values: deberta-small, bge-m3
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import sys
|
|
from collections import Counter
|
|
from datetime import datetime, timezone
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from sklearn.model_selection import train_test_split
|
|
from sklearn.metrics import f1_score, accuracy_score
|
|
from transformers import (
|
|
AutoTokenizer,
|
|
AutoModelForSequenceClassification,
|
|
EvalPrediction,
|
|
Trainer,
|
|
TrainingArguments,
|
|
EarlyStoppingCallback,
|
|
)
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|
|
|
from scripts.classifier_adapters import LABELS
|
|
|
|
_ROOT = Path(__file__).parent.parent
|
|
|
|
_MODEL_CONFIG: dict[str, dict[str, Any]] = {
|
|
"deberta-small": {
|
|
"base_model_id": "cross-encoder/nli-deberta-v3-small",
|
|
"max_tokens": 512,
|
|
"fp16": False,
|
|
"batch_size": 16,
|
|
"grad_accum": 1,
|
|
"gradient_checkpointing": False,
|
|
},
|
|
"bge-m3": {
|
|
"base_model_id": "MoritzLaurer/bge-m3-zeroshot-v2.0",
|
|
"max_tokens": 512,
|
|
"fp16": True,
|
|
"batch_size": 4,
|
|
"grad_accum": 4,
|
|
"gradient_checkpointing": True,
|
|
},
|
|
}
|
|
|
|
|
|
def load_and_prepare_data(score_file: Path) -> tuple[list[str], list[str]]:
|
|
"""Load labeled JSONL and return (texts, labels) filtered to canonical LABELS.
|
|
|
|
Drops rows with non-canonical labels (with warning), and drops entire classes
|
|
that have fewer than 2 total samples (required for stratified split).
|
|
Warns (but continues) for classes with fewer than 5 samples.
|
|
"""
|
|
if not score_file.exists():
|
|
raise FileNotFoundError(
|
|
f"Labeled data not found: {score_file}\n"
|
|
"Run the label tool first to generate email_score.jsonl."
|
|
)
|
|
|
|
label_set = set(LABELS)
|
|
rows: list[dict] = []
|
|
|
|
with score_file.open() as fh:
|
|
for line in fh:
|
|
line = line.strip()
|
|
if not line:
|
|
continue
|
|
try:
|
|
r = json.loads(line)
|
|
except json.JSONDecodeError:
|
|
continue
|
|
lbl = r.get("label", "")
|
|
if lbl not in label_set:
|
|
print(
|
|
f"[data] WARNING: Dropping row with non-canonical label {lbl!r}",
|
|
flush=True,
|
|
)
|
|
continue
|
|
rows.append(r)
|
|
|
|
# Count samples per class
|
|
counts: Counter = Counter(r["label"] for r in rows)
|
|
|
|
# Drop classes with < 2 total samples (cannot stratify-split)
|
|
drop_classes: set[str] = set()
|
|
for lbl, cnt in counts.items():
|
|
if cnt < 2:
|
|
print(
|
|
f"[data] WARNING: Dropping class {lbl!r} — only {counts[lbl]} total "
|
|
f"sample(s). Need at least 2 for stratified split.",
|
|
flush=True,
|
|
)
|
|
drop_classes.add(lbl)
|
|
|
|
# Warn for classes with < 5 samples (unreliable eval F1)
|
|
for lbl, cnt in counts.items():
|
|
if lbl not in drop_classes and cnt < 5:
|
|
print(
|
|
f"[data] WARNING: Class {lbl!r} has only {cnt} sample(s). "
|
|
f"Eval F1 for this class will be unreliable.",
|
|
flush=True,
|
|
)
|
|
|
|
# Filter rows
|
|
rows = [r for r in rows if r["label"] not in drop_classes]
|
|
|
|
texts = [f"{r['subject']} [SEP] {r['body'][:400]}" for r in rows]
|
|
labels = [r["label"] for r in rows]
|
|
|
|
return texts, labels
|
|
|
|
|
|
def compute_class_weights(label_ids: list[int], n_classes: int) -> torch.Tensor:
|
|
"""Compute inverse-frequency class weights.
|
|
|
|
Formula: total / (n_classes * class_count) per class.
|
|
Unseen classes (count=0) use count=1 to avoid division by zero.
|
|
|
|
Returns a CPU float32 tensor of shape (n_classes,).
|
|
"""
|
|
counts = Counter(label_ids)
|
|
total = len(label_ids)
|
|
weights = []
|
|
for cls in range(n_classes):
|
|
cnt = counts.get(cls, 1) # use 1 for unseen to avoid div-by-zero
|
|
weights.append(total / (n_classes * cnt))
|
|
return torch.tensor(weights, dtype=torch.float32)
|
|
|
|
|
|
def compute_metrics_for_trainer(eval_pred: EvalPrediction) -> dict:
|
|
"""Compute macro F1 and accuracy from EvalPrediction.
|
|
|
|
Called by Hugging Face Trainer at each evaluation step.
|
|
"""
|
|
logits, label_ids = eval_pred.predictions, eval_pred.label_ids
|
|
preds = logits.argmax(axis=-1)
|
|
macro_f1 = f1_score(label_ids, preds, average="macro", zero_division=0)
|
|
acc = accuracy_score(label_ids, preds)
|
|
return {"macro_f1": float(macro_f1), "accuracy": float(acc)}
|
|
|
|
|
|
class WeightedTrainer(Trainer):
|
|
"""Trainer subclass that applies per-class weights to the cross-entropy loss."""
|
|
|
|
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
|
|
# **kwargs is required — absorbs num_items_in_batch added in Transformers 4.38.
|
|
# Do not remove it; removing it causes TypeError on the first training step.
|
|
labels = inputs.pop("labels")
|
|
outputs = model(**inputs)
|
|
# Move class_weights to the same device as logits — required for GPU training.
|
|
# class_weights is created on CPU; logits are on cuda:0 during training.
|
|
weight = self.class_weights.to(outputs.logits.device)
|
|
loss = F.cross_entropy(outputs.logits, labels, weight=weight)
|
|
return (loss, outputs) if return_outputs else loss
|