Eren 2018-09-19 14:25:30 +02:00
Parent 9b29b4e281
Commit 16db5159f1
1 changed file with 8 additions and 2 deletions


@@ -89,6 +89,9 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
         # backpass and check the grad norm for spec losses
         loss.backward(retain_graph=True)
+        for group in optimizer.param_groups:
+            for param in group['params']:
+                param.data = param.data.add(-c.wd * group['lr'], param.data)
         grad_norm, skip_flag = check_update(model, 1)
         if skip_flag:
             optimizer.zero_grad()
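
The three added lines apply weight decay directly to the weights after the backward pass, scaled by each group's learning rate, instead of letting Adam fold an L2 penalty into the gradient, where the adaptive second-moment estimate would rescale it. This is the decoupled weight decay idea behind AdamW. A minimal standalone sketch of the same step, with a hypothetical helper name and using the modern add_(other, alpha=...) overload in place of the deprecated add(alpha, other) form that appears in the diff:

import torch

def apply_decoupled_weight_decay(optimizer, wd):
    # Shrink every parameter in place: param <- param - lr * wd * param.
    # Done outside Adam's update, so the decay is not normalized by the
    # second-moment estimate the way Adam's built-in weight_decay is.
    with torch.no_grad():
        for group in optimizer.param_groups:
            for param in group['params']:
                param.add_(param, alpha=-wd * group['lr'])
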
@@ -98,6 +101,9 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
         # backpass and check the grad norm for stop loss
         stop_loss.backward()
+        for group in optimizer_st.param_groups:
+            for param in group['params']:
+                param.data = param.data.add(-c.wd * group['lr'], param.data)
         grad_norm_st, skip_flag = check_update(model.decoder.stopnet, 0.5)
         if skip_flag:
             optimizer_st.zero_grad()
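
The same decay loop is repeated here for optimizer_st, which holds only the stopnet parameters, so each parameter set is decayed exactly once per step. With a shared helper like the sketch above (name hypothetical), both call sites in train() would reduce to:

# spec losses
loss.backward(retain_graph=True)
apply_decoupled_weight_decay(optimizer, c.wd)

# stop loss
stop_loss.backward()
apply_decoupled_weight_decay(optimizer_st, c.wd)
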
@@ -390,9 +396,9 @@ def main(args):
     model = Tacotron(c.embedding_size, ap.num_freq, c.num_mels, c.r)
     print(" | > Num output units : {}".format(ap.num_freq), flush=True)
-    optimizer = optim.Adam(model.parameters(), lr=c.lr, weight_decay=c.wd)
+    optimizer = optim.Adam(model.parameters(), lr=c.lr, weight_decay=0)
     optimizer_st = optim.Adam(
-        model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=c.wd)
+        model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=0)
     criterion = L1LossMasked()
     criterion_st = nn.BCELoss()
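
Setting weight_decay=0 in both Adam constructors is the other half of the change: it disables Adam's built-in coupled L2 term so the decay is not applied twice once the manual loops above run. Later PyTorch releases (after this 2018 commit) ship torch.optim.AdamW, which performs the same decoupled decay inside step(). A sketch of the equivalent setup, assuming the same c.lr and c.wd config values, with the caveat that AdamW decays during step() rather than right after backward():

import torch.optim as optim

# Decoupled weight decay built into the optimizer; this would replace
# both the weight_decay=0 constructors and the manual per-group loops.
optimizer = optim.AdamW(model.parameters(), lr=c.lr, weight_decay=c.wd)
optimizer_st = optim.AdamW(
    model.decoder.stopnet.parameters(), lr=c.lr, weight_decay=c.wd)
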