This commit is contained in:
Eren Golge 2019-11-15 14:30:28 +01:00
Родитель 8af75cad46
Коммит 79cca4ac80
2 изменённых файлов: 14 добавлений и 2 удалений

Просмотреть файл

@ -18,7 +18,7 @@ class L1LossMasked(nn.Module):
length: A Variable containing a LongTensor of size (batch,)
which contains the length of each data in a batch.
Returns:
loss: An average loss value masked by the length.
loss: An average loss value in range [0, 1] masked by the length.
"""
# mask: (batch, max_len, 1)
target.requires_grad = False
@ -44,7 +44,7 @@ class MSELossMasked(nn.Module):
length: A Variable containing a LongTensor of size (batch,)
which contains the length of each data in a batch.
Returns:
loss: An average loss value masked by the length.
loss: An average loss value in range [0, 1] masked by the length.
"""
# mask: (batch, max_len, 1)
target.requires_grad = False

Просмотреть файл

@ -118,6 +118,7 @@ class EncoderTests(unittest.TestCase):
class L1LossMaskedTests(unittest.TestCase):
def test_in_out(self):
# test input == target
layer = L1LossMasked()
dummy_input = T.ones(4, 8, 128).float()
dummy_target = T.ones(4, 8, 128).float()
@ -125,11 +126,14 @@ class L1LossMaskedTests(unittest.TestCase):
output = layer(dummy_input, dummy_target, dummy_length)
assert output.item() == 0.0
# test input != target
dummy_input = T.ones(4, 8, 128).float()
dummy_target = T.zeros(4, 8, 128).float()
dummy_length = (T.ones(4) * 8).long()
output = layer(dummy_input, dummy_target, dummy_length)
assert output.item() == 1.0, "1.0 vs {}".format(output.data[0])
# test if padded values of input makes any difference
dummy_input = T.ones(4, 8, 128).float()
dummy_target = T.zeros(4, 8, 128).float()
dummy_length = (T.arange(5, 9)).long()
@ -137,3 +141,11 @@ class L1LossMaskedTests(unittest.TestCase):
(sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
output = layer(dummy_input + mask, dummy_target, dummy_length)
assert output.item() == 1.0, "1.0 vs {}".format(output.data[0])
dummy_input = T.rand(4, 8, 128).float()
dummy_target = dummy_input.detach()
dummy_length = (T.arange(5, 9)).long()
mask = (
(sequence_mask(dummy_length).float() - 1.0) * 100.0).unsqueeze(2)
output = layer(dummy_input + mask, dummy_target, dummy_length)
assert output.item() == 0, "0 vs {}".format(output.data[0])