This commit is contained in:
Thomas Werkmeister 2019-04-29 11:07:04 +02:00
Родитель 2e8446539f
Коммит 7b2804cc0d
9 изменённых файлов: 15 добавлений и 174 удалений

Просмотреть файл

@ -69,9 +69,7 @@ Audio length is approximately 6 secs.
## Datasets and Data-Loading
TTS provides a generic dataloader that is easy to use with new datasets. You only need to write an adaptor that formats your dataset. Check ```datasets/preprocess.py``` to see example adaptors. After you have written an adaptor, you need to set the ```dataset``` field in ```config.json```. Do not forget the other data-related fields.
You can also use pre-computed features. In this case, compute features with ```extract_features.py``` and set ```dataset``` field as ```tts_cache```.
TTS provides a generic dataloader that is easy to use with new datasets. You only need to write an adaptor that formats your dataset. Check ```datasets/preprocess.py``` to see example adaptors. After you have written an adaptor, you need to set the ```dataset``` field in ```config.json```. Do not forget the other data-related fields.
Example datasets to which we have successfully applied TTS are linked below.

Просмотреть файл

@ -62,7 +62,7 @@
"data_path": "/home/erogol/Data/LJSpeech-1.1", // DATASET-RELATED: can be overwritten by a command-line argument
"meta_file_train": "metadata_train.csv", // DATASET-RELATED: metafile for training dataloader.
"meta_file_val": "metadata_val.csv", // DATASET-RELATED: metafile for evaluation dataloader.
"dataset": "ljspeech", // DATASET-RELATED: one of TTS.dataset.preprocessors depending on your target dataset. Use "tts_cache" for pre-computed dataset by extract_features.py
"dataset": "ljspeech", // DATASET-RELATED: one of TTS.dataset.preprocessors depending on your target dataset.
"min_seq_len": 0, // DATASET-RELATED: minimum text length to use in training
"max_seq_len": 150, // DATASET-RELATED: maximum text length
"output_path": "/media/erogol/data_ssd/Data/models/ljspeech_models/", // DATASET-RELATED: output path for all training outputs.

Просмотреть файл

@ -61,7 +61,7 @@
"data_path": "/media/erogol/data_ssd/Data/LJSpeech-1.1", // DATASET-RELATED: can be overwritten by a command-line argument
"meta_file_train": "prompts_train.data", // DATASET-RELATED: metafile for training dataloader.
"meta_file_val": "prompts_val.data", // DATASET-RELATED: metafile for evaluation dataloader.
"dataset": "mozilla", // DATASET-RELATED: one of TTS.dataset.preprocessors depending on your target dataset. Use "tts_cache" for pre-computed dataset by extract_features.py
"dataset": "mozilla", // DATASET-RELATED: one of TTS.dataset.preprocessors depending on your target dataset.
"min_seq_len": 0, // DATASET-RELATED: minimum text length to use in training
"max_seq_len": 150, // DATASET-RELATED: maximum text length
"output_path": "../keep/", // DATASET-RELATED: output path for all training outputs.

Просмотреть файл

@ -22,7 +22,6 @@ class MyDataset(Dataset):
batch_group_size=0,
min_seq_len=0,
max_seq_len=float("inf"),
cached=False,
use_phonemes=True,
phoneme_cache_path=None,
phoneme_language="en-us",
@ -61,7 +60,6 @@ class MyDataset(Dataset):
self.min_seq_len = min_seq_len
self.max_seq_len = max_seq_len
self.ap = ap
self.cached = cached
self.use_phonemes = use_phonemes
self.phoneme_cache_path = phoneme_cache_path
self.phoneme_language = phoneme_language
@ -113,23 +111,10 @@ class MyDataset(Dataset):
return text
def load_data(self, idx):
if self.cached:
wav_name = self.items[idx][1]
mel_name = self.items[idx][2]
linear_name = self.items[idx][3]
text = self.items[idx][0]
if wav_name.split('.')[-1] == 'npy':
wav = self.load_np(wav_name)
else:
wav = np.asarray(self.load_wav(wav_name), dtype=np.float32)
mel = self.load_np(mel_name)
linear = self.load_np(linear_name)
else:
text, wav_file = self.items[idx]
wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
mel = None
linear = None
text, wav_file = self.items[idx]
wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
mel = None
linear = None
if self.use_phonemes:
text = self.load_phoneme_sequence(wav_file, text)

