diff --git a/mup/optim.py b/mup/optim.py
index 0b04c48..a327996 100644
--- a/mup/optim.py
+++ b/mup/optim.py
@@ -40,6 +40,16 @@ def MuAdam(params, impl=Adam, decoupled_wd=False, **kwargs):
 
     Note for this to work properly, your model needs to have its base shapes set
     already using `mup.set_base_shapes`.
+
+    Inputs:
+        impl: the specific Adam-like optimizer implementation from torch.optim or
+            elsewhere
+        decoupled_wd: if True, skips the mup scaling for weight decay, which should
+            be used for optimizer implementations that decouple weight decay from
+            learning rate. See https://github.com/microsoft/mup/issues/1 for a use case.
+    Outputs:
+        An instance of `impl` with refined parameter groups, each of which has the correctly
+        scaled learning rate according to mup.
     '''
     new_param_groups = []
     for param_group in process_param_groups(params, **kwargs):
@@ -83,6 +93,16 @@ def MuSGD(params, impl=SGD, decoupled_wd=False, **kwargs):
 
     Note for this to work properly, your model needs to have its base shapes set
     already using `mup.set_base_shapes`.
+
+    Inputs:
+        impl: the specific SGD-like optimizer implementation from torch.optim or
+            elsewhere
+        decoupled_wd: if True, skips the mup scaling for weight decay, which should
+            be used for optimizer implementations that decouple weight decay from
+            learning rate. See https://github.com/microsoft/mup/issues/1 for a use case.
+    Outputs:
+        An instance of `impl` with refined parameter groups, each of which has the correctly
+        scaled learning rate according to mup.
     '''
     new_param_groups = []
     for param_group in process_param_groups(params, **kwargs):
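
For reference, a minimal usage sketch of the two documented arguments. The toy model, its widths, and the choice of AdamW as `impl` are illustrative assumptions, not part of this diff; only `MuAdam(params, impl=..., decoupled_wd=..., **kwargs)` and the `set_base_shapes` requirement come from the docstring above.

    # Sketch only: the model, widths, and hyperparameters are assumptions for illustration.
    import torch.nn as nn
    from torch.optim import AdamW

    from mup import MuReadout, set_base_shapes
    from mup.optim import MuAdam

    def make_model(width):
        # Hypothetical toy model; mup expects the output layer to be a MuReadout.
        return nn.Sequential(nn.Linear(10, width), nn.ReLU(), MuReadout(width, 2))

    model = make_model(width=64)
    # Required before constructing the optimizer, per the docstring note.
    set_base_shapes(model, make_model(width=8))

    # AdamW decouples weight decay from the learning rate, so pass decoupled_wd=True
    # to skip the mup scaling of weight decay; learning rates are still scaled per
    # parameter group in the returned optimizer instance.
    optimizer = MuAdam(model.parameters(), impl=AdamW, decoupled_wd=True,
                       lr=1e-3, weight_decay=0.01)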