Remove DataParallel from the model state before saving

This commit is contained in:
Eren Golge 2018-02-21 07:03:53 -08:00
Родитель 2c345b622d
Коммит 1bfd8f73e7
4 изменённых файлов: 25707 добавлений и 277 удалений

Просмотреть файл

@ -213,7 +213,7 @@ class Decoder(nn.Module):
r (int): number of outputs per time step.
eps (float): threshold for detecting the end of a sentence.
"""
def __init__(self, in_features, memory_dim, r, eps=0.2):
def __init__(self, in_features, memory_dim, r, eps=0.05):
super(Decoder, self).__init__()
self.max_decoder_steps = 200
self.memory_dim = memory_dim

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Двоичные данные
utils/.data.py.swp

Двоичный файл не отображается.

Просмотреть файл

@ -48,12 +48,26 @@ def copy_config_file(config_file, path):
shutil.copyfile(config_file, out_path)
def _trim_model_state_dict(state_dict):
r"""Remove 'module.' prefix from state dictionary. It is necessary as it
is loded for the next time by model.load_state(). Otherwise, it complains
about the torch.DataParallel()"""
new_state_dict = OrderedDict()
for k, v in state_dict.items():
name = k[7:] # remove `module.`
new_state_dict[name] = v
return new_state_dict
def save_checkpoint(model, optimizer, model_loss, best_loss, out_path,
current_step, epoch):
checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
checkpoint_path = os.path.join(out_path, checkpoint_path)
print("\n | > Checkpoint saving : {}".format(checkpoint_path))
state = {'model': model.state_dict(),
new_state_dict = _trim_model_state_dict(model.state_dict())
state = {'model': new_state_dict,
'optimizer': optimizer.state_dict(),
'step': current_step,
'epoch': epoch,
@ -65,7 +79,8 @@ def save_checkpoint(model, optimizer, model_loss, best_loss, out_path,
def save_best_model(model, optimizer, model_loss, best_loss, out_path,
current_step, epoch):
if model_loss < best_loss:
state = {'model': model.state_dict(),
new_state_dict = _trim_model_state_dict(model.state_dict())
state = {'model': new_state_dict,
'optimizer': optimizer.state_dict(),
'step': current_step,
'epoch': epoch,