зеркало из https://github.com/mozilla/TTS.git
Enable speaker embeddings only when num_speakers is larger than 1
This commit is contained in:
Родитель
d23e29ea1f
Коммит
6390c3b2e6
|
@@ -29,7 +29,7 @@ class Tacotron(nn.Module):
|
|||
self.linear_dim = linear_dim
|
||||
self.embedding = nn.Embedding(num_chars, 256)
|
||||
self.embedding.weight.data.normal_(0, 0.3)
|
||||
if num_speakers > 0:
|
||||
if num_speakers > 1:
|
||||
self.speaker_embedding = nn.Embedding(num_speakers, 256)
|
||||
self.speaker_embedding.weight.data.normal_(0, 0.3)
|
||||
self.encoder = Encoder(256)
|
||||
|
|
|
@@ -29,7 +29,7 @@ class Tacotron2(nn.Module):
|
|||
std = sqrt(2.0 / (num_chars + 512))
|
||||
val = sqrt(3.0) * std # uniform bounds for std
|
||||
self.embedding.weight.data.uniform_(-val, val)
|
||||
if num_speakers > 0:
|
||||
if num_speakers > 1:
|
||||
self.speaker_embedding = nn.Embedding(num_speakers, 512)
|
||||
self.speaker_embedding.weight.data.normal_(0, 0.3)
|
||||
self.encoder = Encoder(512)
|
||||
|
|
|
@@ -30,7 +30,7 @@ class TacotronGST(nn.Module):
|
|||
self.linear_dim = linear_dim
|
||||
self.embedding = nn.Embedding(num_chars, 256)
|
||||
self.embedding.weight.data.normal_(0, 0.3)
|
||||
if num_speakers > 0:
|
||||
if num_speakers > 1:
|
||||
self.speaker_embedding = nn.Embedding(num_speakers, 256)
|
||||
self.speaker_embedding.weight.data.normal_(0, 0.3)
|
||||
self.encoder = Encoder(256)
|
||||
|
|
train.py: 14 changed lines (7 additions, 7 deletions)
|
@@ -78,7 +78,7 @@ def setup_loader(is_val=False, verbose=False):
|
|||
def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
||||
ap, epoch):
|
||||
data_loader = setup_loader(is_val=False, verbose=(epoch==0))
|
||||
if c.num_speakers > 0:
|
||||
if c.num_speakers > 1:
|
||||
speaker_mapping = load_speaker_mapping(OUT_PATH)
|
||||
model.train()
|
||||
epoch_time = 0
|
||||
|
@@ -102,7 +102,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
|||
avg_text_length = torch.mean(text_lengths.float())
|
||||
avg_spec_length = torch.mean(mel_lengths.float())
|
||||
|
||||
if c.num_speakers > 0:
|
||||
if c.num_speakers > 1:
|
||||
speaker_ids = []
|
||||
for speaker_name in speaker_names:
|
||||
if speaker_name not in speaker_mapping:
|
||||
|
@@ -272,14 +272,14 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
|
|||
tb_logger.tb_model_weights(model, current_step)
|
||||
|
||||
# save speaker mapping
|
||||
if c.num_speakers > 0:
|
||||
if c.num_speakers > 1:
|
||||
save_speaker_mapping(OUT_PATH, speaker_mapping)
|
||||
return avg_postnet_loss, current_step
|
||||
|
||||
|
||||
def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
|
||||
data_loader = setup_loader(is_val=True)
|
||||
if c.num_speakers > 0:
|
||||
if c.num_speakers > 1:
|
||||
speaker_mapping = load_speaker_mapping(OUT_PATH)
|
||||
model.eval()
|
||||
epoch_time = 0
|
||||
|
@@ -311,7 +311,7 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
|
|||
mel_lengths = data[5]
|
||||
stop_targets = data[6]
|
||||
|
||||
if c.num_speakers > 0:
|
||||
if c.num_speakers > 1:
|
||||
speaker_ids = [speaker_mapping[speaker_name]
|
||||
for speaker_name in speaker_names]
|
||||
speaker_ids = torch.LongTensor(speaker_ids)
|
||||
|
@@ -415,7 +415,7 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
|
|||
test_audios = {}
|
||||
test_figures = {}
|
||||
print(" | > Synthesizing test sentences")
|
||||
speaker_id = 0 if c.num_speakers > 0 else None
|
||||
speaker_id = 0 if c.num_speakers > 1 else None
|
||||
for idx, test_sentence in enumerate(test_sentences):
|
||||
try:
|
||||
wav, alignment, decoder_output, postnet_output, stop_tokens = synthesis(
|
||||
|
@@ -484,7 +484,7 @@ def main(args):
|
|||
args.restore_step = checkpoint['step']
|
||||
# copying speakers.json
|
||||
prev_out_path = os.path.dirname(args.restore_path)
|
||||
if c.num_speakers > 0:
|
||||
if c.num_speakers > 1:
|
||||
copy_speaker_mapping(prev_out_path, OUT_PATH)
|
||||
else:
|
||||
args.restore_step = 0
|
||||
|
|
Загрузка…
Ссылка в новой задаче