Edward Hu 2022-05-18 19:09:49 -04:00 committed by GitHub
Parent 44303b6e63
Commit 59b0c8694f
No key found matching this signature
GPG key ID: 4AEE18F83AFDEB23
1 changed file: 20 additions and 0 deletions


@@ -40,6 +40,16 @@ def MuAdam(params, impl=Adam, decoupled_wd=False, **kwargs):
Note for this to work properly, your model needs to have its base shapes set
already using `mup.set_base_shapes`.
Inputs:
impl: the specific Adam-like optimizer implementation from torch.optim or
elsewhere
decoupled_wd: if True, skips the mup scaling for weight decay, which should
be used for optimizer implementations that decouple weight decay from
learning rate. See https://github.com/microsoft/mup/issues/1 for a use case.
Outputs:
An instance of `impl` with refined parameter groups, each of which has the correctly
scaled learning rate according to mup.
'''
new_param_groups = []
for param_group in process_param_groups(params, **kwargs):
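
For context, not part of this diff: a minimal usage sketch of the `decoupled_wd` flag documented above. `MyModel`, `base_model`, and `delta_model` are hypothetical placeholders; the base shapes are assumed to have been set with `mup.set_base_shapes` first, as the docstring requires.

# Sketch: MuAdam with a decoupled-weight-decay implementation.
# MyModel, base_model, and delta_model are hypothetical placeholders.
from torch.optim import AdamW
from mup import MuAdam, set_base_shapes

model = set_base_shapes(MyModel(width=1024), base_model, delta=delta_model)

# AdamW decouples weight decay from the learning rate, so pass decoupled_wd=True
# to skip mup's weight-decay scaling (see https://github.com/microsoft/mup/issues/1).
optimizer = MuAdam(model.parameters(), impl=AdamW, lr=1e-3,
                   weight_decay=1e-2, decoupled_wd=True)
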
@@ -83,6 +93,16 @@ def MuSGD(params, impl=SGD, decoupled_wd=False, **kwargs):
Note for this to work properly, your model needs to have its base shapes set
already using `mup.set_base_shapes`.
Inputs:
impl: the specific SGD-like optimizer implementation from torch.optim or
elsewhere
decoupled_wd: if True, skips the mup scaling for weight decay, which should
be used for optimizer implementations that decouple weight decay from
learning rate. See https://github.com/microsoft/mup/issues/1 for a use case.
Outputs:
An instance of `impl` with refined parameter groups, each of which has the correctly
scaled learning rate according to mup.
'''
new_param_groups = []
for param_group in process_param_groups(params, **kwargs):
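
Likewise for MuSGD, a sketch of passing user-defined parameter groups, which the docstring says are refined into groups with mup-scaled learning rates. `MyModel` and its `body`/`head` submodules are hypothetical placeholders; base shapes are again assumed to be set already.

# Sketch: MuSGD with per-group hyperparameters. Plain torch.optim.SGD couples
# weight decay to the learning rate, so decoupled_wd stays at its default False.
from torch.optim import SGD
from mup import MuSGD, set_base_shapes

model = set_base_shapes(MyModel(width=1024), base_model, delta=delta_model)

# Each group is split further so every parameter receives a mup-scaled lr.
optimizer = MuSGD([
    {'params': model.body.parameters(), 'lr': 0.1},
    {'params': model.head.parameters(), 'lr': 0.01},
], impl=SGD, momentum=0.9)
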