Mirror of https://github.com/microsoft/mup.git
Merge branch 'torchdistx' of github.com:microsoft/mup into torchdistx
Commit 18f2ff4fe9
@@ -96,7 +96,7 @@ model = MyModel(width=100)
 ### `model` behaves exactly the same as `base_model`
 ### (which is in PyTorch's default parametrization).
 ### This provides backward compatibility at this particular model size.
-### Otherwise, `model`'s init and LR is scaled by μP.
+### Otherwise, `model`'s init and LR are scaled by μP.
 ### IMPORTANT: this should be called as soon as possible,
 ### before re-initialization and optimizer definition.
 set_base_shapes(model, base_model, delta=delta_model)
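The hunk above edits a comment block inside a usage snippet; for orientation, a minimal sketch of the surrounding code is given below. Only `model = MyModel(width=100)` and the `set_base_shapes(...)` call come from the diff context; the `MyModel` definition and the smaller widths are placeholders chosen for illustration.

import torch.nn as nn
from mup import set_base_shapes

class MyModel(nn.Module):
    # Hypothetical model whose hidden width is the only thing that varies.
    def __init__(self, width):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(32, width), nn.ReLU(), nn.Linear(width, 10))

    def forward(self, x):
        return self.body(x)

base_model = MyModel(width=8)    # defines the base shapes
delta_model = MyModel(width=16)  # differs from base_model only in width
model = MyModel(width=100)       # the model actually trained

# Call this as soon as possible after construction, before any
# re-initialization and before creating the optimizer.
set_base_shapes(model, base_model, delta=delta_model)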
mup/optim.py (33 changed lines)
@@ -35,11 +35,21 @@ def process_param_groups(params, **kwargs):
             param_group['weight_decay'] = kwargs.get('weight_decay', 0.)
     return param_groups
 
-def MuAdam(params, impl=Adam, **kwargs):
+def MuAdam(params, impl=Adam, decoupled_wd=False, **kwargs):
     '''Adam with μP scaling.
 
     Note for this to work properly, your model needs to have its base shapes set
     already using `mup.set_base_shapes`.
+
+    Inputs:
+        impl: the specific Adam-like optimizer implementation from torch.optim or
+            elsewhere
+        decoupled_wd: if True, skips the mup scaling for weight decay, which should
+            be used for optimizer implementations that decouple weight decay from
+            learning rate. See https://github.com/microsoft/mup/issues/1 for a use case.
+    Outputs:
+        An instance of `impl` with refined parameter groups, each of which has the correctly
+            scaled learning rate according to mup.
     '''
     new_param_groups = []
     for param_group in process_param_groups(params, **kwargs):
@@ -65,7 +75,8 @@ def MuAdam(params, impl=Adam, **kwargs):
         for width_mult, group in matrix_like_p.items():
             # Scale learning rate and weight decay accordingly
             group['lr'] /= width_mult
-            group['weight_decay'] *= width_mult
+            if not decoupled_wd:
+                group['weight_decay'] *= width_mult
         new_param_groups.extend(list(matrix_like_p.values()) + [vector_like_p])
     return impl(new_param_groups, **kwargs)
 
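As the docstring above suggests, the new guard exists because coupled weight decay (the default in `torch.optim.Adam`) is applied through the gradient and therefore picks up the μP learning-rate scaling; multiplying `weight_decay` by `width_mult` compensates for dividing `lr` by `width_mult`, whereas a decoupled implementation needs no such compensation. A hedged usage sketch, assuming `model` has already been prepared with `set_base_shapes` as in the earlier snippet, and with placeholder hyperparameter values:

from torch.optim import AdamW
from mup import MuAdam

# Default impl is torch.optim.Adam (coupled weight decay):
# μP rescales both lr and weight_decay per parameter group.
opt = MuAdam(model.parameters(), lr=1e-3, weight_decay=1e-2)

# AdamW decouples weight decay from the learning rate, so pass
# decoupled_wd=True to skip the μP weight-decay scaling.
opt_decoupled = MuAdam(model.parameters(), impl=AdamW, lr=1e-3,
                       weight_decay=1e-2, decoupled_wd=True)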
@@ -77,11 +88,21 @@ def MuAdamW(params, **kwargs):
     '''
     return MuAdam(params, impl=AdamW, **kwargs)
 
-def MuSGD(params, impl=SGD, **kwargs):
+def MuSGD(params, impl=SGD, decoupled_wd=False, **kwargs):
     '''SGD with μP scaling.
 
     Note for this to work properly, your model needs to have its base shapes set
     already using `mup.set_base_shapes`.
+
+    Inputs:
+        impl: the specific SGD-like optimizer implementation from torch.optim or
+            elsewhere
+        decoupled_wd: if True, skips the mup scaling for weight decay, which should
+            be used for optimizer implementations that decouple weight decay from
+            learning rate. See https://github.com/microsoft/mup/issues/1 for a use case.
+    Outputs:
+        An instance of `impl` with refined parameter groups, each of which has the correctly
+            scaled learning rate according to mup.
     '''
     new_param_groups = []
     for param_group in process_param_groups(params, **kwargs):
@ -110,10 +131,12 @@ def MuSGD(params, impl=SGD, **kwargs):
|
|||
for width_mult, group in vector_like_p.items():
|
||||
# Scale learning rate and weight decay accordingly
|
||||
group['lr'] *= width_mult
|
||||
group['weight_decay'] /= width_mult
|
||||
if not decoupled_wd:
|
||||
group['weight_decay'] /= width_mult
|
||||
for shape_ratio, group in matrix_like_p.items():
|
||||
group['lr'] /= shape_ratio
|
||||
group['weight_decay'] *= shape_ratio
|
||||
if not decoupled_wd:
|
||||
group['weight_decay'] *= shape_ratio
|
||||
new_param_groups.extend(list(matrix_like_p.values()) + \
|
||||
list(vector_like_p.values()) + [fixed_p])
|
||||
return impl(new_param_groups, **kwargs)
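A matching `MuSGD` sketch, again assuming `model` already has its base shapes set; `DecoupledSGD` below is a hypothetical stand-in for any SGD-like implementation that decouples weight decay:

from mup import MuSGD

# Default impl is torch.optim.SGD (coupled weight decay); μP rescales
# lr and weight_decay per parameter group as in the hunk above.
opt = MuSGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)

# With a decoupled-weight-decay SGD variant, skip the weight-decay scaling:
# opt = MuSGD(model.parameters(), impl=DecoupledSGD, lr=0.1, momentum=0.9,
#             weight_decay=1e-4, decoupled_wd=True)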