continuing training from where it left off

This commit is contained in:
Eren Golge 2019-10-30 15:48:38 +01:00
Parent 8f53f9fc8f
Commit 05f038880b
3 changed files with 35 additions and 46 deletions

View file

@@ -1,6 +1,6 @@
 {
-    "run_name": "ljspeech",
-    "run_description": "t bidirectional decoder test train",
+    "run_name": "ljspeech-w/o-bd",
+    "run_description": "tacotron2 without bidirectional decoder",
     "audio":{
         // Audio processing parameters
@@ -46,7 +46,7 @@
     "forward_attn_mask": false,
     "transition_agent": false,  // enable/disable transition agent of forward attention.
     "location_attn": true,  // enable_disable location sensitive attention. It is enabled for TACOTRON by default.
-    "bidirectional_decoder": true,  // use https://arxiv.org/abs/1907.09006. Use it, if attention does not work well with your dataset.
+    "bidirectional_decoder": false,  // use https://arxiv.org/abs/1907.09006. Use it, if attention does not work well with your dataset.
     "loss_masking": true,  // enable / disable loss masking against the sequence padding.
     "enable_eos_bos_chars": false,  // enable/disable beginning of sentence and end of sentence chars.
     "stopnet": true,  // Train stopnet predicting the end of synthesis.

View file

@@ -125,50 +125,45 @@ def main():
     Call train.py as a new process and pass command arguments
     """
     parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--continue_path',
+        type=str,
+        help='Training output folder to continue training. Use to continue a training run.',
+        default='')
     parser.add_argument(
         '--restore_path',
         type=str,
-        help='Folder path to checkpoints',
+        help='Model file to be restored. Use to finetune a model.',
         default='')
     parser.add_argument(
         '--config_path',
         type=str,
         help='path to config file for training',
     )
-    parser.add_argument(
-        '--data_path', type=str, help='dataset path.', default='')
     args = parser.parse_args()
-    CONFIG = load_config(args.config_path)
-    OUT_PATH = create_experiment_folder(CONFIG.output_path, CONFIG.run_name,
-                                        True)
-    stdout_path = os.path.join(OUT_PATH, "process_stdout/")
+    # OUT_PATH = create_experiment_folder(CONFIG.output_path, CONFIG.run_name,
+    #                                     True)
+    # stdout_path = os.path.join(OUT_PATH, "process_stdout/")
     num_gpus = torch.cuda.device_count()
     group_id = time.strftime("%Y_%m_%d-%H%M%S")
     # set arguments for train.py
     command = ['train.py']
+    command.append('--continue_path={}'.format(args.continue_path))
     command.append('--restore_path={}'.format(args.restore_path))
     command.append('--config_path={}'.format(args.config_path))
     command.append('--group_id=group_{}'.format(group_id))
-    command.append('--data_path={}'.format(args.data_path))
-    command.append('--output_path={}'.format(OUT_PATH))
     command.append('')
-    if not os.path.isdir(stdout_path):
-        os.makedirs(stdout_path)
-        os.chmod(stdout_path, 0o775)
     # run processes
     processes = []
     for i in range(num_gpus):
         my_env = os.environ.copy()
         my_env["PYTHON_EGG_CACHE"] = "/tmp/tmp{}".format(i)
-        command[6] = '--rank={}'.format(i)
-        stdout = None if i == 0 else open(
-            os.path.join(stdout_path, "process_{}.log".format(i)), "w")
+        command[-1] = '--rank={}'.format(i)
+        stdout = None if i == 0 else open(os.devnull, 'w')
         p = subprocess.Popen(['python3'] + command, stdout=stdout, env=my_env)
         processes.append(p)
         print(command)
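
Pulled out of the diff, the launch pattern distribute.py now follows is roughly the sketch below: the trailing empty element of the command list is overwritten with a per-process --rank flag, and only rank 0 keeps stdout. Paths, group id, and GPU count here are illustrative, not taken from the commit:

import os
import subprocess

num_gpus = 2                                            # illustrative; the script uses torch.cuda.device_count()
command = ['train.py',
           '--continue_path=/path/to/previous/run',     # hypothetical output folder of an earlier run
           '--restore_path=',
           '--config_path=config.json',
           '--group_id=group_2019_10_30-000000',        # illustrative group id
           '']                                          # placeholder, replaced with --rank below

processes = []
for rank in range(num_gpus):
    env = os.environ.copy()
    command[-1] = '--rank={}'.format(rank)              # last slot carries the per-process rank
    stdout = None if rank == 0 else open(os.devnull, 'w')   # only rank 0 writes to the console
    processes.append(subprocess.Popen(['python3'] + command, stdout=stdout, env=env))

for p in processes:
    p.wait()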

View file

@@ -1,6 +1,7 @@
 import argparse
 import os
 import sys
+import glob
 import time
 import traceback
@@ -643,11 +644,16 @@ def main(args):  # pylint: disable=redefined-outer-name
 if __name__ == '__main__':
     parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--continue_path',
+        type=str,
+        help='Training output folder to continue training. Use to continue a training run.',
+        default='')
     parser.add_argument(
         '--restore_path',
         type=str,
-        help='Path to model outputs (checkpoint, tensorboard etc.).',
-        default=0)
+        help='Model file to be restored. Use to finetune a model.',
+        default='')
     parser.add_argument(
         '--config_path',
         type=str,
@@ -657,19 +663,6 @@ if __name__ == '__main__':
         type=bool,
         default=True,
         help='Do not verify commit integrity to run training.')
-    parser.add_argument(
-        '--data_path',
-        type=str,
-        default='',
-        help='Defines the data path. It overwrites config.json.')
-    parser.add_argument('--output_path',
-                        type=str,
-                        help='path for training outputs.',
-                        default='')
-    parser.add_argument('--output_folder',
-                        type=str,
-                        default='',
-                        help='folder name for training outputs.')
     # DISTRUBUTED
     parser.add_argument(
@@ -683,21 +676,22 @@ if __name__ == '__main__':
         help='DISTRIBUTED: process group id.')
     args = parser.parse_args()
+    if args.continue_path != '':
+        args.output_path = args.continue_path
+        args.config_path = os.path.join(args.continue_path, 'config.json')
+        list_of_files = glob.glob(args.continue_path + "/*.pth.tar")  # pick up every saved checkpoint
+        latest_model_file = max(list_of_files, key=os.path.getctime)
+        args.restore_path = latest_model_file
+        print(f" > Training continues for {args.restore_path}")
     # setup output paths and read configs
     c = load_config(args.config_path)
     _ = os.path.dirname(os.path.realpath(__file__))
-    if args.data_path != '':
-        c.data_path = args.data_path
-    if args.output_path == '':
-        OUT_PATH = os.path.join(_, c.output_path)
+    if args.continue_path != '':
+        OUT_PATH = create_experiment_folder(args.continue_path, c.run_name, args.debug)
     else:
-        OUT_PATH = args.output_path
-    if args.group_id == '' and args.output_folder == '':
-        OUT_PATH = create_experiment_folder(OUT_PATH, c.run_name, args.debug)
-    else:
-        OUT_PATH = os.path.join(OUT_PATH, args.output_folder)
+        OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug)
     AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
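
The resume logic added to train.py boils down to: point config_path at the config.json saved inside the previous run folder, and restore the checkpoint with the newest creation time. A self-contained sketch of that selection step (the helper name and the example folder are illustrative, not part of the commit):

import glob
import os

def resolve_continue_run(continue_path):
    """Return (config_path, restore_path) for a previous training folder."""
    config_path = os.path.join(continue_path, 'config.json')
    checkpoints = glob.glob(os.path.join(continue_path, '*.pth.tar'))   # every saved checkpoint
    if not checkpoints:
        raise FileNotFoundError('no *.pth.tar checkpoints in {}'.format(continue_path))
    restore_path = max(checkpoints, key=os.path.getctime)               # most recently created file
    return config_path, restore_path

# hypothetical usage with an earlier run's output folder:
# config_path, restore_path = resolve_continue_run('/data/runs/ljspeech-October-30-2019')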