split dataset outside preprocessor

This commit is contained in:
Eren Golge 2019-07-16 21:15:04 +02:00
Родитель b7036e458d
Коммит fd081c49b7
3 изменённых файлов: 32 добавлений и 10 удалений

Просмотреть файл

@ -146,7 +146,7 @@ def common_voice(root_path, meta_file):
return items
def libri_tts(root_path, meta_files=None, is_eval=False):
def libri_tts(root_path, meta_files=None):
"""https://ai.google/tools/datasets/libri-tts/"""
items = []
if meta_files is None:
@ -164,6 +164,4 @@ def libri_tts(root_path, meta_files=None, is_eval=False):
items.append([text, wav_file, speaker_name])
for item in items:
assert os.path.exists(item[1]), f" [!] wav file is not exist - {item[1]}"
if meta_files is None:
return items[:500] if is_eval else items[500:]
return items

Просмотреть файл

@ -21,7 +21,8 @@ from utils.generic_utils import (NoamLR, check_update, count_parameters,
load_config, lr_decay,
remove_experiment_folder, save_best_model,
save_checkpoint, sequence_mask, weight_decay,
set_init_dict, copy_config_file, setup_model)
set_init_dict, copy_config_file, setup_model,
split_dataset)
from utils.logger import Logger
from utils.speakers import load_speaker_mapping, save_speaker_mapping, \
get_speakers
@ -44,15 +45,15 @@ def setup_loader(is_val=False, verbose=False):
global meta_data_train
global meta_data_eval
if "meta_data_train" not in globals():
if c.meta_file_train:
if c.meta_file_train is not None:
meta_data_train = get_preprocessor_by_name(c.dataset)(c.data_path, c.meta_file_train)
else:
meta_data_train = get_preprocessor_by_name(c.dataset)(c.data_path, is_eval=False)
if "meta_data_eval" not in globals():
if c.meta_file_val:
meta_data_train = get_preprocessor_by_name(c.dataset)(c.data_path)
if "meta_data_eval" not in globals() and c.run_eval:
if c.meta_file_val is not None:
meta_data_eval = get_preprocessor_by_name(c.dataset)(c.data_path, c.meta_file_val)
else:
meta_data_eval = get_preprocessor_by_name(c.dataset)(c.data_path, is_eval=True)
meta_data_eval, meta_data_train = split_dataset(meta_data_train)
if is_val and not c.run_eval:
loader = None
else:

Просмотреть файл

@ -10,7 +10,7 @@ import torch
import subprocess
import importlib
import numpy as np
from collections import OrderedDict
from collections import OrderedDict, Counter
from torch.autograd import Variable
from utils.text import text_to_sequence
@ -287,3 +287,26 @@ def setup_model(num_chars, num_speakers, c):
location_attn=c.location_attn,
separate_stopnet=c.separate_stopnet)
return model
def split_dataset(items):
is_multi_speaker = False
speakers = [item[-1] for item in items]
is_multi_speaker = len(set(speakers)) > 1
eval_split_size = 500 if 500 < len(items) * 0.01 else int(len(items) * 0.01)
np.random.seed(0)
np.random.shuffle(items)
if is_multi_speaker:
items_eval = []
# most stupid code ever -- Fix it !
while len(items_eval) < eval_split_size:
speakers = [item[-1] for item in items]
speaker_counter = Counter(speakers)
item_idx = np.random.randint(0, len(items))
if speaker_counter[items[item_idx][-1]] > 1:
items_eval.append(items[item_idx])
del items[item_idx]
return items_eval, items
else:
return items[:eval_split_size], items[eval_split_size:]