зеркало из https://github.com/mozilla/TTS.git
Count total number of model parameters
This commit is contained in:
Родитель
e3b2d2a827
Коммит
c72b8fd64c
Различия файлов скрыты, потому что одна или несколько строк слишком длинны
|
@ -2,4 +2,4 @@ librosa
|
|||
inflect
|
||||
unidecode
|
||||
tensorboard
|
||||
tensorboardX
|
||||
tensorboardX
|
||||
|
|
|
@ -10,7 +10,6 @@ class PrenetTests(unittest.TestCase):
|
|||
layer = Prenet(128, out_features=[256, 128])
|
||||
dummy_input = T.autograd.Variable(T.rand(4, 128))
|
||||
|
||||
|
||||
print(layer)
|
||||
output = layer(dummy_input)
|
||||
assert output.shape[0] == 4
|
||||
|
@ -49,7 +48,7 @@ class EncoderTests(unittest.TestCase):
|
|||
|
||||
def test_in_out(self):
|
||||
layer = Encoder(128)
|
||||
dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
|
||||
dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
|
||||
|
||||
print(layer)
|
||||
output = layer(dummy_input)
|
||||
|
|
6
train.py
6
train.py
|
@ -20,7 +20,8 @@ from tensorboardX import SummaryWriter
|
|||
|
||||
from utils.generic_utils import (Progbar, remove_experiment_folder,
|
||||
create_experiment_folder, save_checkpoint,
|
||||
save_best_model, load_config, lr_decay)
|
||||
save_best_model, load_config, lr_decay,
|
||||
count_parameters)
|
||||
from utils.model import get_param_size
|
||||
from utils.visual import plot_alignment, plot_spectrogram
|
||||
from datasets.LJSpeech import LJSpeechDataset
|
||||
|
@ -106,6 +107,9 @@ def main(args):
|
|||
start_epoch = 0
|
||||
print("\n > Starting a new training")
|
||||
|
||||
num_params = count_parameters(model)
|
||||
print(" | > Model has {} parameters".format(num_params))
|
||||
|
||||
model = model.train()
|
||||
|
||||
if not os.path.exists(CHECKPOINT_PATH):
|
||||
|
|
|
@ -101,6 +101,12 @@ def lr_decay(init_lr, global_step):
|
|||
step**-0.5)
|
||||
return lr
|
||||
|
||||
|
||||
def count_parameters(model):
|
||||
r"""Count number of trainable parameters in a network"""
|
||||
return sum(p.numel() for p in model.parameters() if p.requires_grad)
|
||||
|
||||
|
||||
class Progbar(object):
|
||||
"""Displays a progress bar.
|
||||
# Arguments
|
||||
|
|
Загрузка…
Ссылка в новой задаче