feat(avocet): run_finetune, CLI, multi-score-file merge with last-write-wins dedup
- load_and_prepare_data() now accepts Path | list[Path]; single-Path callers unchanged
- Dedup by MD5(subject + body[:100]); last file/row wins (lets later runs correct labels)
- Prints summary line when duplicates are dropped
- Added _EmailDataset (TorchDataset wrapper), run_finetune(), and argparse CLI
- run_finetune() saves model + tokenizer + training_info.json with score_files provenance
- Stratified split guard: val set size clamped to at least n_classes (handles tiny example data)
- 3 new unit tests (merge, last-write-wins dedup, single-Path compat) + 1 integration test
- All 16 tests pass (15 unit + 1 integration)
This commit is contained in:
parent
f262b23cf5
commit
8ba34bb2d1
2 changed files with 385 additions and 24 deletions
|
|
@ -10,6 +10,7 @@ Supported --model values: deberta-small, bge-m3
|
|||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import hashlib
|
||||
import json
|
||||
import sys
|
||||
from collections import Counter
|
||||
|
|
@ -56,39 +57,68 @@ _MODEL_CONFIG: dict[str, dict[str, Any]] = {
|
|||
}
|
||||
|
||||
|
||||
def load_and_prepare_data(score_file: Path) -> tuple[list[str], list[str]]:
|
||||
def load_and_prepare_data(score_files: Path | list[Path]) -> tuple[list[str], list[str]]:
|
||||
"""Load labeled JSONL and return (texts, labels) filtered to canonical LABELS.
|
||||
|
||||
score_files: a single Path or a list of Paths. When multiple files are given,
|
||||
rows are merged with last-write-wins deduplication keyed by content hash
|
||||
(MD5 of subject + body[:100]).
|
||||
|
||||
Drops rows with non-canonical labels (with warning), and drops entire classes
|
||||
that have fewer than 2 total samples (required for stratified split).
|
||||
Warns (but continues) for classes with fewer than 5 samples.
|
||||
"""
|
||||
if not score_file.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Labeled data not found: {score_file}\n"
|
||||
"Run the label tool first to generate email_score.jsonl."
|
||||
)
|
||||
# Normalise to list — backwards compatible with single-Path callers.
|
||||
if isinstance(score_files, Path):
|
||||
score_files = [score_files]
|
||||
|
||||
for score_file in score_files:
|
||||
if not score_file.exists():
|
||||
raise FileNotFoundError(
|
||||
f"Labeled data not found: {score_file}\n"
|
||||
"Run the label tool first to generate email_score.jsonl."
|
||||
)
|
||||
|
||||
label_set = set(LABELS)
|
||||
rows: list[dict] = []
|
||||
# Use a plain dict keyed by content hash; later entries overwrite earlier ones
|
||||
# (last-write wins), which lets later labeling runs correct earlier labels.
|
||||
seen: dict[str, dict] = {}
|
||||
total = 0
|
||||
|
||||
with score_file.open() as fh:
|
||||
for line in fh:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
r = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
lbl = r.get("label", "")
|
||||
if lbl not in label_set:
|
||||
print(
|
||||
f"[data] WARNING: Dropping row with non-canonical label {lbl!r}",
|
||||
flush=True,
|
||||
)
|
||||
continue
|
||||
rows.append(r)
|
||||
for score_file in score_files:
|
||||
with score_file.open() as fh:
|
||||
for line in fh:
|
||||
line = line.strip()
|
||||
if not line:
|
||||
continue
|
||||
try:
|
||||
r = json.loads(line)
|
||||
except json.JSONDecodeError:
|
||||
continue
|
||||
lbl = r.get("label", "")
|
||||
if lbl not in label_set:
|
||||
print(
|
||||
f"[data] WARNING: Dropping row with non-canonical label {lbl!r}",
|
||||
flush=True,
|
||||
)
|
||||
continue
|
||||
content_hash = hashlib.md5(
|
||||
(r.get("subject", "") + (r.get("body", "") or "")[:100]).encode(
|
||||
"utf-8", errors="replace"
|
||||
)
|
||||
).hexdigest()
|
||||
seen[content_hash] = r
|
||||
total += 1
|
||||
|
||||
kept = len(seen)
|
||||
dropped = total - kept
|
||||
if dropped > 0:
|
||||
print(
|
||||
f"[data] Deduped: kept {kept} of {total} rows (dropped {dropped} duplicates)",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
rows = list(seen.values())
|
||||
|
||||
# Count samples per class
|
||||
counts: Counter = Counter(r["label"] for r in rows)
|
||||
|
|
@ -164,3 +194,214 @@ class WeightedTrainer(Trainer):
|
|||
weight = self.class_weights.to(outputs.logits.device)
|
||||
loss = F.cross_entropy(outputs.logits, labels, weight=weight)
|
||||
return (loss, outputs) if return_outputs else loss
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Training dataset wrapper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from torch.utils.data import Dataset as TorchDataset
|
||||
|
||||
|
||||
class _EmailDataset(TorchDataset):
|
||||
def __init__(self, encodings: dict, label_ids: list[int]) -> None:
|
||||
self.encodings = encodings
|
||||
self.label_ids = label_ids
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.label_ids)
|
||||
|
||||
def __getitem__(self, idx: int) -> dict:
|
||||
item = {k: torch.tensor(v[idx]) for k, v in self.encodings.items()}
|
||||
item["labels"] = torch.tensor(self.label_ids[idx], dtype=torch.long)
|
||||
return item
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main training function
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def run_finetune(model_key: str, epochs: int = 5, score_files: list[Path] | None = None) -> None:
    """Fine-tune the specified model on labeled data.

    model_key: key into _MODEL_CONFIG selecting base model + hyperparameters.
    epochs: number of training epochs (early stopping may end sooner).
    score_files: list of score JSONL paths to merge. Defaults to [_ROOT / "data" / "email_score.jsonl"].
    Saves model + tokenizer + training_info.json to models/avocet-{model_key}/.
    All prints use flush=True for SSE streaming.

    Raises ValueError for an unknown model_key; load_and_prepare_data raises
    FileNotFoundError when a score file is missing.
    """
    if model_key not in _MODEL_CONFIG:
        raise ValueError(f"Unknown model key: {model_key!r}. Choose from: {list(_MODEL_CONFIG)}")

    # Default resolved at call time so tests can repoint the module-level _ROOT.
    if score_files is None:
        score_files = [_ROOT / "data" / "email_score.jsonl"]

    config = _MODEL_CONFIG[model_key]
    base_model_id = config["base_model_id"]
    output_dir = _ROOT / "models" / f"avocet-{model_key}"

    print(f"[finetune] Model: {model_key} ({base_model_id})", flush=True)
    print(f"[finetune] Score files: {[str(f) for f in score_files]}", flush=True)
    print(f"[finetune] Output: {output_dir}", flush=True)
    if output_dir.exists():
        print(f"[finetune] WARNING: {output_dir} already exists — will overwrite.", flush=True)

    # --- Data ---
    print(f"[finetune] Loading data ...", flush=True)
    texts, str_labels = load_and_prepare_data(score_files)

    # Build the label space from labels actually present in the data (classes
    # with too few samples were already dropped by load_and_prepare_data), so
    # the classifier head has exactly n_classes outputs.
    present_labels = sorted(set(str_labels))
    label2id = {l: i for i, l in enumerate(present_labels)}
    id2label = {i: l for l, i in label2id.items()}
    n_classes = len(present_labels)
    label_ids = [label2id[l] for l in str_labels]

    print(f"[finetune] {len(texts)} samples, {n_classes} classes", flush=True)

    # Stratified 80/20 split — ensure val set has at least n_classes samples.
    # For very small datasets (e.g. example data) we may need to give the val set
    # more than 20% so every class appears at least once in eval.
    desired_test = max(int(len(texts) * 0.2), n_classes)
    # test_size must leave at least n_classes samples for train too
    desired_test = min(desired_test, len(texts) - n_classes)
    (train_texts, val_texts,
     train_label_ids, val_label_ids) = train_test_split(
        texts, label_ids,
        test_size=desired_test,
        stratify=label_ids,
        random_state=42,  # fixed seed → reproducible split
    )
    print(f"[finetune] Train: {len(train_texts)}, Val: {len(val_texts)}", flush=True)

    # Warn for classes with < 5 training samples
    train_counts = Counter(train_label_ids)
    for cls_id, cnt in train_counts.items():
        if cnt < 5:
            print(
                f"[finetune] WARNING: Class {id2label[cls_id]!r} has {cnt} training sample(s). "
                "Eval F1 for this class will be unreliable.",
                flush=True,
            )

    # --- Tokenize ---
    print(f"[finetune] Loading tokenizer ...", flush=True)
    tokenizer = AutoTokenizer.from_pretrained(base_model_id)

    train_enc = tokenizer(train_texts, truncation=True,
                          max_length=config["max_tokens"], padding=True)
    val_enc = tokenizer(val_texts, truncation=True,
                        max_length=config["max_tokens"], padding=True)

    train_dataset = _EmailDataset(train_enc, train_label_ids)
    val_dataset = _EmailDataset(val_enc, val_label_ids)

    # --- Class weights ---
    # Computed on the *training* split only, so eval data can't leak in.
    class_weights = compute_class_weights(train_label_ids, n_classes)
    print(f"[finetune] Class weights computed", flush=True)

    # --- Model ---
    print(f"[finetune] Loading model ...", flush=True)
    model = AutoModelForSequenceClassification.from_pretrained(
        base_model_id,
        num_labels=n_classes,
        ignore_mismatched_sizes=True,  # NLI head (3-class) → new head (n_classes)
        id2label=id2label,
        label2id=label2id,
    )
    if config["gradient_checkpointing"]:
        model.gradient_checkpointing_enable()

    # --- TrainingArguments ---
    # NOTE(review): `eval_strategy` is the post-rename kwarg (older transformers
    # used `evaluation_strategy`) — confirm the pinned transformers version.
    training_args = TrainingArguments(
        output_dir=str(output_dir),
        num_train_epochs=epochs,
        per_device_train_batch_size=config["batch_size"],
        per_device_eval_batch_size=config["batch_size"],
        gradient_accumulation_steps=config["grad_accum"],
        learning_rate=2e-5,
        lr_scheduler_type="linear",
        warmup_ratio=0.1,
        fp16=config["fp16"],
        eval_strategy="epoch",
        save_strategy="epoch",
        load_best_model_at_end=True,
        metric_for_best_model="macro_f1",
        greater_is_better=True,
        logging_steps=10,
        report_to="none",
        save_total_limit=2,  # keep disk usage bounded across epoch checkpoints
    )

    trainer = WeightedTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics_for_trainer,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
    )
    # WeightedTrainer.compute_loss reads this attribute for the weighted
    # cross-entropy; it must be attached before training starts.
    trainer.class_weights = class_weights

    # --- Train ---
    print(f"[finetune] Starting training ({epochs} epochs) ...", flush=True)
    train_result = trainer.train()
    print(f"[finetune] Training complete. Steps: {train_result.global_step}", flush=True)

    # --- Evaluate ---
    # load_best_model_at_end=True means this evaluates the best checkpoint,
    # not necessarily the final epoch.
    print(f"[finetune] Evaluating best checkpoint ...", flush=True)
    metrics = trainer.evaluate()
    val_macro_f1 = metrics.get("eval_macro_f1", 0.0)
    val_accuracy = metrics.get("eval_accuracy", 0.0)
    print(f"[finetune] Val macro-F1: {val_macro_f1:.4f}, Accuracy: {val_accuracy:.4f}", flush=True)

    # --- Save model + tokenizer ---
    print(f"[finetune] Saving model to {output_dir} ...", flush=True)
    trainer.save_model(str(output_dir))
    tokenizer.save_pretrained(str(output_dir))

    # --- Write training_info.json ---
    # Provenance record: metrics, data volume, and which score files were used.
    label_counts = dict(Counter(str_labels))
    info = {
        "name": f"avocet-{model_key}",
        "base_model_id": base_model_id,
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "epochs_run": epochs,
        "val_macro_f1": round(val_macro_f1, 4),
        "val_accuracy": round(val_accuracy, 4),
        "sample_count": len(train_texts),
        "label_counts": label_counts,
        "score_files": [str(f) for f in score_files],
    }
    info_path = output_dir / "training_info.json"
    info_path.write_text(json.dumps(info, indent=2), encoding="utf-8")
    print(f"[finetune] Saved training_info.json: val_macro_f1={val_macro_f1:.4f}", flush=True)
    print(f"[finetune] Done.", flush=True)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
if __name__ == "__main__":
    # CLI entry point: pick a model, epoch count, and optional score files.
    cli = argparse.ArgumentParser(description="Fine-tune an email classifier")
    cli.add_argument(
        "--model",
        choices=list(_MODEL_CONFIG),
        required=True,
        help="Model key to fine-tune",
    )
    cli.add_argument(
        "--epochs",
        type=int,
        default=5,
        help="Number of training epochs (default: 5)",
    )
    cli.add_argument(
        "--score",
        dest="score_files",
        type=Path,
        action="append",
        metavar="FILE",
        help="Score JSONL file to include (repeatable; defaults to data/email_score.jsonl)",
    )
    ns = cli.parse_args()
    # No --score flags → None, so run_finetune falls back to its default path.
    run_finetune(ns.model, ns.epochs, score_files=ns.score_files or None)
|
||||
|
|
|
|||
|
|
@ -249,3 +249,123 @@ def test_weighted_trainer_compute_loss_returns_outputs_when_requested():
|
|||
assert isinstance(result, tuple)
|
||||
loss, outputs = result
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
|
||||
|
||||
# ---- Multi-file merge / dedup tests ----
|
||||
|
||||
def test_load_and_prepare_data_merges_multiple_files(tmp_path):
    """Multiple score files must be merged into a single dataset."""
    from scripts.finetune_classifier import load_and_prepare_data

    def _write_jsonl(name, rows):
        # One JSON object per line, newline-terminated, matching the label tool output.
        path = tmp_path / name
        path.write_text("".join(json.dumps(r) + "\n" for r in rows))
        return path

    file1 = _write_jsonl("run1.jsonl", [
        {"subject": "s1", "body": "b1", "label": "digest"},
        {"subject": "s2", "body": "b2", "label": "digest"},
    ])
    file2 = _write_jsonl("run2.jsonl", [
        {"subject": "s3", "body": "b3", "label": "neutral"},
        {"subject": "s4", "body": "b4", "label": "neutral"},
    ])

    texts, labels = load_and_prepare_data([file1, file2])

    # All four distinct rows survive the merge, two per class.
    assert len(texts) == 4
    assert labels.count("digest") == 2
    assert labels.count("neutral") == 2
