This commit is contained in:
Eren Golge 2018-04-25 08:00:30 -07:00
Родитель 52b4bc6bed
Коммит e257bd7278
1 изменённых файлов: 2 добавлений и 1 удалений

Просмотреть файл

@ -44,10 +44,11 @@ class L1LossMasked(nn.Module):
# target_flat: (batch * max_len, dim)
target_flat = target.view(-1, target.shape[-1])
# losses_flat: (batch * max_len, dim)
losses_flat = functional.l1_loss(input, target, size_average=False,
losses_flat = functional.l1_loss(input, target_flat, size_average=False,
reduce=False)
# losses: (batch, max_len, dim)
losses = losses_flat.view(*target.size())
# mask: (batch, max_len, 1)
mask = _sequence_mask(sequence_length=length,
max_len=target.size(1)).unsqueeze(2)