feat(RGIN): Add relational generalisation of GIN

Marc Brockschmidt 2019-10-16 11:00:48 +00:00
Parent 9dc36b8cfe
Commit 9764626607
5 changed files with 187 additions and 2 deletions


@@ -3,4 +3,5 @@ from .gnn_edge_mlp import sparse_gnn_edge_mlp_layer
from .gnn_film import sparse_gnn_film_layer
from .rgat import sparse_rgat_layer
from .rgcn import sparse_rgcn_layer
from .rgdcn import sparse_rgdcn_layer
from .rgin import sparse_rgin_layer

gnns/rgin.py Normal file

@@ -0,0 +1,136 @@
from typing import List, Optional, Callable

import tensorflow as tf

from utils import get_activation


class MLP(object):
    def __init__(self, out_size: int, num_hidden_layers: int = 1, activation_fun: Optional[Callable] = None):
        if activation_fun is None:
            self.__activation_fun = tf.nn.relu
        else:
            self.__activation_fun = activation_fun
        self.__layers = []  # type: List[tf.layers.Dense]
        for _ in range(num_hidden_layers + 1):  # all hidden layers + one linear output:
            self.__layers.append(tf.layers.Dense(units=out_size,
                                                 use_bias=False,
                                                 activation=None))

    def __call__(self, input: tf.Tensor) -> tf.Tensor:
        activations = input
        for linear_layer in self.__layers[:-1]:
            activations = linear_layer(activations)
            activations = self.__activation_fun(activations)
        return self.__layers[-1](activations)


def sparse_rgin_layer(
        node_embeddings: tf.Tensor,
        adjacency_lists: List[tf.Tensor],
        state_dim: Optional[int],
        num_timesteps: int = 1,
        activation_function: Optional[str] = "ReLU",
        num_MLP_hidden_layers: int = 1,
        learn_epsilon: bool = True,
        ) -> tf.Tensor:
"""
Compute new graph states by neural message passing using MLPs for state updates
and message computation.
For this, we assume existing node states h^t_v and a list of per-edge-type adjacency
matrices A_\ell.
We compute new states as follows:
h^{t+1}_v := MLP_{out}((1 + \epsilon) * MLP_{self}(h^t_v)
+ \sum_\ell \sum_{(u, v) \in A_\ell} MLP_\ell(h^t_u))
The learnable parameters of this are the MLPs and (if enabled) epsilon.
This is derived from Cor. 6 of arXiv:1810.00826, instantiating the functions f, \phi
with _separate_ MLPs. This is more powerful than the GIN formulation in Eq. (4.1) of
arXiv:1810.00826, as we want to be able to distinguish graphs of the form
G_1 = (V={1, 2, 3}, E_1={(1, 2)}, E_2={(3, 2)})
and
G_2 = (V={1, 2, 3}, E_1={(3, 2)}, E_2={(1, 2)})
from each other. If we would treat all edges the same,
G_1.E_1 \cup G_1.E_2 == G_2.E_1 \cup G_2.E_2 would imply that the two graphs
become indistuingishable.
Hence, we introduce per-edge-type MLPs, which also means that we have to drop
the optimisation of modelling f \circ \phi by a single MLP used in the original
GIN formulation.
We use the following abbreviations in shape descriptions:
* V: number of nodes
* D: state dimension
* L: number of different edge types
* E: number of edges of a given edge type
Arguments:
node_embeddings: float32 tensor of shape [V, D], the original representation of
each node in the graph.
adjacency_lists: List of L adjacency lists, represented as int32 tensors of shape
[E, 2]. Concretely, adjacency_lists[l][k,:] == [v, u] means that the k-th edge
of type l connects node v to node u.
state_dim: Optional size of output dimension of the GNN layer. If not set, defaults
to D, the dimensionality of the input. If different from the input dimension,
parameter num_timesteps has to be 1.
num_timesteps: Number of repeated applications of this message passing layer.
activation_function: Type of activation function used.
num_MLP_hidden_layers: Number of hidden layers of the MLPs.
learn_epsilon: Flag indicating if the value of epsilon should be learned. If
False, epsilon defaults to 0.
Returns:
float32 tensor of shape [V, state_dim]
"""
    num_nodes = tf.shape(node_embeddings, out_type=tf.int32)[0]
    if state_dim is None:
        state_dim = tf.shape(node_embeddings, out_type=tf.int32)[1]

    # === Prepare things we need across all timesteps:
    activation_fn = get_activation(activation_function)
    aggregation_MLP = MLP(out_size=state_dim,
                          num_hidden_layers=num_MLP_hidden_layers,
                          activation_fun=activation_fn)
    edge_type_to_edge_mlp = []  # MLPs to compute the edge messages
    edge_type_to_message_targets = []  # List of tensors of message targets
    for edge_type_idx, adjacency_list_for_edge_type in enumerate(adjacency_lists):
        with tf.variable_scope("Edge_%i_MLP" % edge_type_idx):
            edge_type_to_edge_mlp.append(
                MLP(out_size=state_dim,
                    num_hidden_layers=num_MLP_hidden_layers,
                    activation_fun=activation_fn))
        edge_type_to_message_targets.append(adjacency_list_for_edge_type[:, 1])

    # Initialize epsilon: Note that we merge the 1+epsilon here:
    if learn_epsilon:
        epsilon = tf.get_variable("epsilon", shape=(), dtype=tf.float32,
                                  initializer=tf.ones_initializer, trainable=True)
    else:
        epsilon = 1
    self_loop_MLP = MLP(out_size=state_dim,
                        num_hidden_layers=num_MLP_hidden_layers,
                        activation_fun=activation_fn)

    # Let M be the number of messages (sum of all E):
    message_targets = tf.concat(edge_type_to_message_targets, axis=0)  # Shape [M]

    cur_node_states = node_embeddings
    for _ in range(num_timesteps):
        messages_per_type = []  # list of tensors of messages of shape [E, D]
        # Collect incoming messages per edge type
        for edge_type_idx, adjacency_list_for_edge_type in enumerate(adjacency_lists):
            edge_sources = adjacency_list_for_edge_type[:, 0]
            edge_source_states = \
                tf.nn.embedding_lookup(params=cur_node_states,
                                       ids=edge_sources)  # Shape [E, D]
            messages = edge_type_to_edge_mlp[edge_type_idx](edge_source_states)  # Shape [E, D]
            messages_per_type.append(messages)

        all_messages = tf.concat(messages_per_type, axis=0)  # Shape [M, D]
        aggregated_messages = \
            tf.unsorted_segment_sum(data=all_messages,
                                    segment_ids=message_targets,
                                    num_segments=num_nodes)  # Shape [V, D]
        cur_node_states = aggregation_MLP(epsilon * self_loop_MLP(cur_node_states) + aggregated_messages)
        # Note that the final MLP layer has no activation, so we do that here explicitly:
        cur_node_states = activation_fn(cur_node_states)

    return cur_node_states
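For illustration, a minimal usage sketch of the new layer, assuming a TF 1.x session environment as in the rest of this repository; the toy tensors are made up and mirror the docstring's G_1 example (nodes renumbered to 0, 1, 2):

import tensorflow as tf
from gnns import sparse_rgin_layer

# Toy graph: V=3 nodes with D=8 features and two edge types, one edge each.
node_embeddings = tf.random_normal([3, 8])
adjacency_lists = [
    tf.constant([[0, 1]], dtype=tf.int32),  # edge type 1: node 0 -> node 1
    tf.constant([[2, 1]], dtype=tf.int32),  # edge type 2: node 2 -> node 1
]

new_states = sparse_rgin_layer(
    node_embeddings=node_embeddings,
    adjacency_lists=adjacency_lists,
    state_dim=8,       # equal to the input dimension D, so num_timesteps > 1 is allowed
    num_timesteps=2,
    learn_epsilon=True,
)  # float32 tensor of shape [V, state_dim] == [3, 8]

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(new_states).shape)  # (3, 8)

Swapping the two adjacency lists would give the docstring's G_2 variant; because each edge type gets its own MLP, the two graphs then generically produce different states for node 1, which is the behaviour the docstring's argument calls for.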


@@ -5,3 +5,4 @@ from .gnn_film_model import GNN_FiLM_Model
from .rgat_model import RGAT_Model
from .rgcn_model import RGCN_Model
from .rgdcn_model import RGDCN_Model
from .rgin_model import RGIN_Model

models/rgin_model.py Normal file

@@ -0,0 +1,45 @@
from typing import Dict, Any, List

import tensorflow as tf

from .sparse_graph_model import Sparse_Graph_Model
from tasks import Sparse_Graph_Task
from gnns import sparse_rgin_layer


class RGIN_Model(Sparse_Graph_Model):
    @classmethod
    def default_params(cls):
        params = super().default_params()
        params.update({
            'hidden_size': 128,
            'graph_activation_function': "ReLU",
            'graph_layer_input_dropout_keep_prob': 1.0,
            'graph_dense_between_every_num_gnn_layers': 10000,
            'graph_residual_connection_every_num_layers': 10000,
            'graph_num_MLP_hidden_layers': 1,
            'graph_learn_epsilon': False,
        })
        return params

    @staticmethod
    def name(params: Dict[str, Any]) -> str:
        return "RGIN"

    def __init__(self, params: Dict[str, Any], task: Sparse_Graph_Task, run_id: str, result_dir: str) -> None:
        super().__init__(params, task, run_id, result_dir)

    def _apply_gnn_layer(self,
                         node_representations: tf.Tensor,
                         adjacency_lists: List[tf.Tensor],
                         type_to_num_incoming_edges: tf.Tensor,
                         num_timesteps: int) -> tf.Tensor:
        return sparse_rgin_layer(
            node_embeddings=node_representations,
            adjacency_lists=adjacency_lists,
            state_dim=self.params['hidden_size'],
            num_timesteps=num_timesteps,
            activation_function=self.params['graph_activation_function'],
            num_MLP_hidden_layers=self.params['graph_num_MLP_hidden_layers'],
            learn_epsilon=self.params['graph_learn_epsilon'],
        )
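For illustration, a minimal sketch of overriding the new RGIN-specific defaults before construction; `task`, `run_id`, and `result_dir` are hypothetical placeholders for objects supplied by the surrounding training framework:

from models import RGIN_Model

params = RGIN_Model.default_params()
params['graph_num_MLP_hidden_layers'] = 2  # deeper per-edge-type MLPs
params['graph_learn_epsilon'] = True       # learn the merged (1 + epsilon) self-loop weight
model = RGIN_Model(params, task, run_id, result_dir)  # placeholders, see lead-in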


@@ -5,7 +5,7 @@ from typing import Tuple, Type, Dict, Any
import pickle
from models import (Sparse_Graph_Model, GGNN_Model, GNN_FiLM_Model, GNN_Edge_MLP_Model,
                    RGAT_Model, RGCN_Model, RGDCN_Model, RGIN_Model)
from tasks import Sparse_Graph_Task, QM9_Task, Citation_Network_Task, PPI_Task, VarMisuse_Task
@@ -49,6 +49,8 @@ def name_to_model_class(name: str) -> Tuple[Type[Sparse_Graph_Model], Dict[str,
        return RGCN_Model, {}
    if name in ["rgdcn", "rgdcn_model"]:
        return RGDCN_Model, {}
    if name in ["rgin", "rgin_model"]:
        return RGIN_Model, {}
    raise ValueError("Unknown model type '%s'" % name)
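Finally, a quick sanity check of the new name mapping, as a hypothetical REPL-style snippet (the module path in the import is an assumption, not shown in this diff):

from utils.model_utils import name_to_model_class  # module path assumed

model_cls, model_param_overrides = name_to_model_class("rgin")
assert model_cls.name({}) == "RGIN"
assert model_param_overrides == {}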