зеркало из https://github.com/mozilla/TTS.git
Set num output units based on audio processor and set imported audio processor based on config.json
This commit is contained in:
Родитель
355df8fe9b
Коммит
adedd7b1a9
26
train.py
26
train.py
|
@ -325,17 +325,20 @@ def evaluate(model, criterion, criterion_st, data_loader, ap, current_step):
|
|||
def main(args):
|
||||
dataset = importlib.import_module('datasets.'+c.dataset)
|
||||
Dataset = getattr(dataset, 'MyDataset')
|
||||
audio = importlib.import_module('utils.'+c.audio_processor)
|
||||
AudioProcessor = getattr(audio, 'AudioProcessor')
|
||||
|
||||
ap = AudioProcessor(sample_rate = c.sample_rate,
|
||||
num_mels = c.num_mels,
|
||||
min_level_db = c.min_level_db,
|
||||
frame_shift_ms = c.frame_shift_ms,
|
||||
frame_length_ms = c.frame_length_ms,
|
||||
ref_level_db = c.ref_level_db,
|
||||
num_freq = c.num_freq,
|
||||
power = c.power,
|
||||
min_mel_freq = c.min_mel_freq,
|
||||
max_mel_freq = c.max_mel_freq)
|
||||
ap = AudioProcessor(sample_rate=c.sample_rate,
|
||||
num_mels=c.num_mels,
|
||||
min_level_db=c.min_level_db,
|
||||
frame_shift_ms=c.frame_shift_ms,
|
||||
frame_length_ms=c.frame_length_ms,
|
||||
ref_level_db=c.ref_level_db,
|
||||
num_freq=c.num_freq,
|
||||
power=c.power,
|
||||
preemphasis=c.preemphasis,
|
||||
min_mel_freq=c.min_mel_freq,
|
||||
max_mel_freq=c.max_mel_freq)
|
||||
|
||||
# Setup the dataset
|
||||
train_dataset = Dataset(c.data_path,
|
||||
|
@ -367,9 +370,10 @@ def main(args):
|
|||
val_loader = None
|
||||
|
||||
model = Tacotron(c.embedding_size,
|
||||
c.num_freq,
|
||||
ap.num_freq,
|
||||
c.num_mels,
|
||||
c.r)
|
||||
print(" | > Num output units : {}".format(ap.num_freq))
|
||||
|
||||
optimizer = optim.Adam(model.parameters(), lr=c.lr)
|
||||
optimizer_st = optim.Adam(model.decoder.stopnet.parameters(), lr=c.lr)
|
||||
|
|
Загрузка…
Ссылка в новой задаче