This commit is contained in:
Eren Golge 2018-05-16 19:20:40 -07:00
Родитель 520ba9551f
Коммит e6112f7b2d
1 изменённых файлов: 7 добавлений и 1 удалений

Просмотреть файл

@ -371,12 +371,18 @@ def main(args):
if args.restore_path:
checkpoint = torch.load(args.restore_path)
model.load_state_dict(checkpoint['model'])
optimizer = optim.Adam(model.parameters(), lr=c.lr)
optimizer.load_state_dict(checkpoint['optimizer'])
print("\n > Model restored from step %d\n" % checkpoint['step'])
for state in optimizer.state.values():
for k, v in state.items():
if torch.is_tensor(v):
state[k] = v.cuda()
print(" > Model restored from step %d" % checkpoint['step'])
start_epoch = checkpoint['step'] // len(train_loader)
best_loss = checkpoint['linear_loss']
start_epoch = 0
args.restore_step = checkpoint['step']
optimizer_st = optim.Adam(model.decoder.stopnet.parameters(), lr=c.lr)
else:
args.restore_step = 0
print("\n > Starting a new training")