106 строки
4.6 KiB
Python
106 строки
4.6 KiB
Python
#!/usr/bin/env python
|
|
"""
|
|
Usage:
|
|
train.py [options] MODEL_NAME TASK_NAME
|
|
|
|
MODEL_NAME has to be one of the supported models, which currently are
|
|
GGNN, GNN-Edge-MLP, GNN-FiLM, RGAT, RGCN, RGDCN
|
|
|
|
Options:
|
|
-h --help Show this screen.
|
|
--data-path PATH Path to load data from, has task-specific defaults under data/.
|
|
--result-dir DIR Directory to store logfiles and trained models. [default: trained_models]
|
|
--run-test Indicate if the task's test should be run.
|
|
--model-param-overrides PARAMS Parameter settings overriding model defaults (in JSON format).
|
|
--task-param-overrides PARAMS Parameter settings overriding task defaults (in JSON format).
|
|
--quiet Show less output.
|
|
--tensorboard DIR Dump tensorboard event files to DIR.
|
|
--azure-info=<path> Azure authentication information file (JSON). [default: azure_auth.json]
|
|
--debug Turn on debugger.
|
|
"""
|
|
import json
|
|
import os
|
|
import sys
|
|
import time
|
|
|
|
from docopt import docopt
|
|
from dpu_utils.utils import run_and_debug, RichPath, git_tag_run
|
|
|
|
from utils.model_utils import name_to_model_class, name_to_task_class
|
|
from test import test
|
|
|
|
|
|
def run(args):
|
|
azure_info_path = args.get('--azure-info', None)
|
|
model_cls, additional_model_params = name_to_model_class(args['MODEL_NAME'])
|
|
task_cls, additional_task_params = name_to_task_class(args['TASK_NAME'])
|
|
|
|
# Collect parameters from first the class defaults, potential task defaults, and then CLI:
|
|
task_params = task_cls.default_params()
|
|
task_params.update(additional_task_params)
|
|
model_params = model_cls.default_params()
|
|
model_params.update(additional_model_params)
|
|
|
|
# Load potential task-specific defaults:
|
|
task_model_default_hypers_file = \
|
|
os.path.join(os.path.dirname(__file__),
|
|
"tasks",
|
|
"default_hypers",
|
|
"%s_%s.json" % (task_cls.name(), model_cls.name(model_params)))
|
|
if os.path.exists(task_model_default_hypers_file):
|
|
print("Loading task/model-specific default parameters from %s." % task_model_default_hypers_file)
|
|
with open(task_model_default_hypers_file, "rt") as f:
|
|
default_task_model_hypers = json.load(f)
|
|
task_params.update(default_task_model_hypers['task_params'])
|
|
model_params.update(default_task_model_hypers['model_params'])
|
|
|
|
# Load overrides from command line:
|
|
task_params.update(json.loads(args.get('--task-param-overrides') or '{}'))
|
|
model_params.update(json.loads(args.get('--model-param-overrides') or '{}'))
|
|
|
|
# Finally, upgrade every parameters that's a path to a RichPath:
|
|
task_params_orig = dict(task_params)
|
|
for (param_name, param_value) in task_params.items():
|
|
if param_name.endswith("_path"):
|
|
task_params[param_name] = RichPath.create(param_value, azure_info_path)
|
|
|
|
# Now prepare to actually run by setting up directories, creating object instances and running:
|
|
result_dir = args.get('--result-dir', 'trained_models')
|
|
os.makedirs(result_dir, exist_ok=True)
|
|
task = task_cls(task_params)
|
|
data_path = args.get('--data-path') or task.default_data_path()
|
|
data_path = RichPath.create(data_path, azure_info_path)
|
|
task.load_data(data_path)
|
|
|
|
random_seeds = model_params['random_seed']
|
|
if not isinstance(random_seeds, list):
|
|
random_seeds = [random_seeds]
|
|
|
|
for random_seed in random_seeds:
|
|
model_params['random_seed'] = random_seed
|
|
run_id = "_".join([task_cls.name(), model_cls.name(model_params), time.strftime("%Y-%m-%d-%H-%M-%S"), str(os.getpid())])
|
|
|
|
model = model_cls(model_params, task, run_id, result_dir)
|
|
model.log_line("Run %s starting." % run_id)
|
|
model.log_line(" Using the following task params: %s" % json.dumps(task_params_orig))
|
|
model.log_line(" Using the following model params: %s" % json.dumps(model_params))
|
|
|
|
if sys.stdin.isatty():
|
|
try:
|
|
git_sha = git_tag_run(run_id)
|
|
model.log_line(" git tagged as %s" % git_sha)
|
|
except:
|
|
print(" Tried tagging run in git, but failed.")
|
|
pass
|
|
|
|
model.initialize_model()
|
|
model.train(quiet=args.get('--quiet'), tf_summary_path=args.get('--tensorboard'))
|
|
|
|
if args.get('--run-test'):
|
|
test(model.best_model_file, data_path, result_dir, quiet=args.get('--quiet'), run_id=run_id)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
args = docopt(__doc__)
|
|
run_and_debug(lambda: run(args), enable_debugging=args['--debug'])
|