зеркало из https://github.com/mozilla/TTS.git
fix optimizer for restored model
This commit is contained in:
Родитель
855bf8e195
Коммит
1b59d8110c
7
train.py
7
train.py
|
@ -362,17 +362,18 @@ def main(args):
|
|||
checkpoint = torch.load(args.restore_path)
|
||||
model.load_state_dict(checkpoint['model'])
|
||||
optimizer.load_state_dict(checkpoint['optimizer'])
|
||||
print("\n > Model restored from step %d\n" % checkpoint['step'])
|
||||
print(" > Model restored from step %d" % checkpoint['step'])
|
||||
start_epoch = checkpoint['step'] // len(train_loader)
|
||||
best_loss = checkpoint['linear_loss']
|
||||
start_epoch = 0
|
||||
args.restore_step = checkpoint['step']
|
||||
else:
|
||||
args.restore_step = 0
|
||||
print("\n > Starting a new training")
|
||||
print(" > Starting a new training")
|
||||
|
||||
if use_cuda:
|
||||
model = nn.DataParallel(model.cuda())
|
||||
print(" > Using CUDA.")
|
||||
model = nn.DataParallel(model).cuda()
|
||||
|
||||
num_params = count_parameters(model)
|
||||
print(" | > Model has {} parameters".format(num_params))
|
||||
|
|
Загрузка…
Ссылка в новой задаче