зеркало из https://github.com/mozilla/TTS.git
Remove DataParallel from the model state before saving
This commit is contained in:
Родитель
2c345b622d
Коммит
1bfd8f73e7
|
@ -213,7 +213,7 @@ class Decoder(nn.Module):
|
|||
r (int): number of outputs per time step.
|
||||
eps (float): threshold for detecting the end of a sentence.
|
||||
"""
|
||||
def __init__(self, in_features, memory_dim, r, eps=0.2):
|
||||
def __init__(self, in_features, memory_dim, r, eps=0.05):
|
||||
super(Decoder, self).__init__()
|
||||
self.max_decoder_steps = 200
|
||||
self.memory_dim = memory_dim
|
||||
|
|
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
Двоичные данные
utils/.data.py.swp
Двоичные данные
utils/.data.py.swp
Двоичный файл не отображается.
|
@ -48,12 +48,26 @@ def copy_config_file(config_file, path):
|
|||
shutil.copyfile(config_file, out_path)
|
||||
|
||||
|
||||
def _trim_model_state_dict(state_dict):
|
||||
r"""Remove 'module.' prefix from state dictionary. It is necessary as it
|
||||
is loded for the next time by model.load_state(). Otherwise, it complains
|
||||
about the torch.DataParallel()"""
|
||||
|
||||
new_state_dict = OrderedDict()
|
||||
for k, v in state_dict.items():
|
||||
name = k[7:] # remove `module.`
|
||||
new_state_dict[name] = v
|
||||
return new_state_dict
|
||||
|
||||
|
||||
def save_checkpoint(model, optimizer, model_loss, best_loss, out_path,
|
||||
current_step, epoch):
|
||||
checkpoint_path = 'checkpoint_{}.pth.tar'.format(current_step)
|
||||
checkpoint_path = os.path.join(out_path, checkpoint_path)
|
||||
print("\n | > Checkpoint saving : {}".format(checkpoint_path))
|
||||
state = {'model': model.state_dict(),
|
||||
|
||||
new_state_dict = _trim_model_state_dict(model.state_dict())
|
||||
state = {'model': new_state_dict,
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'step': current_step,
|
||||
'epoch': epoch,
|
||||
|
@ -65,7 +79,8 @@ def save_checkpoint(model, optimizer, model_loss, best_loss, out_path,
|
|||
def save_best_model(model, optimizer, model_loss, best_loss, out_path,
|
||||
current_step, epoch):
|
||||
if model_loss < best_loss:
|
||||
state = {'model': model.state_dict(),
|
||||
new_state_dict = _trim_model_state_dict(model.state_dict())
|
||||
state = {'model': new_state_dict,
|
||||
'optimizer': optimizer.state_dict(),
|
||||
'step': current_step,
|
||||
'epoch': epoch,
|
||||
|
|
Загрузка…
Ссылка в новой задаче