diff --git a/README.md b/README.md
index 082b529..409f746 100644
--- a/README.md
+++ b/README.md
@@ -96,7 +96,7 @@ model = MyModel(width=100)
 ### `model` behaves exactly the same as `base_model`
 ### (which is in PyTorch's default parametrization).
 ### This provides backward compatibility at this particular model size.
-### Otherwise, `model`'s init and LR is scaled by μP.
+### Otherwise, `model`'s init and LR are scaled by μP.
 ### IMPORTANT: this should be called as soon as possible,
 ### before re-initialization and optimizer definition.
 set_base_shapes(model, base_model, delta=delta_model)
diff --git a/mup/optim.py b/mup/optim.py
index 3c3e22f..a327996 100644
--- a/mup/optim.py
+++ b/mup/optim.py
@@ -35,11 +35,21 @@ def process_param_groups(params, **kwargs):
         param_group['weight_decay'] = kwargs.get('weight_decay', 0.)
     return param_groups
 
-def MuAdam(params, impl=Adam, **kwargs):
+def MuAdam(params, impl=Adam, decoupled_wd=False, **kwargs):
     '''Adam with μP scaling.
 
     Note for this to work properly, your model needs to have its base shapes set
     already using `mup.set_base_shapes`.
+
+    Inputs:
+        impl: the specific Adam-like optimizer implementation from torch.optim or
+            elsewhere
+        decoupled_wd: if True, skips the mup scaling for weight decay, which should
+            be used for optimizer implementations that decouple weight decay from
+            learning rate. See https://github.com/microsoft/mup/issues/1 for a use case.
+    Outputs:
+        An instance of `impl` with refined parameter groups, each of which has the correctly
+        scaled learning rate according to mup.
     '''
     new_param_groups = []
     for param_group in process_param_groups(params, **kwargs):
@@ -65,7 +75,8 @@ def MuAdam(params, impl=Adam, **kwargs):
         for width_mult, group in matrix_like_p.items():
             # Scale learning rate and weight decay accordingly
             group['lr'] /= width_mult
-            group['weight_decay'] *= width_mult
+            if not decoupled_wd:
+                group['weight_decay'] *= width_mult
         new_param_groups.extend(list(matrix_like_p.values()) + [vector_like_p])
     return impl(new_param_groups, **kwargs)
 
@@ -77,11 +88,21 @@ def MuAdamW(params, **kwargs):
     '''
     return MuAdam(params, impl=AdamW, **kwargs)
 
-def MuSGD(params, impl=SGD, **kwargs):
+def MuSGD(params, impl=SGD, decoupled_wd=False, **kwargs):
     '''SGD with μP scaling.
 
     Note for this to work properly, your model needs to have its base shapes set
     already using `mup.set_base_shapes`.
+
+    Inputs:
+        impl: the specific SGD-like optimizer implementation from torch.optim or
+            elsewhere
+        decoupled_wd: if True, skips the mup scaling for weight decay, which should
+            be used for optimizer implementations that decouple weight decay from
+            learning rate. See https://github.com/microsoft/mup/issues/1 for a use case.
+    Outputs:
+        An instance of `impl` with refined parameter groups, each of which has the correctly
+        scaled learning rate according to mup.
     '''
     new_param_groups = []
     for param_group in process_param_groups(params, **kwargs):
@@ -110,10 +131,12 @@ def MuSGD(params, impl=SGD, **kwargs):
         for width_mult, group in vector_like_p.items():
             # Scale learning rate and weight decay accordingly
             group['lr'] *= width_mult
-            group['weight_decay'] /= width_mult
+            if not decoupled_wd:
+                group['weight_decay'] /= width_mult
         for shape_ratio, group in matrix_like_p.items():
             group['lr'] /= shape_ratio
-            group['weight_decay'] *= shape_ratio
+            if not decoupled_wd:
+                group['weight_decay'] *= shape_ratio
         new_param_groups.extend(list(matrix_like_p.values()) + \
                 list(vector_like_p.values()) + [fixed_p])
     return impl(new_param_groups, **kwargs)
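
For context, here is a minimal usage sketch of the new `decoupled_wd` flag; it is not part of the patch itself. The toy `MyModel`, its widths, and its layer sizes below are hypothetical placeholders standing in for the README's model, and a real μP model would typically also use `mup.MuReadout` for its output layer, which this sketch omits for brevity.

```python
# Sketch only: `MyModel` and its dimensions are hypothetical placeholders.
import torch.nn as nn
from mup import set_base_shapes
from mup.optim import MuAdamW

class MyModel(nn.Module):
    def __init__(self, width):
        super().__init__()
        # `width` is the dimension that scales under μP; a real μP model
        # would use mup.MuReadout for the final layer instead of nn.Linear.
        self.net = nn.Sequential(
            nn.Linear(64, width), nn.ReLU(), nn.Linear(width, 10))

    def forward(self, x):
        return self.net(x)

base_model = MyModel(width=8)     # defines the base shapes
delta_model = MyModel(width=16)   # tells mup which dimensions scale with width
model = MyModel(width=100)        # the model actually being trained

# IMPORTANT (per the README): call this before defining the optimizer.
set_base_shapes(model, base_model, delta=delta_model)

# AdamW decouples weight decay from the learning rate, so pass
# decoupled_wd=True to skip the μP weight-decay rescaling; the flag is
# forwarded from MuAdamW to MuAdam via **kwargs.
optimizer = MuAdamW(model.parameters(), lr=1e-3, weight_decay=1e-2,
                    decoupled_wd=True)
```

With a coupled implementation such as `torch.optim.Adam` or `torch.optim.SGD`, leaving `decoupled_wd=False` preserves the existing behavior, so the change is backward compatible by default.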