diff --git a/scripts/finetune_classifier.py b/scripts/finetune_classifier.py index c225b71..b3d4fe9 100644 --- a/scripts/finetune_classifier.py +++ b/scripts/finetune_classifier.py @@ -10,6 +10,7 @@ Supported --model values: deberta-small, bge-m3 from __future__ import annotations import argparse +import hashlib import json import sys from collections import Counter @@ -56,39 +57,68 @@ _MODEL_CONFIG: dict[str, dict[str, Any]] = { } -def load_and_prepare_data(score_file: Path) -> tuple[list[str], list[str]]: +def load_and_prepare_data(score_files: Path | list[Path]) -> tuple[list[str], list[str]]: """Load labeled JSONL and return (texts, labels) filtered to canonical LABELS. + score_files: a single Path or a list of Paths. When multiple files are given, + rows are merged with last-write-wins deduplication keyed by content hash + (MD5 of subject + body[:100]). + Drops rows with non-canonical labels (with warning), and drops entire classes that have fewer than 2 total samples (required for stratified split). Warns (but continues) for classes with fewer than 5 samples. """ - if not score_file.exists(): - raise FileNotFoundError( - f"Labeled data not found: {score_file}\n" - "Run the label tool first to generate email_score.jsonl." - ) + # Normalise to list — backwards compatible with single-Path callers. + if isinstance(score_files, Path): + score_files = [score_files] + + for score_file in score_files: + if not score_file.exists(): + raise FileNotFoundError( + f"Labeled data not found: {score_file}\n" + "Run the label tool first to generate email_score.jsonl." + ) label_set = set(LABELS) - rows: list[dict] = [] + # Use a plain dict keyed by content hash; later entries overwrite earlier ones + # (last-write wins), which lets later labeling runs correct earlier labels. + seen: dict[str, dict] = {} + total = 0 - with score_file.open() as fh: - for line in fh: - line = line.strip() - if not line: - continue - try: - r = json.loads(line) - except json.JSONDecodeError: - continue - lbl = r.get("label", "") - if lbl not in label_set: - print( - f"[data] WARNING: Dropping row with non-canonical label {lbl!r}", - flush=True, - ) - continue - rows.append(r) + for score_file in score_files: + with score_file.open() as fh: + for line in fh: + line = line.strip() + if not line: + continue + try: + r = json.loads(line) + except json.JSONDecodeError: + continue + lbl = r.get("label", "") + if lbl not in label_set: + print( + f"[data] WARNING: Dropping row with non-canonical label {lbl!r}", + flush=True, + ) + continue + content_hash = hashlib.md5( + (r.get("subject", "") + (r.get("body", "") or "")[:100]).encode( + "utf-8", errors="replace" + ) + ).hexdigest() + seen[content_hash] = r + total += 1 + + kept = len(seen) + dropped = total - kept + if dropped > 0: + print( + f"[data] Deduped: kept {kept} of {total} rows (dropped {dropped} duplicates)", + flush=True, + ) + + rows = list(seen.values()) # Count samples per class counts: Counter = Counter(r["label"] for r in rows) @@ -164,3 +194,214 @@ class WeightedTrainer(Trainer): weight = self.class_weights.to(outputs.logits.device) loss = F.cross_entropy(outputs.logits, labels, weight=weight) return (loss, outputs) if return_outputs else loss + + +# --------------------------------------------------------------------------- +# Training dataset wrapper +# --------------------------------------------------------------------------- + +from torch.utils.data import Dataset as TorchDataset + + +class _EmailDataset(TorchDataset): + def __init__(self, encodings: dict, label_ids: list[int]) -> None: + self.encodings = encodings + self.label_ids = label_ids + + def __len__(self) -> int: + return len(self.label_ids) + + def __getitem__(self, idx: int) -> dict: + item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()} + item["labels"] = torch.tensor(self.label_ids[idx], dtype=torch.long) + return item + + +# --------------------------------------------------------------------------- +# Main training function +# --------------------------------------------------------------------------- + +def run_finetune(model_key: str, epochs: int = 5, score_files: list[Path] | None = None) -> None: + """Fine-tune the specified model on labeled data. + + score_files: list of score JSONL paths to merge. Defaults to [_ROOT / "data" / "email_score.jsonl"]. + Saves model + tokenizer + training_info.json to models/avocet-{model_key}/. + All prints use flush=True for SSE streaming. + """ + if model_key not in _MODEL_CONFIG: + raise ValueError(f"Unknown model key: {model_key!r}. Choose from: {list(_MODEL_CONFIG)}") + + if score_files is None: + score_files = [_ROOT / "data" / "email_score.jsonl"] + + config = _MODEL_CONFIG[model_key] + base_model_id = config["base_model_id"] + output_dir = _ROOT / "models" / f"avocet-{model_key}" + + print(f"[finetune] Model: {model_key} ({base_model_id})", flush=True) + print(f"[finetune] Score files: {[str(f) for f in score_files]}", flush=True) + print(f"[finetune] Output: {output_dir}", flush=True) + if output_dir.exists(): + print(f"[finetune] WARNING: {output_dir} already exists — will overwrite.", flush=True) + + # --- Data --- + print(f"[finetune] Loading data ...", flush=True) + texts, str_labels = load_and_prepare_data(score_files) + + present_labels = sorted(set(str_labels)) + label2id = {l: i for i, l in enumerate(present_labels)} + id2label = {i: l for l, i in label2id.items()} + n_classes = len(present_labels) + label_ids = [label2id[l] for l in str_labels] + + print(f"[finetune] {len(texts)} samples, {n_classes} classes", flush=True) + + # Stratified 80/20 split — ensure val set has at least n_classes samples. + # For very small datasets (e.g. example data) we may need to give the val set + # more than 20% so every class appears at least once in eval. + desired_test = max(int(len(texts) * 0.2), n_classes) + # test_size must leave at least n_classes samples for train too + desired_test = min(desired_test, len(texts) - n_classes) + (train_texts, val_texts, + train_label_ids, val_label_ids) = train_test_split( + texts, label_ids, + test_size=desired_test, + stratify=label_ids, + random_state=42, + ) + print(f"[finetune] Train: {len(train_texts)}, Val: {len(val_texts)}", flush=True) + + # Warn for classes with < 5 training samples + train_counts = Counter(train_label_ids) + for cls_id, cnt in train_counts.items(): + if cnt < 5: + print( + f"[finetune] WARNING: Class {id2label[cls_id]!r} has {cnt} training sample(s). " + "Eval F1 for this class will be unreliable.", + flush=True, + ) + + # --- Tokenize --- + print(f"[finetune] Loading tokenizer ...", flush=True) + tokenizer = AutoTokenizer.from_pretrained(base_model_id) + + train_enc = tokenizer(train_texts, truncation=True, + max_length=config["max_tokens"], padding=True) + val_enc = tokenizer(val_texts, truncation=True, + max_length=config["max_tokens"], padding=True) + + train_dataset = _EmailDataset(train_enc, train_label_ids) + val_dataset = _EmailDataset(val_enc, val_label_ids) + + # --- Class weights --- + class_weights = compute_class_weights(train_label_ids, n_classes) + print(f"[finetune] Class weights computed", flush=True) + + # --- Model --- + print(f"[finetune] Loading model ...", flush=True) + model = AutoModelForSequenceClassification.from_pretrained( + base_model_id, + num_labels=n_classes, + ignore_mismatched_sizes=True, # NLI head (3-class) → new head (n_classes) + id2label=id2label, + label2id=label2id, + ) + if config["gradient_checkpointing"]: + model.gradient_checkpointing_enable() + + # --- TrainingArguments --- + training_args = TrainingArguments( + output_dir=str(output_dir), + num_train_epochs=epochs, + per_device_train_batch_size=config["batch_size"], + per_device_eval_batch_size=config["batch_size"], + gradient_accumulation_steps=config["grad_accum"], + learning_rate=2e-5, + lr_scheduler_type="linear", + warmup_ratio=0.1, + fp16=config["fp16"], + eval_strategy="epoch", + save_strategy="epoch", + load_best_model_at_end=True, + metric_for_best_model="macro_f1", + greater_is_better=True, + logging_steps=10, + report_to="none", + save_total_limit=2, + ) + + trainer = WeightedTrainer( + model=model, + args=training_args, + train_dataset=train_dataset, + eval_dataset=val_dataset, + compute_metrics=compute_metrics_for_trainer, + callbacks=[EarlyStoppingCallback(early_stopping_patience=3)], + ) + trainer.class_weights = class_weights + + # --- Train --- + print(f"[finetune] Starting training ({epochs} epochs) ...", flush=True) + train_result = trainer.train() + print(f"[finetune] Training complete. Steps: {train_result.global_step}", flush=True) + + # --- Evaluate --- + print(f"[finetune] Evaluating best checkpoint ...", flush=True) + metrics = trainer.evaluate() + val_macro_f1 = metrics.get("eval_macro_f1", 0.0) + val_accuracy = metrics.get("eval_accuracy", 0.0) + print(f"[finetune] Val macro-F1: {val_macro_f1:.4f}, Accuracy: {val_accuracy:.4f}", flush=True) + + # --- Save model + tokenizer --- + print(f"[finetune] Saving model to {output_dir} ...", flush=True) + trainer.save_model(str(output_dir)) + tokenizer.save_pretrained(str(output_dir)) + + # --- Write training_info.json --- + label_counts = dict(Counter(str_labels)) + info = { + "name": f"avocet-{model_key}", + "base_model_id": base_model_id, + "timestamp": datetime.now(timezone.utc).isoformat(), + "epochs_run": epochs, + "val_macro_f1": round(val_macro_f1, 4), + "val_accuracy": round(val_accuracy, 4), + "sample_count": len(train_texts), + "label_counts": label_counts, + "score_files": [str(f) for f in score_files], + } + info_path = output_dir / "training_info.json" + info_path.write_text(json.dumps(info, indent=2), encoding="utf-8") + print(f"[finetune] Saved training_info.json: val_macro_f1={val_macro_f1:.4f}", flush=True) + print(f"[finetune] Done.", flush=True) + + +# --------------------------------------------------------------------------- +# CLI +# --------------------------------------------------------------------------- + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Fine-tune an email classifier") + parser.add_argument( + "--model", + choices=list(_MODEL_CONFIG), + required=True, + help="Model key to fine-tune", + ) + parser.add_argument( + "--epochs", + type=int, + default=5, + help="Number of training epochs (default: 5)", + ) + parser.add_argument( + "--score", + dest="score_files", + type=Path, + action="append", + metavar="FILE", + help="Score JSONL file to include (repeatable; defaults to data/email_score.jsonl)", + ) + args = parser.parse_args() + score_files = args.score_files or None # None → run_finetune uses default + run_finetune(args.model, args.epochs, score_files=score_files) diff --git a/tests/test_finetune.py b/tests/test_finetune.py index 1c3e0bd..603a582 100644 --- a/tests/test_finetune.py +++ b/tests/test_finetune.py @@ -249,3 +249,123 @@ def test_weighted_trainer_compute_loss_returns_outputs_when_requested(): assert isinstance(result, tuple) loss, outputs = result assert isinstance(loss, torch.Tensor) + + +# ---- Multi-file merge / dedup tests ---- + +def test_load_and_prepare_data_merges_multiple_files(tmp_path): + """Multiple score files must be merged into a single dataset.""" + from scripts.finetune_classifier import load_and_prepare_data + + file1 = tmp_path / "run1.jsonl" + file2 = tmp_path / "run2.jsonl" + file1.write_text( + json.dumps({"subject": "s1", "body": "b1", "label": "digest"}) + "\n" + + json.dumps({"subject": "s2", "body": "b2", "label": "digest"}) + "\n" + ) + file2.write_text( + json.dumps({"subject": "s3", "body": "b3", "label": "neutral"}) + "\n" + + json.dumps({"subject": "s4", "body": "b4", "label": "neutral"}) + "\n" + ) + + texts, labels = load_and_prepare_data([file1, file2]) + assert len(texts) == 4 + assert labels.count("digest") == 2 + assert labels.count("neutral") == 2 + + +def test_load_and_prepare_data_deduplicates_last_write_wins(tmp_path, capsys): + """Duplicate rows (same content hash) keep the last occurrence.""" + from scripts.finetune_classifier import load_and_prepare_data + + # Same subject+body[:100] = same hash + row_early = {"subject": "Hello", "body": "World", "label": "neutral"} + row_late = {"subject": "Hello", "body": "World", "label": "digest"} # relabeled + + file1 = tmp_path / "run1.jsonl" + file2 = tmp_path / "run2.jsonl" + # Add a second row with different content so class count >= 2 for both classes + file1.write_text( + json.dumps(row_early) + "\n" + + json.dumps({"subject": "Other1", "body": "Other", "label": "neutral"}) + "\n" + ) + file2.write_text( + json.dumps(row_late) + "\n" + + json.dumps({"subject": "Other2", "body": "Stuff", "label": "digest"}) + "\n" + ) + + texts, labels = load_and_prepare_data([file1, file2]) + captured = capsys.readouterr() + + # The duplicate row should be counted as dropped + assert "Deduped" in captured.out + # The relabeled row should have "digest" (last-write wins), not "neutral" + hello_idx = next(i for i, t in enumerate(texts) if t.startswith("Hello [SEP]")) + assert labels[hello_idx] == "digest" + + +def test_load_and_prepare_data_single_path_still_works(tmp_path): + """Passing a single Path (not a list) must still work — backwards compatibility.""" + from scripts.finetune_classifier import load_and_prepare_data + + rows = [ + {"subject": "s1", "body": "b1", "label": "digest"}, + {"subject": "s2", "body": "b2", "label": "digest"}, + ] + score_file = tmp_path / "email_score.jsonl" + score_file.write_text("\n".join(json.dumps(r) for r in rows)) + + texts, labels = load_and_prepare_data(score_file) # single Path, not list + assert len(texts) == 2 + + +# ---- Integration test ---- + +def test_integration_finetune_on_example_data(tmp_path): + """Fine-tune deberta-small on example data for 1 epoch. + + Uses data/email_score.jsonl.example (8 samples, 5 labels represented). + The 5 missing labels must trigger the < 2 samples drop warning. + Verifies training_info.json is written with correct keys. + + Requires job-seeker-classifiers env and downloads deberta-small (~100MB on first run). + """ + import shutil + from scripts import finetune_classifier as ft_mod + from scripts.finetune_classifier import run_finetune + + example_file = ft_mod._ROOT / "data" / "email_score.jsonl.example" + if not example_file.exists(): + pytest.skip("email_score.jsonl.example not found") + + orig_root = ft_mod._ROOT + ft_mod._ROOT = tmp_path + (tmp_path / "data").mkdir() + shutil.copy(example_file, tmp_path / "data" / "email_score.jsonl") + + try: + import io + from contextlib import redirect_stdout + captured = io.StringIO() + with redirect_stdout(captured): + run_finetune("deberta-small", epochs=1) + output = captured.getvalue() + finally: + ft_mod._ROOT = orig_root + + # Missing labels should trigger the < 2 samples drop warning + assert "WARNING: Dropping class" in output + + # training_info.json must exist with correct keys + info_path = tmp_path / "models" / "avocet-deberta-small" / "training_info.json" + assert info_path.exists(), "training_info.json not written" + + info = json.loads(info_path.read_text()) + for key in ("name", "base_model_id", "timestamp", "epochs_run", + "val_macro_f1", "val_accuracy", "sample_count", + "label_counts", "score_files"): + assert key in info, f"Missing key: {key}" + + assert info["name"] == "avocet-deberta-small" + assert info["epochs_run"] == 1 + assert isinstance(info["score_files"], list)