Mirror of https://github.com/mozilla/TTS.git
dropped dataset caching
This commit is contained in:
Parent: 2e8446539f
Commit: 7b2804cc0d
@@ -69,9 +69,7 @@ Audio length is approximately 6 secs.
 ## Datasets and Data-Loading
 
 TTS provides a generic dataloader that is easy to use for new datasets. You only need to write an adaptor to format your data. Check ```datasets/preprocess.py``` to see example adaptors. After you write an adaptor, set the ```dataset``` field in ```config.json```. Do not forget the other data-related fields.
 
-You can also use pre-computed features. In this case, compute features with ```extract_features.py``` and set the ```dataset``` field to ```tts_cache```.
-
 Example datasets we successfully applied TTS to are linked below.
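The README's adaptor instructions are terse, so here is a minimal sketch of such an adaptor. The dataset name, directory layout, and metadata format below are hypothetical; the only firm contract (visible in the `load_data` hunk further down) is that each item is a `[text, wav_path]` pair:

```python
import os


def my_dataset(root_path, meta_file):
    """Hypothetical adaptor: one '<wav_name>|<transcript>' pair per line."""
    items = []
    with open(os.path.join(root_path, meta_file), 'r', encoding='utf8') as f:
        for line in f:
            cols = line.strip().split('|')
            wav_file = os.path.join(root_path, 'wavs', cols[0] + '.wav')
            text = cols[1]
            # each item is [text, wav_path], matching what load_data() unpacks
            items.append([text, wav_file])
    return items
```

With this added to ```datasets/preprocess.py```, setting ```"dataset": "my_dataset"``` in ```config.json``` would pick it up.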
@@ -62,7 +62,7 @@
     "data_path": "/home/erogol/Data/LJSpeech-1.1", // DATASET-RELATED: can be overwritten from command argument
     "meta_file_train": "metadata_train.csv", // DATASET-RELATED: metafile for training dataloader.
     "meta_file_val": "metadata_val.csv", // DATASET-RELATED: metafile for evaluation dataloader.
-    "dataset": "ljspeech", // DATASET-RELATED: one of TTS.dataset.preprocessors depending on your target dataset. Use "tts_cache" for pre-computed dataset by extract_features.py
+    "dataset": "ljspeech", // DATASET-RELATED: one of TTS.dataset.preprocessors depending on your target dataset.
     "min_seq_len": 0, // DATASET-RELATED: minimum text length to use in training
     "max_seq_len": 150, // DATASET-RELATED: maximum text length
     "output_path": "/media/erogol/data_ssd/Data/models/ljspeech_models/", // DATASET-RELATED: output path for all training outputs.
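For reference, these fields reach the code through `load_config`; a minimal sketch of reading them (the config path is illustrative):

```python
from utils.generic_utils import load_config

c = load_config('config.json')  # illustrative path
# attribute access mirrors how train.py and the tests use the config
print(c.data_path, c.dataset, c.min_seq_len, c.max_seq_len)
```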
@@ -61,7 +61,7 @@
     "data_path": "/media/erogol/data_ssd/Data/LJSpeech-1.1", // DATASET-RELATED: can be overwritten from command argument
     "meta_file_train": "prompts_train.data", // DATASET-RELATED: metafile for training dataloader.
     "meta_file_val": "prompts_val.data", // DATASET-RELATED: metafile for evaluation dataloader.
-    "dataset": "mozilla", // DATASET-RELATED: one of TTS.dataset.preprocessors depending on your target dataset. Use "tts_cache" for pre-computed dataset by extract_features.py
+    "dataset": "mozilla", // DATASET-RELATED: one of TTS.dataset.preprocessors depending on your target dataset.
     "min_seq_len": 0, // DATASET-RELATED: minimum text length to use in training
     "max_seq_len": 150, // DATASET-RELATED: maximum text length
     "output_path": "../keep/", // DATASET-RELATED: output path for all training outputs.
@@ -22,7 +22,6 @@ class MyDataset(Dataset):
                  batch_group_size=0,
                  min_seq_len=0,
                  max_seq_len=float("inf"),
-                 cached=False,
                  use_phonemes=True,
                  phoneme_cache_path=None,
                  phoneme_language="en-us",
@@ -61,7 +60,6 @@ class MyDataset(Dataset):
         self.min_seq_len = min_seq_len
         self.max_seq_len = max_seq_len
         self.ap = ap
-        self.cached = cached
         self.use_phonemes = use_phonemes
         self.phoneme_cache_path = phoneme_cache_path
         self.phoneme_language = phoneme_language
@@ -113,23 +111,10 @@ class MyDataset(Dataset):
         return text
 
     def load_data(self, idx):
-        if self.cached:
-            wav_name = self.items[idx][1]
-            mel_name = self.items[idx][2]
-            linear_name = self.items[idx][3]
-            text = self.items[idx][0]
-
-            if wav_name.split('.')[-1] == 'npy':
-                wav = self.load_np(wav_name)
-            else:
-                wav = np.asarray(self.load_wav(wav_name), dtype=np.float32)
-            mel = self.load_np(mel_name)
-            linear = self.load_np(linear_name)
-        else:
-            text, wav_file = self.items[idx]
-            wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
-            mel = None
-            linear = None
+        text, wav_file = self.items[idx]
+        wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
+        mel = None
+        linear = None
 
         if self.use_phonemes:
             text = self.load_phoneme_sequence(wav_file, text)
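With the cached branch gone, `load_data` always receives the plain items the preprocessors emit. A tiny sketch of that contract (paths are illustrative):

```python
# each adaptor item is [text, wav_path]; features are computed on the fly
items = [["Hello world.", "/data/wavs/sample_0001.wav"]]  # illustrative

text, wav_file = items[0]
# before this commit, a 'tts_cache' item also carried precomputed mel/linear
# spectrogram paths; now mel and linear start as None and are computed later
mel, linear = None, None
```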
@@ -1,18 +1,6 @@
 import os
 
 
-def tts_cache(root_path, meta_file):
-    """This format is set for the meta-file generated by extract_features.py"""
-    txt_file = os.path.join(root_path, meta_file)
-    items = []
-    with open(txt_file, 'r', encoding='utf8') as f:
-        for line in f:
-            cols = line.split('| ')
-            # text, wav_full_path, mel_name, linear_name, wav_len, mel_len
-            items.append(cols)
-    return items
-
-
 def tweb(root_path, meta_file):
     """Normalize TWEB dataset.
     https://www.kaggle.com/bryanpark/the-world-english-bible-speech-dataset
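Adaptors in `datasets/preprocess.py` are resolved dynamically from the dataset name given in the config, the same way the deleted `extract_features.py` script did; a minimal sketch (paths are illustrative):

```python
import importlib

# resolve an adaptor function by name, e.g. "ljspeech" or "tweb"
preprocessor = importlib.import_module('datasets.preprocess')
preprocessor = getattr(preprocessor, 'ljspeech')
items = preprocessor('/data/LJSpeech-1.1', 'metadata_train.csv')
```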
@@ -1,126 +0,0 @@
-'''
-Extract spectrograms and save them to file for training
-'''
-import os
-import sys
-import time
-import glob
-import argparse
-import librosa
-import importlib
-import numpy as np
-import tqdm
-from utils.generic_utils import load_config, copy_config_file
-from utils.audio import AudioProcessor
-
-from multiprocessing import Pool
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser()
-    parser.add_argument('--data_path', type=str, help='Data folder.')
-    parser.add_argument('--cache_path', type=str, help='Cache folder, place to output all the spectrogram files.')
-    parser.add_argument(
-        '--config', type=str, help='conf.json file for run settings.')
-    parser.add_argument(
-        "--num_proc", type=int, default=8, help="number of processes.")
-    parser.add_argument(
-        "--trim_silence",
-        type=bool,
-        default=False,
-        help="trim silence in the voice clip.")
-    parser.add_argument("--only_mel", type=bool, default=False, help="If True, only melsceptrogram is extracted.")
-    parser.add_argument("--dataset", type=str, help="Target dataset to be processed.")
-    parser.add_argument("--val_split", type=int, default=0, help="Number of instances for validation.")
-    parser.add_argument("--meta_file", type=str, help="Meta data file to be used for the dataset.")
-    parser.add_argument("--process_audio", type=bool, default=False, help="Preprocess audio files.")
-    args = parser.parse_args()
-
-    DATA_PATH = args.data_path
-    CACHE_PATH = args.cache_path
-    CONFIG = load_config(args.config)
-
-    # load the right preprocessor
-    preprocessor = importlib.import_module('datasets.preprocess')
-    preprocessor = getattr(preprocessor, args.dataset.lower())
-    items = preprocessor(args.data_path, args.meta_file)
-
-    print(" > Input path: ", DATA_PATH)
-    print(" > Cache path: ", CACHE_PATH)
-
-    ap = AudioProcessor(**CONFIG.audio)
-
-
-def extract_mel(item):
-    """ Compute spectrograms, length information """
-    text = item[0]
-    file_path = item[1]
-    x = ap.load_wav(file_path, ap.sample_rate)
-    file_name = os.path.basename(file_path).replace(".wav", "")
-    mel_file = file_name + "_mel"
-    mel_path = os.path.join(CACHE_PATH, 'mel', mel_file)
-    mel = ap.melspectrogram(x.astype('float32')).astype('float32')
-    np.save(mel_path, mel, allow_pickle=False)
-    mel_len = mel.shape[1]
-    wav_len = x.shape[0]
-    output = [text, file_path, mel_path+".npy", str(wav_len), str(mel_len)]
-    if not args.only_mel:
-        linear_file = file_name + "_linear"
-        linear_path = os.path.join(CACHE_PATH, 'linear', linear_file)
-        linear = ap.spectrogram(x.astype('float32')).astype('float32')
-        linear_len = linear.shape[1]
-        np.save(linear_path, linear, allow_pickle=False)
-        output.insert(3, linear_path+".npy")
-        assert mel_len == linear_len
-    if args.process_audio:
-        audio_file = file_name + "_audio"
-        audio_path = os.path.join(CACHE_PATH, 'audio', audio_file)
-        np.save(audio_path, x, allow_pickle=False)
-        del output[0]
-        output.insert(1, audio_path+".npy")
-    return output
-
-
-if __name__ == "__main__":
-    print(" > Number of files: %i" % (len(items)))
-    if not os.path.exists(CACHE_PATH):
-        os.makedirs(os.path.join(CACHE_PATH, 'mel'))
-        if not args.only_mel:
-            os.makedirs(os.path.join(CACHE_PATH, 'linear'))
-        if args.process_audio:
-            os.makedirs(os.path.join(CACHE_PATH, 'audio'))
-        print(" > A new folder created at {}".format(CACHE_PATH))
-
-    # Extract features
-    r = []
-    if args.num_proc > 1:
-        print(" > Using {} processes.".format(args.num_proc))
-        with Pool(args.num_proc) as p:
-            r = list(
-                tqdm.tqdm(
-                    p.imap(extract_mel, items),
-                    total=len(items)))
-            # r = list(p.imap(extract_mel, file_names))
-    else:
-        print(" > Using single process run.")
-        for item in items:
-            print(" > ", item[1])
-            r.append(extract_mel(item))
-
-    # Save meta data
-    if args.cache_path is not None:
-        file_path = os.path.join(CACHE_PATH, "tts_metadata_val.csv")
-        file = open(file_path, "w")
-        for line in r[:args.val_split]:
-            line = "| ".join(line)
-            file.write(line + '\n')
-        file.close()
-
-        file_path = os.path.join(CACHE_PATH, "tts_metadata.csv")
-        file = open(file_path, "w")
-        for line in r[args.val_split:]:
-            line = "| ".join(line)
-            file.write(line + '\n')
-        file.close()
-
-    # copy the used config file to output path for sanity
-    copy_config_file(args.config, CACHE_PATH)
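With the cache script deleted, the features it precomputed are now generated on the fly by `AudioProcessor`. A minimal sketch using the same calls the script made (the config and wav paths are illustrative):

```python
from utils.generic_utils import load_config
from utils.audio import AudioProcessor

CONFIG = load_config('config.json')  # illustrative config path
ap = AudioProcessor(**CONFIG.audio)

x = ap.load_wav('/data/wavs/sample_0001.wav', ap.sample_rate)  # illustrative
mel = ap.melspectrogram(x.astype('float32'))   # what the cache stored as *_mel.npy
linear = ap.spectrogram(x.astype('float32'))   # what it stored as *_linear.npy
```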
@@ -7,7 +7,7 @@ from torch.utils.data import DataLoader
 from utils.generic_utils import load_config
 from utils.audio import AudioProcessor
 from datasets import TTSDataset
-from datasets.preprocess import ljspeech, tts_cache
+from datasets.preprocess import ljspeech
 
 file_path = os.path.dirname(os.path.realpath(__file__))
 OUTPATH = os.path.join(file_path, "outputs/loader_tests/")
@@ -16,15 +16,11 @@ c = load_config(os.path.join(file_path, 'test_config.json'))
 ok_ljspeech = os.path.exists(c.data_path)
 
 DATA_EXIST = True
-CACHE_EXIST = True
-if not os.path.exists(c.data_path_cache):
-    CACHE_EXIST = False
 
 if not os.path.exists(c.data_path):
     DATA_EXIST = False
 
 print(" > Dynamic data loader test: {}".format(DATA_EXIST))
-print(" > Cache data loader test: {}".format(CACHE_EXIST))
 
 
 class TestTTSDataset(unittest.TestCase):
     def __init__(self, *args, **kwargs):
@@ -126,8 +122,9 @@ class TestTTSDataset(unittest.TestCase):
             wav = self.ap.load_wav(item_idx[0])
             mel = self.ap.melspectrogram(wav)
             mel_dl = mel_input[0].cpu().numpy()
-            assert (
-                abs(mel.T).astype("float32") - abs(mel_dl[:-1])).sum() == 0
+            assert (abs(mel.T).astype("float32")
+                    - abs(mel_dl[:-1])
+                    ).sum() == 0
 
             # check mel-spec correctness
             mel_spec = mel_input[0].cpu().numpy()
@@ -139,7 +136,8 @@ class TestTTSDataset(unittest.TestCase):
             linear_spec = linear_input[0].cpu().numpy()
             wav = self.ap.inv_spectrogram(linear_spec.T)
             self.ap.save_wav(wav, OUTPATH + '/linear_inv_dataloader.wav')
-            shutil.copy(item_idx[0], OUTPATH + '/linear_target_dataloader.wav')
+            shutil.copy(item_idx[0],
+                        OUTPATH + '/linear_target_dataloader.wav')
 
             # check the last time step to be zero padded
             assert linear_input[0, -1].sum() == 0
@@ -192,4 +190,4 @@ class TestTTSDataset(unittest.TestCase):
 
             # check batch conditions
             assert (linear_input * stop_target.unsqueeze(2)).sum() == 0
-            assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
+            assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
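The batch-condition asserts kept above check that padded time steps carry no spectrogram energy. A toy, self-contained illustration of the same check (the tensor shapes and `stop_target` layout are assumptions, not taken from the test):

```python
import torch

# toy batch: 2 sequences padded to length 5; stop_target marks padded steps with 1
mel_input = torch.zeros(2, 5, 80)
mel_input[0, :3] = torch.rand(3, 80)  # real frames only in the first 3 steps
stop_target = torch.tensor([[0, 0, 0, 1, 1],
                            [0, 0, 0, 0, 1]], dtype=torch.float32)

# mirrors the batch condition asserted in the test: padded frames must be zero
assert (mel_input * stop_target.unsqueeze(2)).sum() == 0
```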
@@ -36,7 +36,6 @@
 
     "save_step": 200,
     "data_path": "/home/erogol/Data/LJSpeech-1.1/",
-    "data_path_cache": "/media/erogol/data_ssd/Data/Nancy/tts_cache/",
    "output_path": "result",
    "min_seq_len": 0,
    "max_seq_len": 300,
train.py
@@ -53,7 +53,6 @@ def setup_loader(is_val=False, verbose=False):
         batch_group_size=0 if is_val else c.batch_group_size * c.batch_size,
         min_seq_len=0 if is_val else c.min_seq_len,
         max_seq_len=float("inf") if is_val else c.max_seq_len,
-        cached=False if c.dataset != "tts_cache" else True,
         phoneme_cache_path=c.phoneme_cache_path,
         use_phonemes=c.use_phonemes,
         phoneme_language=c.phoneme_language,
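The dropped `cached` argument was derived from the dataset name; for reference, a small sketch of that (now removed) gating and its more idiomatic equivalent:

```python
# behaviour removed by this commit: 'cached' tracked the special preprocessor
dataset_name = "ljspeech"  # illustrative config value
cached = False if dataset_name != "tts_cache" else True
# equivalent, more idiomatic form:
cached = dataset_name == "tts_cache"
assert cached is False
```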