From 6a00b4a65c064bd31d8c434bffb064b0d1287728 Mon Sep 17 00:00:00 2001
From: Miltos Allamanis
Date: Thu, 11 Mar 2021 17:12:06 +0000
Subject: [PATCH] Simplify PNA aggregation/layers.

---
 .../gnn/messagepassing/__init__.py            |   2 +-
 .../messagepassing/abstractmessagepassing.py  |  10 ++
 .../gnn/messagepassing/mlpmessagepassing.py   |  37 +++--
 .../gnn/messagepassing/pna_aggregation.py     |  55 +++++++
 .../gnn/messagepassing/pnamessagepassing.py   | 142 ------------------
 5 files changed, 92 insertions(+), 154 deletions(-)
 create mode 100644 ptgnn/neuralmodels/gnn/messagepassing/pna_aggregation.py
 delete mode 100644 ptgnn/neuralmodels/gnn/messagepassing/pnamessagepassing.py

diff --git a/ptgnn/neuralmodels/gnn/messagepassing/__init__.py b/ptgnn/neuralmodels/gnn/messagepassing/__init__.py
index c121b0f..7445fa1 100644
--- a/ptgnn/neuralmodels/gnn/messagepassing/__init__.py
+++ b/ptgnn/neuralmodels/gnn/messagepassing/__init__.py
@@ -2,5 +2,5 @@ from .abstractmessagepassing import AbstractMessagePassingLayer
 from .gatedmessagepassing import GatedMessagePassingLayer
 from .globalgraphexchange import AbstractGlobalGraphExchange, GruGlobalStateUpdate
 from .mlpmessagepassing import MlpMessagePassingLayer
-from .pnamessagepassing import PnaMessagePassingLayer
+from .pna_aggregation import PnaMessageAggregation
 from .residuallayers import LinearResidualLayer, MeanResidualLayer
diff --git a/ptgnn/neuralmodels/gnn/messagepassing/abstractmessagepassing.py b/ptgnn/neuralmodels/gnn/messagepassing/abstractmessagepassing.py
index 5ce698c..ada96ac 100644
--- a/ptgnn/neuralmodels/gnn/messagepassing/abstractmessagepassing.py
+++ b/ptgnn/neuralmodels/gnn/messagepassing/abstractmessagepassing.py
@@ -52,3 +52,13 @@ class AbstractMessagePassingLayer(nn.Module):
     @abstractmethod
     def output_state_dimension(self) -> int:
         pass
+
+
+class AbstractMessageAggregation(nn.Module):
+    @abstractmethod
+    def forward(self, messages: torch.Tensor, message_targets: torch.Tensor, num_nodes):
+        pass
+
+    @abstractmethod
+    def output_state_size(self, message_input_size: int) -> int:
+        pass
diff --git a/ptgnn/neuralmodels/gnn/messagepassing/mlpmessagepassing.py b/ptgnn/neuralmodels/gnn/messagepassing/mlpmessagepassing.py
index f5fca5a..ad12bc1 100644
--- a/ptgnn/neuralmodels/gnn/messagepassing/mlpmessagepassing.py
+++ b/ptgnn/neuralmodels/gnn/messagepassing/mlpmessagepassing.py
@@ -1,8 +1,11 @@
 import torch
 from torch import nn
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
-from ptgnn.neuralmodels.gnn.messagepassing.abstractmessagepassing import AbstractMessagePassingLayer
+from ptgnn.neuralmodels.gnn.messagepassing.abstractmessagepassing import (
+    AbstractMessageAggregation,
+    AbstractMessagePassingLayer,
+)
 from ptgnn.neuralmodels.mlp import MLP
@@ -13,7 +16,7 @@ class MlpMessagePassingLayer(AbstractMessagePassingLayer):
         output_state_dimension: int,
         message_dimension: int,
         num_edge_types: int,
-        message_aggregation_function: str,
+        message_aggregation_function: Union[str, AbstractMessageAggregation],
         message_activation: Optional[nn.Module] = nn.GELU(),
         use_target_state_as_message_input: bool = True,
         mlp_hidden_layers: Union[List[int], int] = 0,
@@ -43,13 +46,18 @@ class MlpMessagePassingLayer(AbstractMessagePassingLayer):
             ]
         )
         self.__aggregation_fn = message_aggregation_function
+        if isinstance(message_aggregation_function, str):
+            aggregated_state_size = message_dimension
+        else:
+            aggregated_state_size = self.__aggregation_fn.output_state_size(message_dimension)
+
         self.__message_activation = message_activation
 
         state_update_layers: List[nn.Module] = []
         if use_layer_norm:
-            state_update_layers.append(nn.LayerNorm(message_dimension))
+            state_update_layers.append(nn.LayerNorm(aggregated_state_size))
         if use_dense_layer:
-            state_update_layers.append(nn.Linear(message_dimension, output_state_dimension))
+            state_update_layers.append(nn.Linear(aggregated_state_size, output_state_dimension))
             nn.init.xavier_uniform_(state_update_layers[-1].weight)
         if dense_activation is not None:
             state_update_layers.append(dense_activation)
@@ -87,12 +95,19 @@ class MlpMessagePassingLayer(AbstractMessagePassingLayer):
                 edge_transformation_layer(torch.cat([message_input, features], dim=-1))
             )
 
-        aggregated_messages = self._aggregate_messages(
-            messages=torch.cat(all_messages, dim=0),
-            message_targets=torch.cat(all_message_targets, dim=0),
-            num_nodes=node_states.shape[0],
-            aggregation_fn=self.__aggregation_fn,
-        )
+        if isinstance(self.__aggregation_fn, AbstractMessageAggregation):
+            aggregated_messages = self.__aggregation_fn(
+                messages=torch.cat(all_messages, dim=0),
+                message_targets=torch.cat(all_message_targets, dim=0),
+                num_nodes=node_states.shape[0],
+            )
+        else:
+            aggregated_messages = self._aggregate_messages(
+                messages=torch.cat(all_messages, dim=0),
+                message_targets=torch.cat(all_message_targets, dim=0),
+                num_nodes=node_states.shape[0],
+                aggregation_fn=self.__aggregation_fn,
+            )
 
         if self.__message_activation is not None:
             aggregated_messages = self.__message_activation(aggregated_messages)
diff --git a/ptgnn/neuralmodels/gnn/messagepassing/pna_aggregation.py b/ptgnn/neuralmodels/gnn/messagepassing/pna_aggregation.py
new file mode 100644
index 0000000..7ad0e75
--- /dev/null
+++ b/ptgnn/neuralmodels/gnn/messagepassing/pna_aggregation.py
@@ -0,0 +1,55 @@
+import torch
+from torch import nn
+from torch_scatter import scatter
+from typing import Dict, List, Optional, Tuple, Union
+
+from ptgnn.neuralmodels.gnn.messagepassing.abstractmessagepassing import (
+    AbstractMessageAggregation,
+    AbstractMessagePassingLayer,
+)
+from ptgnn.neuralmodels.mlp import MLP
+
+
+class PnaMessageAggregation(AbstractMessageAggregation):
+    """
+    Principal Neighbourhood Aggregation for Graph Nets
+
+    https://arxiv.org/abs/2004.05718
+    """
+
+    def __init__(
+        self,
+        delta: float = 1,
+    ):
+        super().__init__()
+        self._delta = delta  # See Eq 5 of paper
+
+    def forward(self, messages: torch.Tensor, message_targets: torch.Tensor, num_nodes):
+        degree = scatter(
+            torch.ones_like(message_targets),
+            index=message_targets,
+            dim_size=num_nodes,
+            reduce="sum",
+        )
+
+        sum_agg = scatter(messages, index=message_targets, dim=0, dim_size=num_nodes, reduce="sum")
+        mean_agg = sum_agg / (degree.unsqueeze(-1) + 1e-5)
+        max_agg = scatter(messages, index=message_targets, dim=0, dim_size=num_nodes, reduce="max")
+        min_agg = scatter(messages, index=message_targets, dim=0, dim_size=num_nodes, reduce="min")
+
+        std_components = torch.relu(messages.pow(2) - mean_agg[message_targets].pow(2)) + 1e-10
+        std = torch.sqrt(
+            scatter(std_components, index=message_targets, dim=0, dim_size=num_nodes, reduce="sum")
+        )
+
+        all_aggregations = torch.cat([sum_agg, mean_agg, max_agg, min_agg, std], dim=-1)
+
+        scaler_p1 = torch.log(degree.float() + 1).unsqueeze(-1) / self._delta
+        scaler_m1 = 1 / (scaler_p1 + 1e-3)
+
+        return torch.cat(
+            [all_aggregations, all_aggregations * scaler_p1, all_aggregations * scaler_m1], dim=-1
+        )
+
+    def output_state_size(self, message_input_size: int) -> int:
+        return message_input_size * 5 * 3
diff --git a/ptgnn/neuralmodels/gnn/messagepassing/pnamessagepassing.py b/ptgnn/neuralmodels/gnn/messagepassing/pnamessagepassing.py
deleted file mode 100644
index 2aa40e0..0000000
--- a/ptgnn/neuralmodels/gnn/messagepassing/pnamessagepassing.py
+++ /dev/null
@@ -1,142 +0,0 @@
-import torch
-from torch import nn
-from torch_scatter import scatter
-from typing import Dict, List, Optional, Tuple, Union
-
-from ptgnn.neuralmodels.gnn.messagepassing.abstractmessagepassing import AbstractMessagePassingLayer
-from ptgnn.neuralmodels.mlp import MLP
-
-
-class PnaMessagePassingLayer(AbstractMessagePassingLayer):
-    """
-    Principal Neighbourhood Aggregation for Graph Nets
-
-    https://arxiv.org/abs/2004.05718
-    """
-
-    def __init__(
-        self,
-        input_state_dimension: int,
-        output_state_dimension: int,
-        message_dimension: int,
-        num_edge_types: int,
-        message_aggregation_function: str,
-        message_activation: Optional[nn.Module] = nn.GELU(),
-        use_target_state_as_message_input: bool = True,
-        mlp_hidden_layers: Union[List[int], int] = 0,
-        use_layer_norm: bool = True,
-        use_dense_layer: bool = True,
-        dropout_rate: float = 0.0,
-        dense_activation: Optional[nn.Module] = nn.Tanh(),
-        delta: float = 1,
-    ):
-        super().__init__()
-        self.__input_state_dim = input_state_dimension
-        self.__use_target_state_as_message_input = use_target_state_as_message_input
-        self.__output_state_dim = output_state_dimension
-
-        if use_target_state_as_message_input:
-            message_input_size = 2 * input_state_dimension
-        else:
-            message_input_size = input_state_dimension
-        self.__edge_message_transformation_layers = nn.ModuleList(
-            [
-                MLP(
-                    input_dimension=message_input_size,
-                    output_dimension=message_dimension,
-                    hidden_layers=mlp_hidden_layers,
-                )
-                for _ in range(num_edge_types)
-            ]
-        )
-        self.__aggregation_fn = message_aggregation_function
-        self.__message_activation = message_activation
-
-        state_update_layers: List[nn.Module] = []
-        if use_layer_norm:
-            state_update_layers.append(nn.LayerNorm(message_dimension * 5 * 3))
-        if use_dense_layer:
-            state_update_layers.append(nn.Linear(message_dimension * 5 * 3, output_state_dimension))
-            nn.init.xavier_uniform_(state_update_layers[-1].weight)
-        if dense_activation is not None:
-            state_update_layers.append(dense_activation)
-        state_update_layers.append(nn.Dropout(p=dropout_rate))
-
-        self.__state_update = nn.Sequential(*state_update_layers)
-        self._delta = delta  # See Eq 5 of paper
-
-    def __pna_aggregation_and_scaling(
-        self, messages: torch.Tensor, message_targets: torch.Tensor, num_nodes
-    ):
-        degree = scatter(
-            torch.ones_like(message_targets),
-            index=message_targets,
-            dim_size=num_nodes,
-            reduce="sum",
-        )
-
-        sum_agg = scatter(messages, index=message_targets, dim=0, dim_size=num_nodes, reduce="sum")
-        mean_agg = sum_agg / (degree.unsqueeze(-1) + 1e-5)
-        max_agg = scatter(messages, index=message_targets, dim=0, dim_size=num_nodes, reduce="max")
-        min_agg = scatter(messages, index=message_targets, dim=0, dim_size=num_nodes, reduce="min")
-
-        std_components = torch.relu(messages.pow(2) - mean_agg[message_targets].pow(2)) + 1e-10
-        std = torch.sqrt(
-            scatter(std_components, index=message_targets, dim=0, dim_size=num_nodes, reduce="sum")
-        )
-
-        all_aggregations = torch.cat([sum_agg, mean_agg, max_agg, min_agg, std], dim=-1)
-
-        scaler_p1 = torch.log(degree.float() + 1).unsqueeze(-1) / self._delta
-        scaler_m1 = 1 / (scaler_p1 + 1e-3)
-
-        return torch.cat(
-            [all_aggregations, all_aggregations * scaler_p1, all_aggregations * scaler_m1], dim=-1
-        )
-
-    def forward(
-        self,
-        node_states: torch.Tensor,
-        adjacency_lists: List[Tuple[torch.Tensor, torch.Tensor]],
-        node_to_graph_idx: torch.Tensor,
-        reference_node_ids: Dict[str, torch.Tensor],
-        reference_node_graph_idx: Dict[str, torch.Tensor],
-    ) -> torch.Tensor:
-        assert len(adjacency_lists) == len(self.__edge_message_transformation_layers)
-
-        all_message_targets, all_messages = [], []
-        for edge_type_idx, (adj_list, edge_transformation_layer) in enumerate(
-            zip(adjacency_lists, self.__edge_message_transformation_layers)
-        ):
-            edge_sources_idxs, edge_target_idxs = adj_list
-            all_message_targets.append(edge_target_idxs)
-
-            edge_source_states = nn.functional.embedding(edge_sources_idxs, node_states)
-
-            if self.__use_target_state_as_message_input:
-                edge_target_states = nn.functional.embedding(edge_target_idxs, node_states)
-                message_input = torch.cat([edge_source_states, edge_target_states], dim=-1)
-            else:
-                message_input = edge_source_states
-
-            all_messages.append(edge_transformation_layer(message_input))
-
-        all_messages = torch.cat(all_messages, dim=0)
-        if self.__message_activation is not None:
-            all_messages = self.__message_activation(all_messages)
-
-        aggregated_messages = self.__pna_aggregation_and_scaling(
-            messages=all_messages,
-            message_targets=torch.cat(all_message_targets, dim=0),
-            num_nodes=node_states.shape[0],
-        )
-
-        return self.__state_update(aggregated_messages)  # num_nodes x H
-
-    @property
-    def input_state_dimension(self) -> int:
-        return self.__input_state_dim
-
-    @property
-    def output_state_dimension(self) -> int:
-        return self.__output_state_dim
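
Usage sketch (not part of the patch): PnaMessageAggregation widens the aggregated
message from message_dimension to message_dimension * 5 * 3 (five aggregators --
sum, mean, max, min, std -- each repeated under three degree scalers), which is why
the LayerNorm/Linear sizes in MlpMessagePassingLayer switch to aggregated_state_size.
The sketch below exercises only the signatures introduced in this patch; the tensor
sizes and variable names are illustrative assumptions, and torch_scatter must be
installed for the aggregation to run.

    import torch

    from ptgnn.neuralmodels.gnn.messagepassing import (
        MlpMessagePassingLayer,
        PnaMessageAggregation,
    )

    num_nodes, num_edges, hidden = 10, 30, 64

    # Standalone: aggregate per-edge messages onto their target nodes.
    pna = PnaMessageAggregation()  # delta=1; see Eq. 5 of the PNA paper
    messages = torch.randn(num_edges, hidden)
    message_targets = torch.randint(num_nodes, (num_edges,))
    aggregated = pna(messages=messages, message_targets=message_targets, num_nodes=num_nodes)
    assert aggregated.shape == (num_nodes, pna.output_state_size(hidden))  # hidden * 5 * 3

    # Plugged into the generic MLP layer: an AbstractMessageAggregation instance
    # now replaces the old aggregation-name string.
    layer = MlpMessagePassingLayer(
        input_state_dimension=hidden,
        output_state_dimension=hidden,
        message_dimension=hidden,
        num_edge_types=2,
        message_aggregation_function=pna,  # previously a string such as "sum"
    )

Behavior matches the deleted PnaMessagePassingLayer: aggregated_state_size resolves
to output_state_size(message_dimension) for PNA and to message_dimension for the
string aggregators, so PNA no longer needs its own copy of the message-passing layer.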