зеркало из https://github.com/mozilla/TTS.git
continuining training from where it is left off
This commit is contained in:
Родитель
8f53f9fc8f
Коммит
05f038880b
|
@ -1,6 +1,6 @@
|
|||
{
|
||||
"run_name": "ljspeech",
|
||||
"run_description": "t bidirectional decoder test train",
|
||||
"run_name": "ljspeech-w/o-bd",
|
||||
"run_description": "tacotron2 without bidirectional decoder",
|
||||
|
||||
"audio":{
|
||||
// Audio processing parameters
|
||||
|
@ -46,7 +46,7 @@
|
|||
"forward_attn_mask": false,
|
||||
"transition_agent": false, // enable/disable transition agent of forward attention.
|
||||
"location_attn": true, // enable_disable location sensitive attention. It is enabled for TACOTRON by default.
|
||||
"bidirectional_decoder": true, // use https://arxiv.org/abs/1907.09006. Use it, if attention does not work well with your dataset.
|
||||
"bidirectional_decoder": false, // use https://arxiv.org/abs/1907.09006. Use it, if attention does not work well with your dataset.
|
||||
"loss_masking": true, // enable / disable loss masking against the sequence padding.
|
||||
"enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars.
|
||||
"stopnet": true, // Train stopnet predicting the end of synthesis.
|
||||
|
|
|
@ -125,50 +125,45 @@ def main():
|
|||
Call train.py as a new process and pass command arguments
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--continue_path',
|
||||
type=str,
|
||||
help='Training output folder to conitnue training. Use to continue a training.',
|
||||
default='')
|
||||
parser.add_argument(
|
||||
'--restore_path',
|
||||
type=str,
|
||||
help='Folder path to checkpoints',
|
||||
help='Model file to be restored. Use to finetune a model.',
|
||||
default='')
|
||||
parser.add_argument(
|
||||
'--config_path',
|
||||
type=str,
|
||||
help='path to config file for training',
|
||||
)
|
||||
parser.add_argument(
|
||||
'--data_path', type=str, help='dataset path.', default='')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
CONFIG = load_config(args.config_path)
|
||||
OUT_PATH = create_experiment_folder(CONFIG.output_path, CONFIG.run_name,
|
||||
True)
|
||||
stdout_path = os.path.join(OUT_PATH, "process_stdout/")
|
||||
# OUT_PATH = create_experiment_folder(CONFIG.output_path, CONFIG.run_name,
|
||||
# True)
|
||||
# stdout_path = os.path.join(OUT_PATH, "process_stdout/")
|
||||
|
||||
num_gpus = torch.cuda.device_count()
|
||||
group_id = time.strftime("%Y_%m_%d-%H%M%S")
|
||||
|
||||
# set arguments for train.py
|
||||
command = ['train.py']
|
||||
command.append('--continue_path={}'.format(args.continue_path))
|
||||
command.append('--restore_path={}'.format(args.restore_path))
|
||||
command.append('--config_path={}'.format(args.config_path))
|
||||
command.append('--group_id=group_{}'.format(group_id))
|
||||
command.append('--data_path={}'.format(args.data_path))
|
||||
command.append('--output_path={}'.format(OUT_PATH))
|
||||
command.append('')
|
||||
|
||||
if not os.path.isdir(stdout_path):
|
||||
os.makedirs(stdout_path)
|
||||
os.chmod(stdout_path, 0o775)
|
||||
|
||||
# run processes
|
||||
processes = []
|
||||
for i in range(num_gpus):
|
||||
my_env = os.environ.copy()
|
||||
my_env["PYTHON_EGG_CACHE"] = "/tmp/tmp{}".format(i)
|
||||
command[6] = '--rank={}'.format(i)
|
||||
stdout = None if i == 0 else open(
|
||||
os.path.join(stdout_path, "process_{}.log".format(i)), "w")
|
||||
command[-1] = '--rank={}'.format(i)
|
||||
stdout = None if i == 0 else open(os.devnull, 'w')
|
||||
p = subprocess.Popen(['python3'] + command, stdout=stdout, env=my_env)
|
||||
processes.append(p)
|
||||
print(command)
|
||||
|
|
46
train.py
46
train.py
|
@ -1,6 +1,7 @@
|
|||
import argparse
|
||||
import os
|
||||
import sys
|
||||
import glob
|
||||
import time
|
||||
import traceback
|
||||
|
||||
|
@ -643,11 +644,16 @@ def main(args): # pylint: disable=redefined-outer-name
|
|||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'--continue_path',
|
||||
type=str,
|
||||
help='Training output folder to conitnue training. Use to continue a training.',
|
||||
default='')
|
||||
parser.add_argument(
|
||||
'--restore_path',
|
||||
type=str,
|
||||
help='Path to model outputs (checkpoint, tensorboard etc.).',
|
||||
default=0)
|
||||
help='Model file to be restored. Use to finetune a model.',
|
||||
default='')
|
||||
parser.add_argument(
|
||||
'--config_path',
|
||||
type=str,
|
||||
|
@ -657,19 +663,6 @@ if __name__ == '__main__':
|
|||
type=bool,
|
||||
default=True,
|
||||
help='Do not verify commit integrity to run training.')
|
||||
parser.add_argument(
|
||||
'--data_path',
|
||||
type=str,
|
||||
default='',
|
||||
help='Defines the data path. It overwrites config.json.')
|
||||
parser.add_argument('--output_path',
|
||||
type=str,
|
||||
help='path for training outputs.',
|
||||
default='')
|
||||
parser.add_argument('--output_folder',
|
||||
type=str,
|
||||
default='',
|
||||
help='folder name for training outputs.')
|
||||
|
||||
# DISTRUBUTED
|
||||
parser.add_argument(
|
||||
|
@ -683,21 +676,22 @@ if __name__ == '__main__':
|
|||
help='DISTRIBUTED: process group id.')
|
||||
args = parser.parse_args()
|
||||
|
||||
# setup output paths and read configs
|
||||
if args.continue_path != '':
|
||||
args.output_path = args.continue_path
|
||||
args.config_path = os.path.join(args.continue_path, 'config.json')
|
||||
list_of_files = glob.glob(args.continue_path + "/*.pth.tar") # * means all if need specific format then *.csv
|
||||
latest_model_file = max(list_of_files, key=os.path.getctime)
|
||||
args.restore_path = latest_model_file
|
||||
print(f" > Training continues for {args.restore_path}")
|
||||
|
||||
# setup output paths and read configs
|
||||
c = load_config(args.config_path)
|
||||
_ = os.path.dirname(os.path.realpath(__file__))
|
||||
if args.data_path != '':
|
||||
c.data_path = args.data_path
|
||||
|
||||
if args.output_path == '':
|
||||
OUT_PATH = os.path.join(_, c.output_path)
|
||||
if args.continue_path != '':
|
||||
OUT_PATH = create_experiment_folder(args.continue_path, c.run_name, args.debug)
|
||||
else:
|
||||
OUT_PATH = args.output_path
|
||||
|
||||
if args.group_id == '' and args.output_folder == '':
|
||||
OUT_PATH = create_experiment_folder(OUT_PATH, c.run_name, args.debug)
|
||||
else:
|
||||
OUT_PATH = os.path.join(OUT_PATH, args.output_folder)
|
||||
OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug)
|
||||
|
||||
AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
|
||||
|
||||
|
|
Загрузка…
Ссылка в новой задаче