This commit is contained in:
Thomas Werkmeister 2019-07-02 14:46:41 +02:00
Parent d23e29ea1f
Commit 6390c3b2e6
4 changed files with 10 additions and 10 deletions

View file

@@ -29,7 +29,7 @@ class Tacotron(nn.Module):
         self.linear_dim = linear_dim
         self.embedding = nn.Embedding(num_chars, 256)
         self.embedding.weight.data.normal_(0, 0.3)
-        if num_speakers > 0:
+        if num_speakers > 1:
             self.speaker_embedding = nn.Embedding(num_speakers, 256)
             self.speaker_embedding.weight.data.normal_(0, 0.3)
         self.encoder = Encoder(256)

View file

@@ -29,7 +29,7 @@ class Tacotron2(nn.Module):
         std = sqrt(2.0 / (num_chars + 512))
         val = sqrt(3.0) * std  # uniform bounds for std
         self.embedding.weight.data.uniform_(-val, val)
-        if num_speakers > 0:
+        if num_speakers > 1:
             self.speaker_embedding = nn.Embedding(num_speakers, 512)
             self.speaker_embedding.weight.data.normal_(0, 0.3)
         self.encoder = Encoder(512)

View file

@@ -30,7 +30,7 @@ class TacotronGST(nn.Module):
         self.linear_dim = linear_dim
         self.embedding = nn.Embedding(num_chars, 256)
         self.embedding.weight.data.normal_(0, 0.3)
-        if num_speakers > 0:
+        if num_speakers > 1:
             self.speaker_embedding = nn.Embedding(num_speakers, 256)
             self.speaker_embedding.weight.data.normal_(0, 0.3)
         self.encoder = Encoder(256)
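The same guard changes in all three models: a speaker embedding table is only created when the configuration declares more than one speaker, so a single-speaker setup (num_speakers == 1) no longer allocates an embedding it never uses. A minimal sketch of the resulting pattern, using the 256-dimensional size from the Tacotron hunk above; the class name and forward signature here are illustrative assumptions, not the repo's actual model:

import torch
from torch import nn

class SpeakerConditionedEncoderStub(nn.Module):
    """Illustrative only: shows the num_speakers > 1 guard, not the real Tacotron."""
    def __init__(self, num_chars, num_speakers, embed_dim=256):
        super().__init__()
        self.num_speakers = num_speakers
        self.embedding = nn.Embedding(num_chars, embed_dim)
        self.embedding.weight.data.normal_(0, 0.3)
        if num_speakers > 1:
            # only multi-speaker configs get a speaker table
            self.speaker_embedding = nn.Embedding(num_speakers, embed_dim)
            self.speaker_embedding.weight.data.normal_(0, 0.3)

    def forward(self, char_ids, speaker_ids=None):
        x = self.embedding(char_ids)                       # (B, T, embed_dim)
        if self.num_speakers > 1 and speaker_ids is not None:
            # broadcast one speaker vector over the character sequence
            x = x + self.speaker_embedding(speaker_ids).unsqueeze(1)
        return x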

View file

@@ -78,7 +78,7 @@ def setup_loader(is_val=False, verbose=False):
 def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
           ap, epoch):
     data_loader = setup_loader(is_val=False, verbose=(epoch==0))
-    if c.num_speakers > 0:
+    if c.num_speakers > 1:
         speaker_mapping = load_speaker_mapping(OUT_PATH)
     model.train()
     epoch_time = 0
@@ -102,7 +102,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
         avg_text_length = torch.mean(text_lengths.float())
         avg_spec_length = torch.mean(mel_lengths.float())
-        if c.num_speakers > 0:
+        if c.num_speakers > 1:
             speaker_ids = []
             for speaker_name in speaker_names:
                 if speaker_name not in speaker_mapping:
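This hunk extends the speaker mapping on the fly during training: a speaker name seen for the first time is assigned the next free integer id before the batch's ids are turned into a tensor. A hedged sketch of that lookup, assuming speaker_mapping is the plain name-to-id dict returned by load_speaker_mapping; the helper name below is hypothetical:

import torch

def speaker_ids_from_names(speaker_names, speaker_mapping):
    # assign a fresh id to any speaker seen for the first time
    for name in speaker_names:
        if name not in speaker_mapping:
            speaker_mapping[name] = len(speaker_mapping)
    # evaluate() later does the same lookup without growing the mapping
    return torch.LongTensor([speaker_mapping[name] for name in speaker_names])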
@@ -272,14 +272,14 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
             tb_logger.tb_model_weights(model, current_step)

     # save speaker mapping
-    if c.num_speakers > 0:
+    if c.num_speakers > 1:
         save_speaker_mapping(OUT_PATH, speaker_mapping)
     return avg_postnet_loss, current_step


 def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
     data_loader = setup_loader(is_val=True)
-    if c.num_speakers > 0:
+    if c.num_speakers > 1:
         speaker_mapping = load_speaker_mapping(OUT_PATH)
     model.eval()
     epoch_time = 0
@@ -311,7 +311,7 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
            mel_lengths = data[5]
            stop_targets = data[6]

-           if c.num_speakers > 0:
+           if c.num_speakers > 1:
                speaker_ids = [speaker_mapping[speaker_name]
                               for speaker_name in speaker_names]
                speaker_ids = torch.LongTensor(speaker_ids)
@@ -415,7 +415,7 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
         test_audios = {}
         test_figures = {}
         print(" | > Synthesizing test sentences")
-        speaker_id = 0 if c.num_speakers > 0 else None
+        speaker_id = 0 if c.num_speakers > 1 else None
         for idx, test_sentence in enumerate(test_sentences):
             try:
                 wav, alignment, decoder_output, postnet_output, stop_tokens = synthesis(
@@ -484,7 +484,7 @@ def main(args):
         args.restore_step = checkpoint['step']
         # copying speakers.json
         prev_out_path = os.path.dirname(args.restore_path)
-        if c.num_speakers > 0:
+        if c.num_speakers > 1:
             copy_speaker_mapping(prev_out_path, OUT_PATH)
     else:
         args.restore_step = 0
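The remaining hunks gate persistence of the mapping on the same num_speakers > 1 condition: it is saved after each training epoch, reloaded in evaluate(), and copied over from the previous run directory when restoring a checkpoint. A minimal sketch of JSON-backed helpers of that shape, assuming the mapping lives in a speakers.json next to the checkpoints; the real load_speaker_mapping / save_speaker_mapping / copy_speaker_mapping in the repo may differ:

import json
import os

SPEAKERS_FILE = "speakers.json"  # assumed file name

def load_speaker_mapping(out_path):
    # return an empty mapping on a fresh run
    path = os.path.join(out_path, SPEAKERS_FILE)
    if not os.path.exists(path):
        return {}
    with open(path) as f:
        return json.load(f)

def save_speaker_mapping(out_path, speaker_mapping):
    with open(os.path.join(out_path, SPEAKERS_FILE), "w") as f:
        json.dump(speaker_mapping, f, indent=4)

def copy_speaker_mapping(prev_out_path, out_path):
    # carry the mapping over from the run being restored
    save_speaker_mapping(out_path, load_speaker_mapping(prev_out_path))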