feat(*): Support GNN-Edge-MLP0/MLP1 model names

Marc Brockschmidt 2019-10-08 17:46:35 +00:00
Parent b0ffd07f0e
Commit 1b1530e4dd
16 changed files with 100 additions and 38 deletions

View file

@@ -87,7 +87,7 @@ Currently, five model types are implemented:
* `GGNN`: Gated Graph Neural Networks ([Li et al., 2015](#li-et-al-2015)).
* `RGCN`: Relational Graph Convolutional Networks ([Schlichtkrull et al., 2017](#schlichtkrull-et-al-2017)).
* `RGAT`: Relational Graph Attention Networks ([Veličković et al., 2018](#veličković-et-al-2018)).
* `GNN-Edge-MLP`: Graph Neural Network with Edge MLPs - a variant of RGCN in which messages on edges are computed using full MLPs, not just a single layer.
* `GNN-Edge-MLP`: Graph Neural Network with Edge MLPs - a variant of RGCN in which messages on edges are computed using full MLPs, not just a single layer applied to the source state.
* `RGDCN`: Relational Graph Dynamic Convolution Networks - a new variant of RGCN in which the weights of convolutional layers are dynamically computed.
* `GNN-FiLM`: Graph Neural Networks with Feature-wise Linear Modulation - a new extension of RGCN with FiLM layers.
@@ -130,13 +130,14 @@ by using `--data-path "SOME/OTHER/DIR"`.
Running `python run_ppi_benchs.py ppi_results/` should yield results looking
like this (on an NVidia V100):
| Model | Avg. MicroF1 | Avg. Time |
|--------------|-------------------|------------|
| GGNN | 0.990 (+/- 0.001) | 432.6 |
| RGCN | 0.989 (+/- 0.000) | 759.0 |
| GAT | 0.989 (+/- 0.001) | 782.3 |
| GNN-Edge-MLP | 0.992 (+/- 0.001) | 479.2 |
| GNN-FiLM | 0.992 (+/- 0.000) | 308.1 |
| Model | Avg. MicroF1 | Avg. Time |
|---------------|-------------------|------------|
| GGNN | 0.990 (+/- 0.001) | 432.6 |
| RGCN | 0.989 (+/- 0.000) | 759.0 |
| GAT | 0.989 (+/- 0.001) | 782.3 |
| GNN-Edge-MLP0 | 0.992 (+/- 0.000) | 556.9 |
| GNN-Edge-MLP1 | 0.992 (+/- 0.001) | 479.2 |
| GNN-FiLM | 0.992 (+/- 0.000) | 308.1 |
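
The `GNN-Edge-MLP0` and `GNN-Edge-MLP1` rows above are the same model class run with `num_edge_hidden_layers` set to 0 and 1 respectively. The sketch below is purely illustrative (it is not the repository's layer implementation; it assumes a single edge type and messages computed from the source state only) and shows what the 0/1 suffix means for the per-edge message function:

```python
import numpy as np

def edge_mlp_message(source_state: np.ndarray,
                     weights: list,
                     num_edge_hidden_layers: int) -> np.ndarray:
    # GNN-Edge-MLP0: no hidden layer, i.e. a single linear transform of the source state.
    # GNN-Edge-MLP1: one ReLU hidden layer before the final linear transform.
    h = source_state
    for layer in range(num_edge_hidden_layers):
        h = np.maximum(h @ weights[layer], 0.0)    # hidden layer with ReLU
    return h @ weights[num_edge_hidden_layers]     # output layer

# Example: a 4-dimensional node state with 0 vs. 1 hidden layers.
state = np.ones(4)
w = [np.eye(4), np.eye(4)]
print(edge_mlp_message(state, w, 0))  # single-layer message
print(edge_mlp_message(state, w, 1))  # message from a one-hidden-layer MLP
```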
## QM9
The `QM9` task (implemented in `tasks/qm9_task.py`) handles the quantum chemistry

View file

@@ -23,7 +23,7 @@ class GGNN_Model(Sparse_Graph_Model):
return params
@staticmethod
def name() -> str:
def name(params: Dict[str, Any]) -> str:
return "GGNN"
def __init__(self, params: Dict[str, Any], task: Sparse_Graph_Task, run_id: str, result_dir: str) -> None:

View file

@@ -24,8 +24,8 @@ class GNN_Edge_MLP_Model(Sparse_Graph_Model):
return params
@staticmethod
def name() -> str:
return "GNN-Edge-MLP"
def name(params: Dict[str, Any]) -> str:
return "GNN-Edge-MLP%i" % (params['num_edge_hidden_layers'])
def __init__(self, params: Dict[str, Any], task: Sparse_Graph_Task, run_id: str, result_dir: str) -> None:
super().__init__(params, task, run_id, result_dir)
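
A minimal standalone sketch of the behaviour introduced by this change (only the static `name` method is reproduced, not the full model class): the reported model name now encodes the `num_edge_hidden_layers` hyperparameter, so the two configurations below report distinct names.

```python
from typing import Any, Dict

class GNN_Edge_MLP_Model:
    @staticmethod
    def name(params: Dict[str, Any]) -> str:
        # 0 hidden layers -> "GNN-Edge-MLP0", 1 hidden layer -> "GNN-Edge-MLP1"
        return "GNN-Edge-MLP%i" % (params['num_edge_hidden_layers'])

print(GNN_Edge_MLP_Model.name({'num_edge_hidden_layers': 0}))  # GNN-Edge-MLP0
print(GNN_Edge_MLP_Model.name({'num_edge_hidden_layers': 1}))  # GNN-Edge-MLP1
```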

View file

@@ -20,7 +20,7 @@ class GNN_FiLM_Model(Sparse_Graph_Model):
return params
@staticmethod
def name() -> str:
def name(params: Dict[str, Any]) -> str:
return "GNN-FiLM"
def __init__(self, params: Dict[str, Any], task: Sparse_Graph_Task, run_id: str, result_dir: str) -> None:

View file

@@ -22,7 +22,7 @@ class RGAT_Model(Sparse_Graph_Model):
return params
@staticmethod
def name() -> str:
def name(params: Dict[str, Any]) -> str:
return "RGAT"
def __init__(self, params: Dict[str, Any], task: Sparse_Graph_Task, run_id: str, result_dir: str) -> None:

View file

@@ -22,7 +22,7 @@ class RGCN_Model(Sparse_Graph_Model):
return params
@staticmethod
def name() -> str:
def name(params: Dict[str, Any]) -> str:
return "RGCN"
def __init__(self, params: Dict[str, Any], task: Sparse_Graph_Task, run_id: str, result_dir: str) -> None:

View file

@@ -24,7 +24,7 @@ class RGDCN_Model(Sparse_Graph_Model):
return params
@staticmethod
def name() -> str:
def name(params: Dict[str, Any]) -> str:
return "RGDCN"
def __init__(self, params: Dict[str, Any], task: Sparse_Graph_Task, run_id: str, result_dir: str) -> None:

View file

@@ -46,7 +46,7 @@ class Sparse_Graph_Model(ABC):
@staticmethod
@abstractmethod
def name() -> str:
def name(params: Dict[str, Any]) -> str:
raise NotImplementedError()
def __init__(self,
@@ -96,7 +96,7 @@ class Sparse_Graph_Model(ABC):
weights_to_save = self.sess.run(vars_to_retrieve)
data_to_save = {
"model_class": self.name(),
"model_class": self.name(self.params),
"task_class": self.task.name(),
"model_params": self.params,
"task_params": self.task.params,

View file

@@ -16,7 +16,7 @@ import numpy as np
from docopt import docopt
from dpu_utils.utils import run_and_debug
MODEL_TYPES = ["GGNN", "RGCN", "RGAT", "GNN-Edge-MLP", "GNN_FiLM"]
MODEL_TYPES = ["GGNN", "RGCN", "RGAT", "GNN-Edge-MLP0", "GNN-Edge-MLP1", "GNN_FiLM"]
TEST_RES_RE = re.compile('^Metrics: Avg MicroF1: (0.\d+)')
TIME_RE = re.compile('^Training took (\d+)s')

View file

@@ -0,0 +1,7 @@
{"task_params": {},
"model_params": {"graph_num_layers": 5,
"hidden_size": 256,
"max_nodes_in_batch": 6000,
"graph_layer_input_dropout_keep_prob": 0.8
}
}

View file

@@ -1,7 +0,0 @@
{
"task_params": {},
"model_params": {
"max_nodes_in_batch": 50000,
"num_edge_hidden_layers": 0
}
}

View file

@@ -0,0 +1,27 @@
{
"task_params": {},
"model_params": {
"max_nodes_in_batch": 50000,
"graph_num_layers": 8,
"graph_num_timesteps_per_layer": 1,
"graph_layer_input_dropout_keep_prob": 0.9,
"graph_dense_between_every_num_gnn_layers": 32,
"graph_model_activation_function": "tanh",
"graph_residual_connection_every_num_layers": 2,
"graph_inter_layer_norm": true,
"max_epochs": 10000,
"patience": 25,
"optimizer": "RMSProp",
"learning_rate": 0.0005072060718321982,
"learning_rate_decay": 0.98,
"lr_for_num_graphs_per_batch": null,
"momentum": 0.85,
"clamp_gradient_norm": 1.0,
"hidden_size": 128,
"graph_activation_function": "relu",
"message_aggregation_function": "sum",
"graph_message_weights_dropout_ratio": 0.0,
"use_target_state_as_input": true,
"num_edge_hidden_layers": 0
}
}

View file

@@ -0,0 +1,27 @@
{
"task_params": {},
"model_params": {
"max_nodes_in_batch": 50000,
"graph_num_layers": 8,
"graph_num_timesteps_per_layer": 1,
"graph_layer_input_dropout_keep_prob": 0.9,
"graph_dense_between_every_num_gnn_layers": 32,
"graph_model_activation_function": "tanh",
"graph_residual_connection_every_num_layers": 2,
"graph_inter_layer_norm": false,
"max_epochs": 10000,
"patience": 25,
"optimizer": "Adam",
"learning_rate": 0.0006482335154980316,
"learning_rate_decay": 0.98,
"lr_for_num_graphs_per_batch": null,
"momentum": 0.85,
"clamp_gradient_norm": 1.0,
"hidden_size": 128,
"graph_activation_function": "gelu",
"message_aggregation_function": "sum",
"graph_message_weights_dropout_ratio": 0.0,
"use_target_state_as_input": true,
"num_edge_hidden_layers": 1
}
}

View file

@@ -32,20 +32,21 @@ from test import test
def run(args):
azure_info_path = args.get('--azure-info', None)
model_cls = name_to_model_class(args['MODEL_NAME'])
model_cls, additional_model_params = name_to_model_class(args['MODEL_NAME'])
task_cls, additional_task_params = name_to_task_class(args['TASK_NAME'])
# Collect parameters from first the class defaults, potential task defaults, and then CLI:
task_params = task_cls.default_params()
task_params.update(additional_task_params)
model_params = model_cls.default_params()
model_params.update(additional_model_params)
# Load potential task-specific defaults:
task_model_default_hypers_file = \
os.path.join(os.path.dirname(__file__),
"tasks",
"default_hypers",
"%s_%s.json" % (task_cls.name(), model_cls.name()))
"%s_%s.json" % (task_cls.name(), model_cls.name(model_params)))
if os.path.exists(task_model_default_hypers_file):
print("Loading task/model-specific default parameters from %s." % task_model_default_hypers_file)
with open(task_model_default_hypers_file, "rt") as f:
@@ -77,7 +78,7 @@ def run(args):
for random_seed in random_seeds:
model_params['random_seed'] = random_seed
run_id = "_".join([task_cls.name(), model_cls.name(), time.strftime("%Y-%m-%d-%H-%M-%S"), str(os.getpid())])
run_id = "_".join([task_cls.name(), model_cls.name(model_params), time.strftime("%Y-%m-%d-%H-%M-%S"), str(os.getpid())])
model = model_cls(model_params, task, run_id, result_dir)
model.log_line("Run %s starting." % run_id)
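
Since the default-hyperparameter file is located via `"%s_%s.json" % (task_cls.name(), model_cls.name(model_params))`, the parametrised name now selects per-variant files such as the two JSON configurations added in this commit. Below is a small sketch of that lookup; the `PPI` task name and the resulting file names are illustrative examples, not taken from the diff.

```python
import os

def default_hypers_file(task_name: str, model_name: str) -> str:
    # Mirrors the path construction in run(): tasks/default_hypers/<TASK>_<MODEL>.json
    return os.path.join("tasks", "default_hypers", "%s_%s.json" % (task_name, model_name))

print(default_hypers_file("PPI", "GNN-Edge-MLP0"))  # e.g. tasks/default_hypers/PPI_GNN-Edge-MLP0.json
print(default_hypers_file("PPI", "GNN-Edge-MLP1"))  # e.g. tasks/default_hypers/PPI_GNN-Edge-MLP1.json
```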

View file

@@ -29,34 +29,40 @@ def name_to_task_class(name: str) -> Tuple[Type[Sparse_Graph_Task], Dict[str, An
raise ValueError("Unknown task type '%s'" % name)
def name_to_model_class(name: str) -> Type[Sparse_Graph_Model]:
def name_to_model_class(name: str) -> Tuple[Type[Sparse_Graph_Model], Dict[str, Any]]:
name = name.lower()
if name in ["ggnn", "ggnn_model"]:
return GGNN_Model
return GGNN_Model, {}
if name in ["gnn_edge_mlp", "gnn-edge-mlp", "gnn_edge_mlp_model"]:
return GNN_Edge_MLP_Model
return GNN_Edge_MLP_Model, {}
if name in ["gnn_edge_mlp0", "gnn-edge-mlp0", "gnn_edge_mlp0_model"]:
return GNN_Edge_MLP_Model, {'num_edge_hidden_layers': 0}
if name in ["gnn_edge_mlp1", "gnn-edge-mlp1", "gnn_edge_mlp1_model"]:
return GNN_Edge_MLP_Model, {'num_edge_hidden_layers': 1}
if name in ["gnn_edge_mlp", "gnn-edge-mlp"]:
return GNN_Edge_MLP_Model, {}
if name in ["gnn_film", "gnn-film", "gnn_film_model"]:
return GNN_FiLM_Model
return GNN_FiLM_Model, {}
if name in ["rgat", "rgat_model"]:
return RGAT_Model
return RGAT_Model, {}
if name in ["rgcn", "rgcn_model"]:
return RGCN_Model
return RGCN_Model, {}
if name in ["rgdcn", "rgdcn_model"]:
return RGDCN_Model
return RGDCN_Model, {}
raise ValueError("Unknown model type '%s'" % name)
def restore(saved_model_path: str, result_dir: str, run_id: str = None) -> None:
def restore(saved_model_path: str, result_dir: str, run_id: str = None) -> Sparse_Graph_Model:
print("Loading model from file %s." % saved_model_path)
with open(saved_model_path, 'rb') as in_file:
data_to_load = pickle.load(in_file)
model_cls = name_to_model_class(data_to_load['model_class'])
model_cls, _ = name_to_model_class(data_to_load['model_class'])
task_cls, additional_task_params = name_to_task_class(data_to_load['task_class'])
if run_id is None:
run_id = "_".join([task_cls.name(), model_cls.name(), time.strftime("%Y-%m-%d-%H-%M-%S"), str(os.getpid())])
run_id = "_".join([task_cls.name(), model_cls.name(data_to_load['model_params']), time.strftime("%Y-%m-%d-%H-%M-%S"), str(os.getpid())])
task = task_cls(data_to_load['task_params'])
task.restore_from_metadata(data_to_load['task_metadata'])
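
A self-contained sketch of the new name-to-class dispatch for the Edge-MLP variants (simplified: the real function also handles the other model classes, and the class below is only a stand-in): asking for `GNN-Edge-MLP0` or `GNN-Edge-MLP1` returns the same model class together with a `num_edge_hidden_layers` override, which is merged into the model parameters before training.

```python
from typing import Any, Dict, Tuple

class GNN_Edge_MLP_Model:  # stand-in for the real model class
    pass

def name_to_model_class(name: str) -> Tuple[type, Dict[str, Any]]:
    name = name.lower()
    if name in ("gnn_edge_mlp0", "gnn-edge-mlp0", "gnn_edge_mlp0_model"):
        return GNN_Edge_MLP_Model, {"num_edge_hidden_layers": 0}
    if name in ("gnn_edge_mlp1", "gnn-edge-mlp1", "gnn_edge_mlp1_model"):
        return GNN_Edge_MLP_Model, {"num_edge_hidden_layers": 1}
    if name in ("gnn_edge_mlp", "gnn-edge-mlp"):
        return GNN_Edge_MLP_Model, {}
    raise ValueError("Unknown model type '%s'" % name)

model_cls, overrides = name_to_model_class("GNN-Edge-MLP1")
assert overrides == {"num_edge_hidden_layers": 1}
```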