Update attention module Possible BUG FIX2

This commit is contained in:
Eren Golge 2018-02-05 08:22:30 -08:00
Parent b6c5771a6f
Commit 3cafc6568c
1 changed file with 4 additions and 1 deletion


@@ -63,6 +63,8 @@ class AttentionWrapper(nn.Module):
         # Alignment
         # (batch, max_time)
         # e_{ij} = a(s_{i-1}, h_j)
+        # import ipdb
+        # ipdb.set_trace()
         alignment = self.alignment_model(cell_state, processed_inputs)
         if mask is not None:
@@ -80,12 +82,13 @@
         # Concat input query and previous context_vec context
         cell_input = torch.cat((query, context_vec), -1)
-        cell_input = cell_input.unsqueeze(1)
+        #cell_input = cell_input.unsqueeze(1)
         # Feed it to RNN
         # s_i = f(y_{i-1}, c_{i}, s_{i-1})
         cell_output = self.rnn_cell(cell_input, cell_state)
+        context_vec = context_vec.squeeze(1)
         return cell_output, context_vec, alignment
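For context, a minimal sketch of the shape mismatch this commit appears to fix, assuming the wrapped rnn_cell is an nn.GRUCell (its construction is not shown in the diff): GRUCell expects a 2D (batch, input_size) input, so the extra unsqueeze(1) fed it the wrong rank, and a context vector carrying a singleton time axis needs squeeze(1) before being returned.

# Sketch only: shape check for the change above. Assumes rnn_cell is an
# nn.GRUCell, which this diff does not confirm.
import torch
import torch.nn as nn

batch, feat, hidden = 4, 128, 128
cell = nn.GRUCell(feat, hidden)
state = torch.zeros(batch, hidden)

flat = torch.randn(batch, feat)     # (batch, input_size): the rank GRUCell expects
out = cell(flat, state)             # -> (batch, hidden)

stacked = flat.unsqueeze(1)         # (batch, 1, input_size): the old cell_input rank
try:
    cell(stacked, state)            # fails: GRUCell only takes 1D or 2D input
except (RuntimeError, ValueError) as err:
    print("3D input rejected:", err)

# squeeze(1) drops the singleton time axis, as done for context_vec above
assert torch.equal(stacked.squeeze(1), flat)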