This commit is contained in:
Eren 2018-09-20 11:08:12 +02:00
Родитель a165cd7bda
Коммит 30fea0b957
3 изменённых файлов: 62 добавлений и 5 удалений

Просмотреть файл

@ -3,6 +3,7 @@ import numpy as np
import collections
import librosa
import torch
import random
from torch.utils.data import Dataset
from utils.text import text_to_sequence
@ -17,8 +18,10 @@ class MyDataset(Dataset):
outputs_per_step,
text_cleaner,
ap,
batch_group_size=0,
min_seq_len=0):
self.root_dir = root_dir
self.batch_group_size = batch_group_size
self.wav_dir = os.path.join(root_dir, 'wavs')
self.csv_dir = os.path.join(root_dir, csv_file)
with open(self.csv_dir, "r", encoding="utf8") as f:
@ -30,7 +33,7 @@ class MyDataset(Dataset):
self.ap = ap
print(" > Reading LJSpeech from - {}".format(root_dir))
print(" | > Number of instances : {}".format(len(self.frames)))
self._sort_frames()
self.sort_frames()
def load_wav(self, filename):
try:
@ -39,8 +42,8 @@ class MyDataset(Dataset):
except RuntimeError as e:
print(" !! Cannot read file : {}".format(filename))
def _sort_frames(self):
r"""Sort sequences in ascending order"""
def sort_frames(self):
r"""Sort text sequences in ascending order"""
lengths = np.array([len(ins[1]) for ins in self.frames])
print(" | > Max length sequence {}".format(np.max(lengths)))
@ -58,6 +61,15 @@ class MyDataset(Dataset):
new_frames.append(self.frames[idx])
print(" | > {} instances are ignored by min_seq_len ({})".format(
len(ignored), self.min_seq_len))
# shuffle batch groups
if self.batch_group_size > 0:
print(" | > Batch group shuffling is active.")
for i in range(len(new_frames) // self.batch_group_size):
offset = i * self.batch_group_size
end_offset = offset + self.batch_group_size
temp_frames = new_frames[offset : end_offset]
random.shuffle(temp_frames)
new_frames[offset : end_offset] = temp_frames
self.frames = new_frames
def __len__(self):

Просмотреть файл

@ -66,6 +66,49 @@ class TestLJSpeechDataset(unittest.TestCase):
assert mel_input.shape[0] == c.batch_size
assert mel_input.shape[2] == c.num_mels
def test_batch_group_shuffle(self):
if ok_ljspeech:
dataset = LJSpeech.MyDataset(
os.path.join(c.data_path_LJSpeech),
os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
c.r,
c.text_cleaner,
ap=self.ap,
batch_group_size=16,
min_seq_len=c.min_seq_len)
dataloader = DataLoader(
dataset,
batch_size=2,
shuffle=True,
collate_fn=dataset.collate_fn,
drop_last=True,
num_workers=c.num_loader_workers)
frames = dataset.frames
for i, data in enumerate(dataloader):
if i == self.max_loader_iter:
break
text_input = data[0]
text_lengths = data[1]
linear_input = data[2]
mel_input = data[3]
mel_lengths = data[4]
stop_target = data[5]
item_idx = data[6]
neg_values = text_input[text_input < 0]
check_count = len(neg_values)
assert check_count == 0, \
" !! Negative values in text_input: {}".format(check_count)
# TODO: more assertion here
assert linear_input.shape[0] == c.batch_size
assert mel_input.shape[0] == c.batch_size
assert mel_input.shape[2] == c.num_mels
dataloader.dataset.sort_frames()
assert frames[0] != dataloader.dataset.frames[0]
def test_padding(self):
if ok_ljspeech:
dataset = LJSpeech.MyDataset(

Просмотреть файл

@ -191,7 +191,6 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
tb.add_scalar('TrainEpochLoss/StopLoss', avg_stop_loss, current_step)
tb.add_scalar('Time/EpochTime', epoch_time, epoch)
epoch_time = 0
return avg_linear_loss, current_step
@ -361,6 +360,7 @@ def main(args):
c.r,
c.text_cleaner,
ap=ap,
batch_group_size=16*c.batch_size,
min_seq_len=c.min_seq_len)
train_loader = DataLoader(
@ -374,7 +374,7 @@ def main(args):
if c.run_eval:
val_dataset = Dataset(
c.data_path, c.meta_file_val, c.r, c.text_cleaner, ap=ap)
c.data_path, c.meta_file_val, c.r, c.text_cleaner, ap=ap, batch_group_size=0)
val_loader = DataLoader(
val_dataset,
@ -444,6 +444,8 @@ def main(args):
flush=True)
best_loss = save_best_model(model, optimizer, train_loss, best_loss,
OUT_PATH, current_step, epoch)
# shuffle batch groups
train_loader.dataset.sort_frames()
if __name__ == '__main__':