fix(avocet): move TorchDataset import to top; split sample_count into total+train
This commit is contained in:
parent
8ba34bb2d1
commit
64fd19a7b6
2 changed files with 4 additions and 4 deletions
|
|
@@ -20,6 +20,7 @@ from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from torch.utils.data import Dataset as TorchDataset
|
||||||
from sklearn.model_selection import train_test_split
|
from sklearn.model_selection import train_test_split
|
||||||
from sklearn.metrics import f1_score, accuracy_score
|
from sklearn.metrics import f1_score, accuracy_score
|
||||||
from transformers import (
|
from transformers import (
|
||||||
|
|
@@ -200,8 +201,6 @@ class WeightedTrainer(Trainer):
|
||||||
# Training dataset wrapper
|
# Training dataset wrapper
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
from torch.utils.data import Dataset as TorchDataset
|
|
||||||
|
|
||||||
|
|
||||||
class _EmailDataset(TorchDataset):
|
class _EmailDataset(TorchDataset):
|
||||||
def __init__(self, encodings: dict, label_ids: list[int]) -> None:
|
def __init__(self, encodings: dict, label_ids: list[int]) -> None:
|
||||||
|
|
@@ -366,7 +365,8 @@ def run_finetune(model_key: str, epochs: int = 5, score_files: list[Path] | None
|
||||||
"epochs_run": epochs,
|
"epochs_run": epochs,
|
||||||
"val_macro_f1": round(val_macro_f1, 4),
|
"val_macro_f1": round(val_macro_f1, 4),
|
||||||
"val_accuracy": round(val_accuracy, 4),
|
"val_accuracy": round(val_accuracy, 4),
|
||||||
"sample_count": len(train_texts),
|
"sample_count": len(texts),
|
||||||
|
"train_sample_count": len(train_texts),
|
||||||
"label_counts": label_counts,
|
"label_counts": label_counts,
|
||||||
"score_files": [str(f) for f in score_files],
|
"score_files": [str(f) for f in score_files],
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@@ -362,7 +362,7 @@ def test_integration_finetune_on_example_data(tmp_path):
|
||||||
|
|
||||||
info = json.loads(info_path.read_text())
|
info = json.loads(info_path.read_text())
|
||||||
for key in ("name", "base_model_id", "timestamp", "epochs_run",
|
for key in ("name", "base_model_id", "timestamp", "epochs_run",
|
||||||
"val_macro_f1", "val_accuracy", "sample_count",
|
"val_macro_f1", "val_accuracy", "sample_count", "train_sample_count",
|
||||||
"label_counts", "score_files"):
|
"label_counts", "score_files"):
|
||||||
assert key in info, f"Missing key: {key}"
|
assert key in info, f"Missing key: {key}"
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue