fix(avocet): move TorchDataset import to top; split sample_count into total+train

This commit is contained in:
pyr0ball 2026-03-15 16:02:43 -07:00
parent 8ba34bb2d1
commit 64fd19a7b6
2 changed files with 4 additions and 4 deletions

View file

@ -20,6 +20,7 @@ from typing import Any
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch.utils.data import Dataset as TorchDataset
from sklearn.model_selection import train_test_split from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, accuracy_score from sklearn.metrics import f1_score, accuracy_score
from transformers import ( from transformers import (
@ -200,8 +201,6 @@ class WeightedTrainer(Trainer):
# Training dataset wrapper # Training dataset wrapper
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
from torch.utils.data import Dataset as TorchDataset
class _EmailDataset(TorchDataset): class _EmailDataset(TorchDataset):
def __init__(self, encodings: dict, label_ids: list[int]) -> None: def __init__(self, encodings: dict, label_ids: list[int]) -> None:
@ -366,7 +365,8 @@ def run_finetune(model_key: str, epochs: int = 5, score_files: list[Path] | None
"epochs_run": epochs, "epochs_run": epochs,
"val_macro_f1": round(val_macro_f1, 4), "val_macro_f1": round(val_macro_f1, 4),
"val_accuracy": round(val_accuracy, 4), "val_accuracy": round(val_accuracy, 4),
"sample_count": len(train_texts), "sample_count": len(texts),
"train_sample_count": len(train_texts),
"label_counts": label_counts, "label_counts": label_counts,
"score_files": [str(f) for f in score_files], "score_files": [str(f) for f in score_files],
} }

View file

@ -362,7 +362,7 @@ def test_integration_finetune_on_example_data(tmp_path):
info = json.loads(info_path.read_text()) info = json.loads(info_path.read_text())
for key in ("name", "base_model_id", "timestamp", "epochs_run", for key in ("name", "base_model_id", "timestamp", "epochs_run",
"val_macro_f1", "val_accuracy", "sample_count", "val_macro_f1", "val_accuracy", "sample_count", "train_sample_count",
"label_counts", "score_files"): "label_counts", "score_files"):
assert key in info, f"Missing key: {key}" assert key in info, f"Missing key: {key}"