зеркало из https://github.com/mozilla/TTS.git
Batch group shuffling
This commit is contained in:
Родитель
a165cd7bda
Коммит
30fea0b957
|
@ -3,6 +3,7 @@ import numpy as np
|
|||
import collections
|
||||
import librosa
|
||||
import torch
|
||||
import random
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
from utils.text import text_to_sequence
|
||||
|
@ -17,8 +18,10 @@ class MyDataset(Dataset):
|
|||
outputs_per_step,
|
||||
text_cleaner,
|
||||
ap,
|
||||
batch_group_size=0,
|
||||
min_seq_len=0):
|
||||
self.root_dir = root_dir
|
||||
self.batch_group_size = batch_group_size
|
||||
self.wav_dir = os.path.join(root_dir, 'wavs')
|
||||
self.csv_dir = os.path.join(root_dir, csv_file)
|
||||
with open(self.csv_dir, "r", encoding="utf8") as f:
|
||||
|
@ -30,7 +33,7 @@ class MyDataset(Dataset):
|
|||
self.ap = ap
|
||||
print(" > Reading LJSpeech from - {}".format(root_dir))
|
||||
print(" | > Number of instances : {}".format(len(self.frames)))
|
||||
self._sort_frames()
|
||||
self.sort_frames()
|
||||
|
||||
def load_wav(self, filename):
|
||||
try:
|
||||
|
@ -39,8 +42,8 @@ class MyDataset(Dataset):
|
|||
except RuntimeError as e:
|
||||
print(" !! Cannot read file : {}".format(filename))
|
||||
|
||||
def _sort_frames(self):
|
||||
r"""Sort sequences in ascending order"""
|
||||
def sort_frames(self):
|
||||
r"""Sort text sequences in ascending order"""
|
||||
lengths = np.array([len(ins[1]) for ins in self.frames])
|
||||
|
||||
print(" | > Max length sequence {}".format(np.max(lengths)))
|
||||
|
@ -58,6 +61,15 @@ class MyDataset(Dataset):
|
|||
new_frames.append(self.frames[idx])
|
||||
print(" | > {} instances are ignored by min_seq_len ({})".format(
|
||||
len(ignored), self.min_seq_len))
|
||||
# shuffle batch groups
|
||||
if self.batch_group_size > 0:
|
||||
print(" | > Batch group shuffling is active.")
|
||||
for i in range(len(new_frames) // self.batch_group_size):
|
||||
offset = i * self.batch_group_size
|
||||
end_offset = offset + self.batch_group_size
|
||||
temp_frames = new_frames[offset : end_offset]
|
||||
random.shuffle(temp_frames)
|
||||
new_frames[offset : end_offset] = temp_frames
|
||||
self.frames = new_frames
|
||||
|
||||
def __len__(self):
|
||||
|
|
|
@ -66,6 +66,49 @@ class TestLJSpeechDataset(unittest.TestCase):
|
|||
assert mel_input.shape[0] == c.batch_size
|
||||
assert mel_input.shape[2] == c.num_mels
|
||||
|
||||
def test_batch_group_shuffle(self):
|
||||
if ok_ljspeech:
|
||||
dataset = LJSpeech.MyDataset(
|
||||
os.path.join(c.data_path_LJSpeech),
|
||||
os.path.join(c.data_path_LJSpeech, 'metadata.csv'),
|
||||
c.r,
|
||||
c.text_cleaner,
|
||||
ap=self.ap,
|
||||
batch_group_size=16,
|
||||
min_seq_len=c.min_seq_len)
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=2,
|
||||
shuffle=True,
|
||||
collate_fn=dataset.collate_fn,
|
||||
drop_last=True,
|
||||
num_workers=c.num_loader_workers)
|
||||
|
||||
frames = dataset.frames
|
||||
for i, data in enumerate(dataloader):
|
||||
if i == self.max_loader_iter:
|
||||
break
|
||||
text_input = data[0]
|
||||
text_lengths = data[1]
|
||||
linear_input = data[2]
|
||||
mel_input = data[3]
|
||||
mel_lengths = data[4]
|
||||
stop_target = data[5]
|
||||
item_idx = data[6]
|
||||
|
||||
neg_values = text_input[text_input < 0]
|
||||
check_count = len(neg_values)
|
||||
assert check_count == 0, \
|
||||
" !! Negative values in text_input: {}".format(check_count)
|
||||
# TODO: more assertion here
|
||||
assert linear_input.shape[0] == c.batch_size
|
||||
assert mel_input.shape[0] == c.batch_size
|
||||
assert mel_input.shape[2] == c.num_mels
|
||||
dataloader.dataset.sort_frames()
|
||||
assert frames[0] != dataloader.dataset.frames[0]
|
||||
|
||||
|
||||
def test_padding(self):
|
||||
if ok_ljspeech:
|
||||
dataset = LJSpeech.MyDataset(
|
||||
|
|
6
train.py
6
train.py
|
@ -191,7 +191,6 @@ def train(model, criterion, criterion_st, data_loader, optimizer, optimizer_st,
|
|||
tb.add_scalar('TrainEpochLoss/StopLoss', avg_stop_loss, current_step)
|
||||
tb.add_scalar('Time/EpochTime', epoch_time, epoch)
|
||||
epoch_time = 0
|
||||
|
||||
return avg_linear_loss, current_step
|
||||
|
||||
|
||||
|
@ -361,6 +360,7 @@ def main(args):
|
|||
c.r,
|
||||
c.text_cleaner,
|
||||
ap=ap,
|
||||
batch_group_size=16*c.batch_size,
|
||||
min_seq_len=c.min_seq_len)
|
||||
|
||||
train_loader = DataLoader(
|
||||
|
@ -374,7 +374,7 @@ def main(args):
|
|||
|
||||
if c.run_eval:
|
||||
val_dataset = Dataset(
|
||||
c.data_path, c.meta_file_val, c.r, c.text_cleaner, ap=ap)
|
||||
c.data_path, c.meta_file_val, c.r, c.text_cleaner, ap=ap, batch_group_size=0)
|
||||
|
||||
val_loader = DataLoader(
|
||||
val_dataset,
|
||||
|
@ -444,6 +444,8 @@ def main(args):
|
|||
flush=True)
|
||||
best_loss = save_best_model(model, optimizer, train_loss, best_loss,
|
||||
OUT_PATH, current_step, epoch)
|
||||
# shuffle batch groups
|
||||
train_loader.dataset.sort_frames()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
|
Загрузка…
Ссылка в новой задаче