зеркало из https://github.com/microsoft/mup.git
Update optim.py
This commit is contained in:
Родитель
44303b6e63
Коммит
59b0c8694f
20
mup/optim.py
20
mup/optim.py
|
@@ -40,6 +40,16 @@ def MuAdam(params, impl=Adam, decoupled_wd=False, **kwargs):
|
||||||
|
|
||||||
Note: for this to work properly, your model must already have its base shapes
|
Note: for this to work properly, your model must already have its base shapes
|
||||||
set using `mup.set_base_shapes`.
|
set using `mup.set_base_shapes`.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
impl: the specific Adam-like optimizer implementation from torch.optim or
|
||||||
|
elsewhere
|
||||||
|
decoupled_wd: if True, skips the mup scaling for weight decay. Enable this
|
||||||
|
for optimizer implementations that decouple weight decay from the
|
||||||
|
learning rate. See https://github.com/microsoft/mup/issues/1 for a use case.
|
||||||
|
Outputs:
|
||||||
|
An instance of `impl` with refined parameter groups, each of which has the correctly
|
||||||
|
scaled learning rate according to mup.
|
||||||
'''
|
'''
|
||||||
new_param_groups = []
|
new_param_groups = []
|
||||||
for param_group in process_param_groups(params, **kwargs):
|
for param_group in process_param_groups(params, **kwargs):
|
||||||
|
@@ -83,6 +93,16 @@ def MuSGD(params, impl=SGD, decoupled_wd=False, **kwargs):
|
||||||
|
|
||||||
Note: for this to work properly, your model must already have its base shapes
|
Note: for this to work properly, your model must already have its base shapes
|
||||||
set using `mup.set_base_shapes`.
|
set using `mup.set_base_shapes`.
|
||||||
|
|
||||||
|
Inputs:
|
||||||
|
impl: the specific SGD-like optimizer implementation from torch.optim or
|
||||||
|
elsewhere
|
||||||
|
decoupled_wd: if True, skips the mup scaling for weight decay. Enable this
|
||||||
|
for optimizer implementations that decouple weight decay from the
|
||||||
|
learning rate. See https://github.com/microsoft/mup/issues/1 for a use case.
|
||||||
|
Outputs:
|
||||||
|
An instance of `impl` with refined parameter groups, each of which has the correctly
|
||||||
|
scaled learning rate according to mup.
|
||||||
'''
|
'''
|
||||||
new_param_groups = []
|
new_param_groups = []
|
||||||
for param_group in process_param_groups(params, **kwargs):
|
for param_group in process_param_groups(params, **kwargs):
|
||||||
|
|
Загрузка…
Ссылка в новой задаче