Mirror of https://github.com/mozilla/TTS.git
skip weight decay for BN and biases, some formatting
Parent: 53d658fb74
Commit: b76aaf8ad4
train.py | 10 lines changed
train.py
@@ -20,7 +20,8 @@ from TTS.utils.generic_utils import (NoamLR, check_update, count_parameters,
                                      load_config, remove_experiment_folder,
                                      save_best_model, save_checkpoint, weight_decay,
                                      set_init_dict, copy_config_file, setup_model,
-                                     split_dataset, gradual_training_scheduler, KeepAverage)
+                                     split_dataset, gradual_training_scheduler, KeepAverage,
+                                     set_weight_decay)
 from TTS.utils.logger import Logger
 from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \
     get_speakers
@@ -186,7 +187,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         loss += stop_loss

         loss.backward()
-        optimizer, current_lr = weight_decay(optimizer, c.wd)
+        optimizer, current_lr = weight_decay(optimizer)
         grad_norm, _ = check_update(model, c.grad_clip)
         optimizer.step()

@@ -197,7 +198,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         # backpass and check the grad norm for stop loss
         if c.separate_stopnet:
             stop_loss.backward()
-            optimizer_st, _ = weight_decay(optimizer_st, c.wd)
+            optimizer_st, _ = weight_decay(optimizer_st)
             grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
             optimizer_st.step()
         else:
@@ -511,7 +512,8 @@ def main(args): # pylint: disable=redefined-outer-name

     print(" | > Num output units : {}".format(ap.num_freq), flush=True)

-    optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0)
+    params = set_weight_decay(model, c.wd)
+    optimizer = RAdam(params, lr=c.lr, weight_decay=0)
     if c.stopnet and c.separate_stopnet:
         optimizer_st = RAdam(
             model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0)
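Note on the hunk above: set_weight_decay() (added to the utils module below) returns two parameter groups, one with weight_decay=0 for biases, BatchNorm weights and the attention v projection, and one carrying c.wd for everything else; the optimizer itself is still built with weight_decay=0, so the only decay applied is the manual per-group step inside train(). A minimal sketch of the grouping, assuming the TTS package from this repo is importable and using torch.optim.Adam as a stand-in for the external RAdam:

    import torch
    from torch import nn
    from TTS.utils.generic_utils import set_weight_decay

    # Toy model: two Linear layers plus a BatchNorm layer.
    model = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16), nn.Linear(16, 4))

    # Group 0: biases and BatchNorm weights (weight_decay=0);
    # group 1: the two Linear weight matrices (weight_decay=1e-6).
    # set_weight_decay() also prints the names of the parameters it skips.
    params = set_weight_decay(model, weight_decay=1e-6)
    print([len(g['params']) for g in params])   # [4, 2]

    # weight_decay=0 on purpose: decay is applied manually in train().
    optimizer = torch.optim.Adam(params, lr=1e-4, weight_decay=0)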
utils/generic_utils.py
@@ -31,8 +31,8 @@ def load_config(config_path):
 def get_git_branch():
     try:
         out = subprocess.check_output(["git", "branch"]).decode("utf8")
-        current = next(line for line in out.split(
-            "\n") if line.startswith("*"))
+        current = next(line for line in out.split("\n")
+                       if line.startswith("*"))
         current.replace("* ", "")
     except subprocess.CalledProcessError:
         current = "inside_docker"
@@ -48,8 +48,8 @@ def get_commit_hash():
     # raise RuntimeError(
     #     " !! Commit before training to get the commit hash.")
     try:
-        commit = subprocess.check_output(['git', 'rev-parse', '--short',
-                                          'HEAD']).decode().strip()
+        commit = subprocess.check_output(
+            ['git', 'rev-parse', '--short', 'HEAD']).decode().strip()
     # Not copying .git folder into docker container
     except subprocess.CalledProcessError:
         commit = "0000000"
@@ -169,17 +169,43 @@ def lr_decay(init_lr, global_step, warmup_steps):
     return lr


-def weight_decay(optimizer, wd):
+def weight_decay(optimizer):
     """
     Custom weight decay operation, not effecting grad values.
     """
     for group in optimizer.param_groups:
         for param in group['params']:
             current_lr = group['lr']
-            param.data = param.data.add(-wd * group['lr'], param.data)
+            weight_decay = group['weight_decay']
+            param.data = param.data.add(-weight_decay * group['lr'],
+                                        param.data)
     return optimizer, current_lr


+def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v"}):
+    """
+    Skip biases, BatchNorm parameters for weight decay
+    and attention projection layer v
+    """
+    decay = []
+    no_decay = []
+    for name, param in model.named_parameters():
+        if not param.requires_grad:
+            continue
+        if len(param.shape) == 1 or name in skip_list:
+            print(name)
+            no_decay.append(param)
+        else:
+            decay.append(param)
+    return [{
+        'params': no_decay,
+        'weight_decay': 0.
+    }, {
+        'params': decay,
+        'weight_decay': weight_decay
+    }]
+
+
 class NoamLR(torch.optim.lr_scheduler._LRScheduler):
     def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
         self.warmup_steps = float(warmup_steps)
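With this change weight_decay() no longer takes a global wd value; it reads each group's own 'weight_decay' entry, so parameters that set_weight_decay() put in the no-decay group are left untouched while the rest are shrunk by lr * weight_decay directly on the weights (gradients are not modified). A standalone sketch of that step, assuming a recent PyTorch (the code above still uses the older two-argument Tensor.add signature):

    import torch

    def apply_group_weight_decay(optimizer):
        # Shrink every parameter by lr * weight_decay of its own group;
        # a group with weight_decay=0 is effectively left alone.
        for group in optimizer.param_groups:
            lr, wd = group['lr'], group['weight_decay']
            for param in group['params']:
                param.data.add_(param.data, alpha=-wd * lr)
            current_lr = lr
        return optimizer, current_lr

    weight = torch.nn.Parameter(torch.ones(4, 4))   # 2-D tensor -> decay group
    bias = torch.nn.Parameter(torch.ones(4))        # 1-D tensor -> no-decay group
    opt = torch.optim.SGD([{'params': [bias], 'weight_decay': 0.0},
                           {'params': [weight], 'weight_decay': 0.1}], lr=0.5)
    apply_group_weight_decay(opt)
    print(weight[0, 0].item(), bias[0].item())      # 0.95 1.0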
@@ -188,8 +214,8 @@ class NoamLR(torch.optim.lr_scheduler._LRScheduler):
     def get_lr(self):
         step = max(self.last_epoch, 1)
         return [
-            base_lr * self.warmup_steps**0.5 * min(
-                step * self.warmup_steps**-1.5, step**-0.5)
+            base_lr * self.warmup_steps**0.5 *
+            min(step * self.warmup_steps**-1.5, step**-0.5)
             for base_lr in self.base_lrs
         ]
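The get_lr() hunk above is only a line wrap; the value is still the Noam schedule, base_lr * warmup_steps**0.5 * min(step * warmup_steps**-1.5, step**-0.5): a linear warm-up that reaches base_lr at step == warmup_steps, followed by inverse-square-root decay. A small illustration (the numbers are only examples, not config defaults):

    def noam_lr(base_lr, warmup_steps, step):
        # Linear warm-up to base_lr, then decay proportional to 1/sqrt(step).
        step = max(step, 1)
        return base_lr * warmup_steps**0.5 * min(step * warmup_steps**-1.5,
                                                 step**-0.5)

    for s in (1, 1000, 4000, 16000):
        print(s, noam_lr(1e-3, 4000, s))
    # 1 -> 2.5e-07, 1000 -> 2.5e-04, 4000 -> 1.0e-03 (peak), 16000 -> 5.0e-04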
@@ -244,8 +270,8 @@ def set_init_dict(model_dict, checkpoint, c):
     }
     # 4. overwrite entries in the existing state dict
     model_dict.update(pretrained_dict)
-    print(" | > {} / {} layers are restored.".format(
-        len(pretrained_dict), len(model_dict)))
+    print(" | > {} / {} layers are restored.".format(len(pretrained_dict),
+                                                     len(model_dict)))
     return model_dict

@@ -254,8 +280,7 @@ def setup_model(num_chars, num_speakers, c):
     MyModel = importlib.import_module('TTS.models.' + c.model.lower())
     MyModel = getattr(MyModel, c.model)
     if c.model.lower() in "tacotron":
-        model = MyModel(
-            num_chars=num_chars,
+        model = MyModel(num_chars=num_chars,
                         num_speakers=num_speakers,
                         r=c.r,
                         linear_dim=1025,
@@ -272,8 +297,7 @@ def setup_model(num_chars, num_speakers, c):
                         location_attn=c.location_attn,
                         separate_stopnet=c.separate_stopnet)
     elif c.model.lower() == "tacotron2":
-        model = MyModel(
-            num_chars=num_chars,
+        model = MyModel(num_chars=num_chars,
                         num_speakers=num_speakers,
                         r=c.r,
                         attn_win=c.windowing,
@@ -292,7 +316,8 @@ def split_dataset(items):
     is_multi_speaker = False
     speakers = [item[-1] for item in items]
     is_multi_speaker = len(set(speakers)) > 1
-    eval_split_size = 500 if 500 < len(items) * 0.01 else int(len(items) * 0.01)
+    eval_split_size = 500 if 500 < len(items) * 0.01 else int(
+        len(items) * 0.01)
     np.random.seed(0)
     np.random.shuffle(items)
     if is_multi_speaker: