зеркало из https://github.com/mozilla/TTS.git
bug fixes
This commit is contained in:
Родитель
23d9f8a8bc
Коммит
5629292bde
|
@@ -364,13 +364,13 @@ class Decoder(nn.Module):
|
|||
processed_memory = self.prenet(self.memory_input)
|
||||
# Attention RNN
|
||||
self.attention_rnn_hidden = self.attention_rnn(
|
||||
torch.cat((processed_memory, self.current_context_vec), -1),
|
||||
torch.cat((processed_memory, self.context_vec), -1),
|
||||
self.attention_rnn_hidden)
|
||||
self.context_vec = self.attention_layer(
|
||||
self.context_vec = self.attention(
|
||||
self.attention_rnn_hidden, inputs, self.processed_inputs, mask)
|
||||
# Concat RNN output and attention context vector
|
||||
decoder_input = self.project_to_decoder_in(
|
||||
torch.cat((self.query, self.context_vec), -1))
|
||||
torch.cat((self.attention_rnn_hidden, self.context_vec), -1))
|
||||
|
||||
# Pass through the decoder RNNs
|
||||
for idx in range(len(self.decoder_rnns)):
|
||||
|
@@ -390,7 +390,7 @@ class Decoder(nn.Module):
|
|||
else:
|
||||
stop_token = self.stopnet(stopnet_input)
|
||||
output = output[:, : self.r * self.memory_dim]
|
||||
return output, stop_token, self.attention_layer.attention_weights
|
||||
return output, stop_token, self.attention.attention_weights
|
||||
|
||||
def _update_memory_input(self, new_memory):
|
||||
if self.use_memory_queue:
|
||||
|
|
|
@@ -1,7 +1,9 @@
|
|||
import unittest
|
||||
|
||||
from utils.text import phonemes
|
||||
from collections import Counter
|
||||
|
||||
class SymbolsTest(unittest.TestCase):
|
||||
def test_uniqueness(self):
|
||||
assert sorted(phonemes) == sorted(list(set(phonemes)))
|
||||
assert sorted(phonemes) == sorted(list(set(phonemes))), " {} vs {} ".format(len(phonemes), len(set(phonemes)))
|
||||
|
|
@@ -1,6 +1,7 @@
|
|||
import os
|
||||
import unittest
|
||||
import shutil
|
||||
import torch
|
||||
|
||||
from torch.utils.data import DataLoader
|
||||
from utils.generic_utils import load_config
|
||||
|
@@ -130,10 +131,11 @@ class TestTTSDataset(unittest.TestCase):
|
|||
# check mel_spec consistency
|
||||
wav = self.ap.load_wav(item_idx[0])
|
||||
mel = self.ap.melspectrogram(wav)
|
||||
mel_dl = mel_input[0].cpu().numpy()
|
||||
assert (abs(mel.T).astype("float32")
|
||||
mel = torch.FloatTensor(mel)
|
||||
mel_dl = mel_input[0]
|
||||
assert (abs(mel.T)
|
||||
- abs(mel_dl[:-1])
|
||||
).sum() == 0
|
||||
).sum() == 0, (abs(mel.T)- abs(mel_dl[:-1])).sum()
|
||||
|
||||
# check mel-spec correctness
|
||||
mel_spec = mel_input[0].cpu().numpy()
|
||||
|
|
|
@@ -18,7 +18,7 @@ _vowels = 'iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ'
|
|||
_non_pulmonic_consonants = 'ʘɓǀɗǃʄǂɠǁʛ'
|
||||
_pulmonic_consonants = 'pbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ'
|
||||
_suprasegmentals = 'ˈˌːˑ'
|
||||
_other_symbols = 'ʍwɥʜʢʡɕʑɺɧ '
|
||||
_other_symbols = 'ʍwɥʜʢʡɕʑɺɧ'
|
||||
_diacrilics = 'ɚ˞ɫ'
|
||||
_phonemes = sorted(list(_vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics))
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче