diff --git a/scripts/finetune_classifier.py b/scripts/finetune_classifier.py index c70929e..e936466 100644 --- a/scripts/finetune_classifier.py +++ b/scripts/finetune_classifier.py @@ -310,7 +310,12 @@ def run_finetune(model_key: str, epochs: int = 5, score_files: list[Path] | None label2id=label2id, ) if config["gradient_checkpointing"]: - model.gradient_checkpointing_enable() + # use_reentrant=False avoids "backward through graph a second time" errors + # when Accelerate's gradient accumulation context is layered on top. + # Reentrant checkpointing (the default) conflicts with Accelerate ≥ 0.27. + model.gradient_checkpointing_enable( + gradient_checkpointing_kwargs={"use_reentrant": False} + ) # --- TrainingArguments --- training_args = TrainingArguments(