diff --git a/scripts/finetune_classifier.py b/scripts/finetune_classifier.py index b3d4fe9..9bd832e 100644 --- a/scripts/finetune_classifier.py +++ b/scripts/finetune_classifier.py @@ -20,6 +20,7 @@ from typing import Any import torch import torch.nn.functional as F +from torch.utils.data import Dataset as TorchDataset from sklearn.model_selection import train_test_split from sklearn.metrics import f1_score, accuracy_score from transformers import ( @@ -200,8 +201,6 @@ class WeightedTrainer(Trainer): # Training dataset wrapper # --------------------------------------------------------------------------- -from torch.utils.data import Dataset as TorchDataset - class _EmailDataset(TorchDataset): def __init__(self, encodings: dict, label_ids: list[int]) -> None: @@ -366,7 +365,8 @@ def run_finetune(model_key: str, epochs: int = 5, score_files: list[Path] | None "epochs_run": epochs, "val_macro_f1": round(val_macro_f1, 4), "val_accuracy": round(val_accuracy, 4), - "sample_count": len(train_texts), + "sample_count": len(texts), + "train_sample_count": len(train_texts), "label_counts": label_counts, "score_files": [str(f) for f in score_files], } diff --git a/tests/test_finetune.py b/tests/test_finetune.py index 603a582..29c59d6 100644 --- a/tests/test_finetune.py +++ b/tests/test_finetune.py @@ -362,7 +362,7 @@ def test_integration_finetune_on_example_data(tmp_path): info = json.loads(info_path.read_text()) for key in ("name", "base_model_id", "timestamp", "epochs_run", - "val_macro_f1", "val_accuracy", "sample_count", + "val_macro_f1", "val_accuracy", "sample_count", "train_sample_count", "label_counts", "score_files"): assert key in info, f"Missing key: {key}"