From 5dcc1c6847b31913a7d257ec51a817ddd303fecd Mon Sep 17 00:00:00 2001
From: Greg Yang <53244851+thegregyang@users.noreply.github.com>
Date: Mon, 9 May 2022 00:37:07 -0400
Subject: [PATCH 1/5] typo

---
 README.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/README.md b/README.md
index d54dc91..7422e4a 100644
--- a/README.md
+++ b/README.md
@@ -96,7 +96,7 @@ model = MyModel(width=100)
 ### `model` behaves exactly the same as `base_model`
 ### (which is in PyTorch's default parametrization).
 ### This provides backward compatibility at this particular model size.
-### Otherwise, `model`'s init and LR is scaled by μP.
+### Otherwise, `model`'s init and LR are scaled by μP.
 ### IMPORTANT: this should be called as soon as possible,
 ### before re-initialization and optimizer definition.
 set_base_shapes(model, base_model, delta=delta_model)

From d7c94f9e34b6ec6c84c93928def88eb4195d87da Mon Sep 17 00:00:00 2001
From: Edward Hu
Date: Wed, 27 Apr 2022 13:46:44 -0400
Subject: [PATCH 2/5] add an option to not scale wd for decoupled optimizers

---
 mup/optim.py | 13 ++++++++-----
 1 file changed, 8 insertions(+), 5 deletions(-)

diff --git a/mup/optim.py b/mup/optim.py
index 3c3e22f..0b04c48 100644
--- a/mup/optim.py
+++ b/mup/optim.py
@@ -35,7 +35,7 @@ def process_param_groups(params, **kwargs):
             param_group['weight_decay'] = kwargs.get('weight_decay', 0.)
     return param_groups
 
-def MuAdam(params, impl=Adam, **kwargs):
+def MuAdam(params, impl=Adam, decoupled_wd=False, **kwargs):
     '''Adam with μP scaling.
 
     Note for this to work properly, your model needs to have its base shapes set
@@ -65,7 +65,8 @@ def MuAdam(params, impl=Adam, **kwargs):
         for width_mult, group in matrix_like_p.items():
             # Scale learning rate and weight decay accordingly
             group['lr'] /= width_mult
-            group['weight_decay'] *= width_mult
+            if not decoupled_wd:
+                group['weight_decay'] *= width_mult
         new_param_groups.extend(list(matrix_like_p.values()) + [vector_like_p])
     return impl(new_param_groups, **kwargs)
 
@@ -77,7 +78,7 @@ def MuAdamW(params, **kwargs):
     '''
     return MuAdam(params, impl=AdamW, **kwargs)
 
-def MuSGD(params, impl=SGD, **kwargs):
+def MuSGD(params, impl=SGD, decoupled_wd=False, **kwargs):
     '''SGD with μP scaling.
 
     Note for this to work properly, your model needs to have its base shapes set
@@ -110,10 +111,12 @@ def MuSGD(params, impl=SGD, **kwargs):
         for width_mult, group in vector_like_p.items():
             # Scale learning rate and weight decay accordingly
             group['lr'] *= width_mult
-            group['weight_decay'] /= width_mult
+            if not decoupled_wd:
+                group['weight_decay'] /= width_mult
         for shape_ratio, group in matrix_like_p.items():
             group['lr'] /= shape_ratio
-            group['weight_decay'] *= shape_ratio
+            if not decoupled_wd:
+                group['weight_decay'] *= shape_ratio
     new_param_groups.extend(list(matrix_like_p.values()) + \
                             list(vector_like_p.values()) + [fixed_p])
     return impl(new_param_groups, **kwargs)

From ba61bd1b4b65cb9a1f13dd24a5dc52ca615ad982 Mon Sep 17 00:00:00 2001
From: Edward Hu <41635632+edwardjhu@users.noreply.github.com>
Date: Wed, 18 May 2022 19:09:49 -0400
Subject: [PATCH 3/5] Update optim.py

---
 mup/optim.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/mup/optim.py b/mup/optim.py
index 0b04c48..a327996 100644
--- a/mup/optim.py
+++ b/mup/optim.py
@@ -40,6 +40,16 @@ def MuAdam(params, impl=Adam, decoupled_wd=False, **kwargs):
 
     Note for this to work properly, your model needs to have its base shapes set
     already using `mup.set_base_shapes`.
+    
+    Inputs:
+        impl: the specific Adam-like optimizer implementation from torch.optim or
+            elsewhere
+        decoupled_wd: if True, skips the mup scaling for weight decay, which should
+            be used for optimizer implementations that decouple weight decay from
+            learning rate. See https://github.com/microsoft/mup/issues/1 for a use case.
+    Outputs:
+        An instance of `impl` with refined parameter groups, each of which has the correctly
+        scaled learning rate according to mup.
     '''
     new_param_groups = []
     for param_group in process_param_groups(params, **kwargs):
@@ -83,6 +93,16 @@ def MuSGD(params, impl=SGD, decoupled_wd=False, **kwargs):
 
     Note for this to work properly, your model needs to have its base shapes set
     already using `mup.set_base_shapes`.
+    
+    Inputs:
+        impl: the specific SGD-like optimizer implementation from torch.optim or
+            elsewhere
+        decoupled_wd: if True, skips the mup scaling for weight decay, which should
+            be used for optimizer implementations that decouple weight decay from
+            learning rate. See https://github.com/microsoft/mup/issues/1 for a use case.
+    Outputs:
+        An instance of `impl` with refined parameter groups, each of which has the correctly
+        scaled learning rate according to mup.
     '''
     new_param_groups = []
     for param_group in process_param_groups(params, **kwargs):

From 812fb0261f7238226d301d44da77f7139cbe2ab7 Mon Sep 17 00:00:00 2001
From: Edward Hu
Date: Sun, 8 May 2022 08:17:03 -0400
Subject: [PATCH 4/5] add torchdistx to readme

---
 README.md | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/README.md b/README.md
index 7422e4a..d85345e 100644
--- a/README.md
+++ b/README.md
@@ -76,14 +76,14 @@ class MyModel(nn.Module):
 
 ### Instantiate a base model
 base_model = MyModel(width=1)
-### Optionally, use `device='meta'` to avoid instantiating the parameters
-### This requires you to pass the device flag down to all sub-modules
-# base_model = MyModel(width=1, device='meta')
+### Optionally, use `torchdistx.deferred_init.deferred_init` to avoid instantiating the parameters
+### Simply install `torchdistx` and use
+# base_model = torchdistx.deferred_init.deferred_init(MyModel, width=1)
 ### Instantiate a "delta" model that differs from the base model
 ### in all dimensions ("widths") that one wishes to scale.
 ### Here it's simple, but e.g., in a Transformer, you may want to scale
 ### both nhead and dhead, so the delta model should differ in both.
-delta_model = MyModel(width=2) # Optionally add the `device='meta'` to avoid instantiating
+delta_model = MyModel(width=2) # Optionally use `torchdistx` to avoid instantiating
 
 ### Instantiate the target model (the model you actually want to train).
 ### This should be the same as the base model except

From 244c36086a7c13060c5956c750201807f795b12f Mon Sep 17 00:00:00 2001
From: Edward Hu
Date: Sun, 8 May 2022 08:34:13 -0400
Subject: [PATCH 5/5] add torchdistx to readme

---
 README.md | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/README.md b/README.md
index d85345e..409f746 100644
--- a/README.md
+++ b/README.md
@@ -126,9 +126,8 @@ optimizer = MuSGD(model.parameters(), lr=0.1)
 ```
 
 Note the base and delta models *do not need to be trained* --- we are only extracting parameter shape information from them.
-Therefore, optionally, we can avoid instantiating these potentially large models by passing `device='meta'` to their constructor.
-However, you need to make sure that the `device` flag is appropriately passed down to the constructor of all submodules.
-Of course, it'd be even better if PyTorch can do this automatically for any existing `nn.Module`. If you want to see this happen, please upvote [this PyTorch issue](https://github.com/pytorch/pytorch/issues/74143).
+Therefore, optionally, we can avoid instantiating these potentially large models by using the `deferred_init` function in `torchdistx`.
+After installing [`torchdistx`](https://github.com/pytorch/torchdistx), use `torchdistx.deferred_init.deferred_init(MyModel, **args)` instead of `MyModel(**args)`. See [this page](https://pytorch.org/torchdistx/latest/deferred_init.html) for more detail.
 
 ## How `mup` Works Under the Hood
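
A minimal end-to-end sketch of how the changes in this series fit together, assuming the `MyModel(width=...)` class from the README example, an installed `torchdistx`, and placeholder hyperparameter values:

    from mup import MuAdamW, set_base_shapes
    from torchdistx.deferred_init import deferred_init

    # Base and delta models only supply shape information (patches 4-5), so they
    # can be built with deferred_init to avoid allocating their parameters.
    base_model = deferred_init(MyModel, width=1)
    delta_model = deferred_init(MyModel, width=2)

    # The target model is instantiated normally, then rescaled by μP.
    model = MyModel(width=100)
    set_base_shapes(model, base_model, delta=delta_model)

    # decoupled_wd=True (patches 2-3) skips the μP scaling of weight decay; it is
    # meant for optimizer implementations that decouple weight decay from the
    # learning rate. The per-group learning-rate scaling still applies.
    optimizer = MuAdamW(model.parameters(), lr=1e-3, weight_decay=0.01,
                        decoupled_wd=True)

The base and delta models are only read by `set_base_shapes`, so deferring their initialization costs nothing; `decoupled_wd` changes nothing except whether weight decay is rescaled along with the learning rate.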