This commit is contained in:
Eren Golge 2018-04-30 06:12:12 -07:00
Родитель 40b479b3b9
Коммит fda7e7f6c9
1 изменённых файлов: 3 добавлений и 0 удалений

Просмотреть файл

@ -100,6 +100,8 @@ def train(model, criterion, criterion_st, data_loader, optimizer, epoch):
mel_spec = mel_spec.cuda()
mel_lengths = mel_lengths.cuda()
linear_spec = linear_spec.cuda()
stop_target = stop_target.cuda()
# create attention mask
if c.mk > 0.0:
@ -241,6 +243,7 @@ def evaluate(model, criterion, criterion_st, data_loader, current_step):
mel_spec = mel_spec.cuda()
mel_lengths = mel_lengths.cuda()
linear_spec = linear_spec.cuda()
stop_target = stop_target.cuda()
# forward pass
mel_output, linear_output, alignments, stop_tokens =\