Loading from last checkpoint functionality in Trainer.train (#10334)

Enhance the resume_from_checkpoint argument of Trainer.train to also accept
a bool. If True is given, the last saved checkpoint in self.args.output_dir
will be loaded. (#10280)
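
With this change a caller no longer has to name the checkpoint folder explicitly; passing True tells the Trainer to resume from whatever was saved last in args.output_dir. A minimal usage sketch, assuming a Trainer has already been constructed and its output_dir holds checkpoint-* folders from an earlier run (the path below is illustrative):

    # Explicit path, the previously supported form:
    trainer.train(resume_from_checkpoint="output/checkpoint-1500")

    # New with this commit: True resolves to the latest checkpoint in args.output_dir
    # and raises ValueError if no checkpoint is found there.
    trainer.train(resume_from_checkpoint=True)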
Tanmay Garg 2021-02-23 02:03:00 +05:30 committed by GitHub
Parent eab0afc19c
Commit 94d8767ba3
No key matching this signature was found
GPG key ID: 4AEE18F83AFDEB23
1 changed file with 12 additions and 4 deletions


@@ -97,6 +97,7 @@ from .trainer_utils import (
TrainOutput,
default_compute_objective,
default_hp_space,
get_last_checkpoint,
set_seed,
speed_metrics,
)
@@ -758,7 +759,7 @@ class Trainer:
def train(
self,
resume_from_checkpoint: Optional[str] = None,
resume_from_checkpoint: Optional[Union[str, bool]] = None,
trial: Union["optuna.Trial", Dict[str, Any]] = None,
**kwargs,
):
@@ -766,9 +767,11 @@
Main training entry point.
Args:
resume_from_checkpoint (:obj:`str`, `optional`):
Local path to a saved checkpoint as saved by a previous instance of :class:`~transformers.Trainer`. If
present, training will resume from the model/optimizer/scheduler states loaded here.
resume_from_checkpoint (:obj:`str` or :obj:`bool`, `optional`):
If a :obj:`str`, local path to a saved checkpoint as saved by a previous instance of
:class:`~transformers.Trainer`. If a :obj:`bool` and equals `True`, load the last checkpoint in
`args.output_dir` as saved by a previous instance of :class:`~transformers.Trainer`. If present,
training will resume from the model/optimizer/scheduler states loaded here.
trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
The trial run or the hyperparameter dictionary for hyperparameter search.
kwargs:
@@ -803,6 +806,11 @@
self.optimizer, self.lr_scheduler = None, None
# Load potential model checkpoint
if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint:
resume_from_checkpoint = get_last_checkpoint(self.args.output_dir)
if resume_from_checkpoint is None:
raise ValueError(f"No valid checkpoint found in output directory ({self.args.output_dir})")
if resume_from_checkpoint is not None and os.path.isfile(os.path.join(resume_from_checkpoint, WEIGHTS_NAME)):
logger.info(f"Loading model from {resume_from_checkpoint}).")
if isinstance(self.model, PreTrainedModel):
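
The True branch relies on the newly imported get_last_checkpoint helper from transformers.trainer_utils: it looks for checkpoint-<step> subfolders of output_dir and picks the most recent one. A rough, simplified sketch of that discovery logic (an illustrative re-implementation, not the library's exact code):

    import os
    import re

    _CHECKPOINT_RE = re.compile(r"^checkpoint-(\d+)$")

    def last_checkpoint_sketch(output_dir: str):
        # Collect subdirectories named checkpoint-<number>.
        candidates = [
            name
            for name in os.listdir(output_dir)
            if _CHECKPOINT_RE.match(name) and os.path.isdir(os.path.join(output_dir, name))
        ]
        if not candidates:
            return None  # Trainer.train then raises the ValueError added above.
        # The checkpoint with the highest global step wins.
        latest = max(candidates, key=lambda name: int(_CHECKPOINT_RE.match(name).group(1)))
        return os.path.join(output_dir, latest)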