fix(avocet): move TorchDataset import to top; split sample_count into total+train
This commit is contained in:
parent
8ba34bb2d1
commit
64fd19a7b6
2 changed files with 4 additions and 4 deletions
|
|
@ -20,6 +20,7 @@ from typing import Any
|
|||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.data import Dataset as TorchDataset
|
||||
from sklearn.model_selection import train_test_split
|
||||
from sklearn.metrics import f1_score, accuracy_score
|
||||
from transformers import (
|
||||
|
|
@ -200,8 +201,6 @@ class WeightedTrainer(Trainer):
|
|||
# Training dataset wrapper
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
from torch.utils.data import Dataset as TorchDataset
|
||||
|
||||
|
||||
class _EmailDataset(TorchDataset):
|
||||
def __init__(self, encodings: dict, label_ids: list[int]) -> None:
|
||||
|
|
@ -366,7 +365,8 @@ def run_finetune(model_key: str, epochs: int = 5, score_files: list[Path] | None
|
|||
"epochs_run": epochs,
|
||||
"val_macro_f1": round(val_macro_f1, 4),
|
||||
"val_accuracy": round(val_accuracy, 4),
|
||||
"sample_count": len(train_texts),
|
||||
"sample_count": len(texts),
|
||||
"train_sample_count": len(train_texts),
|
||||
"label_counts": label_counts,
|
||||
"score_files": [str(f) for f in score_files],
|
||||
}
|
||||
|
|
|
|||
|
|
@ -362,7 +362,7 @@ def test_integration_finetune_on_example_data(tmp_path):
|
|||
|
||||
info = json.loads(info_path.read_text())
|
||||
for key in ("name", "base_model_id", "timestamp", "epochs_run",
|
||||
"val_macro_f1", "val_accuracy", "sample_count",
|
||||
"val_macro_f1", "val_accuracy", "sample_count", "train_sample_count",
|
||||
"label_counts", "score_files"):
|
||||
assert key in info, f"Missing key: {key}"
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue