This commit is contained in:
Thomas Werkmeister 2019-07-02 14:46:41 +02:00
Родитель d23e29ea1f
Коммит 6390c3b2e6
4 изменённых файла: 10 добавлений и 10 удалений

Просмотреть файл

@@ -29,7 +29,7 @@ class Tacotron(nn.Module):
self.linear_dim = linear_dim
self.embedding = nn.Embedding(num_chars, 256)
self.embedding.weight.data.normal_(0, 0.3)
if num_speakers > 0:
if num_speakers > 1:
self.speaker_embedding = nn.Embedding(num_speakers, 256)
self.speaker_embedding.weight.data.normal_(0, 0.3)
self.encoder = Encoder(256)

Просмотреть файл

@@ -29,7 +29,7 @@ class Tacotron2(nn.Module):
std = sqrt(2.0 / (num_chars + 512))
val = sqrt(3.0) * std # uniform bounds for std
self.embedding.weight.data.uniform_(-val, val)
if num_speakers > 0:
if num_speakers > 1:
self.speaker_embedding = nn.Embedding(num_speakers, 512)
self.speaker_embedding.weight.data.normal_(0, 0.3)
self.encoder = Encoder(512)

Просмотреть файл

@@ -30,7 +30,7 @@ class TacotronGST(nn.Module):
self.linear_dim = linear_dim
self.embedding = nn.Embedding(num_chars, 256)
self.embedding.weight.data.normal_(0, 0.3)
if num_speakers > 0:
if num_speakers > 1:
self.speaker_embedding = nn.Embedding(num_speakers, 256)
self.speaker_embedding.weight.data.normal_(0, 0.3)
self.encoder = Encoder(256)

Просмотреть файл

@@ -78,7 +78,7 @@ def setup_loader(is_val=False, verbose=False):
def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
ap, epoch):
data_loader = setup_loader(is_val=False, verbose=(epoch==0))
if c.num_speakers > 0:
if c.num_speakers > 1:
speaker_mapping = load_speaker_mapping(OUT_PATH)
model.train()
epoch_time = 0
@@ -102,7 +102,7 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
avg_text_length = torch.mean(text_lengths.float())
avg_spec_length = torch.mean(mel_lengths.float())
if c.num_speakers > 0:
if c.num_speakers > 1:
speaker_ids = []
for speaker_name in speaker_names:
if speaker_name not in speaker_mapping:
@@ -272,14 +272,14 @@ def train(model, criterion, criterion_st, optimizer, optimizer_st, scheduler,
tb_logger.tb_model_weights(model, current_step)
# save speaker mapping
if c.num_speakers > 0:
if c.num_speakers > 1:
save_speaker_mapping(OUT_PATH, speaker_mapping)
return avg_postnet_loss, current_step
def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
data_loader = setup_loader(is_val=True)
if c.num_speakers > 0:
if c.num_speakers > 1:
speaker_mapping = load_speaker_mapping(OUT_PATH)
model.eval()
epoch_time = 0
@@ -311,7 +311,7 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
mel_lengths = data[5]
stop_targets = data[6]
if c.num_speakers > 0:
if c.num_speakers > 1:
speaker_ids = [speaker_mapping[speaker_name]
for speaker_name in speaker_names]
speaker_ids = torch.LongTensor(speaker_ids)
@@ -415,7 +415,7 @@ def evaluate(model, criterion, criterion_st, ap, current_step, epoch):
test_audios = {}
test_figures = {}
print(" | > Synthesizing test sentences")
speaker_id = 0 if c.num_speakers > 0 else None
speaker_id = 0 if c.num_speakers > 1 else None
for idx, test_sentence in enumerate(test_sentences):
try:
wav, alignment, decoder_output, postnet_output, stop_tokens = synthesis(
@@ -484,7 +484,7 @@ def main(args):
args.restore_step = checkpoint['step']
# copying speakers.json
prev_out_path = os.path.dirname(args.restore_path)
if c.num_speakers > 0:
if c.num_speakers > 1:
copy_speaker_mapping(prev_out_path, OUT_PATH)
else:
args.restore_step = 0