This commit is contained in:
Eren Golge 2019-08-16 15:08:04 +02:00
Родитель 23d9f8a8bc
Коммит 5629292bde
4 изменённых файлов: 13 добавлений и 9 удалений

Просмотреть файл

@ -364,13 +364,13 @@ class Decoder(nn.Module):
processed_memory = self.prenet(self.memory_input)
# Attention RNN
self.attention_rnn_hidden = self.attention_rnn(
torch.cat((processed_memory, self.current_context_vec), -1),
torch.cat((processed_memory, self.context_vec), -1),
self.attention_rnn_hidden)
self.context_vec = self.attention_layer(
self.context_vec = self.attention(
self.attention_rnn_hidden, inputs, self.processed_inputs, mask)
# Concat RNN output and attention context vector
decoder_input = self.project_to_decoder_in(
torch.cat((self.query, self.context_vec), -1))
torch.cat((self.attention_rnn_hidden, self.context_vec), -1))
# Pass through the decoder RNNs
for idx in range(len(self.decoder_rnns)):
@ -390,7 +390,7 @@ class Decoder(nn.Module):
else:
stop_token = self.stopnet(stopnet_input)
output = output[:, : self.r * self.memory_dim]
return output, stop_token, self.attention_layer.attention_weights
return output, stop_token, self.attention.attention_weights
def _update_memory_input(self, new_memory):
if self.use_memory_queue:

Просмотреть файл

@ -1,7 +1,9 @@
import unittest
from utils.text import phonemes
from collections import Counter
class SymbolsTest(unittest.TestCase):
def test_uniqueness(self):
assert sorted(phonemes) == sorted(list(set(phonemes)))
assert sorted(phonemes) == sorted(list(set(phonemes))), " {} vs {} ".format(len(phonemes), len(set(phonemes)))

Просмотреть файл

@ -1,6 +1,7 @@
import os
import unittest
import shutil
import torch
from torch.utils.data import DataLoader
from utils.generic_utils import load_config
@ -130,10 +131,11 @@ class TestTTSDataset(unittest.TestCase):
# check mel_spec consistency
wav = self.ap.load_wav(item_idx[0])
mel = self.ap.melspectrogram(wav)
mel_dl = mel_input[0].cpu().numpy()
assert (abs(mel.T).astype("float32")
mel = torch.FloatTensor(mel)
mel_dl = mel_input[0]
assert (abs(mel.T)
- abs(mel_dl[:-1])
).sum() == 0
).sum() == 0, (abs(mel.T)- abs(mel_dl[:-1])).sum()
# check mel-spec correctness
mel_spec = mel_input[0].cpu().numpy()

Просмотреть файл

@ -18,7 +18,7 @@ _vowels = 'iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ'
_non_pulmonic_consonants = 'ʘɓǀɗǃʄǂɠǁʛ'
_pulmonic_consonants = 'pbtdʈɖcɟkɡʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ'
_suprasegmentals = 'ˈˌːˑ'
_other_symbols = 'ʍwɥʜʢʡɕʑɺɧ '
_other_symbols = 'ʍwɥʜʢʡɕʑɺɧ'
_diacrilics = 'ɚ˞ɫ'
_phonemes = sorted(list(_vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics))