refactor partial reinit script as a function. Allow user to select layers to reinit during fine-tuning

Eren Golge 2019-03-23 17:19:40 +01:00
Parent 06a7aeb26d
Commit d8908692c5
4 changed files with 35 additions and 18 deletions

View file

@@ -29,6 +29,8 @@
"url": "tcp:\/\/localhost:54321"
},
"reinit_layers": ["model.decoder.attention_layer"],
"model": "Tacotron2", // one of the model in models/
"grad_clip": 0.02, // upper limit for gradients for clipping.
"epochs": 1000, // total number of epochs to train.

View file

@@ -29,6 +29,8 @@
"url": "tcp:\/\/localhost:54321"
},
"reinit_layers": ["model.decoder.attention_layer"],
"model": "Tacotron2", // one of the model in models/
"grad_clip": 1, // upper limit for gradients for clipping.
"epochs": 1000, // total number of epochs to train.

View file

@@ -22,7 +22,8 @@ from utils.generic_utils import (NoamLR, check_update, count_parameters,
create_experiment_folder, get_commit_hash,
load_config, lr_decay,
remove_experiment_folder, save_best_model,
save_checkpoint, sequence_mask, weight_decay)
save_checkpoint, sequence_mask, weight_decay,
set_init_dict)
from utils.logger import Logger
from utils.synthesis import synthesis
from utils.text.symbols import phonemes, symbols
@@ -396,24 +397,9 @@ def main(args):
print(" > Partial model initialization.")
partial_init_flag = True
model_dict = model.state_dict()
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
# 1. filter out unnecessary keys
pretrained_dict = {
k: v
for k, v in checkpoint['model'].items() if k in model_dict
}
# 2. filter out different size layers
pretrained_dict = {
k: v
for k, v in pretrained_dict.items()
if v.numel() == model_dict[k].numel()
}
# 3. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
# 4. load the new state dict
model_dict = set_init_dict(model_dict, checkpoint, c)
model.load_state_dict(model_dict)
print(" | > {} / {} layers are initialized".format(
len(pretrained_dict), len(model_dict)))
del model_dict
if use_cuda:
model = model.cuda()
criterion.cuda()
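After the refactor, the restore path in train.py delegates the partial-initialization bookkeeping to set_init_dict. A rough sketch of that flow; the model, checkpoint path, and config object below are stand-ins, not the real Tacotron2 or config.json:

    import torch
    from torch import nn
    from types import SimpleNamespace
    from utils.generic_utils import set_init_dict  # helper added by this commit

    model = nn.Sequential(nn.Linear(8, 8))           # placeholder model
    checkpoint = torch.load("pretrained.pth.tar",    # hypothetical checkpoint path
                            map_location="cpu")
    c = SimpleNamespace(reinit_layers=["model.decoder.attention_layer"])

    model_dict = model.state_dict()                  # fresh weights for every layer
    model_dict = set_init_dict(model_dict, checkpoint, c)  # copy compatible, non-reinit layers
    model.load_state_dict(model_dict)                # reinit layers keep their fresh weights
    del model_dict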

View file

@@ -195,3 +195,30 @@ def sequence_mask(sequence_length, max_len=None):
.expand_as(seq_range_expand))
# B x T_max
return seq_range_expand < seq_length_expand
def set_init_dict(model_dict, checkpoint, c):
# Partial initialization: if there is a mismatch with new and old layer, it is skipped.
# 1. filter out unnecessary keys
pretrained_dict = {
k: v
for k, v in checkpoint['model'].items() if k in model_dict
}
# 2. filter out different size layers
pretrained_dict = {
k: v
for k, v in pretrained_dict.items()
if v.numel() == model_dict[k].numel()
}
# 3. skip reinit layers
if c.reinit_layers is not None:
for reinit_layer_name in c.reinit_layers:
pretrained_dict = {
k: v
for k, v in pretrained_dict.items()
if reinit_layer_name not in k
}
# 4. overwrite entries in the existing state dict
model_dict.update(pretrained_dict)
print(" | > {} / {} layers are initialized".format(len(pretrained_dict), len(model_dict)))
return model_dict
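A quick illustration of the three filters, using toy tensors; all layer names and shapes below are made up, only the behavior of set_init_dict itself comes from the function above:

    import torch
    from types import SimpleNamespace
    from utils.generic_utils import set_init_dict

    model_dict = {                                          # current model (fresh weights)
        "encoder.weight": torch.zeros(4, 4),
        "decoder.attention_layer.weight": torch.zeros(4, 4),
        "postnet.weight": torch.zeros(8, 8),
    }
    checkpoint = {"model": {
        "encoder.weight": torch.ones(4, 4),                 # compatible -> copied
        "decoder.attention_layer.weight": torch.ones(4, 4), # listed in reinit_layers -> skipped
        "postnet.weight": torch.ones(4, 4),                 # size mismatch -> skipped
        "old_module.weight": torch.ones(2, 2),              # not in the new model -> skipped
    }}
    c = SimpleNamespace(reinit_layers=["decoder.attention_layer"])

    model_dict = set_init_dict(model_dict, checkpoint, c)
    # Prints " | > 1 / 3 layers are initialized": only encoder.weight is copied over.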