Merge branch 'torchdistx' of github.com:microsoft/mup into torchdistx

This commit is contained in:
Edward Hu 2022-05-30 17:49:17 -04:00
Родитель 265f2d9f63 244c36086a
Коммит 18f2ff4fe9
2 изменённых файлов: 29 добавлений и 6 удалений

Просмотреть файл

@ -96,7 +96,7 @@ model = MyModel(width=100)
### `model` behaves exactly the same as `base_model`
### (which is in PyTorch's default parametrization).
### This provides backward compatibility at this particular model size.
### Otherwise, `model`'s init and LR is scaled by μP.
### Otherwise, `model`'s init and LR are scaled by μP.
### IMPORTANT: this should be called as soon as possible,
### before re-initialization and optimizer definition.
set_base_shapes(model, base_model, delta=delta_model)

Просмотреть файл

@ -35,11 +35,21 @@ def process_param_groups(params, **kwargs):
param_group['weight_decay'] = kwargs.get('weight_decay', 0.)
return param_groups
def MuAdam(params, impl=Adam, **kwargs):
def MuAdam(params, impl=Adam, decoupled_wd=False, **kwargs):
'''Adam with μP scaling.
Note for this to work properly, your model needs to have its base shapes set
already using `mup.set_base_shapes`.
Inputs:
impl: the specific Adam-like optimizer implementation from torch.optim or
elsewhere
decoupled_wd: if True, skips the mup scaling for weight decay, which should
be used for optimizer implementations that decouple weight decay from
learning rate. See https://github.com/microsoft/mup/issues/1 for a use case.
Outputs:
An instance of `impl` with refined parameter groups, each of which has the correctly
scaled learning rate according to mup.
'''
new_param_groups = []
for param_group in process_param_groups(params, **kwargs):
@ -65,7 +75,8 @@ def MuAdam(params, impl=Adam, **kwargs):
for width_mult, group in matrix_like_p.items():
# Scale learning rate and weight decay accordingly
group['lr'] /= width_mult
group['weight_decay'] *= width_mult
if not decoupled_wd:
group['weight_decay'] *= width_mult
new_param_groups.extend(list(matrix_like_p.values()) + [vector_like_p])
return impl(new_param_groups, **kwargs)
@ -77,11 +88,21 @@ def MuAdamW(params, **kwargs):
'''
return MuAdam(params, impl=AdamW, **kwargs)
def MuSGD(params, impl=SGD, **kwargs):
def MuSGD(params, impl=SGD, decoupled_wd=False, **kwargs):
'''SGD with μP scaling.
Note for this to work properly, your model needs to have its base shapes set
already using `mup.set_base_shapes`.
Inputs:
impl: the specific SGD-like optimizer implementation from torch.optim or
elsewhere
decoupled_wd: if True, skips the mup scaling for weight decay, which should
be used for optimizer implementations that decouple weight decay from
learning rate. See https://github.com/microsoft/mup/issues/1 for a use case.
Outputs:
An instance of `impl` with refined parameter groups, each of which has the correctly
scaled learning rate according to mup.
'''
new_param_groups = []
for param_group in process_param_groups(params, **kwargs):
@ -110,10 +131,12 @@ def MuSGD(params, impl=SGD, **kwargs):
for width_mult, group in vector_like_p.items():
# Scale learning rate and weight decay accordingly
group['lr'] *= width_mult
group['weight_decay'] /= width_mult
if not decoupled_wd:
group['weight_decay'] /= width_mult
for shape_ratio, group in matrix_like_p.items():
group['lr'] /= shape_ratio
group['weight_decay'] *= shape_ratio
if not decoupled_wd:
group['weight_decay'] *= shape_ratio
new_param_groups.extend(list(matrix_like_p.values()) + \
list(vector_like_p.values()) + [fixed_p])
return impl(new_param_groups, **kwargs)