зеркало из https://github.com/mozilla/TTS.git
Count total number of model parameters
This commit is contained in:
Родитель
e3b2d2a827
Коммит
c72b8fd64c
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -2,4 +2,4 @@ librosa
|
||||||
inflect
|
inflect
|
||||||
unidecode
|
unidecode
|
||||||
tensorboard
|
tensorboard
|
||||||
tensorboardX
|
tensorboardX
|
||||||
|
|
|
@ -10,7 +10,6 @@ class PrenetTests(unittest.TestCase):
|
||||||
layer = Prenet(128, out_features=[256, 128])
|
layer = Prenet(128, out_features=[256, 128])
|
||||||
dummy_input = T.autograd.Variable(T.rand(4, 128))
|
dummy_input = T.autograd.Variable(T.rand(4, 128))
|
||||||
|
|
||||||
|
|
||||||
print(layer)
|
print(layer)
|
||||||
output = layer(dummy_input)
|
output = layer(dummy_input)
|
||||||
assert output.shape[0] == 4
|
assert output.shape[0] == 4
|
||||||
|
@ -49,7 +48,7 @@ class EncoderTests(unittest.TestCase):
|
||||||
|
|
||||||
def test_in_out(self):
|
def test_in_out(self):
|
||||||
layer = Encoder(128)
|
layer = Encoder(128)
|
||||||
dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
|
dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
|
||||||
|
|
||||||
print(layer)
|
print(layer)
|
||||||
output = layer(dummy_input)
|
output = layer(dummy_input)
|
||||||
|
|
6
train.py
6
train.py
|
@ -20,7 +20,8 @@ from tensorboardX import SummaryWriter
|
||||||
|
|
||||||
from utils.generic_utils import (Progbar, remove_experiment_folder,
|
from utils.generic_utils import (Progbar, remove_experiment_folder,
|
||||||
create_experiment_folder, save_checkpoint,
|
create_experiment_folder, save_checkpoint,
|
||||||
save_best_model, load_config, lr_decay)
|
save_best_model, load_config, lr_decay,
|
||||||
|
count_parameters)
|
||||||
from utils.model import get_param_size
|
from utils.model import get_param_size
|
||||||
from utils.visual import plot_alignment, plot_spectrogram
|
from utils.visual import plot_alignment, plot_spectrogram
|
||||||
from datasets.LJSpeech import LJSpeechDataset
|
from datasets.LJSpeech import LJSpeechDataset
|
||||||
|
@ -106,6 +107,9 @@ def main(args):
|
||||||
start_epoch = 0
|
start_epoch = 0
|
||||||
print("\n > Starting a new training")
|
print("\n > Starting a new training")
|
||||||
|
|
||||||
|
num_params = count_parameters(model)
|
||||||
|
print(" | > Model has {} parameters".format(num_params))
|
||||||
|
|
||||||
model = model.train()
|
model = model.train()
|
||||||
|
|
||||||
if not os.path.exists(CHECKPOINT_PATH):
|
if not os.path.exists(CHECKPOINT_PATH):
|
||||||
|
|
|
@ -101,6 +101,12 @@ def lr_decay(init_lr, global_step):
|
||||||
step**-0.5)
|
step**-0.5)
|
||||||
return lr
|
return lr
|
||||||
|
|
||||||
|
|
||||||
|
def count_parameters(model):
|
||||||
|
r"""Count number of trainable parameters in a network"""
|
||||||
|
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||||
|
|
||||||
|
|
||||||
class Progbar(object):
|
class Progbar(object):
|
||||||
"""Displays a progress bar.
|
"""Displays a progress bar.
|
||||||
# Arguments
|
# Arguments
|
||||||
|
|
Загрузка…
Ссылка в новой задаче