Add early stopping callback to pytorch trainer (#8581)
* Add early stopping patience and minimum threshold metric must improve to prevent early stopping to pytorch trainer * Add early stopping test * Set patience counter to 0 if best metric not defined yet * Make early stopping a callback. Add callback event for updating the best metric for early stopping callback to trigger on. * Run make style * make funciton name sensible * Improve new argument docstring wording and hope that flakey CI test passes. * Use on_evaluation callback instead of custom. Remove some debug printing * Move early stopping arguments and state into early stopping callback * Run make style * Remove old code * Fix docs formatting. make style went rogue on me. * Remove copied attributes and fix variable * Add assertions on training arguments instead of mutating them. Move comment out of public docs. * Make separate test for early stopping callback. Add test of invalid arguments. * Run make style... I remembered before CI this time! * appease flake8 * Add EarlyStoppingCallback to callback docs * Make docstring EarlyStoppingCallabck match other callbacks. * Fix typo in docs
This commit is contained in:
Родитель
367f497dec
Коммит
8ffc01a76a
|
@ -44,6 +44,8 @@ Here is the list of the available :class:`~transformers.TrainerCallback` in the
|
|||
|
||||
.. autoclass:: transformers.ProgressCallback
|
||||
|
||||
.. autoclass:: transformers.EarlyStoppingCallback
|
||||
|
||||
.. autoclass:: transformers.integrations.TensorBoardCallback
|
||||
|
||||
.. autoclass:: transformers.integrations.WandbCallback
|
||||
|
|
|
@ -253,6 +253,7 @@ else:
|
|||
# Trainer
|
||||
from .trainer_callback import (
|
||||
DefaultFlowCallback,
|
||||
EarlyStoppingCallback,
|
||||
PrinterCallback,
|
||||
ProgressCallback,
|
||||
TrainerCallback,
|
||||
|
|
|
@ -21,6 +21,7 @@ import json
|
|||
from dataclasses import dataclass
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
|
||||
from .trainer_utils import EvaluationStrategy
|
||||
|
@ -475,3 +476,62 @@ class PrinterCallback(TrainerCallback):
|
|||
_ = logs.pop("total_flos", None)
|
||||
if state.is_local_process_zero:
|
||||
print(logs)
|
||||
|
||||
|
||||
class EarlyStoppingCallback(TrainerCallback):
|
||||
"""
|
||||
A :class:`~transformers.TrainerCallback` that handles early stopping.
|
||||
|
||||
Args:
|
||||
early_stopping_patience (:obj:`int`):
|
||||
Use with :obj:`metric_for_best_model` to stop training when the specified metric worsens for
|
||||
:obj:`early_stopping_patience` evaluation calls.
|
||||
early_stopping_threshold(:obj:`float`, `optional`):
|
||||
Use with TrainingArguments :obj:`metric_for_best_model` and :obj:`early_stopping_patience` to denote how
|
||||
much the specified metric must improve to satisfy early stopping conditions. `
|
||||
|
||||
This callback depends on :class:`~transformers.TrainingArguments` argument `load_best_model_at_end` functionality
|
||||
to set best_metric in :class:`~transformers.TrainerState`.
|
||||
"""
|
||||
|
||||
def __init__(self, early_stopping_patience: int = 1, early_stopping_threshold: Optional[float] = 0.0):
|
||||
self.early_stopping_patience = early_stopping_patience
|
||||
self.early_stopping_threshold = early_stopping_threshold
|
||||
# early_stopping_patience_counter denotes the number of times validation metrics failed to improve.
|
||||
self.early_stopping_patience_counter = 0
|
||||
|
||||
def check_metric_value(self, args, state, control, metric_value):
|
||||
# best_metric is set by code for load_best_model
|
||||
operator = np.greater if args.greater_is_better else np.less
|
||||
if state.best_metric is None or (
|
||||
operator(metric_value, state.best_metric)
|
||||
and abs(metric_value - state.best_metric) > self.early_stopping_threshold
|
||||
):
|
||||
self.early_stopping_patience_counter = 0
|
||||
else:
|
||||
self.early_stopping_patience_counter += 1
|
||||
|
||||
def on_train_begin(self, args, state, control, **kwargs):
|
||||
assert args.load_best_model_at_end, "EarlyStoppingCallback requires load_best_model_at_end = True"
|
||||
assert (
|
||||
args.metric_for_best_model is not None
|
||||
), "EarlyStoppingCallback requires metric_for_best_model is defined"
|
||||
assert (
|
||||
args.evaluation_strategy != EvaluationStrategy.NO
|
||||
), "EarlyStoppingCallback requires EvaluationStrategy of steps or epoch"
|
||||
|
||||
def on_evaluate(self, args, state, control, metrics, **kwargs):
|
||||
metric_to_check = args.metric_for_best_model
|
||||
if not metric_to_check.startswith("eval_"):
|
||||
metric_to_check = f"eval_{metric_to_check}"
|
||||
metric_value = metrics.get(metric_to_check)
|
||||
|
||||
if metric_value is None:
|
||||
logger.warning(
|
||||
f"early stopping required metric_for_best_model, but did not find {metric_to_check} so early stopping is disabled"
|
||||
)
|
||||
return
|
||||
|
||||
self.check_metric_value(args, state, control, metric_value)
|
||||
if self.early_stopping_patience_counter >= self.early_stopping_patience:
|
||||
control.should_training_stop = True
|
||||
|
|
|
@ -42,6 +42,7 @@ if is_torch_available():
|
|||
AutoModelForMaskedLM,
|
||||
AutoModelForSequenceClassification,
|
||||
DataCollatorForLanguageModeling,
|
||||
EarlyStoppingCallback,
|
||||
GlueDataset,
|
||||
GlueDataTrainingArguments,
|
||||
GPT2Config,
|
||||
|
@ -765,6 +766,37 @@ class TrainerIntegrationTest(unittest.TestCase):
|
|||
train_output = trainer.train()
|
||||
self.assertEqual(train_output.global_step, int(self.n_epochs))
|
||||
|
||||
def test_early_stopping_callback(self):
|
||||
# early stopping stops training before num_training_epochs
|
||||
trainer = get_regression_trainer(
|
||||
num_train_epochs=20,
|
||||
gradient_accumulation_steps=1,
|
||||
per_device_train_batch_size=16,
|
||||
load_best_model_at_end=True,
|
||||
evaluation_strategy=EvaluationStrategy.EPOCH,
|
||||
compute_metrics=AlmostAccuracy(),
|
||||
metric_for_best_model="accuracy",
|
||||
)
|
||||
trainer.add_callback(EarlyStoppingCallback(1, 0.0001))
|
||||
train_output = trainer.train()
|
||||
self.assertLess(train_output.global_step, 20 * 64 / 16)
|
||||
|
||||
# Invalid inputs to trainer with early stopping callback result in assertion error
|
||||
trainer = get_regression_trainer(
|
||||
num_train_epochs=20,
|
||||
gradient_accumulation_steps=1,
|
||||
per_device_train_batch_size=16,
|
||||
evaluation_strategy=EvaluationStrategy.EPOCH,
|
||||
compute_metrics=AlmostAccuracy(),
|
||||
metric_for_best_model="accuracy",
|
||||
)
|
||||
trainer.add_callback(EarlyStoppingCallback(1))
|
||||
self.assertEqual(trainer.state.global_step, 0)
|
||||
try:
|
||||
trainer.train()
|
||||
except AssertionError:
|
||||
self.assertEqual(trainer.state.global_step, 0)
|
||||
|
||||
def test_flos_extraction(self):
|
||||
trainer = get_regression_trainer(learning_rate=0.1)
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче