From 64fd19a7b6cbc76f651c0f64d391efb955931d4b Mon Sep 17 00:00:00 2001 From: pyr0ball Date: Sun, 15 Mar 2026 16:02:43 -0700 Subject: [PATCH] fix(avocet): move TorchDataset import to top; split sample_count into total+train --- scripts/finetune_classifier.py | 6 +++--- tests/test_finetune.py | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/finetune_classifier.py b/scripts/finetune_classifier.py index b3d4fe9..9bd832e 100644 --- a/scripts/finetune_classifier.py +++ b/scripts/finetune_classifier.py @@ -20,6 +20,7 @@ from typing import Any import torch import torch.nn.functional as F +from torch.utils.data import Dataset as TorchDataset from sklearn.model_selection import train_test_split from sklearn.metrics import f1_score, accuracy_score from transformers import ( @@ -200,8 +201,6 @@ class WeightedTrainer(Trainer): # Training dataset wrapper # --------------------------------------------------------------------------- -from torch.utils.data import Dataset as TorchDataset - class _EmailDataset(TorchDataset): def __init__(self, encodings: dict, label_ids: list[int]) -> None: @@ -366,7 +365,8 @@ def run_finetune(model_key: str, epochs: int = 5, score_files: list[Path] | None "epochs_run": epochs, "val_macro_f1": round(val_macro_f1, 4), "val_accuracy": round(val_accuracy, 4), - "sample_count": len(train_texts), + "sample_count": len(texts), + "train_sample_count": len(train_texts), "label_counts": label_counts, "score_files": [str(f) for f in score_files], } diff --git a/tests/test_finetune.py b/tests/test_finetune.py index 603a582..29c59d6 100644 --- a/tests/test_finetune.py +++ b/tests/test_finetune.py @@ -362,7 +362,7 @@ def test_integration_finetune_on_example_data(tmp_path): info = json.loads(info_path.read_text()) for key in ("name", "base_model_id", "timestamp", "epochs_run", - "val_macro_f1", "val_accuracy", "sample_count", + "val_macro_f1", "val_accuracy", "sample_count", "train_sample_count", "label_counts", "score_files"): assert key in info, f"Missing key: {key}"