From 05f038880b1fdad7b2ec73738fc669562fec4af7 Mon Sep 17 00:00:00 2001 From: Eren Golge Date: Wed, 30 Oct 2019 15:48:38 +0100 Subject: [PATCH] continuining training from where it is left off --- config.json | 6 +++--- distribute.py | 29 ++++++++++++----------------- train.py | 46 ++++++++++++++++++++-------------------------- 3 files changed, 35 insertions(+), 46 deletions(-) diff --git a/config.json b/config.json index 4ca54b9..623a599 100644 --- a/config.json +++ b/config.json @@ -1,6 +1,6 @@ { - "run_name": "ljspeech", - "run_description": "t bidirectional decoder test train", + "run_name": "ljspeech-w/o-bd", + "run_description": "tacotron2 without bidirectional decoder", "audio":{ // Audio processing parameters @@ -46,7 +46,7 @@ "forward_attn_mask": false, "transition_agent": false, // enable/disable transition agent of forward attention. "location_attn": true, // enable_disable location sensitive attention. It is enabled for TACOTRON by default. - "bidirectional_decoder": true, // use https://arxiv.org/abs/1907.09006. Use it, if attention does not work well with your dataset. + "bidirectional_decoder": false, // use https://arxiv.org/abs/1907.09006. Use it, if attention does not work well with your dataset. "loss_masking": true, // enable / disable loss masking against the sequence padding. "enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars. "stopnet": true, // Train stopnet predicting the end of synthesis. diff --git a/distribute.py b/distribute.py index f65fbe7..99f62f5 100644 --- a/distribute.py +++ b/distribute.py @@ -125,50 +125,45 @@ def main(): Call train.py as a new process and pass command arguments """ parser = argparse.ArgumentParser() + parser.add_argument( + '--continue_path', + type=str, + help='Training output folder to conitnue training. Use to continue a training.', + default='') parser.add_argument( '--restore_path', type=str, - help='Folder path to checkpoints', + help='Model file to be restored. Use to finetune a model.', default='') parser.add_argument( '--config_path', type=str, help='path to config file for training', ) - parser.add_argument( - '--data_path', type=str, help='dataset path.', default='') - args = parser.parse_args() - CONFIG = load_config(args.config_path) - OUT_PATH = create_experiment_folder(CONFIG.output_path, CONFIG.run_name, - True) - stdout_path = os.path.join(OUT_PATH, "process_stdout/") + # OUT_PATH = create_experiment_folder(CONFIG.output_path, CONFIG.run_name, + # True) + # stdout_path = os.path.join(OUT_PATH, "process_stdout/") num_gpus = torch.cuda.device_count() group_id = time.strftime("%Y_%m_%d-%H%M%S") # set arguments for train.py command = ['train.py'] + command.append('--continue_path={}'.format(args.continue_path)) command.append('--restore_path={}'.format(args.restore_path)) command.append('--config_path={}'.format(args.config_path)) command.append('--group_id=group_{}'.format(group_id)) - command.append('--data_path={}'.format(args.data_path)) - command.append('--output_path={}'.format(OUT_PATH)) command.append('') - if not os.path.isdir(stdout_path): - os.makedirs(stdout_path) - os.chmod(stdout_path, 0o775) - # run processes processes = [] for i in range(num_gpus): my_env = os.environ.copy() my_env["PYTHON_EGG_CACHE"] = "/tmp/tmp{}".format(i) - command[6] = '--rank={}'.format(i) - stdout = None if i == 0 else open( - os.path.join(stdout_path, "process_{}.log".format(i)), "w") + command[-1] = '--rank={}'.format(i) + stdout = None if i == 0 else open(os.devnull, 'w') p = subprocess.Popen(['python3'] + command, stdout=stdout, env=my_env) processes.append(p) print(command) diff --git a/train.py b/train.py index 89e2115..48cb058 100644 --- a/train.py +++ b/train.py @@ -1,6 +1,7 @@ import argparse import os import sys +import glob import time import traceback @@ -643,11 +644,16 @@ def main(args): # pylint: disable=redefined-outer-name if __name__ == '__main__': parser = argparse.ArgumentParser() + parser.add_argument( + '--continue_path', + type=str, + help='Training output folder to conitnue training. Use to continue a training.', + default='') parser.add_argument( '--restore_path', type=str, - help='Path to model outputs (checkpoint, tensorboard etc.).', - default=0) + help='Model file to be restored. Use to finetune a model.', + default='') parser.add_argument( '--config_path', type=str, @@ -657,19 +663,6 @@ if __name__ == '__main__': type=bool, default=True, help='Do not verify commit integrity to run training.') - parser.add_argument( - '--data_path', - type=str, - default='', - help='Defines the data path. It overwrites config.json.') - parser.add_argument('--output_path', - type=str, - help='path for training outputs.', - default='') - parser.add_argument('--output_folder', - type=str, - default='', - help='folder name for training outputs.') # DISTRUBUTED parser.add_argument( @@ -683,21 +676,22 @@ if __name__ == '__main__': help='DISTRIBUTED: process group id.') args = parser.parse_args() - # setup output paths and read configs + if args.continue_path != '': + args.output_path = args.continue_path + args.config_path = os.path.join(args.continue_path, 'config.json') + list_of_files = glob.glob(args.continue_path + "/*.pth.tar") # * means all if need specific format then *.csv + latest_model_file = max(list_of_files, key=os.path.getctime) + args.restore_path = latest_model_file + print(f" > Training continues for {args.restore_path}") + + # setup output paths and read configs c = load_config(args.config_path) _ = os.path.dirname(os.path.realpath(__file__)) - if args.data_path != '': - c.data_path = args.data_path - if args.output_path == '': - OUT_PATH = os.path.join(_, c.output_path) + if args.continue_path != '': + OUT_PATH = create_experiment_folder(args.continue_path, c.run_name, args.debug) else: - OUT_PATH = args.output_path - - if args.group_id == '' and args.output_folder == '': - OUT_PATH = create_experiment_folder(OUT_PATH, c.run_name, args.debug) - else: - OUT_PATH = os.path.join(OUT_PATH, args.output_folder) + OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug) AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')