Add early stopping callback to pytorch trainer (#8581)

* Add early stopping patience, and a minimum threshold the metric must improve by to prevent early stopping, to the PyTorch trainer

* Add early stopping test

* Set patience counter to 0 if best metric not defined yet

* Make early stopping a callback. Add callback event for updating the best metric for early stopping callback to trigger on.

* Run make style

* Make function name sensible

* Improve new argument docstring wording and hope that flaky CI test passes.

* Use on_evaluation callback instead of custom. Remove some debug printing

* Move early stopping arguments and state into early stopping callback

* Run make style

* Remove old code

* Fix docs formatting. make style went rogue on me.

* Remove copied attributes and fix variable

* Add assertions on training arguments instead of mutating them. Move comment out of public docs.

* Make separate test for early stopping callback. Add test of invalid arguments.

* Run make style... I remembered before CI this time!

* appease flake8

* Add EarlyStoppingCallback to callback docs

* Make EarlyStoppingCallback docstring match other callbacks.

* Fix typo in docs
Colin Brochtrup 2020-11-23 17:25:35 -05:00 committed by GitHub
Parent 367f497dec
Commit 8ffc01a76a
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 95 additions and 0 deletions

View file

@@ -44,6 +44,8 @@ Here is the list of the available :class:`~transformers.TrainerCallback` in the
.. autoclass:: transformers.ProgressCallback
.. autoclass:: transformers.EarlyStoppingCallback
.. autoclass:: transformers.integrations.TensorBoardCallback
.. autoclass:: transformers.integrations.WandbCallback

View file

@@ -253,6 +253,7 @@ else:
# Trainer
from .trainer_callback import (
DefaultFlowCallback,
EarlyStoppingCallback,
PrinterCallback,
ProgressCallback,
TrainerCallback,

View file

@@ -21,6 +21,7 @@ import json
from dataclasses import dataclass
from typing import Dict, List, Optional, Union
import numpy as np
from tqdm.auto import tqdm
from .trainer_utils import EvaluationStrategy
@@ -475,3 +476,62 @@ class PrinterCallback(TrainerCallback):
_ = logs.pop("total_flos", None)
if state.is_local_process_zero:
print(logs)
class EarlyStoppingCallback(TrainerCallback):
"""
A :class:`~transformers.TrainerCallback` that handles early stopping.
Args:
early_stopping_patience (:obj:`int`):
Use with :obj:`metric_for_best_model` to stop training when the specified metric worsens for
:obj:`early_stopping_patience` evaluation calls.
early_stopping_threshold (:obj:`float`, `optional`):
Use with TrainingArguments :obj:`metric_for_best_model` and :obj:`early_stopping_patience` to denote how
much the specified metric must improve to satisfy early stopping conditions.
This callback depends on :class:`~transformers.TrainingArguments` argument `load_best_model_at_end` functionality
to set best_metric in :class:`~transformers.TrainerState`.
"""
def __init__(self, early_stopping_patience: int = 1, early_stopping_threshold: Optional[float] = 0.0):
self.early_stopping_patience = early_stopping_patience
self.early_stopping_threshold = early_stopping_threshold
# early_stopping_patience_counter denotes the number of times validation metrics failed to improve.
self.early_stopping_patience_counter = 0
def check_metric_value(self, args, state, control, metric_value):
# best_metric is set by code for load_best_model
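# The patience counter resets when the metric improves on best_metric by more than
# early_stopping_threshold (or when no best metric exists yet); otherwise it is incremented.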
operator = np.greater if args.greater_is_better else np.less
if state.best_metric is None or (
operator(metric_value, state.best_metric)
and abs(metric_value - state.best_metric) > self.early_stopping_threshold
):
self.early_stopping_patience_counter = 0
else:
self.early_stopping_patience_counter += 1
def on_train_begin(self, args, state, control, **kwargs):
assert args.load_best_model_at_end, "EarlyStoppingCallback requires load_best_model_at_end = True"
assert (
args.metric_for_best_model is not None
), "EarlyStoppingCallback requires metric_for_best_model is defined"
assert (
args.evaluation_strategy != EvaluationStrategy.NO
), "EarlyStoppingCallback requires EvaluationStrategy of steps or epoch"
def on_evaluate(self, args, state, control, metrics, **kwargs):
metric_to_check = args.metric_for_best_model
if not metric_to_check.startswith("eval_"):
metric_to_check = f"eval_{metric_to_check}"
metric_value = metrics.get(metric_to_check)
if metric_value is None:
logger.warning(
f"early stopping required metric_for_best_model, but did not find {metric_to_check} so early stopping is disabled"
)
return
self.check_metric_value(args, state, control, metric_value)
if self.early_stopping_patience_counter >= self.early_stopping_patience:
control.should_training_stop = True
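
For reference, a minimal usage sketch of the new callback (not part of this diff); the model, datasets, output directory, and the patience/threshold values below are placeholders chosen for illustration:

from transformers import EarlyStoppingCallback, Trainer, TrainingArguments

# EarlyStoppingCallback asserts on these TrainingArguments at train start:
# load_best_model_at_end=True, metric_for_best_model set, and an evaluation strategy other than "no".
training_args = TrainingArguments(
    output_dir="out",                 # placeholder output path
    num_train_epochs=20,
    evaluation_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
)

trainer = Trainer(
    model=model,                      # assumed to be defined elsewhere
    args=training_args,
    train_dataset=train_dataset,      # assumed to be defined elsewhere
    eval_dataset=eval_dataset,        # assumed to be defined elsewhere
    # Stop once eval_loss fails to improve by more than 0.001 for 3 consecutive evaluations.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3, early_stopping_threshold=0.001)],
)
trainer.train()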

View file

@@ -42,6 +42,7 @@ if is_torch_available():
AutoModelForMaskedLM,
AutoModelForSequenceClassification,
DataCollatorForLanguageModeling,
EarlyStoppingCallback,
GlueDataset,
GlueDataTrainingArguments,
GPT2Config,
@@ -765,6 +766,37 @@ class TrainerIntegrationTest(unittest.TestCase):
train_output = trainer.train()
self.assertEqual(train_output.global_step, int(self.n_epochs))
def test_early_stopping_callback(self):
# early stopping stops training before num_training_epochs
trainer = get_regression_trainer(
num_train_epochs=20,
gradient_accumulation_steps=1,
per_device_train_batch_size=16,
load_best_model_at_end=True,
evaluation_strategy=EvaluationStrategy.EPOCH,
compute_metrics=AlmostAccuracy(),
metric_for_best_model="accuracy",
)
trainer.add_callback(EarlyStoppingCallback(1, 0.0001))
train_output = trainer.train()
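# 20 epochs * 64 regression samples / batch size 16 = 80 steps if training ran to completion,
# so early stopping must have halted training before that.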
self.assertLess(train_output.global_step, 20 * 64 / 16)
# Invalid inputs to trainer with early stopping callback result in assertion error
trainer = get_regression_trainer(
num_train_epochs=20,
gradient_accumulation_steps=1,
per_device_train_batch_size=16,
evaluation_strategy=EvaluationStrategy.EPOCH,
compute_metrics=AlmostAccuracy(),
metric_for_best_model="accuracy",
)
trainer.add_callback(EarlyStoppingCallback(1))
self.assertEqual(trainer.state.global_step, 0)
try:
trainer.train()
except AssertionError:
self.assertEqual(trainer.state.global_step, 0)
def test_flos_extraction(self):
trainer = get_regression_trainer(learning_rate=0.1)