|
||||
|
||||
|
||||
def test_load_and_prepare_data_deduplicates_last_write_wins(tmp_path, capsys):
    """Duplicate rows (same content hash) keep the last occurrence."""
    from scripts.finetune_classifier import load_and_prepare_data

    # Identical subject + body[:100] ⇒ identical content hash.
    row_early = {"subject": "Hello", "body": "World", "label": "neutral"}
    row_late = {"subject": "Hello", "body": "World", "label": "digest"}  # relabeled

    # Each file also carries one unique filler row so both classes keep >= 2 samples.
    file1 = tmp_path / "run1.jsonl"
    file1.write_text(
        json.dumps(row_early) + "\n"
        + json.dumps({"subject": "Other1", "body": "Other", "label": "neutral"}) + "\n"
    )
    file2 = tmp_path / "run2.jsonl"
    file2.write_text(
        json.dumps(row_late) + "\n"
        + json.dumps({"subject": "Other2", "body": "Stuff", "label": "digest"}) + "\n"
    )

    texts, labels = load_and_prepare_data([file1, file2])
    printed = capsys.readouterr().out

    # The dropped duplicate must show up in the dedup summary ...
    assert "Deduped" in printed
    # ... and the surviving "Hello" row carries the later label (last-write wins).
    hello_labels = [lbl for txt, lbl in zip(texts, labels) if txt.startswith("Hello [SEP]")]
    assert hello_labels == ["digest"]
