Improve optimizer
This commit is contained in:
Родитель
fe0655cc42
Коммит
69c774a502
7
train.py
7
train.py
|
@ -234,12 +234,9 @@ def main(params):
|
|||
|
||||
### optimizer
|
||||
if isinstance(lm.model, nn.DataParallel):
|
||||
optimizer = eval(conf.optimizer_name)((it.__next__() for it in itertools.cycle([lm.model.parameters(), lm.model.module.layers['embedding'].get_parameters()])), **conf.optimizer_params)
|
||||
optimizer = eval(conf.optimizer_name)(list(lm.model.parameters()) + list(lm.model.module.layers['embedding'].get_parameters()), **conf.optimizer_params)
|
||||
else:
|
||||
optimizer = eval(conf.optimizer_name)((it.__next__() for it in itertools.cycle([lm.model.parameters(), lm.model.layers['embedding'].get_parameters()])), **conf.optimizer_params)
|
||||
|
||||
#optimizer = eval(conf.optimizer_name)(list(lm.model.parameters()) + list(lm.model.module.layers['embedding'].get_parameters()), **conf.optimizer_params)
|
||||
#optimizer = eval(conf.optimizer_name)(lm.model.parameters(), **conf.optimizer_params)
|
||||
optimizer = eval(conf.optimizer_name)(list(lm.model.parameters()) + list(lm.model.layers['embedding'].get_parameters()), **conf.optimizer_params)
|
||||
|
||||
## train
|
||||
lm.train(optimizer, loss_fn)
|
||||
|
|
Загрузка…
Ссылка в новой задаче