зеркало из https://github.com/mozilla/TTS.git
Weight decay described here: http://www.fast.ai/2018/07/02/adam-weight-decay/
This commit is contained in:
Родитель
9b29b4e281
Коммит
16db5159f1
10
train.py
10
train.py
|
@@ -89,6 +89,9 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
|||
|
||||
# backpass and check the grad norm for spec losses
|
||||
loss.backward(retain_graph=True)
|
||||
for group in optimizer.param_groups:
|
||||
for param in group['params']:
|
||||
param.data = param.data.add(-c.wd * group['lr'], param.data)
|
||||
grad_norm, skip_flag = check_update(model, 1)
|
||||
if skip_flag:
|
||||
optimizer.zero_grad()
|
||||
|
@@ -98,6 +101,9 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
|||
|
||||
# backpass and check the grad norm for stop loss
|
||||
stop_loss.backward()
|
||||
for group in optimizer_st.param_groups:
|
||||
for param in group['params']:
|
||||
param.data = param.data.add(-c.wd * group['lr'], param.data)
|
||||
grad_norm_st, skip_flag = check_update(model.decoder.stopnet, 0.5)
|
||||
if skip_flag:
|
||||
optimizer_st.zero_grad()
|
||||
|
@@ -390,9 +396,9 @@ def main(args):
|
|||
model = Tacotron(c.embedding_size, ap.num_freq, c.num_mels, c.r)
|
||||
print(" | > Num output units : {}".format(ap.num_freq), flush=True)
|
||||
|
||||
optimizer = optim.Adam(model.parameters(), lr=c.lr, weight_decay=c.wd)
|
||||
optimizer = optim.Adam(model.parameters(), lr=c.lr, weight_decay=0)
|
||||
optimizer_st = optim.Adam(
|
||||
model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=c.wd)
|
||||
model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0)
|
||||
|
||||
criterion = L1LossMasked()
|
||||
criterion_st = nn.BCELoss()
|
||||
|
|
Загрузка…
Ссылка в новой задаче