Set num output units based on audio processor and set imported audio processor based on config.json

This commit is contained in:
Eren G 2018-07-27 16:14:10 +02:00
Родитель 355df8fe9b
Коммит adedd7b1a9
1 изменённых файлов: 15 добавлений и 11 удалений

Просмотреть файл

@ -325,17 +325,20 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
def main(args):
dataset = importlib.import_module('datasets.'+c.dataset)
Dataset = getattr(dataset, 'MyDataset')
audio = importlib.import_module('utils.'+c.audio_processor)
AudioProcessor = getattr(audio, 'AudioProcessor')
ap = AudioProcessor(sample_rate = c.sample_rate,
num_mels = c.num_mels,
min_level_db = c.min_level_db,
frame_shift_ms = c.frame_shift_ms,
frame_length_ms = c.frame_length_ms,
ref_level_db = c.ref_level_db,
num_freq = c.num_freq,
power = c.power,
min_mel_freq = c.min_mel_freq,
max_mel_freq = c.max_mel_freq)
ap = AudioProcessor(sample_rate=c.sample_rate,
num_mels=c.num_mels,
min_level_db=c.min_level_db,
frame_shift_ms=c.frame_shift_ms,
frame_length_ms=c.frame_length_ms,
ref_level_db=c.ref_level_db,
num_freq=c.num_freq,
power=c.power,
preemphasis=c.preemphasis,
min_mel_freq=c.min_mel_freq,
max_mel_freq=c.max_mel_freq)
# Setup the dataset
train_dataset = Dataset(c.data_path,
@ -367,9 +370,10 @@ def main(args):
val_loader = None
model = Tacotron(c.embedding_size,
c.num_freq,
ap.num_freq,
c.num_mels,
c.r)
print(" | > Num output units : {}".format(ap.num_freq))
optimizer = optim.Adam(model.parameters(), lr=c.lr)
optimizer_st = optim.Adam(model.decoder.stopnet.parameters(), lr=c.lr)