diff --git a/ptgnn/neuralmodels/gnn/messagepassing/abstractmessagepassing.py b/ptgnn/neuralmodels/gnn/messagepassing/abstractmessagepassing.py index ada96ac..a68e536 100644 --- a/ptgnn/neuralmodels/gnn/messagepassing/abstractmessagepassing.py +++ b/ptgnn/neuralmodels/gnn/messagepassing/abstractmessagepassing.py @@ -20,13 +20,19 @@ class AbstractMessagePassingLayer(nn.Module): ) -> torch.Tensor: """ :param node_states: A [num_nodes, D] matrix containing the states of all nodes. - :param adjacency_lists: A list of [num_edges, 2] adjacency lists, one per edge type. - :param node_to_graph_idx: - :param reference_node_ids: - :param reference_node_graph_idx: + :param adjacency_lists: A list with as many elements as edge types. Each element in the list + is a pair (tuple) of [num_edges_for_edge_type]-sized tensors indicating the left/right + hand-side of the elements in the adjacency list. + :param node_to_graph_idx: a vector of shape [num_nodes] indicating the graph that each node + belongs to. + :param reference_node_ids: a dictionary that maps each reference (key) to + the indices of the reference nodes in `node_states`. + :param reference_node_graph_idx: a dictionary that maps each reference (key) to + the graph it belongs to. For each reference `ref_name`, + len(reference_node_ids[ref_name])==len(reference_node_graph_idx[ref_name]) :param edge_features: A list of [num_edges, H] with edge features. - Has the size of `adjacency_lists`. - :return: the next node states in a [num_nodes, D'] matrix. + Has the same length as `adjacency_lists`, ie. len(adjacency_lists) == len(edge_features) + :return: the output node states in a [num_nodes, D'] matrix. """ def _aggregate_messages( diff --git a/ptgnn/neuralmodels/gnn/messagepassing/mlpmessagepassing.py b/ptgnn/neuralmodels/gnn/messagepassing/mlpmessagepassing.py index ad12bc1..bc777cd 100644 --- a/ptgnn/neuralmodels/gnn/messagepassing/mlpmessagepassing.py +++ b/ptgnn/neuralmodels/gnn/messagepassing/mlpmessagepassing.py @@ -74,7 +74,9 @@ class MlpMessagePassingLayer(AbstractMessagePassingLayer): reference_node_graph_idx: Dict[str, torch.Tensor], edge_features: List[torch.Tensor], ) -> torch.Tensor: - assert len(adjacency_lists) == len(self.__edge_message_transformation_layers) + assert len(adjacency_lists) == len( + self.__edge_message_transformation_layers + ), "The number of adjacency lists must be equal to the number of edge types." all_message_targets, all_messages = [], [] for edge_type_idx, (adj_list, features, edge_transformation_layer) in enumerate(