skip weight decay for BN and biases, some formatting

Eren Golge 2019-09-28 01:09:28 +02:00
Parent 53d658fb74
Commit b76aaf8ad4
2 changed files with 72 additions and 45 deletions
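The commit excludes biases, BatchNorm parameters, and the attention projection vector `decoder.attention.v` from weight decay by splitting the model's parameters into two optimizer groups. The heuristic relies on those parameters being 1-D tensors; a minimal illustration in plain PyTorch (the layers below are examples, not taken from this commit):

import torch.nn as nn

# Biases and BatchNorm affine parameters are 1-D, so `len(param.shape) == 1`
# separates them from the weight matrices/kernels that should be decayed.
linear, bn = nn.Linear(4, 8), nn.BatchNorm1d(8)
print(linear.weight.shape)  # torch.Size([8, 4]) -> decayed
print(linear.bias.shape)    # torch.Size([8])    -> skipped
print(bn.weight.shape)      # torch.Size([8])    -> skipped (gamma)
print(bn.bias.shape)        # torch.Size([8])    -> skipped (beta)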

View file

@@ -20,7 +20,8 @@ from TTS.utils.generic_utils import (NoamLR, check_update, count_parameters,
                                      load_config, remove_experiment_folder,
                                      save_best_model, save_checkpoint, weight_decay,
                                      set_init_dict, copy_config_file, setup_model,
-                                     split_dataset, gradual_training_scheduler, KeepAverage)
+                                     split_dataset, gradual_training_scheduler, KeepAverage,
+                                     set_weight_decay)
 from TTS.utils.logger import Logger
 from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \
     get_speakers
@@ -186,7 +187,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
             loss += stop_loss

         loss.backward()
-        optimizer, current_lr = weight_decay(optimizer, c.wd)
+        optimizer, current_lr = weight_decay(optimizer)
         grad_norm, _ = check_update(model, c.grad_clip)
         optimizer.step()
@@ -197,7 +198,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         # backpass and check the grad norm for stop loss
         if c.separate_stopnet:
             stop_loss.backward()
-            optimizer_st, _ = weight_decay(optimizer_st, c.wd)
+            optimizer_st, _ = weight_decay(optimizer_st)
             grad_norm_st, _ = check_update(model.decoder.stopnet, 1.0)
             optimizer_st.step()
         else:
@@ -511,7 +512,8 @@ def main(args):  # pylint: disable=redefined-outer-name
     print(" | > Num output units : {}".format(ap.num_freq), flush=True)

-    optimizer = RAdam(model.parameters(), lr=c.lr, weight_decay=0)
+    params = set_weight_decay(model, c.wd)
+    optimizer = RAdam(params, lr=c.lr, weight_decay=0)
     if c.stopnet and c.separate_stopnet:
         optimizer_st = RAdam(
             model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0)
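With this change the optimizer consumes the parameter groups returned by `set_weight_decay`, while `weight_decay=0` keeps RAdam itself from applying any decay; the decay actually happens in the manual `weight_decay()` step inside the training loop. A rough sketch of the wiring, assuming the TTS package is importable and substituting a toy module, a literal decay value, and `torch.optim.Adam` for the Tacotron model, `c.wd`, and RAdam (the manual step uses the two-argument `Tensor.add` overload, so it expects a PyTorch version from around the time of this commit):

import torch
from TTS.utils.generic_utils import set_weight_decay, weight_decay

# Toy stand-ins for the real model / config / optimizer.
model = torch.nn.Sequential(torch.nn.Linear(10, 10), torch.nn.BatchNorm1d(10))
params = set_weight_decay(model, 1e-6)                     # grouped params
optimizer = torch.optim.Adam(params, lr=1e-4, weight_decay=0)

loss = model(torch.randn(4, 10)).sum()
loss.backward()
optimizer, current_lr = weight_decay(optimizer)            # per-group decoupled decay
optimizer.step()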

View file

@@ -31,8 +31,8 @@ def load_config(config_path):
 def get_git_branch():
     try:
         out = subprocess.check_output(["git", "branch"]).decode("utf8")
-        current = next(line for line in out.split(
-            "\n") if line.startswith("*"))
+        current = next(line for line in out.split("\n")
+                       if line.startswith("*"))
         current.replace("* ", "")
     except subprocess.CalledProcessError:
         current = "inside_docker"
@@ -48,8 +48,8 @@ def get_commit_hash():
     #     raise RuntimeError(
     #         " !! Commit before training to get the commit hash.")
     try:
-        commit = subprocess.check_output(['git', 'rev-parse', '--short',
-                                          'HEAD']).decode().strip()
+        commit = subprocess.check_output(
+            ['git', 'rev-parse', '--short', 'HEAD']).decode().strip()
     # Not copying .git folder into docker container
     except subprocess.CalledProcessError:
         commit = "0000000"
@@ -169,17 +169,43 @@ def lr_decay(init_lr, global_step, warmup_steps):
     return lr


-def weight_decay(optimizer, wd):
+def weight_decay(optimizer):
     """
     Custom weight decay operation, not effecting grad values.
     """
     for group in optimizer.param_groups:
         for param in group['params']:
             current_lr = group['lr']
-            param.data = param.data.add(-wd * group['lr'], param.data)
+            weight_decay = group['weight_decay']
+            param.data = param.data.add(-weight_decay * group['lr'],
+                                        param.data)
     return optimizer, current_lr


+def set_weight_decay(model, weight_decay, skip_list={"decoder.attention.v"}):
+    """
+    Skip biases, BatchNorm parameters for weight decay
+    and attention projection layer v
+    """
+    decay = []
+    no_decay = []
+    for name, param in model.named_parameters():
+        if not param.requires_grad:
+            continue
+
+        if len(param.shape) == 1 or name in skip_list:
+            print(name)
+            no_decay.append(param)
+        else:
+            decay.append(param)
+    return [{
+        'params': no_decay,
+        'weight_decay': 0.
+    }, {
+        'params': decay,
+        'weight_decay': weight_decay
+    }]
+
+
 class NoamLR(torch.optim.lr_scheduler._LRScheduler):
     def __init__(self, optimizer, warmup_steps=0.1, last_epoch=-1):
         self.warmup_steps = float(warmup_steps)
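The reworked update is a decoupled, AdamW-style decay: each step shrinks a parameter in place by the factor (1 - weight_decay * lr) read from its own param group, without ever touching the gradient. A tiny sanity check of that arithmetic with made-up values (not taken from the config):

import torch

lr, wd = 0.1, 0.01
p = torch.ones(3)
# Same update as the loop above: p <- p + (-wd * lr) * p == (1 - wd * lr) * p
p = p.add(-wd * lr * p)
assert torch.allclose(p, torch.full((3,), 1 - wd * lr))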
@@ -188,8 +214,8 @@ class NoamLR(torch.optim.lr_scheduler._LRScheduler):

     def get_lr(self):
         step = max(self.last_epoch, 1)
         return [
-            base_lr * self.warmup_steps**0.5 * min(
-                step * self.warmup_steps**-1.5, step**-0.5)
+            base_lr * self.warmup_steps**0.5 *
+            min(step * self.warmup_steps**-1.5, step**-0.5)
             for base_lr in self.base_lrs
         ]
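The `get_lr` reformatting is behavior-neutral; the expression is the standard Noam/Transformer schedule, rising linearly until `warmup_steps` and then decaying as step**-0.5, with the two branches meeting at lr == base_lr exactly at `warmup_steps`. A quick check with illustrative numbers (4000 warmup steps is an assumption, not this repository's config):

base_lr, warmup = 1e-3, 4000
lr = lambda step: base_lr * warmup**0.5 * min(step * warmup**-1.5, step**-0.5)
print(lr(1), lr(4000), lr(16000))   # ~2.5e-07, 0.001, 0.0005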
@@ -244,8 +270,8 @@ def set_init_dict(model_dict, checkpoint, c):
     }
     # 4. overwrite entries in the existing state dict
     model_dict.update(pretrained_dict)
-    print(" | > {} / {} layers are restored.".format(
-        len(pretrained_dict), len(model_dict)))
+    print(" | > {} / {} layers are restored.".format(len(pretrained_dict),
+                                                     len(model_dict)))
     return model_dict

@@ -254,8 +280,7 @@ def setup_model(num_chars, num_speakers, c):
     MyModel = importlib.import_module('TTS.models.' + c.model.lower())
     MyModel = getattr(MyModel, c.model)
     if c.model.lower() in "tacotron":
-        model = MyModel(
-            num_chars=num_chars,
+        model = MyModel(num_chars=num_chars,
                         num_speakers=num_speakers,
                         r=c.r,
                         linear_dim=1025,
@@ -272,8 +297,7 @@ def setup_model(num_chars, num_speakers, c):
                         location_attn=c.location_attn,
                         separate_stopnet=c.separate_stopnet)
     elif c.model.lower() == "tacotron2":
-        model = MyModel(
-            num_chars=num_chars,
+        model = MyModel(num_chars=num_chars,
                         num_speakers=num_speakers,
                         r=c.r,
                         attn_win=c.windowing,
@@ -292,7 +316,8 @@ def split_dataset(items):
     is_multi_speaker = False
     speakers = [item[-1] for item in items]
     is_multi_speaker = len(set(speakers)) > 1
-    eval_split_size = 500 if 500 < len(items) * 0.01 else int(len(items) * 0.01)
+    eval_split_size = 500 if 500 < len(items) * 0.01 else int(
+        len(items) * 0.01)
     np.random.seed(0)
     np.random.shuffle(items)
     if is_multi_speaker:
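As a usage sketch of the new grouping (assuming the TTS package is importable; the toy module is illustrative), every 1-D parameter — biases and BatchNorm gamma/beta — and any name listed in `skip_list` lands in the zero-decay group:

import torch.nn as nn
from TTS.utils.generic_utils import set_weight_decay

toy = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16), nn.Linear(16, 4))
groups = set_weight_decay(toy, weight_decay=1e-6)

no_decay, decay = groups[0]['params'], groups[1]['params']
# 2-D Linear weights get decay; Linear biases and BatchNorm weight/bias do not.
assert len(decay) == 2 and len(no_decay) == 4
assert groups[0]['weight_decay'] == 0.0 and groups[1]['weight_decay'] == 1e-6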