Mirror of https://github.com/mozilla/TTS.git

linter fix

Parent: 77f5fd0584
Commit: ea32f2368d
@@ -61,7 +61,7 @@ class AttentionEntropyLoss(nn.Module):
     def forward(self, align):
         """
         Forces attention to be more decisive by penalizing
-        soft attention weights
+        soft attention weights

         TODO: arguments
         TODO: unit_test
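The removed and added docstring lines read the same, so this hunk is most likely a trailing-whitespace cleanup. For context on what the loss itself does, here is a minimal sketch of an entropy penalty of the kind the docstring describes; the diff shows only the docstring, so the class body, the assumed `align` shape, and the normalization below are illustrative assumptions, not the repository's actual implementation:

import math

import torch
from torch import nn


class AttentionEntropyLossSketch(nn.Module):
    """Hypothetical sketch: penalize soft (high-entropy) attention."""

    def forward(self, align):
        # align: assumed shape (B, T_decoder, T_encoder), with each
        # decoder step's weights summing to 1 over the encoder axis.
        entropy = torch.distributions.Categorical(probs=align).entropy()
        # Divide by log(T_encoder), the maximum possible entropy, so the
        # penalty lies in [0, 1]; minimizing it sharpens the alignment.
        return (entropy / math.log(align.size(-1))).mean()


# A perfectly sharp (one-hot) alignment yields a penalty of ~0:
sharp = torch.eye(5).unsqueeze(0)            # (1, 5, 5), one-hot rows
print(AttentionEntropyLossSketch()(sharp))   # tensor close to 0.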
@@ -162,15 +162,16 @@ class Decoder(nn.Module):
         B = inputs.size(0)
         # T = inputs.size(1)
         if not keep_states:
-            self.query = torch.zeros(1, device=inputs.device).repeat(B, self.query_dim)
-            self.attention_rnn_cell_state = torch.zeros(1, device=inputs.device).repeat(B,
-                                                                                        self.query_dim)
-            self.decoder_hidden = torch.zeros(1, device=inputs.device).repeat(B,
-                                                                              self.decoder_rnn_dim)
-            self.decoder_cell = torch.zeros(1, device=inputs.device).repeat(B,
-                                                                            self.decoder_rnn_dim)
-            self.context = torch.zeros(1, device=inputs.device).repeat(B,
-                                                                       self.encoder_embedding_dim)
+            self.query = torch.zeros(1, device=inputs.device).repeat(
+                B, self.query_dim)
+            self.attention_rnn_cell_state = torch.zeros(
+                1, device=inputs.device).repeat(B, self.query_dim)
+            self.decoder_hidden = torch.zeros(1, device=inputs.device).repeat(
+                B, self.decoder_rnn_dim)
+            self.decoder_cell = torch.zeros(1, device=inputs.device).repeat(
+                B, self.decoder_rnn_dim)
+            self.context = torch.zeros(1, device=inputs.device).repeat(
+                B, self.encoder_embedding_dim)
         self.inputs = inputs
         self.processed_inputs = self.attention.inputs_layer(inputs)
         self.mask = mask
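This hunk only rewraps the state initializers to satisfy the linter's line-length limit; behavior is unchanged. As a quick sanity check of what the initializer pattern produces, `torch.zeros(1, ...).repeat(B, dim)` tiles the single zero element into a `(B, dim)` tensor on the target device. The sizes below are made up for illustration:

import torch

B, query_dim = 4, 256  # illustrative sizes, not the model's real config
device = torch.device("cpu")

# repeat() prepends dimensions as needed, so the (1,) zero tensor is
# tiled into a (B, query_dim) matrix of zeros, equivalent to allocating
# the zeros directly with torch.zeros(B, query_dim, device=device).
query = torch.zeros(1, device=device).repeat(B, query_dim)
assert query.shape == (B, query_dim)
assert torch.equal(query, torch.zeros(B, query_dim, device=device))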
@@ -277,7 +278,7 @@ class Decoder(nn.Module):
             stop_flags[2] = t > inputs.shape[1] * 2
             if all(stop_flags):
                 break
-            elif len(outputs) == self.max_decoder_steps:
+            if len(outputs) == self.max_decoder_steps:
                 print(" | > Decoder stopped with 'max_decoder_steps")
                 break

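The `elif` here chained off a branch that ends in `break`, so it could never act as an alternative path at runtime; pylint-style linters flag this pattern (the `no-else-break` check), and a plain `if` is behaviorally identical. A self-contained miniature of the loop, with hypothetical names, shows the equivalence:

def decode_steps(stop_flags, max_decoder_steps):
    # Hypothetical reduction of the decoder loop. Because the first
    # branch exits via break, writing the second test as `elif` or as
    # a fresh `if` makes no difference to control flow.
    outputs = []
    while True:
        if all(stop_flags):
            break
        if len(outputs) == max_decoder_steps:  # was: elif len(outputs) == ...
            print(" | > Decoder stopped with 'max_decoder_steps")
            break
        outputs.append(len(outputs))
    return outputs

print(len(decode_steps([False], max_decoder_steps=3)))  # -> 3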
@@ -317,7 +318,7 @@ class Decoder(nn.Module):
             stop_flags[2] = t > inputs.shape[1] * 2
             if all(stop_flags):
                 break
-            elif len(outputs) == self.max_decoder_steps:
+            if len(outputs) == self.max_decoder_steps:
                 print(" | > Decoder stopped with 'max_decoder_steps")
                 break
