зеркало из https://github.com/mozilla/TTS.git
bug fixes
This commit is contained in:
Родитель
23d9f8a8bc
Коммит
5629292bde
|
@ -364,13 +364,13 @@ class Decoder(nn.Module):
|
||||||
processed_memory = self.prenet(self.memory_input)
|
processed_memory = self.prenet(self.memory_input)
|
||||||
# Attention RNN
|
# Attention RNN
|
||||||
self.attention_rnn_hidden = self.attention_rnn(
|
self.attention_rnn_hidden = self.attention_rnn(
|
||||||
torch.cat((processed_memory, self.current_context_vec), -1),
|
torch.cat((processed_memory, self.context_vec), -1),
|
||||||
self.attention_rnn_hidden)
|
self.attention_rnn_hidden)
|
||||||
self.context_vec = self.attention_layer(
|
self.context_vec = self.attention(
|
||||||
self.attention_rnn_hidden, inputs, self.processed_inputs, mask)
|
self.attention_rnn_hidden, inputs, self.processed_inputs, mask)
|
||||||
# Concat RNN output and attention context vector
|
# Concat RNN output and attention context vector
|
||||||
decoder_input = self.project_to_decoder_in(
|
decoder_input = self.project_to_decoder_in(
|
||||||
torch.cat((self.query, self.context_vec), -1))
|
torch.cat((self.attention_rnn_hidden, self.context_vec), -1))
|
||||||
|
|
||||||
# Pass through the decoder RNNs
|
# Pass through the decoder RNNs
|
||||||
for idx in range(len(self.decoder_rnns)):
|
for idx in range(len(self.decoder_rnns)):
|
||||||
|
@ -390,7 +390,7 @@ class Decoder(nn.Module):
|
||||||
else:
|
else:
|
||||||
stop_token = self.stopnet(stopnet_input)
|
stop_token = self.stopnet(stopnet_input)
|
||||||
output = output[:, : self.r * self.memory_dim]
|
output = output[:, : self.r * self.memory_dim]
|
||||||
return output, stop_token, self.attention_layer.attention_weights
|
return output, stop_token, self.attention.attention_weights
|
||||||
|
|
||||||
def _update_memory_input(self, new_memory):
|
def _update_memory_input(self, new_memory):
|
||||||
if self.use_memory_queue:
|
if self.use_memory_queue:
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
from utils.text import phonemes
|
from utils.text import phonemes
|
||||||
|
from collections import Counter
|
||||||
|
|
||||||
class SymbolsTest(unittest.TestCase):
|
class SymbolsTest(unittest.TestCase):
|
||||||
def test_uniqueness(self):
|
def test_uniqueness(self):
|
||||||
assert sorted(phonemes) == sorted(list(set(phonemes)))
|
assert sorted(phonemes) == sorted(list(set(phonemes))), " {} vs {} ".format(len(phonemes), len(set(phonemes)))
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import os
|
import os
|
||||||
import unittest
|
import unittest
|
||||||
import shutil
|
import shutil
|
||||||
|
import torch
|
||||||
|
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from utils.generic_utils import load_config
|
from utils.generic_utils import load_config
|
||||||
|
@ -130,10 +131,11 @@ class TestTTSDataset(unittest.TestCase):
|
||||||
# check mel_spec consistency
|
# check mel_spec consistency
|
||||||
wav = self.ap.load_wav(item_idx[0])
|
wav = self.ap.load_wav(item_idx[0])
|
||||||
mel = self.ap.melspectrogram(wav)
|
mel = self.ap.melspectrogram(wav)
|
||||||
mel_dl = mel_input[0].cpu().numpy()
|
mel = torch.FloatTensor(mel)
|
||||||
assert (abs(mel.T).astype("float32")
|
mel_dl = mel_input[0]
|
||||||
|
assert (abs(mel.T)
|
||||||
- abs(mel_dl[:-1])
|
- abs(mel_dl[:-1])
|
||||||
).sum() == 0
|
).sum() == 0, (abs(mel.T)- abs(mel_dl[:-1])).sum()
|
||||||
|
|
||||||
# check mel-spec correctness
|
# check mel-spec correctness
|
||||||
mel_spec = mel_input[0].cpu().numpy()
|
mel_spec = mel_input[0].cpu().numpy()
|
||||||
|
|
|
@ -18,7 +18,7 @@ _vowels = 'iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻ'
|
||||||
_non_pulmonic_consonants = 'ʘɓǀɗǃʄǂɠǁʛ'
|
_non_pulmonic_consonants = 'ʘɓǀɗǃʄǂɠǁʛ'
|
||||||
_pulmonic_consonants = 'pbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ'
|
_pulmonic_consonants = 'pbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟ'
|
||||||
_suprasegmentals = 'ˈˌːˑ'
|
_suprasegmentals = 'ˈˌːˑ'
|
||||||
_other_symbols = 'ʍwɥʜʢʡɕʑɺɧ '
|
_other_symbols = 'ʍwɥʜʢʡɕʑɺɧ'
|
||||||
_diacrilics = 'ɚ˞ɫ'
|
_diacrilics = 'ɚ˞ɫ'
|
||||||
_phonemes = sorted(list(_vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics))
|
_phonemes = sorted(list(_vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics))
|
||||||
|
|
||||||
|
|
Загрузка…
Ссылка в новой задаче