Fix examples titles and optimization doc page (#5408)
This commit is contained in:
Родитель
d60d231ea4
Коммит
4ade7491f4
|
@ -1,4 +1,4 @@
|
||||||
Optimizer
|
Optimization
|
||||||
----------------------------------------------------
|
----------------------------------------------------
|
||||||
|
|
||||||
The ``.optimization`` module provides:
|
The ``.optimization`` module provides:
|
||||||
|
@ -7,24 +7,25 @@ The ``.optimization`` module provides:
|
||||||
- several schedules in the form of schedule objects that inherit from ``_LRSchedule``:
|
- several schedules in the form of schedule objects that inherit from ``_LRSchedule``:
|
||||||
- a gradient accumulation class to accumulate the gradients of multiple batches
|
- a gradient accumulation class to accumulate the gradients of multiple batches
|
||||||
|
|
||||||
``AdamW``
|
``AdamW`` (PyTorch)
|
||||||
~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
.. autoclass:: transformers.AdamW
|
.. autoclass:: transformers.AdamW
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
``AdamWeightDecay``
|
``AdamWeightDecay`` (TensorFlow)
|
||||||
~~~~~~~~~~~~~~~~~~~
|
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
.. autoclass:: transformers.AdamWeightDecay
|
.. autoclass:: transformers.AdamWeightDecay
|
||||||
|
|
||||||
.. autofunction:: transformers.create_optimizer
|
.. autofunction:: transformers.create_optimizer
|
||||||
|
|
||||||
Schedules
|
Schedules
|
||||||
----------------------------------------------------
|
~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
|
Learning Rate Schedules (Pytorch)
|
||||||
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
Learning Rate Schedules
|
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
||||||
.. autofunction:: transformers.get_constant_schedule
|
.. autofunction:: transformers.get_constant_schedule
|
||||||
|
|
||||||
|
|
||||||
|
@ -56,16 +57,16 @@ Learning Rate Schedules
|
||||||
:target: /imgs/warmup_linear_schedule.png
|
:target: /imgs/warmup_linear_schedule.png
|
||||||
:alt:
|
:alt:
|
||||||
|
|
||||||
``Warmup``
|
``Warmup`` (TensorFlow)
|
||||||
~~~~~~~~~~~~~~~~
|
^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
.. autoclass:: transformers.WarmUp
|
.. autoclass:: transformers.WarmUp
|
||||||
:members:
|
:members:
|
||||||
|
|
||||||
Gradient Strategies
|
Gradient Strategies
|
||||||
----------------------------------------------------
|
~~~~~~~~~~~~~~~~~~~~
|
||||||
|
|
||||||
``GradientAccumulator``
|
``GradientAccumulator`` (TensorFlow)
|
||||||
~~~~~~~~~~~~~~~~~~~~~~~
|
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
|
|
||||||
.. autoclass:: transformers.GradientAccumulator
|
.. autoclass:: transformers.GradientAccumulator
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
## Examples
|
# Examples
|
||||||
|
|
||||||
Version 2.9 of 🤗 Transformers introduces a new [`Trainer`](https://github.com/huggingface/transformers/blob/master/src/transformers/trainer.py) class for PyTorch, and its equivalent [`TFTrainer`](https://github.com/huggingface/transformers/blob/master/src/transformers/trainer_tf.py) for TF 2.
|
Version 2.9 of 🤗 Transformers introduces a new [`Trainer`](https://github.com/huggingface/transformers/blob/master/src/transformers/trainer.py) class for PyTorch, and its equivalent [`TFTrainer`](https://github.com/huggingface/transformers/blob/master/src/transformers/trainer_tf.py) for TF 2.
|
||||||
Running the examples requires PyTorch 1.3.1+ or TensorFlow 2.1+.
|
Running the examples requires PyTorch 1.3.1+ or TensorFlow 2.1+.
|
||||||
|
@ -13,7 +13,7 @@ Here is the list of all our examples:
|
||||||
This is still a work-in-progress – in particular documentation is still sparse – so please **contribute improvements/pull requests.**
|
This is still a work-in-progress – in particular documentation is still sparse – so please **contribute improvements/pull requests.**
|
||||||
|
|
||||||
|
|
||||||
# The Big Table of Tasks
|
## The Big Table of Tasks
|
||||||
|
|
||||||
| Task | Example datasets | Trainer support | TFTrainer support | pytorch-lightning | Colab
|
| Task | Example datasets | Trainer support | TFTrainer support | pytorch-lightning | Colab
|
||||||
|---|---|:---:|:---:|:---:|:---:|
|
|---|---|:---:|:---:|:---:|:---:|
|
||||||
|
|
|
@ -16,6 +16,7 @@
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
from typing import Callable, Iterable, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
@ -25,18 +26,40 @@ from torch.optim.lr_scheduler import LambdaLR
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_constant_schedule(optimizer, last_epoch=-1):
|
def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1):
|
||||||
""" Create a schedule with a constant learning rate.
|
"""
|
||||||
|
Create a schedule with a constant learning rate, using the learning rate set in optimizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
optimizer (:class:`~torch.optim.Optimizer`):
|
||||||
|
The optimizer for which to schedule the learning rate.
|
||||||
|
last_epoch (:obj:`int`, `optional`, defaults to -1):
|
||||||
|
The index of the last epoch when resuming training.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
:obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
||||||
"""
|
"""
|
||||||
return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
|
return LambdaLR(optimizer, lambda _: 1, last_epoch=last_epoch)
|
||||||
|
|
||||||
|
|
||||||
def get_constant_schedule_with_warmup(optimizer, num_warmup_steps, last_epoch=-1):
|
def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1):
|
||||||
""" Create a schedule with a constant learning rate preceded by a warmup
|
"""
|
||||||
period during which the learning rate increases linearly between 0 and 1.
|
Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
|
||||||
|
increases linearly between 0 and the initial lr set in the optimizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
optimizer (:class:`~torch.optim.Optimizer`):
|
||||||
|
The optimizer for which to schedule the learning rate.
|
||||||
|
num_warmup_steps (:obj:`int`):
|
||||||
|
The number of steps for the warmup phase.
|
||||||
|
last_epoch (:obj:`int`, `optional`, defaults to -1):
|
||||||
|
The index of the last epoch when resuming training.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
:obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def lr_lambda(current_step):
|
def lr_lambda(current_step: int):
|
||||||
if current_step < num_warmup_steps:
|
if current_step < num_warmup_steps:
|
||||||
return float(current_step) / float(max(1.0, num_warmup_steps))
|
return float(current_step) / float(max(1.0, num_warmup_steps))
|
||||||
return 1.0
|
return 1.0
|
||||||
|
@ -45,11 +68,25 @@ def get_constant_schedule_with_warmup(optimizer, num_warmup_steps, last_epoch=-1
|
||||||
|
|
||||||
|
|
||||||
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
|
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
|
||||||
""" Create a schedule with a learning rate that decreases linearly after
|
"""
|
||||||
linearly increasing during a warmup period.
|
Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0,
|
||||||
|
after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
optimizer (:class:`~torch.optim.Optimizer`):
|
||||||
|
The optimizer for which to schedule the learning rate.
|
||||||
|
num_warmup_steps (:obj:`int`):
|
||||||
|
The number of steps for the warmup phase.
|
||||||
|
num_training_steps (:obj:`int`):
|
||||||
|
The totale number of training steps.
|
||||||
|
last_epoch (:obj:`int`, `optional`, defaults to -1):
|
||||||
|
The index of the last epoch when resuming training.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
:obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def lr_lambda(current_step):
|
def lr_lambda(current_step: int):
|
||||||
if current_step < num_warmup_steps:
|
if current_step < num_warmup_steps:
|
||||||
return float(current_step) / float(max(1, num_warmup_steps))
|
return float(current_step) / float(max(1, num_warmup_steps))
|
||||||
return max(
|
return max(
|
||||||
|
@ -59,10 +96,29 @@ def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_st
|
||||||
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
return LambdaLR(optimizer, lr_lambda, last_epoch)
|
||||||
|
|
||||||
|
|
||||||
def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, num_cycles=0.5, last_epoch=-1):
|
def get_cosine_schedule_with_warmup(
|
||||||
""" Create a schedule with a learning rate that decreases following the
|
optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
|
||||||
values of the cosine function between 0 and `pi * cycles` after a warmup
|
):
|
||||||
period during which it increases linearly between 0 and 1.
|
"""
|
||||||
|
Create a schedule with a learning rate that decreases following the values of the cosine function between the
|
||||||
|
initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
|
||||||
|
initial lr set in the optimizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
optimizer (:class:`~torch.optim.Optimizer`):
|
||||||
|
The optimizer for which to schedule the learning rate.
|
||||||
|
num_warmup_steps (:obj:`int`):
|
||||||
|
The number of steps for the warmup phase.
|
||||||
|
num_training_steps (:obj:`int`):
|
||||||
|
The total number of training steps.
|
||||||
|
num_cycles (:obj:`float`, `optional`, defaults to 0.5):
|
||||||
|
The number of waves in the cosine schedule (the defaults is to just decrease from the max value to 0
|
||||||
|
following a half-cosine).
|
||||||
|
last_epoch (:obj:`int`, `optional`, defaults to -1):
|
||||||
|
The index of the last epoch when resuming training.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
:obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def lr_lambda(current_step):
|
def lr_lambda(current_step):
|
||||||
|
@ -75,11 +131,27 @@ def get_cosine_schedule_with_warmup(optimizer, num_warmup_steps, num_training_st
|
||||||
|
|
||||||
|
|
||||||
def get_cosine_with_hard_restarts_schedule_with_warmup(
|
def get_cosine_with_hard_restarts_schedule_with_warmup(
|
||||||
optimizer, num_warmup_steps, num_training_steps, num_cycles=1.0, last_epoch=-1
|
optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
|
||||||
):
|
):
|
||||||
""" Create a schedule with a learning rate that decreases following the
|
"""
|
||||||
values of the cosine function with several hard restarts, after a warmup
|
Create a schedule with a learning rate that decreases following the values of the cosine function between the
|
||||||
period during which it increases linearly between 0 and 1.
|
initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
|
||||||
|
linearly between 0 and the initial lr set in the optimizer.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
optimizer (:class:`~torch.optim.Optimizer`):
|
||||||
|
The optimizer for which to schedule the learning rate.
|
||||||
|
num_warmup_steps (:obj:`int`):
|
||||||
|
The number of steps for the warmup phase.
|
||||||
|
num_training_steps (:obj:`int`):
|
||||||
|
The total number of training steps.
|
||||||
|
num_cycles (:obj:`int`, `optional`, defaults to 1):
|
||||||
|
The number of hard restarts to use.
|
||||||
|
last_epoch (:obj:`int`, `optional`, defaults to -1):
|
||||||
|
The index of the last epoch when resuming training.
|
||||||
|
|
||||||
|
Return:
|
||||||
|
:obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def lr_lambda(current_step):
|
def lr_lambda(current_step):
|
||||||
|
@ -94,17 +166,34 @@ def get_cosine_with_hard_restarts_schedule_with_warmup(
|
||||||
|
|
||||||
|
|
||||||
class AdamW(Optimizer):
|
class AdamW(Optimizer):
|
||||||
""" Implements Adam algorithm with weight decay fix.
|
"""
|
||||||
|
Implements Adam algorithm with weight decay fix as introduced in
|
||||||
|
`Decoupled Weight Decay Regularization <https://arxiv.org/abs/1711.05101>`__.
|
||||||
|
|
||||||
Parameters:
|
Parameters:
|
||||||
lr (float): learning rate. Default 1e-3.
|
params (:obj:`Iterable[torch.nn.parameter.Parameter]`):
|
||||||
betas (tuple of 2 floats): Adams beta parameters (b1, b2). Default: (0.9, 0.999)
|
Iterable of parameters to optimize or dictionaries defining parameter groups.
|
||||||
eps (float): Adams epsilon. Default: 1e-6
|
lr (:obj:`float`, `optional`, defaults to 1e-3):
|
||||||
weight_decay (float): Weight decay. Default: 0.0
|
The learning rate to use.
|
||||||
correct_bias (bool): can be set to False to avoid correcting bias in Adam (e.g. like in Bert TF repository). Default True.
|
betas (:obj:`Tuple[float,float]`, `optional`, defaults to (0.9, 0.999)):
|
||||||
|
Adam's betas parameters (b1, b2).
|
||||||
|
eps (:obj:`float`, `optional`, defaults to 1e-6):
|
||||||
|
Adam's epsilon for numerical stability.
|
||||||
|
weight_decay (:obj:`float`, `optional`, defaults to 0):
|
||||||
|
Decoupled weight decay to apply.
|
||||||
|
correct_bias (:obj:`bool`, `optional`, defaults to `True`):
|
||||||
|
Whether ot not to correct bias in Adam (for instance, in Bert TF repository they use :obj:`False`).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.0, correct_bias=True):
|
def __init__(
|
||||||
|
self,
|
||||||
|
params: Iterable[torch.nn.parameter.Parameter],
|
||||||
|
lr: float = 1e-3,
|
||||||
|
betas: Tuple[float, float] = (0.9, 0.999),
|
||||||
|
eps: float = 1e-6,
|
||||||
|
weight_decay: float = 0.0,
|
||||||
|
correct_bias: bool = True,
|
||||||
|
):
|
||||||
if lr < 0.0:
|
if lr < 0.0:
|
||||||
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
|
raise ValueError("Invalid learning rate: {} - should be >= 0.0".format(lr))
|
||||||
if not 0.0 <= betas[0] < 1.0:
|
if not 0.0 <= betas[0] < 1.0:
|
||||||
|
@ -116,12 +205,12 @@ class AdamW(Optimizer):
|
||||||
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
|
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, correct_bias=correct_bias)
|
||||||
super().__init__(params, defaults)
|
super().__init__(params, defaults)
|
||||||
|
|
||||||
def step(self, closure=None):
|
def step(self, closure: Callable = None):
|
||||||
"""Performs a single optimization step.
|
"""
|
||||||
|
Performs a single optimization step.
|
||||||
|
|
||||||
Arguments:
|
Arguments:
|
||||||
closure (callable, optional): A closure that reevaluates the model
|
closure (:obj:`Callable`, `optional`): A closure that reevaluates the model and returns the loss.
|
||||||
and returns the loss.
|
|
||||||
"""
|
"""
|
||||||
loss = None
|
loss = None
|
||||||
if closure is not None:
|
if closure is not None:
|
||||||
|
|
|
@ -16,15 +16,36 @@
|
||||||
|
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
from typing import Callable, List, Optional, Union
|
||||||
|
|
||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
|
class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
|
||||||
"""Applies a warmup schedule on a given learning rate decay schedule."""
|
"""
|
||||||
|
Applies a warmup schedule on a given learning rate decay schedule.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
initial_learning_rate (:obj:`float`):
|
||||||
|
The initial learning rate for the schedule after the warmup (so this will be the learning rate at the end
|
||||||
|
of the warmup).
|
||||||
|
decay_schedule_fn (:obj:`Callable`):
|
||||||
|
The schedule function to apply after the warmup for the rest of training.
|
||||||
|
warmup_steps (:obj:`int`):
|
||||||
|
The number of steps for the warmup part of training.
|
||||||
|
power (:obj:`float`, `optional`, defaults to 1):
|
||||||
|
The power to use for the polynomial warmup (defaults is a linear warmup).
|
||||||
|
name (:obj:`str`, `optional`):
|
||||||
|
Optional name prefix for the returned tensors during the schedule.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, initial_learning_rate, decay_schedule_fn, warmup_steps, power=1.0, name=None,
|
self,
|
||||||
|
initial_learning_rate: float,
|
||||||
|
decay_schedule_fn: Callable,
|
||||||
|
warmup_steps: int,
|
||||||
|
power: float = 1.0,
|
||||||
|
name: str = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.initial_learning_rate = initial_learning_rate
|
self.initial_learning_rate = initial_learning_rate
|
||||||
|
@ -59,15 +80,34 @@ class WarmUp(tf.keras.optimizers.schedules.LearningRateSchedule):
|
||||||
|
|
||||||
|
|
||||||
def create_optimizer(
|
def create_optimizer(
|
||||||
init_lr,
|
init_lr: float,
|
||||||
num_train_steps,
|
num_train_steps: int,
|
||||||
num_warmup_steps,
|
num_warmup_steps: int,
|
||||||
min_lr_ratio=0.0,
|
min_lr_ratio: float = 0.0,
|
||||||
adam_epsilon=1e-8,
|
adam_epsilon: float = 1e-8,
|
||||||
weight_decay_rate=0.0,
|
weight_decay_rate: float = 0.0,
|
||||||
include_in_weight_decay=None,
|
include_in_weight_decay: Optional[List[str]] = None,
|
||||||
):
|
):
|
||||||
"""Creates an optimizer with learning rate schedule."""
|
"""
|
||||||
|
Creates an optimizer with a learning rate schedule using a warmup phase followed by a linear decay.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
init_lr (:obj:`float`):
|
||||||
|
The desired learning rate at the end of the warmup phase.
|
||||||
|
num_train_step (:obj:`int`):
|
||||||
|
The total number of training steps.
|
||||||
|
num_warmup_steps (:obj:`int`):
|
||||||
|
The number of warmup steps.
|
||||||
|
min_lr_ratio (:obj:`float`, `optional`, defaults to 0):
|
||||||
|
The final learning rate at the end of the linear decay will be :obj:`init_lr * min_lr_ratio`.
|
||||||
|
adam_epsilon (:obj:`float`, `optional`, defaults to 1e-8):
|
||||||
|
The epsilon to use in Adam.
|
||||||
|
weight_decay_rate (:obj:`float`, `optional`, defaults to 0):
|
||||||
|
The weight decay to use.
|
||||||
|
include_in_weight_decay (:obj:`List[str]`, `optional`):
|
||||||
|
List of the parameter names (or re patterns) to apply weight decay to. If none is passed, weight decay is
|
||||||
|
applied to all parameters except bias and layer norm parameters.
|
||||||
|
"""
|
||||||
# Implements linear decay of the learning rate.
|
# Implements linear decay of the learning rate.
|
||||||
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
|
lr_schedule = tf.keras.optimizers.schedules.PolynomialDecay(
|
||||||
initial_learning_rate=init_lr,
|
initial_learning_rate=init_lr,
|
||||||
|
@ -96,26 +136,55 @@ def create_optimizer(
|
||||||
|
|
||||||
|
|
||||||
class AdamWeightDecay(tf.keras.optimizers.Adam):
|
class AdamWeightDecay(tf.keras.optimizers.Adam):
|
||||||
"""Adam enables L2 weight decay and clip_by_global_norm on gradients.
|
"""
|
||||||
Just adding the square of the weights to the loss function is *not* the
|
Adam enables L2 weight decay and clip_by_global_norm on gradients. Just adding the square of the weights to the
|
||||||
correct way of using L2 regularization/weight decay with Adam, since that will
|
loss function is *not* the correct way of using L2 regularization/weight decay with Adam, since that will interact
|
||||||
interact with the m and v parameters in strange ways.
|
with the m and v parameters in strange ways as shown in
|
||||||
Instead we want ot decay the weights in a manner that doesn't interact with
|
`Decoupled Weight Decay Regularization <https://arxiv.org/abs/1711.05101>`__.
|
||||||
the m/v parameters. This is equivalent to adding the square of the weights to
|
|
||||||
the loss with plain (non-momentum) SGD.
|
Instead we want ot decay the weights in a manner that doesn't interact with the m/v parameters. This is equivalent
|
||||||
|
to adding the square of the weights to the loss with plain (non-momentum) SGD.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
learning_rate (:obj:`Union[float, tf.keras.optimizers.schedules.LearningRateSchedule]`, `optional`, defaults to 1e-3):
|
||||||
|
The learning rate to use or a schedule.
|
||||||
|
beta_1 (:obj:`float`, `optional`, defaults to 0.9):
|
||||||
|
The beta1 parameter in Adam, which is the exponential decay rate for the 1st momentum estimates.
|
||||||
|
beta_2 (:obj:`float`, `optional`, defaults to 0.999):
|
||||||
|
The beta2 parameter in Adam, which is the exponential decay rate for the 2nd momentum estimates.
|
||||||
|
epsilon (:obj:`float`, `optional`, defaults to 1e-7):
|
||||||
|
The epsilon paramenter in Adam, which is a small constant for numerical stability.
|
||||||
|
amsgrad (:obj:`bool`, `optional`, default to `False`):
|
||||||
|
Wheter to apply AMSGrad varient of this algorithm or not, see
|
||||||
|
`On the Convergence of Adam and Beyond <https://arxiv.org/abs/1904.09237>`__.
|
||||||
|
weight_decay_rate (:obj:`float`, `optional`, defaults to 0):
|
||||||
|
The weight decay to apply.
|
||||||
|
include_in_weight_decay (:obj:`List[str]`, `optional`):
|
||||||
|
List of the parameter names (or re patterns) to apply weight decay to. If none is passed, weight decay is
|
||||||
|
applied to all parameters by default (unless they are in :obj:`exclude_from_weight_decay`).
|
||||||
|
exclude_from_weight_decay (:obj:`List[str]`, `optional`):
|
||||||
|
List of the parameter names (or re patterns) to exclude from applying weight decay to. If a
|
||||||
|
:obj:`include_in_weight_decay` is passed, the names in it will supersede this list.
|
||||||
|
name (:obj:`str`, `optional`, defaults to 'AdamWeightDecay'):
|
||||||
|
Optional name for the operations created when applying gradients.
|
||||||
|
kwargs:
|
||||||
|
Keyward arguments. Allowed to be {``clipnorm``, ``clipvalue``, ``lr``, ``decay``}. ``clipnorm`` is clip
|
||||||
|
gradients by norm; ``clipvalue`` is clip gradients by value, ``decay`` is included for backward
|
||||||
|
compatibility to allow time inverse decay of learning rate. ``lr`` is included for backward compatibility,
|
||||||
|
recommended to use ``learning_rate`` instead.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
learning_rate=0.001,
|
learning_rate: Union[float, tf.keras.optimizers.schedules.LearningRateSchedule] = 0.001,
|
||||||
beta_1=0.9,
|
beta_1: float = 0.9,
|
||||||
beta_2=0.999,
|
beta_2: float = 0.999,
|
||||||
epsilon=1e-7,
|
epsilon: float = 1e-7,
|
||||||
amsgrad=False,
|
amsgrad: bool = False,
|
||||||
weight_decay_rate=0.0,
|
weight_decay_rate: float = 0.0,
|
||||||
include_in_weight_decay=None,
|
include_in_weight_decay: Optional[List[str]] = None,
|
||||||
exclude_from_weight_decay=None,
|
exclude_from_weight_decay: Optional[List[str]] = None,
|
||||||
name="AdamWeightDecay",
|
name: str = "AdamWeightDecay",
|
||||||
**kwargs
|
**kwargs
|
||||||
):
|
):
|
||||||
super().__init__(learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs)
|
super().__init__(learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs)
|
||||||
|
|
Загрузка…
Ссылка в новой задаче