From 59b0c8694f6deb1b4ddc54f3d74e9b7025dac9c7 Mon Sep 17 00:00:00 2001
From: Edward Hu <41635632+edwardjhu@users.noreply.github.com>
Date: Wed, 18 May 2022 19:09:49 -0400
Subject: [PATCH] Update optim.py

---
 mup/optim.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/mup/optim.py b/mup/optim.py
index 0b04c48..a327996 100644
--- a/mup/optim.py
+++ b/mup/optim.py
@@ -40,6 +40,16 @@ def MuAdam(params, impl=Adam, decoupled_wd=False, **kwargs):
 
     Note for this to work properly, your model needs to have its base shapes set
     already using `mup.set_base_shapes`.
+
+    Inputs:
+        impl: the specific Adam-like optimizer implementation from torch.optim or
+            elsewhere
+        decoupled_wd: if True, skips the mup scaling for weight decay; use this for
+            optimizer implementations that decouple weight decay from learning rate.
+            See https://github.com/microsoft/mup/issues/1 for a use case.
+    Outputs:
+        An instance of `impl` with refined parameter groups, each of which has the
+            correctly scaled learning rate according to mup.
     '''
     new_param_groups = []
     for param_group in process_param_groups(params, **kwargs):
@@ -83,6 +93,16 @@ def MuSGD(params, impl=SGD, decoupled_wd=False, **kwargs):
 
     Note for this to work properly, your model needs to have its base shapes set
     already using `mup.set_base_shapes`.
+
+    Inputs:
+        impl: the specific SGD-like optimizer implementation from torch.optim or
+            elsewhere
+        decoupled_wd: if True, skips the mup scaling for weight decay; use this for
+            optimizer implementations that decouple weight decay from learning rate.
+            See https://github.com/microsoft/mup/issues/1 for a use case.
+    Outputs:
+        An instance of `impl` with refined parameter groups, each of which has the
+            correctly scaled learning rate according to mup.
     '''
     new_param_groups = []
    for param_group in process_param_groups(params, **kwargs):
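
A minimal usage sketch of the documented arguments (not part of the patch). It
assumes torch and mup are installed; the toy `make_mlp` model and the
hyperparameter values are made up for illustration, while `MuAdam`,
`set_base_shapes`, `impl`, and `decoupled_wd` are the names referenced in the
docstrings above.

    import torch.nn as nn
    from torch.optim import AdamW
    from mup import MuAdam, set_base_shapes

    def make_mlp(width):
        # Toy 2-layer MLP; a real mup model would typically use mup.MuReadout
        # for the output layer, omitted here to keep the sketch short.
        return nn.Sequential(nn.Linear(32, width), nn.ReLU(), nn.Linear(width, 10))

    base_model = make_mlp(width=64)   # proxy model at the base width
    model = make_mlp(width=1024)      # the model actually being trained

    # As the docstring notes, base shapes must be set before constructing the
    # optimizer so MuAdam can scale each parameter group's learning rate.
    set_base_shapes(model, base_model)

    # AdamW decouples weight decay from the learning rate, so decoupled_wd=True
    # skips the mup scaling of weight decay (see the linked issue above).
    optimizer = MuAdam(model.parameters(), impl=AdamW, decoupled_wd=True,
                       lr=1e-3, weight_decay=1e-2)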