From 8feb326a60fce455ba439d4e1fb7bf0e66642bd4 Mon Sep 17 00:00:00 2001 From: Edresson Casanova Date: Sun, 1 Mar 2020 15:47:08 -0300 Subject: [PATCH] add text parameters in config.json --- config.json | 10 +++++++ datasets/TTSDataset.py | 9 ++++-- notebooks/Benchmark-PWGAN.ipynb | 6 +++- notebooks/Benchmark.ipynb | 6 +++- notebooks/ExtractTTSpectrogram.ipynb | 8 ++++-- notebooks/TestAttention.ipynb | 6 +++- server/synthesizer.py | 9 +++++- synthesize.py | 8 +++++- tests/test_demo_server.py | 5 +++- tests/test_loader.py | 1 + train.py | 8 ++++-- utils/generic_utils.py | 9 ++++++ utils/synthesis.py | 5 ++-- utils/text/__init__.py | 41 +++++++++++++++++++++++----- utils/text/symbols.py | 21 +++++++++----- utils/visual.py | 5 ++-- 16 files changed, 126 insertions(+), 31 deletions(-) diff --git a/config.json b/config.json index c1a8158..2a7c455 100644 --- a/config.json +++ b/config.json @@ -27,6 +27,16 @@ "trim_db": 60 // threshold for timming silence. Set this according to your dataset. }, + // VOCABULARY PARAMETERS + "text":{ + "pad": "_", + "eos": "~", + "bos": "^", + "characters": "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz!'(),-.:;? ", + "punctuations":"!'(),-.:;? ", + "phonemes":"iyɨʉɯuɪʏʊeøɘəɵɤoɛœɜɞʌɔæɐaɶɑɒᵻʘɓǀɗǃʄǂɠǁʛpbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðszʃʒʂʐçʝxɣχʁħʕhɦɬɮʋɹɻjɰlɭʎʟˈˌːˑʍwɥʜʢʡɕʑɺɧɚ˞ɫ" + }, + // DISTRIBUTED TRAINING "distributed":{ "backend": "nccl", diff --git a/datasets/TTSDataset.py b/datasets/TTSDataset.py index a45d77f..cccd65a 100644 --- a/datasets/TTSDataset.py +++ b/datasets/TTSDataset.py @@ -15,6 +15,7 @@ class MyDataset(Dataset): text_cleaner, ap, meta_data, + tp=None, batch_group_size=0, min_seq_len=0, max_seq_len=float("inf"), @@ -49,6 +50,7 @@ class MyDataset(Dataset): self.min_seq_len = min_seq_len self.max_seq_len = max_seq_len self.ap = ap + self.tp = tp self.use_phonemes = use_phonemes self.phoneme_cache_path = phoneme_cache_path self.phoneme_language = phoneme_language @@ -81,7 +83,8 @@ class MyDataset(Dataset): config option.""" phonemes = phoneme_to_sequence(text, [self.cleaners], language=self.phoneme_language, - enable_eos_bos=False) + enable_eos_bos=False, + tp=self.tp) phonemes = np.asarray(phonemes, dtype=np.int32) np.save(cache_path, phonemes) return phonemes @@ -101,7 +104,7 @@ class MyDataset(Dataset): phonemes = self._generate_and_cache_phoneme_sequence(text, cache_path) if self.enable_eos_bos: - phonemes = pad_with_eos_bos(phonemes) + phonemes = pad_with_eos_bos(phonemes, tp=self.tp) phonemes = np.asarray(phonemes, dtype=np.int32) return phonemes @@ -113,7 +116,7 @@ class MyDataset(Dataset): text = self._load_or_generate_phoneme_sequence(wav_file, text) else: text = np.asarray( - text_to_sequence(text, [self.cleaners]), dtype=np.int32) + text_to_sequence(text, [self.cleaners], tp=self.tp), dtype=np.int32) assert text.size > 0, self.items[idx][1] assert wav.size > 0, self.items[idx][1] diff --git a/notebooks/Benchmark-PWGAN.ipynb b/notebooks/Benchmark-PWGAN.ipynb index 430d329..4a2a21d 100644 --- a/notebooks/Benchmark-PWGAN.ipynb +++ b/notebooks/Benchmark-PWGAN.ipynb @@ -132,7 +132,7 @@ "outputs": [], "source": [ "# LOAD TTS MODEL\n", - "from TTS.utils.text.symbols import symbols, phonemes\n", + "from TTS.utils.text.symbols import make_symbols, symbols, phonemes\n", "\n", "# multi speaker \n", "if CONFIG.use_speaker_embedding:\n", @@ -142,6 +142,10 @@ " speakers = []\n", " speaker_id = None\n", "\n", + "# if the vocabulary was passed, replace the default\n", + "if 'text' in CONFIG.keys():\n", + " symbols, phonemes = make_symbols(**CONFIG.text)\n", + "\n", "# load the model\n", "num_chars = len(phonemes) if CONFIG.use_phonemes else len(symbols)\n", "model = setup_model(num_chars, len(speakers), CONFIG)\n", diff --git a/notebooks/Benchmark.ipynb b/notebooks/Benchmark.ipynb index 00ac7d1..528d7a3 100644 --- a/notebooks/Benchmark.ipynb +++ b/notebooks/Benchmark.ipynb @@ -65,7 +65,7 @@ "from TTS.utils.text import text_to_sequence\n", "from TTS.utils.synthesis import synthesis\n", "from TTS.utils.visual import visualize\n", - "from TTS.utils.text.symbols import symbols, phonemes\n", + "from TTS.utils.text.symbols import make_symbols, symbols, phonemes\n", "\n", "import IPython\n", "from IPython.display import Audio\n", @@ -149,6 +149,10 @@ " speakers = []\n", " speaker_id = None\n", "\n", + "# if the vocabulary was passed, replace the default\n", + "if 'text' in CONFIG.keys():\n", + " symbols, phonemes = make_symbols(**CONFIG.text)\n", + "\n", "# load the model\n", "num_chars = len(phonemes) if CONFIG.use_phonemes else len(symbols)\n", "model = setup_model(num_chars, len(speakers), CONFIG)\n", diff --git a/notebooks/ExtractTTSpectrogram.ipynb b/notebooks/ExtractTTSpectrogram.ipynb index 20038f7..2313e47 100644 --- a/notebooks/ExtractTTSpectrogram.ipynb +++ b/notebooks/ExtractTTSpectrogram.ipynb @@ -37,7 +37,7 @@ "from TTS.utils.audio import AudioProcessor\n", "from TTS.utils.visual import plot_spectrogram\n", "from TTS.utils.generic_utils import load_config, setup_model, sequence_mask\n", - "from TTS.utils.text.symbols import symbols, phonemes\n", + "from TTS.utils.text.symbols import make_symbols, symbols, phonemes\n", "\n", "%matplotlib inline\n", "\n", @@ -94,6 +94,10 @@ "metadata": {}, "outputs": [], "source": [ + "# if the vocabulary was passed, replace the default\n", + "if 'text' in C.keys():\n", + " symbols, phonemes = make_symbols(**C.text)\n", + "\n", "# load the model\n", "num_chars = len(phonemes) if C.use_phonemes else len(symbols)\n", "# TODO: multiple speaker\n", @@ -116,7 +120,7 @@ "preprocessor = importlib.import_module('TTS.datasets.preprocess')\n", "preprocessor = getattr(preprocessor, DATASET.lower())\n", "meta_data = preprocessor(DATA_PATH,METADATA_FILE)\n", - "dataset = MyDataset(checkpoint['r'], C.text_cleaner, ap, meta_data, use_phonemes=C.use_phonemes, phoneme_cache_path=C.phoneme_cache_path, enable_eos_bos=C.enable_eos_bos_chars)\n", + "dataset = MyDataset(checkpoint['r'], C.text_cleaner, ap, meta_data,tp=C.text if 'text' in C.keys() else None, use_phonemes=C.use_phonemes, phoneme_cache_path=C.phoneme_cache_path, enable_eos_bos=C.enable_eos_bos_chars)\n", "loader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=4, collate_fn=dataset.collate_fn, shuffle=False, drop_last=False)" ] }, diff --git a/notebooks/TestAttention.ipynb b/notebooks/TestAttention.ipynb index a1867d1..5310fb9 100644 --- a/notebooks/TestAttention.ipynb +++ b/notebooks/TestAttention.ipynb @@ -100,7 +100,7 @@ "outputs": [], "source": [ "# LOAD TTS MODEL\n", - "from TTS.utils.text.symbols import symbols, phonemes\n", + "from TTS.utils.text.symbols import make_symbols, symbols, phonemes\n", "\n", "# multi speaker \n", "if CONFIG.use_speaker_embedding:\n", @@ -110,6 +110,10 @@ " speakers = []\n", " speaker_id = None\n", "\n", + "# if the vocabulary was passed, replace the default\n", + "if 'text' in CONFIG.keys():\n", + " symbols, phonemes = make_symbols(**CONFIG.text)\n", + "\n", "# load the model\n", "num_chars = len(phonemes) if CONFIG.use_phonemes else len(symbols)\n", "model = setup_model(num_chars, len(speakers), CONFIG)\n", diff --git a/server/synthesizer.py b/server/synthesizer.py index 347bef2..f001afc 100644 --- a/server/synthesizer.py +++ b/server/synthesizer.py @@ -10,7 +10,7 @@ from TTS.utils.audio import AudioProcessor from TTS.utils.generic_utils import load_config, setup_model from TTS.utils.speakers import load_speaker_mapping from TTS.utils.synthesis import * -from TTS.utils.text import phonemes, symbols +from TTS.utils.text import make_symbols, phonemes, symbols alphabets = r"([A-Za-z])" prefixes = r"(Mr|St|Mrs|Ms|Dr)[.]" @@ -38,12 +38,19 @@ class Synthesizer(object): self.config.pwgan_config, self.config.use_cuda) def load_tts(self, tts_checkpoint, tts_config, use_cuda): + global symbols, phonemes + print(" > Loading TTS model ...") print(" | > model config: ", tts_config) print(" | > checkpoint file: ", tts_checkpoint) + self.tts_config = load_config(tts_config) self.use_phonemes = self.tts_config.use_phonemes self.ap = AudioProcessor(**self.tts_config.audio) + + if 'text' in self.tts_config.keys(): + symbols, phonemes = make_symbols(**self.tts_config.text) + if self.use_phonemes: self.input_size = len(phonemes) else: diff --git a/synthesize.py b/synthesize.py index bf85d7c..d294701 100644 --- a/synthesize.py +++ b/synthesize.py @@ -8,7 +8,7 @@ import string from TTS.utils.synthesis import synthesis from TTS.utils.generic_utils import load_config, setup_model -from TTS.utils.text.symbols import symbols, phonemes +from TTS.utils.text.symbols import make_symbols, symbols, phonemes from TTS.utils.audio import AudioProcessor @@ -48,6 +48,8 @@ def tts(model, if __name__ == "__main__": + global symbols, phonemes + parser = argparse.ArgumentParser() parser.add_argument('text', type=str, help='Text to generate speech.') parser.add_argument('config_path', @@ -105,6 +107,10 @@ if __name__ == "__main__": # load the audio processor ap = AudioProcessor(**C.audio) + # if the vocabulary was passed, replace the default + if 'text' in C.keys(): + symbols, phonemes = make_symbols(**C.text) + # load speakers if args.speakers_json != '': speakers = json.load(open(args.speakers_json, 'r')) diff --git a/tests/test_demo_server.py b/tests/test_demo_server.py index c343a6a..3e360e2 100644 --- a/tests/test_demo_server.py +++ b/tests/test_demo_server.py @@ -5,13 +5,16 @@ import torch as T from TTS.server.synthesizer import Synthesizer from TTS.tests import get_tests_input_path, get_tests_output_path -from TTS.utils.text.symbols import phonemes, symbols +from TTS.utils.text.symbols import make_symbols, phonemes, symbols from TTS.utils.generic_utils import load_config, save_checkpoint, setup_model class DemoServerTest(unittest.TestCase): def _create_random_model(self): config = load_config(os.path.join(get_tests_output_path(), 'dummy_model_config.json')) + if 'text' in config.keys(): + symbols, phonemes = make_symbols(**config.text) + num_chars = len(phonemes) if config.use_phonemes else len(symbols) model = setup_model(num_chars, 0, config) output_path = os.path.join(get_tests_output_path()) diff --git a/tests/test_loader.py b/tests/test_loader.py index d872789..5141fa8 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -38,6 +38,7 @@ class TestTTSDataset(unittest.TestCase): c.text_cleaner, ap=self.ap, meta_data=items, + tp=c.text if 'text' in c.keys() else None, batch_group_size=bgs, min_seq_len=c.min_seq_len, max_seq_len=float("inf"), diff --git a/train.py b/train.py index 7bfb875..96c268f 100644 --- a/train.py +++ b/train.py @@ -25,7 +25,7 @@ from TTS.utils.logger import Logger from TTS.utils.speakers import load_speaker_mapping, save_speaker_mapping, \ get_speakers from TTS.utils.synthesis import synthesis -from TTS.utils.text.symbols import phonemes, symbols +from TTS.utils.text.symbols import make_symbols, phonemes, symbols from TTS.utils.visual import plot_alignment, plot_spectrogram from TTS.datasets.preprocess import load_meta_data from TTS.utils.radam import RAdam @@ -49,6 +49,7 @@ def setup_loader(ap, r, is_val=False, verbose=False): c.text_cleaner, meta_data=meta_data_eval if is_val else meta_data_train, ap=ap, + tp=c.text if 'text' in c.keys() else None, batch_group_size=0 if is_val else c.batch_group_size * c.batch_size, min_seq_len=c.min_seq_len, @@ -515,9 +516,12 @@ def evaluate(model, criterion, criterion_st, ap, global_step, epoch): # FIXME: move args definition/parsing inside of main? def main(args): # pylint: disable=redefined-outer-name - global meta_data_train, meta_data_eval + global meta_data_train, meta_data_eval, symbols, phonemes # Audio processor ap = AudioProcessor(**c.audio) + + if 'text' in c.keys(): + symbols, phonemes = make_symbols(**c.text) # DISTRUBUTED if num_gpus > 1: diff --git a/utils/generic_utils.py b/utils/generic_utils.py index a8de5bb..6aecdc7 100644 --- a/utils/generic_utils.py +++ b/utils/generic_utils.py @@ -425,6 +425,15 @@ def check_config(c): _check_argument('power', c['audio'], restricted=True, val_type=float, min_val=1, max_val=5) _check_argument('griffin_lim_iters', c['audio'], restricted=True, val_type=int, min_val=10, max_val=1000) + # vocabulary parameters + _check_argument('text', c, restricted=False, val_type=dict) # parameter not mandatory + _check_argument('pad', c['text'] if 'text' in c.keys() else {}, restricted=True if 'text' in c.keys() else False, val_type=str) # mandatory if "text parameters" else no mandatory + _check_argument('eos', c['text'] if 'text' in c.keys() else {}, restricted=True if 'text' in c.keys() else False, val_type=str) + _check_argument('bos', c['text'] if 'text' in c.keys() else {}, restricted=True if 'text' in c.keys() else False, val_type=str) + _check_argument('characters', c['text'] if 'text' in c.keys() else {}, restricted=True if 'text' in c.keys() else False, val_type=str) + _check_argument('phonemes', c['text'] if 'text' in c.keys() else {}, restricted=True if 'text' in c.keys() else False, val_type=str) + _check_argument('punctuations', c['text'] if 'text' in c.keys() else {}, restricted=True if 'text' in c.keys() else False, val_type=str) + # normalization parameters _check_argument('signal_norm', c['audio'], restricted=True, val_type=bool) _check_argument('symmetric_norm', c['audio'], restricted=True, val_type=bool) diff --git a/utils/synthesis.py b/utils/synthesis.py index 79a17c7..c5ff2e7 100644 --- a/utils/synthesis.py +++ b/utils/synthesis.py @@ -9,10 +9,11 @@ def text_to_seqvec(text, CONFIG, use_cuda): if CONFIG.use_phonemes: seq = np.asarray( phoneme_to_sequence(text, text_cleaner, CONFIG.phoneme_language, - CONFIG.enable_eos_bos_chars), + CONFIG.enable_eos_bos_chars, + tp=CONFIG.text if 'text' in CONFIG.keys() else None), dtype=np.int32) else: - seq = np.asarray(text_to_sequence(text, text_cleaner), dtype=np.int32) + seq = np.asarray(text_to_sequence(text, text_cleaner, tp=CONFIG.text if 'text' in CONFIG.keys() else None), dtype=np.int32) # torch tensor chars_var = torch.from_numpy(seq).unsqueeze(0) if use_cuda: diff --git a/utils/text/__init__.py b/utils/text/__init__.py index 0e6684d..fcb239b 100644 --- a/utils/text/__init__.py +++ b/utils/text/__init__.py @@ -4,7 +4,7 @@ import re import phonemizer from phonemizer.phonemize import phonemize from TTS.utils.text import cleaners -from TTS.utils.text.symbols import symbols, phonemes, _phoneme_punctuations, _bos, \ +from TTS.utils.text.symbols import make_symbols, symbols, phonemes, _phoneme_punctuations, _bos, \ _eos # Mappings from symbol to numeric ID and vice versa: @@ -56,11 +56,23 @@ def text2phone(text, language): return ph -def pad_with_eos_bos(phoneme_sequence): +def pad_with_eos_bos(phoneme_sequence, tp=None): + global _PHONEMES_TO_ID, _bos, _eos + if tp: + _bos = tp['bos'] + _eos = tp['eos'] + _, phonemes = make_symbols(**tp) + _PHONEMES_TO_ID = {s: i for i, s in enumerate(phonemes)} + return [_PHONEMES_TO_ID[_bos]] + list(phoneme_sequence) + [_PHONEMES_TO_ID[_eos]] -def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False): +def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False, tp=None): + global _PHONEMES_TO_ID + if tp: + _, phonemes = make_symbols(**tp) + _PHONEMES_TO_ID = {s: i for i, s in enumerate(phonemes)} + sequence = [] text = text.replace(":", "") clean_text = _clean_text(text, cleaner_names) @@ -72,13 +84,18 @@ def phoneme_to_sequence(text, cleaner_names, language, enable_eos_bos=False): sequence += _phoneme_to_sequence(phoneme) # Append EOS char if enable_eos_bos: - sequence = pad_with_eos_bos(sequence) + sequence = pad_with_eos_bos(sequence, tp=tp) return sequence -def sequence_to_phoneme(sequence): +def sequence_to_phoneme(sequence, tp=None): '''Converts a sequence of IDs back to a string''' + global _ID_TO_PHONEMES result = '' + if tp: + _, phonemes = make_symbols(**tp) + _ID_TO_PHONEMES = {i: s for i, s in enumerate(phonemes)} + for symbol_id in sequence: if symbol_id in _ID_TO_PHONEMES: s = _ID_TO_PHONEMES[symbol_id] @@ -86,7 +103,7 @@ def sequence_to_phoneme(sequence): return result.replace('}{', ' ') -def text_to_sequence(text, cleaner_names): +def text_to_sequence(text, cleaner_names, tp=None): '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text. The text can optionally have ARPAbet sequences enclosed in curly braces embedded @@ -99,6 +116,11 @@ def text_to_sequence(text, cleaner_names): Returns: List of integers corresponding to the symbols in the text ''' + global _SYMBOL_TO_ID + if tp: + symbols, _ = make_symbols(**tp) + _SYMBOL_TO_ID = {s: i for i, s in enumerate(symbols)} + sequence = [] # Check for curly braces and treat their contents as ARPAbet: while text: @@ -113,8 +135,13 @@ def text_to_sequence(text, cleaner_names): return sequence -def sequence_to_text(sequence): +def sequence_to_text(sequence, tp=None): '''Converts a sequence of IDs back to a string''' + global _ID_TO_SYMBOL + if tp: + symbols, _ = make_symbols(**tp) + _ID_TO_SYMBOL = {i: s for i, s in enumerate(symbols)} + result = '' for symbol_id in sequence: if symbol_id in _ID_TO_SYMBOL: diff --git a/utils/text/symbols.py b/utils/text/symbols.py index ee6fd2c..e4a4b10 100644 --- a/utils/text/symbols.py +++ b/utils/text/symbols.py @@ -5,6 +5,18 @@ Defines the set of symbols used in text input to the model. The default is a set of ASCII characters that works well for English or text that has been run through Unidecode. For other data, you can modify _characters. See TRAINING_DATA.md for details. ''' +def make_symbols(characters, phonemes, punctuations='!\'(),-.:;? ', pad='_', eos='~', bos='^'): + ''' Function to create symbols and phonemes ''' + _phonemes = sorted(list(phonemes)) + + # Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): + _arpabet = ['@' + s for s in _phonemes] + + # Export all symbols: + symbols = [pad, eos, bos] + list(characters) + _arpabet + phonemes = [pad, eos, bos] + list(_phonemes) + list(punctuations) + + return symbols, phonemes _pad = '_' _eos = '~' @@ -20,14 +32,9 @@ _pulmonic_consonants = 'pbtdʈɖcɟkɡqɢʔɴŋɲɳnɱmʙrʀⱱɾɽɸβfvθðsz _suprasegmentals = 'ˈˌːˑ' _other_symbols = 'ʍwɥʜʢʡɕʑɺɧ' _diacrilics = 'ɚ˞ɫ' -_phonemes = sorted(list(_vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics)) +_phonemes = _vowels + _non_pulmonic_consonants + _pulmonic_consonants + _suprasegmentals + _other_symbols + _diacrilics -# Prepend "@" to ARPAbet symbols to ensure uniqueness (some are the same as uppercase letters): -_arpabet = ['@' + s for s in _phonemes] - -# Export all symbols: -symbols = [_pad, _eos, _bos] + list(_characters) + _arpabet -phonemes = [_pad, _eos, _bos] + list(_phonemes) + list(_punctuations) +symbols, phonemes = make_symbols( _characters, _phonemes,_punctuations, _pad, _eos, _bos) # Generate ALIEN language # from random import shuffle diff --git a/utils/visual.py b/utils/visual.py index ab51366..2f93d81 100644 --- a/utils/visual.py +++ b/utils/visual.py @@ -54,9 +54,10 @@ def visualize(alignment, spectrogram_postnet, stop_tokens, text, hop_length, CON plt.xlabel("Decoder timestamp", fontsize=label_fontsize) plt.ylabel("Encoder timestamp", fontsize=label_fontsize) if CONFIG.use_phonemes: - seq = phoneme_to_sequence(text, [CONFIG.text_cleaner], CONFIG.phoneme_language, CONFIG.enable_eos_bos_chars) - text = sequence_to_phoneme(seq) + seq = phoneme_to_sequence(text, [CONFIG.text_cleaner], CONFIG.phoneme_language, CONFIG.enable_eos_bos_chars, tp=CONFIG.text if 'text' in CONFIG.keys() else None) + text = sequence_to_phoneme(seq, tp=CONFIG.text if 'text' in CONFIG.keys() else None) print(text) + plt.yticks(range(len(text)), list(text)) plt.colorbar()