Count total number of model parameters

This commit is contained in:
Eren Golge 2018-02-23 06:20:22 -08:00
Родитель e3b2d2a827
Коммит c72b8fd64c
5 изменённых файлов: 284 добавлений и 25692 удалений

Различия файлов скрыты, потому что одна или несколько строк слишком длинны

Просмотреть файл

@ -2,4 +2,4 @@ librosa
inflect
unidecode
tensorboard
tensorboardX
tensorboardX

Просмотреть файл

@ -10,7 +10,6 @@ class PrenetTests(unittest.TestCase):
layer = Prenet(128, out_features=[256, 128])
dummy_input = T.autograd.Variable(T.rand(4, 128))
print(layer)
output = layer(dummy_input)
assert output.shape[0] == 4
@ -49,7 +48,7 @@ class EncoderTests(unittest.TestCase):
def test_in_out(self):
layer = Encoder(128)
dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
dummy_input = T.autograd.Variable(T.rand(4, 8, 128))
print(layer)
output = layer(dummy_input)

Просмотреть файл

@ -20,7 +20,8 @@ from tensorboardX import SummaryWriter
from utils.generic_utils import (Progbar, remove_experiment_folder,
create_experiment_folder, save_checkpoint,
save_best_model, load_config, lr_decay)
save_best_model, load_config, lr_decay,
count_parameters)
from utils.model import get_param_size
from utils.visual import plot_alignment, plot_spectrogram
from datasets.LJSpeech import LJSpeechDataset
@ -106,6 +107,9 @@ def main(args):
start_epoch = 0
print("\n > Starting a new training")
num_params = count_parameters(model)
print(" | > Model has {} parameters".format(num_params))
model = model.train()
if not os.path.exists(CHECKPOINT_PATH):

Просмотреть файл

@ -101,6 +101,12 @@ def lr_decay(init_lr, global_step):
step**-0.5)
return lr
def count_parameters(model):
r"""Count number of trainable parameters in a network"""
return sum(p.numel() for p in model.parameters() if p.requires_grad)
class Progbar(object):
"""Displays a progress bar.
# Arguments