Mirror of https://github.com/mozilla/TTS.git

Use MSE loss instead of L1 Loss

Parent: 0afb14ed5e
Commit: f791f4e5e7
@@ -17,10 +17,10 @@ def _sequence_mask(sequence_length, max_len=None):
     return seq_range_expand < seq_length_expand


-class L1LossMasked(nn.Module):
+class L2LossMasked(nn.Module):

     def __init__(self):
-        super(L1LossMasked, self).__init__()
+        super(L2LossMasked, self).__init__()

     def forward(self, input, target, length):
         """
@@ -44,7 +44,7 @@ class L1LossMasked(nn.Module):
         # target_flat: (batch * max_len, dim)
         target_flat = target.view(-1, target.shape[-1])
         # losses_flat: (batch * max_len, dim)
-        losses_flat = functional.l1_loss(input, target_flat, size_average=False,
+        losses_flat = functional.mse_loss(input, target_flat, size_average=False,
                                          reduce=False)
         # losses: (batch, max_len, dim)
         losses = losses_flat.view(*target.size())
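As a hedged illustration of the change above (the file is not named in this view, but the train.py import below points at layers/losses.py), here is a minimal, self-contained sketch of a masked MSE loss of the same shape. It is not the repository's exact code: the masking and normalization inside forward are elided from the diff, so the _sequence_mask handling and the averaging below are assumptions, and the deprecated size_average=False, reduce=False pair is written with the modern reduction='none' instead.

import torch
from torch import nn
from torch.nn import functional


def _sequence_mask(sequence_length, max_len=None):
    # Boolean mask of shape (batch, max_len): True on real frames, False on padding.
    if max_len is None:
        max_len = int(sequence_length.max())
    seq_range = torch.arange(max_len, device=sequence_length.device)
    return seq_range.unsqueeze(0) < sequence_length.unsqueeze(1)


class MaskedMSELoss(nn.Module):
    # Stand-in for L2LossMasked; the normalization choice here is illustrative only.
    def forward(self, input, target, length):
        # input, target: (batch, max_len, dim); length: (batch,) valid frame counts.
        losses = functional.mse_loss(input, target, reduction='none')
        mask = _sequence_mask(length, target.size(1)).unsqueeze(-1).float()
        # Zero the padded frames, then average over the valid elements only.
        return (losses * mask).sum() / (mask.sum() * target.size(-1))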
train.py (4 changed lines)
@@ -26,7 +26,7 @@ from utils.model import get_param_size
 from utils.visual import plot_alignment, plot_spectrogram
 from datasets.LJSpeech import LJSpeechDataset
 from models.tacotron import Tacotron
-from layers.losses import L1LossMasked
+from layers.losses import L2LossMasked

 torch.manual_seed(1)
 use_cuda = torch.cuda.is_available()
@@ -365,7 +365,7 @@ def main(args):
     optimizer = optim.Adam(model.parameters(), lr=c.lr)
     optimizer_st = optim.Adam(model.decoder.stopnet.parameters(), lr=c.lr)

-    criterion = L1LossMasked()
+    criterion = L2LossMasked()
     criterion_st = nn.BCELoss()

     if args.restore_path:
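For context, a hedged sketch of how the renamed criterion is applied during training, continuing the MaskedMSELoss sketch above; the tensor names and shapes are invented for illustration and are not lines from this commit.

import torch

# MaskedMSELoss is the class from the sketch above; it stands in for L2LossMasked.
criterion = MaskedMSELoss()
mel_output = torch.randn(2, 50, 80, requires_grad=True)  # decoder output: (batch, frames, n_mels)
mel_target = torch.randn(2, 50, 80)                      # ground-truth mel spectrogram
mel_lengths = torch.tensor([50, 35])                     # valid frames per utterance

loss = criterion(mel_output, mel_target, mel_lengths)    # padded frames do not contribute
loss.backward()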