зеркало из https://github.com/mozilla/TTS.git
bug fix loss
This commit is contained in:
Родитель
52b4bc6bed
Коммит
e257bd7278
|
@ -44,10 +44,11 @@ class L1LossMasked(nn.Module):
|
|||
# target_flat: (batch * max_len, dim)
|
||||
target_flat = target.view(-1, target.shape[-1])
|
||||
# losses_flat: (batch * max_len, dim)
|
||||
losses_flat = functional.l1_loss(input, target, size_average=False,
|
||||
losses_flat = functional.l1_loss(input, target_flat, size_average=False,
|
||||
reduce=False)
|
||||
# losses: (batch, max_len, dim)
|
||||
losses = losses_flat.view(*target.size())
|
||||
|
||||
# mask: (batch, max_len, 1)
|
||||
mask = _sequence_mask(sequence_length=length,
|
||||
max_len=target.size(1)).unsqueeze(2)
|
||||
|
|
Загрузка…
Ссылка в новой задаче