This commit is contained in:
Eren Golge 2019-03-12 00:20:15 +01:00
Родитель 5cbe0f83f6
Коммит 527567d7ce
1 изменённых файлов: 21 добавлений и 21 удалений

Просмотреть файл

@ -299,15 +299,15 @@ class Decoder(nn.Module):
memories = memories.transpose(0, 1)
return memories
def _parse_outputs(self, outputs, gate_outputs, alignments):
def _parse_outputs(self, outputs, stop_tokens, alignments):
alignments = torch.stack(alignments).transpose(0, 1)
gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
gate_outputs = gate_outputs.contiguous()
stop_tokens = torch.stack(stop_tokens).transpose(0, 1)
stop_tokens = stop_tokens.contiguous()
outputs = torch.stack(outputs).transpose(0, 1).contiguous()
outputs = outputs.view(
outputs.size(0), -1, self.mel_channels)
outputs = outputs.transpose(1, 2)
return outputs, gate_outputs, alignments
return outputs, stop_tokens, alignments
def decode(self, memory):
cell_input = torch.cat((memory, self.context), -1)
@ -354,36 +354,36 @@ class Decoder(nn.Module):
self._init_states(inputs, mask=mask)
outputs, gate_outputs, alignments = [], [], []
outputs, stop_tokens, alignments = [], [], []
while len(outputs) < memories.size(0) - 1:
memory = memories[len(outputs)]
mel_output, gate_output, attention_weights = self.decode(
mel_output, stop_token, attention_weights = self.decode(
memory)
outputs += [mel_output.squeeze(1)]
gate_outputs += [gate_output.squeeze(1)]
stop_tokens += [stop_token.squeeze(1)]
alignments += [attention_weights]
outputs, gate_outputs, alignments = self._parse_outputs(
outputs, gate_outputs, alignments)
outputs, stop_tokens, alignments = self._parse_outputs(
outputs, stop_tokens, alignments)
return outputs, gate_outputs, alignments
return outputs, stop_tokens, alignments
def inference(self, inputs):
memory = self.get_go_frame(inputs)
self._init_states(inputs, mask=None)
self.attention_layer.init_win_idx()
outputs, gate_outputs, alignments, t = [], [], [], 0
outputs, stop_tokens, alignments, t = [], [], [], 0
stop_flags = [False, False]
while True:
memory = self.prenet(memory)
mel_output, gate_output, alignment = self.decode(memory)
gate_output = torch.sigmoid(gate_output.data)
mel_output, stop_token, alignment = self.decode(memory)
stop_token = torch.sigmoid(stop_token.data)
outputs += [mel_output.squeeze(1)]
gate_outputs += [gate_output]
stop_tokens += [stop_token]
alignments += [alignment]
stop_flags[0] = stop_flags[0] or gate_output > 0.5
stop_flags[0] = stop_flags[0] or stop_token > 0.5
stop_flags[1] = stop_flags[1] or alignment[0, -2:].sum() > 0.5
if all(stop_flags):
break
@ -394,10 +394,10 @@ class Decoder(nn.Module):
memory = mel_output
t += 1
outputs, gate_outputs, alignments = self._parse_outputs(
outputs, gate_outputs, alignments)
outputs, stop_tokens, alignments = self._parse_outputs(
outputs, stop_tokens, alignments)
return outputs, gate_outputs, alignments
return outputs, stop_tokens, alignments
def inference_step(self, inputs, t, memory=None):
"""
@ -408,7 +408,7 @@ class Decoder(nn.Module):
self._init_states(inputs, mask=None)
memory = self.prenet(memory)
mel_output, gate_output, alignment = self.decode(memory)
gate_output = torch.sigmoid(gate_output.data)
mel_output, stop_token, alignment = self.decode(memory)
stop_token = torch.sigmoid(stop_token.data)
memory = mel_output
return mel_output, gate_output, alignment
return mel_output, stop_token, alignment