Loading from last checkpoint functionality in Trainer.train (#10334)
Enhance resume_from_checkpoint argument of Trainer.train to accept bool type. If True given, last saved checkpoint in self.args.output_dir will be loaded. (#10280)
This commit is contained in:
Родитель
eab0afc19c
Коммит
94d8767ba3
|
@ -97,6 +97,7 @@ from .trainer_utils import (
|
|||
TrainOutput,
|
||||
default_compute_objective,
|
||||
default_hp_space,
|
||||
get_last_checkpoint,
|
||||
set_seed,
|
||||
speed_metrics,
|
||||
)
|
||||
|
@ -758,7 +759,7 @@ class Trainer:
|
|||
|
||||
def train(
|
||||
self,
|
||||
resume_from_checkpoint: Optional[str] = None,
|
||||
resume_from_checkpoint: Optional[Union[str, bool]] = None,
|
||||
trial: Union["optuna.Trial", Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
|
@ -766,9 +767,11 @@ class Trainer:
|
|||
Main training entry point.
|
||||
|
||||
Args:
|
||||
resume_from_checkpoint (:obj:`str`, `optional`):
|
||||
Local path to a saved checkpoint as saved by a previous instance of :class:`~transformers.Trainer`. If
|
||||
present, training will resume from the model/optimizer/scheduler states loaded here.
|
||||
resume_from_checkpoint (:obj:`str` or :obj:`bool`, `optional`):
|
||||
If a :obj:`str`, local path to a saved checkpoint as saved by a previous instance of
|
||||
:class:`~transformers.Trainer`. If a :obj:`bool` and equals `True`, load the last checkpoint in
|
||||
`args.output_dir` as saved by a previous instance of :class:`~transformers.Trainer`. If present,
|
||||
training will resume from the model/optimizer/scheduler states loaded here.
|
||||
trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
|
||||
The trial run or the hyperparameter dictionary for hyperparameter search.
|
||||
kwargs:
|
||||
|
@ -803,6 +806,11 @@ class Trainer:
|
|||
self.optimizer, self.lr_scheduler = None, None
|
||||
|
||||
# Load potential model checkpoint
|
||||
if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
|
||||
resume_from_checkpoint = get_last_checkpoint(self.args.output_dir)
|
||||
if resume_from_checkpoint is None:
|
||||
raise ValueError(f"No valid checkpoint found in output directory ({self.args.output_dir})")
|
||||
|
||||
if resume_from_checkpoint is not None and os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
|
||||
logger.info(f"Loading model from {resume_from_checkpoint}).")
|
||||
if isinstance(self.model, PreTrainedModel):
|
||||
|
|
Загрузка…
Ссылка в новой задаче