This commit is contained in:
Wutao Lin 2019-05-24 11:55:40 +08:00
Родитель fe0655cc42
Коммит 69c774a502
1 изменённых файлов: 2 добавлений и 5 удалений

Просмотреть файл

@ -234,12 +234,9 @@ def main(params):
### optimizer
if isinstance(lm.model, nn.DataParallel):
optimizer = eval(conf.optimizer_name)((it.__next__() for it in itertools.cycle([lm.model.parameters(), lm.model.module.layers['embedding'].get_parameters()])), **conf.optimizer_params)
optimizer = eval(conf.optimizer_name)(list(lm.model.parameters()) + list(lm.model.module.layers['embedding'].get_parameters()), **conf.optimizer_params)
else:
optimizer = eval(conf.optimizer_name)((it.__next__() for it in itertools.cycle([lm.model.parameters(), lm.model.layers['embedding'].get_parameters()])), **conf.optimizer_params)
#optimizer = eval(conf.optimizer_name)(list(lm.model.parameters()) + list(lm.model.module.layers['embedding'].get_parameters()), **conf.optimizer_params)
#optimizer = eval(conf.optimizer_name)(lm.model.parameters(), **conf.optimizer_params)
optimizer = eval(conf.optimizer_name)(list(lm.model.parameters()) + list(lm.model.layers['embedding'].get_parameters()), **conf.optimizer_params)
## train
lm.train(optimizer, loss_fn)