Просмотреть файл

@ -1,18 +1,6 @@
import os
def tts_cache(root_path, meta_file):
    """Load the items listed in a meta-file generated by extract_features.py.

    Each non-empty line of the meta-file is split on the "| " separator.

    Args:
        root_path: folder that contains the meta-file.
        meta_file: name of the meta-file, relative to ``root_path``.

    Returns:
        A list with one entry per line; each entry is the list of that
        line's columns as strings.
    """
    txt_file = os.path.join(root_path, meta_file)
    items = []
    with open(txt_file, 'r', encoding='utf8') as f:
        for line in f:
            # Drop the line terminator so the last column is not stored
            # with a trailing "\n".
            line = line.rstrip('\n')
            if not line:
                # A blank line carries no item.
                continue
            # text, wav_full_path, mel_name, linear_name, wav_len, mel_len
            items.append(line.split('| '))
    return items
def tweb(root_path, meta_file):
"""Normalize TWEB dataset.
https://www.kaggle.com/bryanpark/the-world-english-bible-speech-dataset

Просмотреть файл

@ -1,126 +0,0 @@
'''
Extract spectrograms and save them to file for training
'''
import os
import sys
import time
import glob
import argparse
import librosa
import importlib
import numpy as np
import tqdm
from utils.generic_utils import load_config, copy_config_file
from utils.audio import AudioProcessor
from multiprocessing import Pool
def str2bool(value):
    """Convert a command-line token to a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so any
    non-empty value -- including "False" -- is parsed as True. This converter
    recognises the usual spellings instead.

    Args:
        value: the raw command-line string (or an actual bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if the token is not a known spelling.
    """
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ("true", "t", "yes", "y", "1"):
        return True
    # The empty string stays False, as it was under ``type=bool``.
    if token in ("false", "f", "no", "n", "0", ""):
        return False
    raise argparse.ArgumentTypeError(
        "expected a boolean value, got {!r}".format(value))


if __name__ == "__main__":
    # Command-line interface of the feature-extraction script.
    parser = argparse.ArgumentParser()
    parser.add_argument('--data_path', type=str, help='Data folder.')
    parser.add_argument('--cache_path', type=str, help='Cache folder, place to output all the spectrogram files.')
    parser.add_argument(
        '--config', type=str, help='conf.json file for run settings.')
    parser.add_argument(
        "--num_proc", type=int, default=8, help="number of processes.")
    parser.add_argument(
        "--trim_silence",
        type=str2bool,
        default=False,
        help="trim silence in the voice clip.")
    parser.add_argument("--only_mel", type=str2bool, default=False, help="If True, only melsceptrogram is extracted.")
    parser.add_argument("--dataset", type=str, help="Target dataset to be processed.")
    parser.add_argument("--val_split", type=int, default=0, help="Number of instances for validation.")
    parser.add_argument("--meta_file", type=str, help="Meta data file to be used for the dataset.")
    parser.add_argument("--process_audio", type=str2bool, default=False, help="Preprocess audio files.")
    args = parser.parse_args()

    DATA_PATH = args.data_path
    CACHE_PATH = args.cache_path
    CONFIG = load_config(args.config)

    # load the right preprocessor: the function in datasets.preprocess whose
    # name is the lower-cased --dataset value
    preprocessor = importlib.import_module('datasets.preprocess')
    preprocessor = getattr(preprocessor, args.dataset.lower())
    items = preprocessor(args.data_path, args.meta_file)

    print(" > Input path: ", DATA_PATH)
    print(" > Cache path: ", CACHE_PATH)

    # Audio processor built from the "audio" section of the config.
    ap = AudioProcessor(**CONFIG.audio)
def extract_mel(item):
    """Compute spectrograms and length information for one dataset item.

    The mel spectrogram is always saved under ``CACHE_PATH/mel``. Unless
    ``args.only_mel`` is set, the linear spectrogram is saved under
    ``CACHE_PATH/linear``. When ``args.process_audio`` is set, the loaded
    waveform is saved under ``CACHE_PATH/audio`` and its ``.npy`` path takes
    the place of the original wav path in the returned row.

    Args:
        item: sequence whose first element is the text and whose second
            element is the path of the audio file.

    Returns:
        A list of strings:
        ``[text, wav_path, mel_path, (linear_path,) wav_len, mel_len]``,
        where ``linear_path`` is present only when ``args.only_mel`` is false.
    """
    text = item[0]
    file_path = item[1]
    x = ap.load_wav(file_path, ap.sample_rate)
    file_name = os.path.basename(file_path).replace(".wav", "")
    mel_file = file_name + "_mel"
    mel_path = os.path.join(CACHE_PATH, 'mel', mel_file)
    mel = ap.melspectrogram(x.astype('float32')).astype('float32')
    # np.save appends ".npy" to the given path, hence the suffix added below.
    np.save(mel_path, mel, allow_pickle=False)
    mel_len = mel.shape[1]
    wav_len = x.shape[0]
    output = [text, file_path, mel_path+".npy", str(wav_len), str(mel_len)]
    if not args.only_mel:
        linear_file = file_name + "_linear"
        linear_path = os.path.join(CACHE_PATH, 'linear', linear_file)
        linear = ap.spectrogram(x.astype('float32')).astype('float32')
        linear_len = linear.shape[1]
        np.save(linear_path, linear, allow_pickle=False)
        # The linear path goes right after the mel path.
        output.insert(3, linear_path+".npy")
        assert mel_len == linear_len
    if args.process_audio:
        audio_file = file_name + "_audio"
        audio_path = os.path.join(CACHE_PATH, 'audio', audio_file)
        np.save(audio_path, x, allow_pickle=False)
        # Swap the wav path (column 1) for the cached audio array. Deleting
        # column 0 instead would drop the text and shift the wav path into
        # the text column.
        output[1] = audio_path+".npy"
    return output
def write_metadata(file_path, rows):
    """Write ``rows`` to ``file_path``, one "| "-separated line per row.

    The file is written as utf8, the encoding the ``tts_cache`` preprocessor
    uses to read it back.

    Args:
        file_path: path of the metadata file to create or overwrite.
        rows: iterable of rows, each an iterable of strings.
    """
    with open(file_path, "w", encoding="utf8") as meta_file:
        meta_file.writelines("| ".join(row) + "\n" for row in rows)


if __name__ == "__main__":
    print(" > Number of files: %i" % (len(items)))
    # Create the output sub-folders even when CACHE_PATH itself already
    # exists; otherwise saving the features into them would fail.
    is_new_cache = not os.path.exists(CACHE_PATH)
    os.makedirs(os.path.join(CACHE_PATH, 'mel'), exist_ok=True)
    if not args.only_mel:
        os.makedirs(os.path.join(CACHE_PATH, 'linear'), exist_ok=True)
    if args.process_audio:
        os.makedirs(os.path.join(CACHE_PATH, 'audio'), exist_ok=True)
    if is_new_cache:
        print(" > A new folder created at {}".format(CACHE_PATH))

    # Extract features
    r = []
    if args.num_proc > 1:
        print(" > Using {} processes.".format(args.num_proc))
        with Pool(args.num_proc) as p:
            r = list(
                tqdm.tqdm(
                    p.imap(extract_mel, items),
                    total=len(items)))
    else:
        print(" > Using single process run.")
        for item in items:
            print(" > ", item[1])
            r.append(extract_mel(item))

    # Save meta data: the first ``val_split`` rows form the validation set,
    # the remaining rows the training set.
    if args.cache_path is not None:
        write_metadata(
            os.path.join(CACHE_PATH, "tts_metadata_val.csv"),
            r[:args.val_split])
        write_metadata(
            os.path.join(CACHE_PATH, "tts_metadata.csv"),
            r[args.val_split:])

    # copy the used config file to output path for sanity
    copy_config_file(args.config, CACHE_PATH)

Просмотреть файл

@ -7,7 +7,7 @@ from torch.utils.data import DataLoader
from utils.generic_utils import load_config
from utils.audio import AudioProcessor
from datasets import TTSDataset
from datasets.preprocess import ljspeech, tts_cache
from datasets.preprocess import ljspeech
file_path = os.path.dirname(os.path.realpath(__file__))
OUTPATH = os.path.join(file_path, "outputs/loader_tests/")
@ -16,15 +16,11 @@ c = load_config(os.path.join(file_path, 'test_config.json'))
ok_ljspeech = os.path.exists(c.data_path)
DATA_EXIST = True
CACHE_EXIST = True
if not os.path.exists(c.data_path_cache):
CACHE_EXIST = False
if not os.path.exists(c.data_path):
DATA_EXIST = False
print(" > Dynamic data loader test: {}".format(DATA_EXIST))
print(" > Cache data loader test: {}".format(CACHE_EXIST))
class TestTTSDataset(unittest.TestCase):
def __init__(self, *args, **kwargs):
@ -126,8 +122,9 @@ class TestTTSDataset(unittest.TestCase):
wav = self.ap.load_wav(item_idx[0])
mel = self.ap.melspectrogram(wav)
mel_dl = mel_input[0].cpu().numpy()
assert (
abs(mel.T).astype("float32") - abs(mel_dl[:-1])).sum() == 0
assert (abs(mel.T).astype("float32")
- abs(mel_dl[:-1])
).sum() == 0
# check mel-spec correctness
mel_spec = mel_input[0].cpu().numpy()
@ -139,7 +136,8 @@ class TestTTSDataset(unittest.TestCase):
linear_spec = linear_input[0].cpu().numpy()
wav = self.ap.inv_spectrogram(linear_spec.T)
self.ap.save_wav(wav, OUTPATH + '/linear_inv_dataloader.wav')
shutil.copy(item_idx[0], OUTPATH + '/linear_target_dataloader.wav')
shutil.copy(item_idx[0],
OUTPATH + '/linear_target_dataloader.wav')
# check the last time step to be zero padded
assert linear_input[0, -1].sum() == 0
@ -192,4 +190,4 @@ class TestTTSDataset(unittest.TestCase):
# check batch conditions
assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
assert (mel_input * stop_target.unsqueeze(2)).sum() == 0

Просмотреть файл

@ -36,7 +36,6 @@
"save_step": 200,
"data_path": "/home/erogol/Data/LJSpeech-1.1/",
"data_path_cache": "/media/erogol/data_ssd/Data/Nancy/tts_cache/",
"output_path": "result",
"min_seq_len": 0,
"max_seq_len": 300,

Просмотреть файл

@ -53,7 +53,6 @@ def setup_loader(is_val=False, verbose=False):
batch_group_size=0 if is_val else c.batch_group_size * c.batch_size,
min_seq_len=0 if is_val else c.min_seq_len,
max_seq_len=float("inf") if is_val else c.max_seq_len,
cached=False if c.dataset != "tts_cache" else True,
phoneme_cache_path=c.phoneme_cache_path,
use_phonemes=c.use_phonemes,
phoneme_language=c.phoneme_language,