diff --git a/train.py b/train.py index 71dd4dc..aa8e92e 100644 --- a/train.py +++ b/train.py @@ -362,17 +362,18 @@ def main(args): checkpoint = torch.load(args.restore_path) model.load_state_dict(checkpoint['model']) optimizer.load_state_dict(checkpoint['optimizer']) - print("\n > Model restored from step %d\n" % checkpoint['step']) + print(" > Model restored from step %d" % checkpoint['step']) start_epoch = checkpoint['step'] // len(train_loader) best_loss = checkpoint['linear_loss'] start_epoch = 0 args.restore_step = checkpoint['step'] else: args.restore_step = 0 - print("\n > Starting a new training") + print(" > Starting a new training") if use_cuda: - model = nn.DataParallel(model.cuda()) + print(" > Using CUDA.") + model = nn.DataParallel(model).cuda() num_params = count_parameters(model) print(" | > Model has {} parameters".format(num_params))