Mirror of https://github.com/mozilla/TTS.git
import random

import numpy as np
import torch
from torch.utils.data import Dataset


class MyDataset(Dataset):
    def __init__(self, ap, meta_data, voice_len=1.6, num_speakers_in_batch=64,
                 num_utter_per_speaker=10, skip_speakers=False, verbose=False):
        """
        Args:
            ap (TTS.utils.AudioProcessor): audio processor object.
            meta_data (list): list of (text, wav_path, speaker_name) dataset instances.
            voice_len (float): voice segment length in seconds.
            num_speakers_in_batch (int): number of speakers sampled in each batch.
            num_utter_per_speaker (int): number of utterances sampled per speaker.
            skip_speakers (bool): skip speakers with fewer than `num_utter_per_speaker` utterances.
            verbose (bool): print diagnostic information.
        """
        self.items = meta_data
        self.sample_rate = ap.sample_rate
        self.voice_len = voice_len
        # segment length in samples
        self.seq_len = int(voice_len * self.sample_rate)
        self.num_speakers_in_batch = num_speakers_in_batch
        self.num_utter_per_speaker = num_utter_per_speaker
        self.skip_speakers = skip_speakers
        self.ap = ap
        self.verbose = verbose
        self.__parse_items()
        if self.verbose:
            print("\n > DataLoader initialization")
            print(f" | > Number of instances : {len(self.items)}")
            print(f" | > Sequence length: {self.seq_len}")
            print(f" | > Num speakers: {len(self.speakers)}")

    def load_wav(self, filename):
        audio = self.ap.load_wav(filename)
        return audio

    def load_data(self, idx):
        text, wav_file, speaker_name = self.items[idx]
        wav = np.asarray(self.load_wav(wav_file), dtype=np.float32)
        mel = self.ap.melspectrogram(wav).astype("float32")
        # sample seq_len

        assert len(text) > 0, self.items[idx][1]
        assert wav.size > 0, self.items[idx][1]

        sample = {
            "mel": mel,
            "item_idx": self.items[idx][1],
            "speaker_name": speaker_name,
        }
        return sample

    def __parse_items(self):
        """
        Find unique speaker ids and build a dict mapping each speaker id to its utterances.
        """
        speakers = list({item[-1] for item in self.items})
        self.speaker_to_utters = {}
        self.speakers = []
        for speaker in speakers:
            speaker_utters = [item[1] for item in self.items if item[2] == speaker]
            if len(speaker_utters) < self.num_utter_per_speaker and self.skip_speakers:
                print(
                    f" [!] Skipped speaker {speaker}. Not enough utterances:"
                    f" {self.num_utter_per_speaker} needed vs {len(speaker_utters)} found."
                )
            else:
                self.speakers.append(speaker)
                self.speaker_to_utters[speaker] = speaker_utters
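        # Resulting structure, with hypothetical paths for illustration:
        #   self.speaker_to_utters == {"speaker_a": ["a1.wav", "a2.wav"], ...}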

    def __len__(self):
        # Effectively unbounded: batches are drawn by sampling in collate_fn,
        # so epoch length is governed by the training loop, not the dataset.
        return int(1e10)

    def __sample_speaker(self):
        speaker = random.sample(self.speakers, 1)[0]
        if self.num_utter_per_speaker > len(self.speaker_to_utters[speaker]):
            # not enough utterances left: sample with replacement
            utters = random.choices(
                self.speaker_to_utters[speaker], k=self.num_utter_per_speaker
            )
        else:
            utters = random.sample(
                self.speaker_to_utters[speaker], self.num_utter_per_speaker
            )
        return speaker, utters

    def __sample_speaker_utterances(self, speaker):
        """
        Sample all M utterances for the given speaker.
        """
        feats = []
        labels = []
        for _ in range(self.num_utter_per_speaker):
            # TODO: dummy but works
            while True:
                if len(self.speaker_to_utters[speaker]) > 0:
                    utter = random.sample(self.speaker_to_utters[speaker], 1)[0]
                else:
                    # speaker has no usable utterances left; fall back to another one
                    self.speakers.remove(speaker)
                    speaker, _ = self.__sample_speaker()
                    continue
                wav = self.load_wav(utter)
                if wav.shape[0] - self.seq_len > 0:
                    break
                # utterance is shorter than seq_len; drop it and retry
                self.speaker_to_utters[speaker].remove(utter)

            # crop a random seq_len-sample segment and turn it into a mel spectrogram
            offset = random.randint(0, wav.shape[0] - self.seq_len)
            mel = self.ap.melspectrogram(wav[offset : offset + self.seq_len])
            feats.append(torch.FloatTensor(mel))
            labels.append(speaker)
        return feats, labels

    def __getitem__(self, idx):
        # Only a speaker is picked here; its utterances are sampled in collate_fn.
        speaker, _ = self.__sample_speaker()
        return speaker

    def collate_fn(self, batch):
        labels = []
        feats = []
        for speaker in batch:
            feats_, labels_ = self.__sample_speaker_utterances(speaker)
            # labels stay grouped per speaker; features are flattened for stacking
            labels.append(labels_)
            feats.extend(feats_)
        feats = torch.stack(feats)
        # [num_speakers_in_batch * num_utter_per_speaker, T, num_mels]
        return feats.transpose(1, 2), labels
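

# Minimal usage sketch, kept as a comment since `ap` and `meta_data` are
# assumptions: `ap` is expected to be a TTS.utils.AudioProcessor and
# `meta_data` a list of (text, wav_path, speaker_name) tuples, as implied by
# load_data(). batch_size counts speakers because __getitem__ yields one
# speaker per index:
#
#     from torch.utils.data import DataLoader
#
#     dataset = MyDataset(ap, meta_data, voice_len=1.6,
#                         num_speakers_in_batch=64, num_utter_per_speaker=10,
#                         skip_speakers=True, verbose=True)
#     loader = DataLoader(dataset,
#                         batch_size=dataset.num_speakers_in_batch,
#                         collate_fn=dataset.collate_fn)
#     feats, labels = next(iter(loader))
#     # feats: [64 * 10, T, num_mels], grouped per speaker via labels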