|
||||
|
||||
|
||||
def test_load_and_prepare_data_single_path_still_works(tmp_path):
    """Passing a single Path (not a list) must still work — backwards compatibility."""
    from scripts.finetune_classifier import load_and_prepare_data

    score_file = tmp_path / "email_score.jsonl"
    lines = [
        json.dumps({"subject": "s1", "body": "b1", "label": "digest"}),
        json.dumps({"subject": "s2", "body": "b2", "label": "digest"}),
    ]
    score_file.write_text("\n".join(lines))

    texts, labels = load_and_prepare_data(score_file)  # single Path, not list
    assert len(texts) == 2
|
||||
|
||||
|
||||
# ---- Integration test ----
|
||||
|
||||
def test_integration_finetune_on_example_data(tmp_path):
    """Fine-tune deberta-small on example data for 1 epoch.

    Uses data/email_score.jsonl.example (8 samples, 5 labels represented).
    The 5 missing labels must trigger the < 2 samples drop warning.
    Verifies training_info.json is written with correct keys.

    Requires job-seeker-classifiers env and downloads deberta-small (~100MB on first run).
    """
    import shutil
    from scripts import finetune_classifier as ft_mod
    from scripts.finetune_classifier import run_finetune

    example_file = ft_mod._ROOT / "data" / "email_score.jsonl.example"
    if not example_file.exists():
        pytest.skip("email_score.jsonl.example not found")

    # Repoint the module-level _ROOT so run_finetune's default score-file path
    # AND its models/ output dir both land in tmp_path instead of the repo.
    orig_root = ft_mod._ROOT
    ft_mod._ROOT = tmp_path
    (tmp_path / "data").mkdir()
    shutil.copy(example_file, tmp_path / "data" / "email_score.jsonl")

    try:
        import io
        from contextlib import redirect_stdout
        # Capture stdout: run_finetune reports progress via print(..., flush=True).
        captured = io.StringIO()
        with redirect_stdout(captured):
            run_finetune("deberta-small", epochs=1)
        output = captured.getvalue()
    finally:
        # Always restore the real _ROOT so other tests see the repo layout.
        ft_mod._ROOT = orig_root

    # Missing labels should trigger the < 2 samples drop warning
    assert "WARNING: Dropping class" in output

    # training_info.json must exist with correct keys
    info_path = tmp_path / "models" / "avocet-deberta-small" / "training_info.json"
    assert info_path.exists(), "training_info.json not written"

    info = json.loads(info_path.read_text())
    for key in ("name", "base_model_id", "timestamp", "epochs_run",
                "val_macro_f1", "val_accuracy", "sample_count",
                "label_counts", "score_files"):
        assert key in info, f"Missing key: {key}"

    assert info["name"] == "avocet-deberta-small"
    assert info["epochs_run"] == 1
    assert isinstance(info["score_files"], list)
||||
|
|
|
|||
Loading…
Reference in a new issue