continuing training from where it left off

Eren Golge 2019-10-30 15:48:38 +01:00
Parent 8f53f9fc8f
Commit 05f038880b
3 changed files with 35 additions and 46 deletions

config.json (View file)

@@ -1,6 +1,6 @@
{
"run_name": "ljspeech",
"run_description": "t bidirectional decoder test train",
"run_name": "ljspeech-w/o-bd",
"run_description": "tacotron2 without bidirectional decoder",
"audio":{
// Audio processing parameters
@@ -46,7 +46,7 @@
"forward_attn_mask": false,
"transition_agent": false, // enable/disable transition agent of forward attention.
"location_attn": true, // enable_disable location sensitive attention. It is enabled for TACOTRON by default.
"bidirectional_decoder": true, // use https://arxiv.org/abs/1907.09006. Use it, if attention does not work well with your dataset.
"bidirectional_decoder": false, // use https://arxiv.org/abs/1907.09006. Use it, if attention does not work well with your dataset.
"loss_masking": true, // enable / disable loss masking against the sequence padding.
"enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars.
"stopnet": true, // Train stopnet predicting the end of synthesis.

distribute.py (View file)

@@ -125,50 +125,45 @@ def main():
Call train.py as a new process and pass command arguments
"""
parser = argparse.ArgumentParser()
parser.add_argument(
'--continue_path',
type=str,
help='Training output folder to continue training. Use to continue a training.',
default='')
parser.add_argument(
'--restore_path',
type=str,
help='Folder path to checkpoints',
help='Model file to be restored. Use to finetune a model.',
default='')
parser.add_argument(
'--config_path',
type=str,
help='path to config file for training',
)
parser.add_argument(
'--data_path', type=str, help='dataset path.', default='')
args = parser.parse_args()
CONFIG = load_config(args.config_path)
OUT_PATH = create_experiment_folder(CONFIG.output_path, CONFIG.run_name,
True)
stdout_path = os.path.join(OUT_PATH, "process_stdout/")
# OUT_PATH = create_experiment_folder(CONFIG.output_path, CONFIG.run_name,
# True)
# stdout_path = os.path.join(OUT_PATH, "process_stdout/")
num_gpus = torch.cuda.device_count()
group_id = time.strftime("%Y_%m_%d-%H%M%S")
# set arguments for train.py
command = ['train.py']
command.append('--continue_path={}'.format(args.continue_path))
command.append('--restore_path={}'.format(args.restore_path))
command.append('--config_path={}'.format(args.config_path))
command.append('--group_id=group_{}'.format(group_id))
command.append('--data_path={}'.format(args.data_path))
command.append('--output_path={}'.format(OUT_PATH))
command.append('')
if not os.path.isdir(stdout_path):
os.makedirs(stdout_path)
os.chmod(stdout_path, 0o775)
# run processes
processes = []
for i in range(num_gpus):
my_env = os.environ.copy()
my_env["PYTHON_EGG_CACHE"] = "/tmp/tmp{}".format(i)
command[6] = '--rank={}'.format(i)
stdout = None if i == 0 else open(
os.path.join(stdout_path, "process_{}.log".format(i)), "w")
command[-1] = '--rank={}'.format(i)
stdout = None if i == 0 else open(os.devnull, 'w')
p = subprocess.Popen(['python3'] + command, stdout=stdout, env=my_env)
processes.append(p)
print(command)
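
To make the launch loop easier to follow, here is a minimal standalone sketch of what the updated distribute.py does per GPU: the empty trailing element of command is overwritten with --rank=<i> on each iteration, and only rank 0 keeps its stdout on the console. The output folder, config path, and GPU count below are illustrative, not taken from the commit.

import os
import subprocess

num_gpus = 2                                            # illustrative
command = [
    'train.py',
    '--continue_path=/data/ljspeech-run',               # hypothetical previous output folder
    '--restore_path=',
    '--config_path=config.json',
    '--group_id=group_2019_10_30-120000',
    '',                                                 # placeholder, replaced by --rank below
]
processes = []
for i in range(num_gpus):
    env = os.environ.copy()
    env["PYTHON_EGG_CACHE"] = "/tmp/tmp{}".format(i)
    command[-1] = '--rank={}'.format(i)
    stdout = None if i == 0 else open(os.devnull, 'w')  # silence non-zero ranks
    processes.append(subprocess.Popen(['python3'] + command, stdout=stdout, env=env))
for p in processes:
    p.wait()

Writing the rank into the last slot (command[-1]) instead of a hard-coded index (the old command[6]) keeps the launcher correct when arguments are added or removed, which is exactly what this commit needed.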

train.py (View file)

@@ -1,6 +1,7 @@
import argparse
import os
import sys
import glob
import time
import traceback
@@ -643,11 +644,16 @@ def main(args): # pylint: disable=redefined-outer-name
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument(
'--continue_path',
type=str,
help='Training output folder to continue training. Use to continue a training.',
default='')
parser.add_argument(
'--restore_path',
type=str,
help='Path to model outputs (checkpoint, tensorboard etc.).',
default=0)
help='Model file to be restored. Use to finetune a model.',
default='')
parser.add_argument(
'--config_path',
type=str,
@@ -657,19 +663,6 @@ if __name__ == '__main__':
type=bool,
default=True,
help='Do not verify commit integrity to run training.')
parser.add_argument(
'--data_path',
type=str,
default='',
help='Defines the data path. It overwrites config.json.')
parser.add_argument('--output_path',
type=str,
help='path for training outputs.',
default='')
parser.add_argument('--output_folder',
type=str,
default='',
help='folder name for training outputs.')
# DISTRIBUTED
parser.add_argument(
@@ -683,21 +676,22 @@
help='DISTRIBUTED: process group id.')
args = parser.parse_args()
# setup output paths and read configs
if args.continue_path != '':
args.output_path = args.continue_path
args.config_path = os.path.join(args.continue_path, 'config.json')
list_of_files = glob.glob(args.continue_path + "/*.pth.tar")  # pick up every checkpoint saved in the output folder
latest_model_file = max(list_of_files, key=os.path.getctime)
args.restore_path = latest_model_file
print(f" > Training continues for {args.restore_path}")
# setup output paths and read configs
c = load_config(args.config_path)
_ = os.path.dirname(os.path.realpath(__file__))
if args.data_path != '':
c.data_path = args.data_path
if args.output_path == '':
OUT_PATH = os.path.join(_, c.output_path)
if args.continue_path != '':
OUT_PATH = create_experiment_folder(args.continue_path, c.run_name, args.debug)
else:
OUT_PATH = args.output_path
if args.group_id == '' and args.output_folder == '':
OUT_PATH = create_experiment_folder(OUT_PATH, c.run_name, args.debug)
else:
OUT_PATH = os.path.join(OUT_PATH, args.output_folder)
OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug)
AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
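
The heart of the change is the --continue_path handling added to train.py above: the saved config.json is reused from the previous output folder and the newest *.pth.tar checkpoint in it becomes the restore target. A minimal standalone sketch of that resolution, with a hypothetical folder path:

import glob
import os

continue_path = '/data/ljspeech-run'                      # hypothetical previous output folder
config_path = os.path.join(continue_path, 'config.json')  # reuse the config saved with the run
checkpoints = glob.glob(os.path.join(continue_path, '*.pth.tar'))
restore_path = max(checkpoints, key=os.path.getctime)     # newest checkpoint by creation time
print(f" > Training continues for {restore_path}")

With this in place, resuming a run comes down to python train.py --continue_path <previous output folder>, while --restore_path on its own still covers the fine-tuning case where only a model file is given.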