feat(RGIN): Add relational generalisation of GIN
This commit is contained in:
Родитель
9dc36b8cfe
Коммит
9764626607
|
@ -3,4 +3,5 @@ from .gnn_edge_mlp import sparse_gnn_edge_mlp_layer
|
||||||
from .gnn_film import sparse_gnn_film_layer
|
from .gnn_film import sparse_gnn_film_layer
|
||||||
from .rgat import sparse_rgat_layer
|
from .rgat import sparse_rgat_layer
|
||||||
from .rgcn import sparse_rgcn_layer
|
from .rgcn import sparse_rgcn_layer
|
||||||
from .rgdcn import sparse_rgdcn_layer
|
from .rgdcn import sparse_rgdcn_layer
|
||||||
|
from .rgin import sparse_rgin_layer
|
||||||
|
|
|
@ -0,0 +1,136 @@
|
||||||
|
from typing import List, Optional, Callable
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
|
||||||
|
from utils import get_activation
|
||||||
|
|
||||||
|
|
||||||
|
class MLP(object):
|
||||||
|
def __init__(self, out_size: int, num_hidden_layers: int = 1, activation_fun: Optional[Callable] = None):
|
||||||
|
if activation_fun is None:
|
||||||
|
self.__activation_fun = tf.nn.relu
|
||||||
|
else:
|
||||||
|
self.__activation_fun = activation_fun
|
||||||
|
|
||||||
|
self.__layers = [] # type: List[tf.layers.Dense]
|
||||||
|
for _ in range(num_hidden_layers + 1): # all hidden layers + one linear output:
|
||||||
|
self.__layers.append(tf.layers.Dense(units=out_size,
|
||||||
|
use_bias=False,
|
||||||
|
activation=None))
|
||||||
|
|
||||||
|
def __call__(self, input: tf.Tensor) -> tf.Tensor:
|
||||||
|
activations = input
|
||||||
|
for linear_layer in self.__layers[:-1]:
|
||||||
|
activations = linear_layer(activations)
|
||||||
|
activations = self.__activation_fun(activations)
|
||||||
|
return self.__layers[-1](activations)
|
||||||
|
|
||||||
|
|
||||||
|
def sparse_rgin_layer(
|
||||||
|
node_embeddings: tf.Tensor,
|
||||||
|
adjacency_lists: List[tf.Tensor],
|
||||||
|
state_dim: Optional[int],
|
||||||
|
num_timesteps: int = 1,
|
||||||
|
activation_function: Optional[str] = "ReLU",
|
||||||
|
num_MLP_hidden_layers: int = 1,
|
||||||
|
learn_epsilon: bool = True,
|
||||||
|
) -> tf.Tensor:
|
||||||
|
"""
|
||||||
|
Compute new graph states by neural message passing using MLPs for state updates
|
||||||
|
and message computation.
|
||||||
|
For this, we assume existing node states h^t_v and a list of per-edge-type adjacency
|
||||||
|
matrices A_\ell.
|
||||||
|
|
||||||
|
We compute new states as follows:
|
||||||
|
h^{t+1}_v := MLP_{out}((1 + \epsilon) * MLP_{self}(h^t_v)
|
||||||
|
+ \sum_\ell \sum_{(u, v) \in A_\ell} MLP_\ell(h^t_u))
|
||||||
|
The learnable parameters of this are the MLPs and (if enabled) epsilon.
|
||||||
|
This is derived from Cor. 6 of arXiv:1810.00826, instantiating the functions f, \phi
|
||||||
|
with _separate_ MLPs. This is more powerful than the GIN formulation in Eq. (4.1) of
|
||||||
|
arXiv:1810.00826, as we want to be able to distinguish graphs of the form
|
||||||
|
G_1 = (V={1, 2, 3}, E_1={(1, 2)}, E_2={(3, 2)})
|
||||||
|
and
|
||||||
|
G_2 = (V={1, 2, 3}, E_1={(3, 2)}, E_2={(1, 2)})
|
||||||
|
from each other. If we would treat all edges the same,
|
||||||
|
G_1.E_1 \cup G_1.E_2 == G_2.E_1 \cup G_2.E_2 would imply that the two graphs
|
||||||
|
become indistuingishable.
|
||||||
|
Hence, we introduce per-edge-type MLPs, which also means that we have to drop
|
||||||
|
the optimisation of modelling f \circ \phi by a single MLP used in the original
|
||||||
|
GIN formulation.
|
||||||
|
|
||||||
|
We use the following abbreviations in shape descriptions:
|
||||||
|
* V: number of nodes
|
||||||
|
* D: state dimension
|
||||||
|
* L: number of different edge types
|
||||||
|
* E: number of edges of a given edge type
|
||||||
|
|
||||||
|
Arguments:
|
||||||
|
node_embeddings: float32 tensor of shape [V, D], the original representation of
|
||||||
|
each node in the graph.
|
||||||
|
adjacency_lists: List of L adjacency lists, represented as int32 tensors of shape
|
||||||
|
[E, 2]. Concretely, adjacency_lists[l][k,:] == [v, u] means that the k-th edge
|
||||||
|
of type l connects node v to node u.
|
||||||
|
state_dim: Optional size of output dimension of the GNN layer. If not set, defaults
|
||||||
|
to D, the dimensionality of the input. If different from the input dimension,
|
||||||
|
parameter num_timesteps has to be 1.
|
||||||
|
num_timesteps: Number of repeated applications of this message passing layer.
|
||||||
|
activation_function: Type of activation function used.
|
||||||
|
num_MLP_hidden_layers: Number of hidden layers of the MLPs.
|
||||||
|
learn_epsilon: Flag indicating if the value of epsilon should be learned. If
|
||||||
|
False, epsilon defaults to 0.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
float32 tensor of shape [V, state_dim]
|
||||||
|
"""
|
||||||
|
num_nodes = tf.shape(node_embeddings, out_type=tf.int32)[0]
|
||||||
|
if state_dim is None:
|
||||||
|
state_dim = tf.shape(node_embeddings, out_type=tf.int32)[1]
|
||||||
|
|
||||||
|
# === Prepare things we need across all timesteps:
|
||||||
|
activation_fn = get_activation(activation_function)
|
||||||
|
aggregation_MLP = MLP(out_size=state_dim,
|
||||||
|
num_hidden_layers=num_MLP_hidden_layers,
|
||||||
|
activation_fun=activation_fn)
|
||||||
|
edge_type_to_edge_mlp = [] # MLPs to compute the edge messages
|
||||||
|
edge_type_to_message_targets = [] # List of tensors of message targets
|
||||||
|
for edge_type_idx, adjacency_list_for_edge_type in enumerate(adjacency_lists):
|
||||||
|
with tf.variable_scope("Edge_%i_MLP" % edge_type_idx):
|
||||||
|
edge_type_to_edge_mlp.append(
|
||||||
|
MLP(out_size=state_dim,
|
||||||
|
num_hidden_layers=num_MLP_hidden_layers,
|
||||||
|
activation_fun=activation_fn))
|
||||||
|
edge_type_to_message_targets.append(adjacency_list_for_edge_type[:, 1])
|
||||||
|
# Initialize epsilon: Note that we merge the 1+epsilon here:
|
||||||
|
if learn_epsilon:
|
||||||
|
epsilon = tf.get_variable("epsilon", shape=(), dtype=tf.float32, initializer=tf.ones_initializer, trainable=True)
|
||||||
|
else:
|
||||||
|
epsilon = 1
|
||||||
|
self_loop_MLP = MLP(out_size=state_dim,
|
||||||
|
num_hidden_layers=num_MLP_hidden_layers,
|
||||||
|
activation_fun=activation_fn)
|
||||||
|
|
||||||
|
# Let M be the number of messages (sum of all E):
|
||||||
|
message_targets = tf.concat(edge_type_to_message_targets, axis=0) # Shape [M]
|
||||||
|
|
||||||
|
cur_node_states = node_embeddings
|
||||||
|
for _ in range(num_timesteps):
|
||||||
|
messages_per_type = [] # list of tensors of messages of shape [E, D]
|
||||||
|
# Collect incoming messages per edge type
|
||||||
|
for edge_type_idx, adjacency_list_for_edge_type in enumerate(adjacency_lists):
|
||||||
|
edge_sources = adjacency_list_for_edge_type[:, 0]
|
||||||
|
edge_source_states = \
|
||||||
|
tf.nn.embedding_lookup(params=cur_node_states,
|
||||||
|
ids=edge_sources) # Shape [E, D]
|
||||||
|
messages = edge_type_to_edge_mlp[edge_type_idx](edge_source_states) # Shape [E, D]
|
||||||
|
messages_per_type.append(messages)
|
||||||
|
|
||||||
|
all_messages = tf.concat(messages_per_type, axis=0) # Shape [M, D]
|
||||||
|
aggregated_messages = \
|
||||||
|
tf.unsorted_segment_sum(data=all_messages,
|
||||||
|
segment_ids=message_targets,
|
||||||
|
num_segments=num_nodes) # Shape [V, D]
|
||||||
|
|
||||||
|
cur_node_states = aggregation_MLP(epsilon * self_loop_MLP(cur_node_states) + aggregated_messages)
|
||||||
|
cur_node_states = activation_fn(cur_node_states) # Note that the final MLP layer has no activation, so we do that here explicitly
|
||||||
|
|
||||||
|
return cur_node_states
|
|
@ -5,3 +5,4 @@ from .gnn_film_model import GNN_FiLM_Model
|
||||||
from .rgat_model import RGAT_Model
|
from .rgat_model import RGAT_Model
|
||||||
from .rgcn_model import RGCN_Model
|
from .rgcn_model import RGCN_Model
|
||||||
from .rgdcn_model import RGDCN_Model
|
from .rgdcn_model import RGDCN_Model
|
||||||
|
from .rgin_model import RGIN_Model
|
||||||
|
|
|
@ -0,0 +1,45 @@
|
||||||
|
from typing import Dict, Any, List
|
||||||
|
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
from .sparse_graph_model import Sparse_Graph_Model
|
||||||
|
from tasks import Sparse_Graph_Task
|
||||||
|
from gnns import sparse_rgin_layer
|
||||||
|
|
||||||
|
|
||||||
|
class RGIN_Model(Sparse_Graph_Model):
|
||||||
|
@classmethod
|
||||||
|
def default_params(cls):
|
||||||
|
params = super().default_params()
|
||||||
|
params.update({
|
||||||
|
'hidden_size': 128,
|
||||||
|
"graph_activation_function": "ReLU",
|
||||||
|
'graph_layer_input_dropout_keep_prob': 1.0,
|
||||||
|
'graph_dense_between_every_num_gnn_layers': 10000,
|
||||||
|
'graph_residual_connection_every_num_layers': 10000,
|
||||||
|
'graph_num_MLP_hidden_layers': 1,
|
||||||
|
'graph_learn_epsilon': False,
|
||||||
|
})
|
||||||
|
return params
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def name(params: Dict[str, Any]) -> str:
|
||||||
|
return "RGIN"
|
||||||
|
|
||||||
|
def __init__(self, params: Dict[str, Any], task: Sparse_Graph_Task, run_id: str, result_dir: str) -> None:
|
||||||
|
super().__init__(params, task, run_id, result_dir)
|
||||||
|
|
||||||
|
def _apply_gnn_layer(self,
|
||||||
|
node_representations: tf.Tensor,
|
||||||
|
adjacency_lists: List[tf.Tensor],
|
||||||
|
type_to_num_incoming_edges: tf.Tensor,
|
||||||
|
num_timesteps: int) -> tf.Tensor:
|
||||||
|
return sparse_rgin_layer(
|
||||||
|
node_embeddings=node_representations,
|
||||||
|
adjacency_lists=adjacency_lists,
|
||||||
|
state_dim=self.params['hidden_size'],
|
||||||
|
num_timesteps=num_timesteps,
|
||||||
|
activation_function=self.params['graph_activation_function'],
|
||||||
|
num_MLP_hidden_layers=self.params['graph_num_MLP_hidden_layers'],
|
||||||
|
learn_epsilon=self.params['graph_learn_epsilon'],
|
||||||
|
)
|
|
@ -5,7 +5,7 @@ from typing import Tuple, Type, Dict, Any
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
from models import (Sparse_Graph_Model, GGNN_Model, GNN_FiLM_Model, GNN_Edge_MLP_Model,
|
from models import (Sparse_Graph_Model, GGNN_Model, GNN_FiLM_Model, GNN_Edge_MLP_Model,
|
||||||
RGAT_Model, RGCN_Model, RGDCN_Model)
|
RGAT_Model, RGCN_Model, RGDCN_Model, RGIN_Model)
|
||||||
from tasks import Sparse_Graph_Task, QM9_Task, Citation_Network_Task, PPI_Task, VarMisuse_Task
|
from tasks import Sparse_Graph_Task, QM9_Task, Citation_Network_Task, PPI_Task, VarMisuse_Task
|
||||||
|
|
||||||
|
|
||||||
|
@ -49,6 +49,8 @@ def name_to_model_class(name: str) -> Tuple[Type[Sparse_Graph_Model], Dict[str,
|
||||||
return RGCN_Model, {}
|
return RGCN_Model, {}
|
||||||
if name in ["rgdcn", "rgdcn_model"]:
|
if name in ["rgdcn", "rgdcn_model"]:
|
||||||
return RGDCN_Model, {}
|
return RGDCN_Model, {}
|
||||||
|
if name in ["rgin", "rgin_model"]:
|
||||||
|
return RGIN_Model, {}
|
||||||
|
|
||||||
raise ValueError("Unknown model type '%s'" % name)
|
raise ValueError("Unknown model type '%s'" % name)
|
||||||
|
|
||||||
|
|
Загрузка…
Ссылка в новой задаче