зеркало из https://github.com/mozilla/TTS.git
update model tests for ddc
This commit is contained in:
Родитель
e272e07ce9
Коммит
43d278be2f
|
@ -51,7 +51,7 @@ class TacotronTrainTest(unittest.TestCase):
|
|||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||
for i in range(5):
|
||||
mel_out, mel_postnet_out, align, stop_tokens = model.forward(
|
||||
input, input_lengths, mel_spec, speaker_ids)
|
||||
input, input_lengths, mel_spec, mel_lengths, speaker_ids)
|
||||
assert torch.sigmoid(stop_tokens).data.max() <= 1.0
|
||||
assert torch.sigmoid(stop_tokens).data.min() >= 0.0
|
||||
optimizer.zero_grad()
|
||||
|
|
|
@ -66,7 +66,7 @@ class TacotronTrainTest(unittest.TestCase):
|
|||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||
for _ in range(5):
|
||||
mel_out, linear_out, align, stop_tokens = model.forward(
|
||||
input_dummy, input_lengths, mel_spec, speaker_ids)
|
||||
input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids)
|
||||
optimizer.zero_grad()
|
||||
loss = criterion(mel_out, mel_spec, mel_lengths)
|
||||
stop_loss = criterion_st(stop_tokens, stop_targets)
|
||||
|
@ -95,6 +95,7 @@ class TacotronGSTTrainTest(unittest.TestCase):
|
|||
mel_spec = torch.rand(8, 120, c.audio['num_mels']).to(device)
|
||||
linear_spec = torch.rand(8, 120, c.audio['num_freq']).to(device)
|
||||
mel_lengths = torch.randint(20, 120, (8, )).long().to(device)
|
||||
mel_lengths[-1] = 120
|
||||
stop_targets = torch.zeros(8, 120, 1).float().to(device)
|
||||
speaker_ids = torch.randint(0, 5, (8, )).long().to(device)
|
||||
|
||||
|
@ -130,7 +131,7 @@ class TacotronGSTTrainTest(unittest.TestCase):
|
|||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||
for _ in range(10):
|
||||
mel_out, linear_out, align, stop_tokens = model.forward(
|
||||
input_dummy, input_lengths, mel_spec, speaker_ids)
|
||||
input_dummy, input_lengths, mel_spec, mel_lengths, speaker_ids)
|
||||
optimizer.zero_grad()
|
||||
loss = criterion(mel_out, mel_spec, mel_lengths)
|
||||
stop_loss = criterion_st(stop_tokens, stop_targets)
|
||||
|
|
Загрузка…
Ссылка в новой задаче