avocet/scripts/finetune_classifier.py
pyr0ball 753f8f5def fix(avocet): use_reentrant=False for gradient checkpointing
Reentrant gradient checkpointing (the default) conflicts with Accelerate's
gradient accumulation context manager, causing a 'backward through the graph a
second time' error on the first training step. use_reentrant=False switches to
the non-reentrant autograd hook path, which is compatible with Accelerate >= 0.27.
2026-03-15 17:23:40 -07:00


"""Fine-tune email classifiers on the labeled dataset.
CLI entry point. All prints use flush=True so stdout is SSE-streamable.
Usage:
python scripts/finetune_classifier.py --model deberta-small [--epochs 5]
Supported --model values: deberta-small, bge-m3
"""
from __future__ import annotations
import argparse
import hashlib
import json
import sys
from collections import Counter
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset as TorchDataset
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    EvalPrediction,
    Trainer,
    TrainingArguments,
    EarlyStoppingCallback,
)
sys.path.insert(0, str(Path(__file__).parent.parent))
from scripts.classifier_adapters import LABELS

_ROOT = Path(__file__).parent.parent

_MODEL_CONFIG: dict[str, dict[str, Any]] = {
    "deberta-small": {
        "base_model_id": "cross-encoder/nli-deberta-v3-small",
        "max_tokens": 512,
        # fp16 must stay OFF — DeBERTa-v3 disentangled attention overflows fp16.
        "fp16": False,
        # batch_size=8 + grad_accum=2 keeps effective batch of 16 while halving
        # per-step activation memory. gradient_checkpointing recomputes activations
        # on backward instead of storing them — ~60% less activation VRAM.
        "batch_size": 8,
        "grad_accum": 2,
        "gradient_checkpointing": True,
    },
    "bge-m3": {
        "base_model_id": "MoritzLaurer/bge-m3-zeroshot-v2.0",
        "max_tokens": 512,
        "fp16": True,
        "batch_size": 4,
        "grad_accum": 4,
        "gradient_checkpointing": True,
    },
}


def load_and_prepare_data(score_files: Path | list[Path]) -> tuple[list[str], list[str]]:
    """Load labeled JSONL and return (texts, labels) filtered to canonical LABELS.

    score_files: a single Path or a list of Paths. When multiple files are given,
    rows are merged with last-write-wins deduplication keyed by content hash
    (MD5 of subject + body[:100]).

    Drops rows with non-canonical labels (with warning), and drops entire classes
    that have fewer than 2 total samples (required for stratified split).
    Warns (but continues) for classes with fewer than 5 samples.
    """
    # Normalise to list — backwards compatible with single-Path callers.
    if isinstance(score_files, Path):
        score_files = [score_files]
    for score_file in score_files:
        if not score_file.exists():
            raise FileNotFoundError(
                f"Labeled data not found: {score_file}\n"
                "Run the label tool first to generate email_score.jsonl."
            )
    label_set = set(LABELS)
    # Use a plain dict keyed by content hash; later entries overwrite earlier ones
    # (last-write wins), which lets later labeling runs correct earlier labels.
    seen: dict[str, dict] = {}
    total = 0
    for score_file in score_files:
        with score_file.open() as fh:
            for line in fh:
                line = line.strip()
                if not line:
                    continue
                try:
                    r = json.loads(line)
                except json.JSONDecodeError:
                    continue
                lbl = r.get("label", "")
                if lbl not in label_set:
                    print(
                        f"[data] WARNING: Dropping row with non-canonical label {lbl!r}",
                        flush=True,
                    )
                    continue
                content_hash = hashlib.md5(
                    (r.get("subject", "") + (r.get("body", "") or "")[:100]).encode(
                        "utf-8", errors="replace"
                    )
                ).hexdigest()
                seen[content_hash] = r
                total += 1
    kept = len(seen)
    dropped = total - kept
    if dropped > 0:
        print(
            f"[data] Deduped: kept {kept} of {total} rows (dropped {dropped} duplicates)",
            flush=True,
        )
    rows = list(seen.values())
    # Count samples per class
    counts: Counter = Counter(r["label"] for r in rows)
    # Drop classes with < 2 total samples (cannot stratify-split)
    drop_classes: set[str] = set()
    for lbl, cnt in counts.items():
        if cnt < 2:
            print(
                f"[data] WARNING: Dropping class {lbl!r} — only {counts[lbl]} total "
                f"sample(s). Need at least 2 for stratified split.",
                flush=True,
            )
            drop_classes.add(lbl)
    # Warn for classes with < 5 samples (unreliable eval F1)
    for lbl, cnt in counts.items():
        if lbl not in drop_classes and cnt < 5:
            print(
                f"[data] WARNING: Class {lbl!r} has only {cnt} sample(s). "
                f"Eval F1 for this class will be unreliable.",
                flush=True,
            )
    # Filter rows
    rows = [r for r in rows if r["label"] not in drop_classes]
    # Guard against missing/None subject or body, matching the dedup logic above.
    texts = [
        f"{r.get('subject', '')} [SEP] {(r.get('body') or '')[:400]}"
        for r in rows
    ]
    labels = [r["label"] for r in rows]
    return texts, labels
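
# Illustrative input row (fields are those read above; the label value is a
# placeholder and must be one of the canonical LABELS from scripts.classifier_adapters):
#   {"subject": "Invoice #1234", "body": "Hi team, please find attached ...", "label": "<label>"}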


def compute_class_weights(label_ids: list[int], n_classes: int) -> torch.Tensor:
    """Compute inverse-frequency class weights.

    Formula: total / (n_classes * class_count) per class.
    Unseen classes (count=0) use count=1 to avoid division by zero.
    Returns a CPU float32 tensor of shape (n_classes,).
    """
    counts = Counter(label_ids)
    total = len(label_ids)
    weights = []
    for cls in range(n_classes):
        cnt = counts.get(cls, 1)  # use 1 for unseen to avoid div-by-zero
        weights.append(total / (n_classes * cnt))
    return torch.tensor(weights, dtype=torch.float32)
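
# Worked example of the inverse-frequency formula (illustrative values):
# label_ids = [0, 0, 0, 1] with n_classes = 2 gives total = 4, so
# weight(class 0) = 4 / (2 * 3) ≈ 0.67 and weight(class 1) = 4 / (2 * 1) = 2.0;
# the minority class contributes roughly 3x more per sample to the weighted loss.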


def compute_metrics_for_trainer(eval_pred: EvalPrediction) -> dict:
    """Compute macro F1 and accuracy from EvalPrediction.

    Called by Hugging Face Trainer at each evaluation step.
    """
    logits, label_ids = eval_pred.predictions, eval_pred.label_ids
    preds = logits.argmax(axis=-1)
    macro_f1 = f1_score(label_ids, preds, average="macro", zero_division=0)
    acc = accuracy_score(label_ids, preds)
    return {"macro_f1": float(macro_f1), "accuracy": float(acc)}


class WeightedTrainer(Trainer):
    """Trainer subclass that applies per-class weights to the cross-entropy loss."""

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # **kwargs is required — it absorbs num_items_in_batch, which Transformers
        # 4.46+ passes to compute_loss. Do not remove it; removing it causes a
        # TypeError on the first training step.
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        # Move class_weights to the same device as logits — required for GPU training.
        # class_weights is created on CPU; logits are on cuda:0 during training.
        weight = self.class_weights.to(outputs.logits.device)
        # With reduction="mean" (the default), F.cross_entropy divides by the sum
        # of the target weights in the batch, not by the batch size.
        loss = F.cross_entropy(outputs.logits, labels, weight=weight)
        return (loss, outputs) if return_outputs else loss


# ---------------------------------------------------------------------------
# Training dataset wrapper
# ---------------------------------------------------------------------------
class _EmailDataset(TorchDataset):
    def __init__(self, encodings: dict, label_ids: list[int]) -> None:
        self.encodings = encodings
        self.label_ids = label_ids

    def __len__(self) -> int:
        return len(self.label_ids)

    def __getitem__(self, idx: int) -> dict:
        item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
        item["labels"] = torch.tensor(self.label_ids[idx], dtype=torch.long)
        return item
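
# Illustrative item shape (assuming the tokenizer's default output keys):
#   dataset[0] -> {"input_ids": tensor([...]), "attention_mask": tensor([...]),
#                  "labels": tensor(3)}
# token_type_ids may also be present, depending on the tokenizer.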


# ---------------------------------------------------------------------------
# Main training function
# ---------------------------------------------------------------------------
def run_finetune(model_key: str, epochs: int = 5, score_files: list[Path] | None = None) -> None:
    """Fine-tune the specified model on labeled data.

    score_files: list of score JSONL paths to merge. Defaults to [_ROOT / "data" / "email_score.jsonl"].
    Saves model + tokenizer + training_info.json to models/avocet-{model_key}/.
    All prints use flush=True for SSE streaming.
    """
    if model_key not in _MODEL_CONFIG:
        raise ValueError(f"Unknown model key: {model_key!r}. Choose from: {list(_MODEL_CONFIG)}")
    if score_files is None:
        score_files = [_ROOT / "data" / "email_score.jsonl"]
    config = _MODEL_CONFIG[model_key]
    base_model_id = config["base_model_id"]
    output_dir = _ROOT / "models" / f"avocet-{model_key}"
    print(f"[finetune] Model: {model_key} ({base_model_id})", flush=True)
    print(f"[finetune] Score files: {[str(f) for f in score_files]}", flush=True)
    print(f"[finetune] Output: {output_dir}", flush=True)
    if output_dir.exists():
        print(f"[finetune] WARNING: {output_dir} already exists — will overwrite.", flush=True)

    # --- Data ---
    print(f"[finetune] Loading data ...", flush=True)
    texts, str_labels = load_and_prepare_data(score_files)
    present_labels = sorted(set(str_labels))
    label2id = {l: i for i, l in enumerate(present_labels)}
    id2label = {i: l for l, i in label2id.items()}
    n_classes = len(present_labels)
    label_ids = [label2id[l] for l in str_labels]
    print(f"[finetune] {len(texts)} samples, {n_classes} classes", flush=True)

    # Stratified 80/20 split — ensure val set has at least n_classes samples.
    # For very small datasets (e.g. example data) we may need to give the val set
    # more than 20% so every class appears at least once in eval.
    desired_test = max(int(len(texts) * 0.2), n_classes)
    # test_size must leave at least n_classes samples for train too
    desired_test = min(desired_test, len(texts) - n_classes)
    (train_texts, val_texts,
     train_label_ids, val_label_ids) = train_test_split(
        texts, label_ids,
        test_size=desired_test,
        stratify=label_ids,
        random_state=42,
    )
    print(f"[finetune] Train: {len(train_texts)}, Val: {len(val_texts)}", flush=True)

    # Warn for classes with < 5 training samples
    train_counts = Counter(train_label_ids)
    for cls_id, cnt in train_counts.items():
        if cnt < 5:
            print(
                f"[finetune] WARNING: Class {id2label[cls_id]!r} has {cnt} training sample(s). "
                "Eval F1 for this class will be unreliable.",
                flush=True,
            )

    # --- Tokenize ---
    print(f"[finetune] Loading tokenizer ...", flush=True)
    tokenizer = AutoTokenizer.from_pretrained(base_model_id)
    train_enc = tokenizer(train_texts, truncation=True,
                          max_length=config["max_tokens"], padding=True)
    val_enc = tokenizer(val_texts, truncation=True,
                        max_length=config["max_tokens"], padding=True)
    train_dataset = _EmailDataset(train_enc, train_label_ids)
    val_dataset = _EmailDataset(val_enc, val_label_ids)

    # --- Class weights ---
    class_weights = compute_class_weights(train_label_ids, n_classes)
    print(f"[finetune] Class weights computed", flush=True)

    # --- Model ---
    print(f"[finetune] Loading model ...", flush=True)
    model = AutoModelForSequenceClassification.from_pretrained(
        base_model_id,
        num_labels=n_classes,
        ignore_mismatched_sizes=True,  # NLI head (3-class) → new head (n_classes)
        id2label=id2label,
        label2id=label2id,
    )
    if config["gradient_checkpointing"]:
        # use_reentrant=False avoids "backward through graph a second time" errors
        # when Accelerate's gradient accumulation context is layered on top.
        # Reentrant checkpointing (the default) conflicts with Accelerate ≥ 0.27.
        model.gradient_checkpointing_enable(
            gradient_checkpointing_kwargs={"use_reentrant": False}
        )

    # --- TrainingArguments ---
    training_args = TrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=epochs,
        per_device_train_batch_size=config["batch_size"],
        per_device_eval_batch_size=config["batch_size"],
        gradient_accumulation_steps=config["grad_accum"],
        learning_rate=2e-5,
        lr_scheduler_type="linear",
        warmup_ratio=0.1,
        fp16=config["fp16"],
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="macro_f1",
        greater_is_better=True,
        logging_steps=10,
        report_to="none",
        save_total_limit=2,
    )
    trainer = WeightedTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics_for_trainer,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    )
    trainer.class_weights = class_weights

    # --- Train ---
    print(f"[finetune] Starting training ({epochs} epochs) ...", flush=True)
    train_result = trainer.train()
    print(f"[finetune] Training complete. Steps: {train_result.global_step}", flush=True)

    # --- Evaluate ---
    print(f"[finetune] Evaluating best checkpoint ...", flush=True)
    metrics = trainer.evaluate()
    val_macro_f1 = metrics.get("eval_macro_f1", 0.0)
    val_accuracy = metrics.get("eval_accuracy", 0.0)
    print(f"[finetune] Val macro-F1: {val_macro_f1:.4f}, Accuracy: {val_accuracy:.4f}", flush=True)

    # --- Save model + tokenizer ---
    print(f"[finetune] Saving model to {output_dir} ...", flush=True)
    trainer.save_model(str(output_dir))
    tokenizer.save_pretrained(str(output_dir))

    # --- Write training_info.json ---
    label_counts = dict(Counter(str_labels))
    info = {
        "name": f"avocet-{model_key}",
        "base_model_id": base_model_id,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "epochs_run": epochs,
        "val_macro_f1": round(val_macro_f1, 4),
        "val_accuracy": round(val_accuracy, 4),
        "sample_count": len(texts),
        "train_sample_count": len(train_texts),
        "label_counts": label_counts,
        "score_files": [str(f) for f in score_files],
    }
    info_path = output_dir / "training_info.json"
    info_path.write_text(json.dumps(info, indent=2), encoding="utf-8")
    print(f"[finetune] Saved training_info.json: val_macro_f1={val_macro_f1:.4f}", flush=True)
    print(f"[finetune] Done.", flush=True)


# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Fine-tune an email classifier")
    parser.add_argument(
        "--model",
        choices=list(_MODEL_CONFIG),
        required=True,
        help="Model key to fine-tune",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=5,
        help="Number of training epochs (default: 5)",
    )
    parser.add_argument(
        "--score",
        dest="score_files",
        type=Path,
        action="append",
        metavar="FILE",
        help="Score JSONL file to include (repeatable; defaults to data/email_score.jsonl)",
    )
    args = parser.parse_args()
    score_files = args.score_files or None  # None → run_finetune uses default
    run_finetune(args.model, args.epochs, score_files=score_files)