This commit is contained in:
Eren Golge 2019-10-24 14:11:07 +02:00
Родитель 77f5fd0584
Коммит ea32f2368d
2 изменённых файлов: 13 добавлений и 12 удалений

Просмотреть файл

@ -61,7 +61,7 @@ class AttentionEntropyLoss(nn.Module):
def forward(self, align):
"""
Forces attention to be more decisive by penalizing
soft attention weights
soft attention weights
TODO: arguments
TODO: unit_test

Просмотреть файл

@ -162,15 +162,16 @@ class Decoder(nn.Module):
B = inputs.size(0)
# T = inputs.size(1)
if not keep_states:
self.query = torch.zeros(1, device=inputs.device).repeat(B, self.query_dim)
self.attention_rnn_cell_state = torch.zeros(1, device=inputs.device).repeat(B,
self.query_dim)
self.decoder_hidden = torch.zeros(1, device=inputs.device).repeat(B,
self.decoder_rnn_dim)
self.decoder_cell = torch.zeros(1, device=inputs.device).repeat(B,
self.decoder_rnn_dim)
self.context = torch.zeros(1, device=inputs.device).repeat(B,
self.encoder_embedding_dim)
self.query = torch.zeros(1, device=inputs.device).repeat(
B, self.query_dim)
self.attention_rnn_cell_state = torch.zeros(
1, device=inputs.device).repeat(B, self.query_dim)
self.decoder_hidden = torch.zeros(1, device=inputs.device).repeat(
B, self.decoder_rnn_dim)
self.decoder_cell = torch.zeros(1, device=inputs.device).repeat(
B, self.decoder_rnn_dim)
self.context = torch.zeros(1, device=inputs.device).repeat(
B, self.encoder_embedding_dim)
self.inputs = inputs
self.processed_inputs = self.attention.inputs_layer(inputs)
self.mask = mask
@ -277,7 +278,7 @@ class Decoder(nn.Module):
stop_flags[2] = t > inputs.shape[1] * 2
if all(stop_flags):
break
elif len(outputs) == self.max_decoder_steps:
if len(outputs) == self.max_decoder_steps:
print(" | > Decoder stopped with 'max_decoder_steps")
break
@ -317,7 +318,7 @@ class Decoder(nn.Module):
stop_flags[2] = t > inputs.shape[1] * 2
if all(stop_flags):
break
elif len(outputs) == self.max_decoder_steps:
if len(outputs) == self.max_decoder_steps:
print(" | > Decoder stopped with 'max_decoder_steps")
break