Edward Hu 2022-05-18 19:09:49 -04:00 committed by GitHub
Parent 44303b6e63
Commit 59b0c8694f
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
1 changed file with 20 additions and 0 deletions


@@ -40,6 +40,16 @@ def MuAdam(params, impl=Adam, decoupled_wd=False, **kwargs):
Note that for this to work properly, your model needs to have its base shapes
already set using `mup.set_base_shapes`.
Inputs:
impl: the specific Adam-like optimizer implementation from torch.optim or
elsewhere
decoupled_wd: if True, skips the mup scaling for weight decay; use this for
optimizer implementations that decouple weight decay from the learning rate,
such as `AdamW`. See https://github.com/microsoft/mup/issues/1 for a use case.
Outputs:
An instance of `impl` with refined parameter groups, each with its learning
rate correctly scaled according to mup.
'''
new_param_groups = []
for param_group in process_param_groups(params, **kwargs):
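For context, here is a minimal usage sketch of `MuAdam` (not part of this commit). The toy `make_mlp` model, its widths, and the hyperparameters are hypothetical; only `MuAdam`, `MuReadout`, and `set_base_shapes` are the actual mup package API.

import torch
import torch.nn as nn
from mup import MuAdam, MuReadout, set_base_shapes

def make_mlp(width):
    # MuReadout replaces the final nn.Linear so mup can rescale the output layer.
    return nn.Sequential(nn.Linear(32, width), nn.ReLU(), MuReadout(width, 10))

base = make_mlp(width=8)      # small-width reference model defining base shapes
model = make_mlp(width=128)   # the model actually being trained
set_base_shapes(model, base)  # required before constructing a mup optimizer

# The default impl=Adam couples weight decay with the learning rate,
# so mup's weight-decay scaling stays on.
optimizer = MuAdam(model.parameters(), lr=1e-3)

# AdamW decouples weight decay from the learning rate; skip the mup
# weight-decay scaling via decoupled_wd=True (see issue #1 linked above).
optimizer = MuAdam(model.parameters(), impl=torch.optim.AdamW,
                   decoupled_wd=True, lr=1e-3, weight_decay=0.01)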
@@ -83,6 +93,16 @@ def MuSGD(params, impl=SGD, decoupled_wd=False, **kwargs):
Note that for this to work properly, your model needs to have its base shapes
already set using `mup.set_base_shapes`.
Inputs:
impl: the specific SGD-like optimizer implementation from torch.optim or
elsewhere
decoupled_wd: if True, skips the mup scaling for weight decay; use this for
optimizer implementations that decouple weight decay from the learning rate,
such as `AdamW`. See https://github.com/microsoft/mup/issues/1 for a use case.
Outputs:
An instance of `impl` with refined parameter groups, each with its learning
rate correctly scaled according to mup.
'''
new_param_groups = []
for param_group in process_param_groups(params, **kwargs):
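The same pattern applies to `MuSGD`. This sketch reuses the hypothetical `model` from the `MuAdam` example above, after `set_base_shapes` has been applied; the hyperparameters are again placeholders.

from torch.optim import SGD
from mup import MuSGD

# torch.optim.SGD applies weight decay through the gradient, coupling it with
# the learning rate, so the default decoupled_wd=False keeps the mup scaling.
optimizer = MuSGD(model.parameters(), impl=SGD, lr=0.1,
                  momentum=0.9, weight_decay=5e-4)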