add an option to not scale wd for decoupled optimizers

Edward Hu 2022-04-27 13:46:44 -04:00
Parent 5dcc1c6847
Commit d7c94f9e34
1 changed file with 8 additions and 5 deletions


@@ -35,7 +35,7 @@ def process_param_groups(params, **kwargs):
             param_group['weight_decay'] = kwargs.get('weight_decay', 0.)
     return param_groups
 
-def MuAdam(params, impl=Adam, **kwargs):
+def MuAdam(params, impl=Adam, decoupled_wd=False, **kwargs):
     '''Adam with μP scaling.
 
     Note for this to work properly, your model needs to have its base shapes set
@@ -65,7 +65,8 @@ def MuAdam(params, impl=Adam, **kwargs):
     for width_mult, group in matrix_like_p.items():
         # Scale learning rate and weight decay accordingly
         group['lr'] /= width_mult
-        group['weight_decay'] *= width_mult
+        if not decoupled_wd:
+            group['weight_decay'] *= width_mult
     new_param_groups.extend(list(matrix_like_p.values()) + [vector_like_p])
     return impl(new_param_groups, **kwargs)
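
With this hunk, passing decoupled_wd=True makes the μP width multipliers rescale only 'lr', while weight_decay is handed to the underlying optimizer unchanged; this is the behavior one wants for optimizers that decouple weight decay from the gradient update, such as AdamW. Because MuAdamW just forwards **kwargs to MuAdam (see the next hunk), the flag works there as well. A minimal usage sketch, following mup's documented two-instance set_base_shapes pattern; the toy model and the widths 512/64 are illustrative, not from this commit:

import torch.nn as nn
from mup import MuAdamW, MuReadout, set_base_shapes

def make_model(width):
    # hidden width scales; MuReadout is mup's μP-aware output layer
    return nn.Sequential(nn.Linear(32, width), nn.ReLU(), MuReadout(width, 10))

# target model at width 512, base shapes taken from a width-64 instance
model = set_base_shapes(make_model(512), make_model(64))

# decoupled_wd=True: μP rescales only 'lr'; weight_decay reaches AdamW as-is
opt = MuAdamW(model.parameters(), lr=1e-3, weight_decay=0.01, decoupled_wd=True)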
@@ -77,7 +78,7 @@ def MuAdamW(params, **kwargs):
     '''
     return MuAdam(params, impl=AdamW, **kwargs)
 
-def MuSGD(params, impl=SGD, **kwargs):
+def MuSGD(params, impl=SGD, decoupled_wd=False, **kwargs):
     '''SGD with μP scaling.
 
     Note for this to work properly, your model needs to have its base shapes set
@@ -110,10 +111,12 @@ def MuSGD(params, impl=SGD, **kwargs):
     for width_mult, group in vector_like_p.items():
         # Scale learning rate and weight decay accordingly
         group['lr'] *= width_mult
-        group['weight_decay'] /= width_mult
+        if not decoupled_wd:
+            group['weight_decay'] /= width_mult
     for shape_ratio, group in matrix_like_p.items():
         group['lr'] /= shape_ratio
-        group['weight_decay'] *= shape_ratio
+        if not decoupled_wd:
+            group['weight_decay'] *= shape_ratio
     new_param_groups.extend(list(matrix_like_p.values()) + \
                             list(vector_like_p.values()) + [fixed_p])
     return impl(new_param_groups, **kwargs)
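
The same flag applies to MuSGD, where both the vector-like and matrix-like groups now leave weight_decay untouched when decoupled_wd=True. A quick way to see the difference, reusing the model from the sketch above (the printed values depend on the widths chosen there; note decoupled_wd is consumed by MuSGD and never forwarded to SGD itself):

from mup import MuSGD

for flag in (False, True):
    opt = MuSGD(model.parameters(), lr=0.1, weight_decay=5e-4, decoupled_wd=flag)
    # flag=False: per-group weight_decay is rescaled by the width multipliers
    # flag=True:  every group keeps the raw 5e-4
    print(flag, sorted({g['weight_decay'] for g in opt.param_groups}))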