зеркало из https://github.com/microsoft/mup.git
add an option to not scale wd for decoupled optimizers
This commit is contained in:
Родитель
5dcc1c6847
Коммит
d7c94f9e34
13
mup/optim.py
13
mup/optim.py
|
@ -35,7 +35,7 @@ def process_param_groups(params, **kwargs):
|
|||
param_group['weight_decay'] = kwargs.get('weight_decay', 0.)
|
||||
return param_groups
|
||||
|
||||
def MuAdam(params, impl=Adam, **kwargs):
|
||||
def MuAdam(params, impl=Adam, decoupled_wd=False, **kwargs):
|
||||
'''Adam with μP scaling.
|
||||
|
||||
Note for this to work properly, your model needs to have its base shapes set
|
||||
|
@ -65,7 +65,8 @@ def MuAdam(params, impl=Adam, **kwargs):
|
|||
for width_mult, group in matrix_like_p.items():
|
||||
# Scale learning rate and weight decay accordingly
|
||||
group['lr'] /= width_mult
|
||||
group['weight_decay'] *= width_mult
|
||||
if not decoupled_wd:
|
||||
group['weight_decay'] *= width_mult
|
||||
new_param_groups.extend(list(matrix_like_p.values()) + [vector_like_p])
|
||||
return impl(new_param_groups, **kwargs)
|
||||
|
||||
|
@ -77,7 +78,7 @@ def MuAdamW(params, **kwargs):
|
|||
'''
|
||||
return MuAdam(params, impl=AdamW, **kwargs)
|
||||
|
||||
def MuSGD(params, impl=SGD, **kwargs):
|
||||
def MuSGD(params, impl=SGD, decoupled_wd=False, **kwargs):
|
||||
'''SGD with μP scaling.
|
||||
|
||||
Note for this to work properly, your model needs to have its base shapes set
|
||||
|
@ -110,10 +111,12 @@ def MuSGD(params, impl=SGD, **kwargs):
|
|||
for width_mult, group in vector_like_p.items():
|
||||
# Scale learning rate and weight decay accordingly
|
||||
group['lr'] *= width_mult
|
||||
group['weight_decay'] /= width_mult
|
||||
if not decoupled_wd:
|
||||
group['weight_decay'] /= width_mult
|
||||
for shape_ratio, group in matrix_like_p.items():
|
||||
group['lr'] /= shape_ratio
|
||||
group['weight_decay'] *= shape_ratio
|
||||
if not decoupled_wd:
|
||||
group['weight_decay'] *= shape_ratio
|
||||
new_param_groups.extend(list(matrix_like_p.values()) + \
|
||||
list(vector_like_p.values()) + [fixed_p])
|
||||
return impl(new_param_groups, **kwargs)
|
||||
|
|
Загрузка…
Ссылка в новой задаче