зеркало из https://github.com/mozilla/TTS.git
continuining training from where it is left off
This commit is contained in:
Родитель
8f53f9fc8f
Коммит
05f038880b
|
@ -1,6 +1,6 @@
|
||||||
{
|
{
|
||||||
"run_name": "ljspeech",
|
"run_name": "ljspeech-w/o-bd",
|
||||||
"run_description": "t bidirectional decoder test train",
|
"run_description": "tacotron2 without bidirectional decoder",
|
||||||
|
|
||||||
"audio":{
|
"audio":{
|
||||||
// Audio processing parameters
|
// Audio processing parameters
|
||||||
|
@ -46,7 +46,7 @@
|
||||||
"forward_attn_mask": false,
|
"forward_attn_mask": false,
|
||||||
"transition_agent": false, // enable/disable transition agent of forward attention.
|
"transition_agent": false, // enable/disable transition agent of forward attention.
|
||||||
"location_attn": true, // enable_disable location sensitive attention. It is enabled for TACOTRON by default.
|
"location_attn": true, // enable_disable location sensitive attention. It is enabled for TACOTRON by default.
|
||||||
"bidirectional_decoder": true, // use https://arxiv.org/abs/1907.09006. Use it, if attention does not work well with your dataset.
|
"bidirectional_decoder": false, // use https://arxiv.org/abs/1907.09006. Use it, if attention does not work well with your dataset.
|
||||||
"loss_masking": true, // enable / disable loss masking against the sequence padding.
|
"loss_masking": true, // enable / disable loss masking against the sequence padding.
|
||||||
"enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars.
|
"enable_eos_bos_chars": false, // enable/disable beginning of sentence and end of sentence chars.
|
||||||
"stopnet": true, // Train stopnet predicting the end of synthesis.
|
"stopnet": true, // Train stopnet predicting the end of synthesis.
|
||||||
|
|
|
@ -125,50 +125,45 @@ def main():
|
||||||
Call train.py as a new process and pass command arguments
|
Call train.py as a new process and pass command arguments
|
||||||
"""
|
"""
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
'--continue_path',
|
||||||
|
type=str,
|
||||||
|
help='Training output folder to conitnue training. Use to continue a training.',
|
||||||
|
default='')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--restore_path',
|
'--restore_path',
|
||||||
type=str,
|
type=str,
|
||||||
help='Folder path to checkpoints',
|
help='Model file to be restored. Use to finetune a model.',
|
||||||
default='')
|
default='')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--config_path',
|
'--config_path',
|
||||||
type=str,
|
type=str,
|
||||||
help='path to config file for training',
|
help='path to config file for training',
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
|
||||||
'--data_path', type=str, help='dataset path.', default='')
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
CONFIG = load_config(args.config_path)
|
# OUT_PATH = create_experiment_folder(CONFIG.output_path, CONFIG.run_name,
|
||||||
OUT_PATH = create_experiment_folder(CONFIG.output_path, CONFIG.run_name,
|
# True)
|
||||||
True)
|
# stdout_path = os.path.join(OUT_PATH, "process_stdout/")
|
||||||
stdout_path = os.path.join(OUT_PATH, "process_stdout/")
|
|
||||||
|
|
||||||
num_gpus = torch.cuda.device_count()
|
num_gpus = torch.cuda.device_count()
|
||||||
group_id = time.strftime("%Y_%m_%d-%H%M%S")
|
group_id = time.strftime("%Y_%m_%d-%H%M%S")
|
||||||
|
|
||||||
# set arguments for train.py
|
# set arguments for train.py
|
||||||
command = ['train.py']
|
command = ['train.py']
|
||||||
|
command.append('--continue_path={}'.format(args.continue_path))
|
||||||
command.append('--restore_path={}'.format(args.restore_path))
|
command.append('--restore_path={}'.format(args.restore_path))
|
||||||
command.append('--config_path={}'.format(args.config_path))
|
command.append('--config_path={}'.format(args.config_path))
|
||||||
command.append('--group_id=group_{}'.format(group_id))
|
command.append('--group_id=group_{}'.format(group_id))
|
||||||
command.append('--data_path={}'.format(args.data_path))
|
|
||||||
command.append('--output_path={}'.format(OUT_PATH))
|
|
||||||
command.append('')
|
command.append('')
|
||||||
|
|
||||||
if not os.path.isdir(stdout_path):
|
|
||||||
os.makedirs(stdout_path)
|
|
||||||
os.chmod(stdout_path, 0o775)
|
|
||||||
|
|
||||||
# run processes
|
# run processes
|
||||||
processes = []
|
processes = []
|
||||||
for i in range(num_gpus):
|
for i in range(num_gpus):
|
||||||
my_env = os.environ.copy()
|
my_env = os.environ.copy()
|
||||||
my_env["PYTHON_EGG_CACHE"] = "/tmp/tmp{}".format(i)
|
my_env["PYTHON_EGG_CACHE"] = "/tmp/tmp{}".format(i)
|
||||||
command[6] = '--rank={}'.format(i)
|
command[-1] = '--rank={}'.format(i)
|
||||||
stdout = None if i == 0 else open(
|
stdout = None if i == 0 else open(os.devnull, 'w')
|
||||||
os.path.join(stdout_path, "process_{}.log".format(i)), "w")
|
|
||||||
p = subprocess.Popen(['python3'] + command, stdout=stdout, env=my_env)
|
p = subprocess.Popen(['python3'] + command, stdout=stdout, env=my_env)
|
||||||
processes.append(p)
|
processes.append(p)
|
||||||
print(command)
|
print(command)
|
||||||
|
|
44
train.py
44
train.py
|
@ -1,6 +1,7 @@
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
import glob
|
||||||
import time
|
import time
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
|
@ -643,11 +644,16 @@ def main(args): # pylint: disable=redefined-outer-name
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument(
|
||||||
|
'--continue_path',
|
||||||
|
type=str,
|
||||||
|
help='Training output folder to conitnue training. Use to continue a training.',
|
||||||
|
default='')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--restore_path',
|
'--restore_path',
|
||||||
type=str,
|
type=str,
|
||||||
help='Path to model outputs (checkpoint, tensorboard etc.).',
|
help='Model file to be restored. Use to finetune a model.',
|
||||||
default=0)
|
default='')
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--config_path',
|
'--config_path',
|
||||||
type=str,
|
type=str,
|
||||||
|
@ -657,19 +663,6 @@ if __name__ == '__main__':
|
||||||
type=bool,
|
type=bool,
|
||||||
default=True,
|
default=True,
|
||||||
help='Do not verify commit integrity to run training.')
|
help='Do not verify commit integrity to run training.')
|
||||||
parser.add_argument(
|
|
||||||
'--data_path',
|
|
||||||
type=str,
|
|
||||||
default='',
|
|
||||||
help='Defines the data path. It overwrites config.json.')
|
|
||||||
parser.add_argument('--output_path',
|
|
||||||
type=str,
|
|
||||||
help='path for training outputs.',
|
|
||||||
default='')
|
|
||||||
parser.add_argument('--output_folder',
|
|
||||||
type=str,
|
|
||||||
default='',
|
|
||||||
help='folder name for training outputs.')
|
|
||||||
|
|
||||||
# DISTRUBUTED
|
# DISTRUBUTED
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
|
@ -683,21 +676,22 @@ if __name__ == '__main__':
|
||||||
help='DISTRIBUTED: process group id.')
|
help='DISTRIBUTED: process group id.')
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
if args.continue_path != '':
|
||||||
|
args.output_path = args.continue_path
|
||||||
|
args.config_path = os.path.join(args.continue_path, 'config.json')
|
||||||
|
list_of_files = glob.glob(args.continue_path + "/*.pth.tar") # * means all if need specific format then *.csv
|
||||||
|
latest_model_file = max(list_of_files, key=os.path.getctime)
|
||||||
|
args.restore_path = latest_model_file
|
||||||
|
print(f" > Training continues for {args.restore_path}")
|
||||||
|
|
||||||
# setup output paths and read configs
|
# setup output paths and read configs
|
||||||
c = load_config(args.config_path)
|
c = load_config(args.config_path)
|
||||||
_ = os.path.dirname(os.path.realpath(__file__))
|
_ = os.path.dirname(os.path.realpath(__file__))
|
||||||
if args.data_path != '':
|
|
||||||
c.data_path = args.data_path
|
|
||||||
|
|
||||||
if args.output_path == '':
|
if args.continue_path != '':
|
||||||
OUT_PATH = os.path.join(_, c.output_path)
|
OUT_PATH = create_experiment_folder(args.continue_path, c.run_name, args.debug)
|
||||||
else:
|
else:
|
||||||
OUT_PATH = args.output_path
|
OUT_PATH = create_experiment_folder(c.output_path, c.run_name, args.debug)
|
||||||
|
|
||||||
if args.group_id == '' and args.output_folder == '':
|
|
||||||
OUT_PATH = create_experiment_folder(OUT_PATH, c.run_name, args.debug)
|
|
||||||
else:
|
|
||||||
OUT_PATH = os.path.join(OUT_PATH, args.output_folder)
|
|
||||||
|
|
||||||
AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
|
AUDIO_PATH = os.path.join(OUT_PATH, 'test_audios')
|
||||||
|
|
||||||
|
|
Загрузка…
Ссылка в новой задаче