From fd081c49b74a616122d08a9e4d3f4ee043efa68c Mon Sep 17 00:00:00 2001
From: Eren Golge
Date: Tue, 16 Jul 2019 21:15:04 +0200
Subject: [PATCH] split dataset outside preprocessor

---
 datasets/preprocess.py |  4 +---
 train.py               | 13 +++++++------
 utils/generic_utils.py | 25 ++++++++++++++++++++++++-
 3 files changed, 32 insertions(+), 10 deletions(-)

diff --git a/datasets/preprocess.py b/datasets/preprocess.py
index e359fd0..2359983 100644
--- a/datasets/preprocess.py
+++ b/datasets/preprocess.py
@@ -146,7 +146,7 @@ def common_voice(root_path, meta_file):
     return items
 
 
-def libri_tts(root_path, meta_files=None, is_eval=False):
+def libri_tts(root_path, meta_files=None):
     """https://ai.google/tools/datasets/libri-tts/"""
     items = []
     if meta_files is None:
@@ -164,6 +164,4 @@ def libri_tts(root_path, meta_files=None, is_eval=False):
             items.append([text, wav_file, speaker_name])
     for item in items:
         assert os.path.exists(item[1]), f" [!] wav file is not exist - {item[1]}"
-    if meta_files is None:
-        return items[:500] if is_eval else items[500:]
     return items
\ No newline at end of file
diff --git a/train.py b/train.py
index ca40986..458845b 100644
--- a/train.py
+++ b/train.py
@@ -21,7 +21,8 @@ from utils.generic_utils import (NoamLR, check_update, count_parameters,
                                  load_config, lr_decay,
                                  remove_experiment_folder, save_best_model,
                                  save_checkpoint, sequence_mask, weight_decay,
-                                 set_init_dict, copy_config_file, setup_model)
+                                 set_init_dict, copy_config_file, setup_model,
+                                 split_dataset)
 from utils.logger import Logger
 from utils.speakers import load_speaker_mapping, save_speaker_mapping, \
     get_speakers
@@ -44,15 +45,15 @@ def setup_loader(is_val=False, verbose=False):
     global meta_data_train
     global meta_data_eval
     if "meta_data_train" not in globals():
-        if c.meta_file_train:
+        if c.meta_file_train is not None:
             meta_data_train = get_preprocessor_by_name(c.dataset)(c.data_path, c.meta_file_train)
         else:
-            meta_data_train = get_preprocessor_by_name(c.dataset)(c.data_path, is_eval=False)
-    if "meta_data_eval" not in globals():
-        if c.meta_file_val:
+            meta_data_train = get_preprocessor_by_name(c.dataset)(c.data_path)
+    if "meta_data_eval" not in globals() and c.run_eval:
+        if c.meta_file_val is not None:
             meta_data_eval = get_preprocessor_by_name(c.dataset)(c.data_path, c.meta_file_val)
         else:
-            meta_data_eval = get_preprocessor_by_name(c.dataset)(c.data_path, is_eval=True)
+            meta_data_eval, meta_data_train = split_dataset(meta_data_train)
     if is_val and not c.run_eval:
         loader = None
     else:
diff --git a/utils/generic_utils.py b/utils/generic_utils.py
index ff97185..fe7e062 100644
--- a/utils/generic_utils.py
+++ b/utils/generic_utils.py
@@ -10,7 +10,7 @@ import torch
 import subprocess
 import importlib
 import numpy as np
-from collections import OrderedDict
+from collections import OrderedDict, Counter
 from torch.autograd import Variable
 
 from utils.text import text_to_sequence
@@ -287,3 +287,26 @@ def setup_model(num_chars, num_speakers, c):
                      location_attn=c.location_attn,
                      separate_stopnet=c.separate_stopnet)
     return model
+
+
+def split_dataset(items):
+    is_multi_speaker = False
+    speakers = [item[-1] for item in items]
+    is_multi_speaker = len(set(speakers)) > 1
+    eval_split_size = 500 if 500 < len(items) * 0.01 else int(len(items) * 0.01)
+    np.random.seed(0)
+    np.random.shuffle(items)
+    if is_multi_speaker:
+        items_eval = []
+        # most stupid code ever -- Fix it !
+        while len(items_eval) < eval_split_size:
+            speakers = [item[-1] for item in items]
+            speaker_counter = Counter(speakers)
+            item_idx = np.random.randint(0, len(items))
+            if speaker_counter[items[item_idx][-1]] > 1:
+                items_eval.append(items[item_idx])
+                del items[item_idx]
+        return items_eval, items
+    else:
+        return items[:eval_split_size], items[eval_split_size:]
+
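
A note on the new split_dataset() in utils/generic_utils.py: the multi-speaker
branch rebuilds the speaker Counter from scratch on every draw, which the
inline comment itself flags for a fix. Below is a minimal sketch of an
equivalent split that updates the Counter incrementally instead; the function
name split_dataset_fast and the parameters eval_split_max, eval_split_frac,
and seed are illustrative, not part of this patch.

    from collections import Counter

    import numpy as np


    def split_dataset_fast(items, eval_split_max=500, eval_split_frac=0.01, seed=0):
        """Split items into (eval, train), keeping >= 1 sample per speaker in train."""
        rng = np.random.RandomState(seed)
        items = list(items)
        rng.shuffle(items)
        # same sizing rule as the patch: min(500, 1% of the dataset)
        eval_split_size = min(eval_split_max, int(len(items) * eval_split_frac))
        # single-speaker data needs no per-speaker bookkeeping
        if len(set(item[-1] for item in items)) <= 1:
            return items[:eval_split_size], items[eval_split_size:]
        speaker_counter = Counter(item[-1] for item in items)
        items_eval = []
        while len(items_eval) < eval_split_size:
            item_idx = rng.randint(0, len(items))
            speaker = items[item_idx][-1]
            # move an item to eval only if its speaker keeps >= 1 training sample
            if speaker_counter[speaker] > 1:
                speaker_counter[speaker] -= 1
                items_eval.append(items.pop(item_idx))
        return items_eval, items

As in the patch, the loop assumes enough speakers have more than one sample to
fill the eval split; if they do not, it never terminates, so bailing out after
a fixed number of rejected draws would be a sensible guard in either version.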