Mirror of https://github.com/mozilla/TTS.git

refactor partial reinit script as a function. Allow user to select layers to reinit during fine-tuning

Parent: 06a7aeb26d
Commit: d8908692c5
@@ -29,6 +29,8 @@
         "url": "tcp:\/\/localhost:54321"
     },
 
+    "reinit_layers": ["model.decoder.attention_layer"],
+
     "model": "Tacotron2", // one of the model in models/
     "grad_clip": 0.02, // upper limit for gradients for clipping.
     "epochs": 1000, // total number of epochs to train.
@@ -29,6 +29,8 @@
         "url": "tcp:\/\/localhost:54321"
     },
 
+    "reinit_layers": ["model.decoder.attention_layer"],
+
     "model": "Tacotron2", // one of the model in models/
     "grad_clip": 1, // upper limit for gradients for clipping.
     "epochs": 1000, // total number of epochs to train.
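The new "reinit_layers" entry is consumed by set_init_dict (added in utils/generic_utils.py below): any state-dict key that contains one of the listed names as a substring is dropped from the pretrained weights, so those layers keep their fresh random initialization during fine-tuning. A minimal sketch of that matching rule, with made-up key names for illustration:

# Sketch only: shows how reinit_layers filters checkpoint keys by substring match.
reinit_layers = ["model.decoder.attention_layer"]   # value from the config above

checkpoint_keys = [                                  # hypothetical key names
    "model.embedding.weight",
    "model.decoder.attention_layer.query_layer.weight",
    "model.decoder.linear_projection.weight",
]

kept = [k for k in checkpoint_keys
        if not any(name in k for name in reinit_layers)]
print(kept)  # ['model.embedding.weight', 'model.decoder.linear_projection.weight']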
train.py (22 changed lines)
@@ -22,7 +22,8 @@ from utils.generic_utils import (NoamLR, check_update, count_parameters,
                                  create_experiment_folder, get_commit_hash,
                                  load_config, lr_decay,
                                  remove_experiment_folder, save_best_model,
-                                 save_checkpoint, sequence_mask, weight_decay)
+                                 save_checkpoint, sequence_mask, weight_decay,
+                                 set_init_dict)
 from utils.logger import Logger
 from utils.synthesis import synthesis
 from utils.text.symbols import phonemes, symbols
@@ -396,24 +397,9 @@ def main(args):
             print(" > Partial model initialization.")
             partial_init_flag = True
             model_dict = model.state_dict()
-            # Partial initialization: if there is a mismatch with new and old layer, it is skipped.
-            # 1. filter out unnecessary keys
-            pretrained_dict = {
-                k: v
-                for k, v in checkpoint['model'].items() if k in model_dict
-            }
-            # 2. filter out different size layers
-            pretrained_dict = {
-                k: v
-                for k, v in pretrained_dict.items()
-                if v.numel() == model_dict[k].numel()
-            }
-            # 3. overwrite entries in the existing state dict
-            model_dict.update(pretrained_dict)
-            # 4. load the new state dict
+            model_dict = set_init_dict(model_dict, checkpoint, c)
             model.load_state_dict(model_dict)
-            print(" | > {} / {} layers are initialized".format(
-                len(pretrained_dict), len(model_dict)))
             del model_dict
         if use_cuda:
             model = model.cuda()
             criterion.cuda()
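The hunk replaces the inlined filtering logic with a single call, so the same behaviour can be reused and tested outside train.py. A small standalone check of the new call pattern follows; it is not part of the commit, assumes it is run from the repository root after this change so the import resolves, and uses made-up toy module names and a stand-in config object:

import torch
from torch import nn

from utils.generic_utils import set_init_dict


class DummyConfig:
    # stands in for the loaded config object `c`; only the field used here
    reinit_layers = ["decoder.attention_layer"]


# toy model whose state-dict keys nest the same way as the config entry above
model = nn.ModuleDict({
    "encoder": nn.Linear(4, 4),
    "decoder": nn.ModuleDict({"attention_layer": nn.Linear(4, 4)}),
})

# pretend this checkpoint was saved by an earlier run of the same architecture
checkpoint = {"model": model.state_dict()}

# the new call pattern from train.py: merge compatible pretrained weights,
# drop the layers listed in reinit_layers, then load the merged state dict
model_dict = model.state_dict()
model_dict = set_init_dict(model_dict, checkpoint, DummyConfig())
model.load_state_dict(model_dict)
# set_init_dict reports " | > 2 / 4 layers are initialized": the two
# decoder.attention_layer tensors were excluded from the copy.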
utils/generic_utils.py

@@ -195,3 +195,30 @@ def sequence_mask(sequence_length, max_len=None):
                          .expand_as(seq_range_expand))
     # B x T_max
     return seq_range_expand < seq_length_expand
+
+
+def set_init_dict(model_dict, checkpoint, c):
+    # Partial initialization: if there is a mismatch with new and old layer, it is skipped.
+    # 1. filter out unnecessary keys
+    pretrained_dict = {
+        k: v
+        for k, v in checkpoint['model'].items() if k in model_dict
+    }
+    # 2. filter out different size layers
+    pretrained_dict = {
+        k: v
+        for k, v in pretrained_dict.items()
+        if v.numel() == model_dict[k].numel()
+    }
+    # 3. skip reinit layers
+    if c.reinit_layers is not None:
+        for reinit_layer_name in c.reinit_layers:
+            pretrained_dict = {
+                k: v
+                for k, v in pretrained_dict.items()
+                if reinit_layer_name not in k
+            }
+    # 4. overwrite entries in the existing state dict
+    model_dict.update(pretrained_dict)
+    print(" | > {} / {} layers are initialized".format(len(pretrained_dict), len(model_dict)))
+    return model_dict
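Step 2 is what makes fine-tuning across small architecture changes possible: a layer whose parameter count differs between checkpoint and model (for example an embedding that grew with the symbol set) is left at its new initialization instead of being overwritten. A quick illustration of steps 1 and 2 in isolation, with made-up layer names and shapes:

import torch

model_dict = {"embedding.weight": torch.zeros(130, 256),          # vocabulary grew
              "decoder.linear.weight": torch.zeros(80, 256)}
checkpoint = {"model": {"embedding.weight": torch.ones(120, 256),
                        "decoder.linear.weight": torch.ones(80, 256),
                        "stopnet.weight": torch.ones(1, 256)}}    # key absent in the new model

# 1. keep only keys the current model knows about
pretrained = {k: v for k, v in checkpoint["model"].items() if k in model_dict}
# 2. keep only entries whose element count matches
pretrained = {k: v for k, v in pretrained.items()
              if v.numel() == model_dict[k].numel()}
print(sorted(pretrained))  # ['decoder.linear.weight']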