Improve some code comments about message-passing layers.

This commit is contained in:
Miltos Allamanis 2021-07-20 17:21:38 +01:00 коммит произвёл Miltos
Родитель 27b3102184
Коммит 5020bb7f6f
2 изменённых файлов: 15 добавлений и 7 удалений

Просмотреть файл

@ -20,13 +20,19 @@ class AbstractMessagePassingLayer(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
""" """
:param node_states: A [num_nodes, D] matrix containing the states of all nodes. :param node_states: A [num_nodes, D] matrix containing the states of all nodes.
:param adjacency_lists: A list of [num_edges, 2] adjacency lists, one per edge type. :param adjacency_lists: A list with as many elements as edge types. Each element in the list
:param node_to_graph_idx: is a pair (tuple) of [num_edges_for_edge_type]-sized tensors indicating the left/right
:param reference_node_ids: hand-side of the elements in the adjacency list.
:param reference_node_graph_idx: :param node_to_graph_idx: a vector of shape [num_nodes] indicating the graph that each node
belongs to.
:param reference_node_ids: a dictionary that maps each reference (key) to
the indices of the reference nodes in `node_states`.
:param reference_node_graph_idx: a dictionary that maps each reference (key) to
the graph it belongs to. For each reference `ref_name`,
len(reference_node_ids[ref_name])==len(reference_node_graph_idx[ref_name])
:param edge_features: A list of [num_edges, H] with edge features. :param edge_features: A list of [num_edges, H] with edge features.
Has the size of `adjacency_lists`. Has the same length as `adjacency_lists`, ie. len(adjacency_lists) == len(edge_features)
:return: the next node states in a [num_nodes, D'] matrix. :return: the output node states in a [num_nodes, D'] matrix.
""" """
def _aggregate_messages( def _aggregate_messages(

Просмотреть файл

@ -74,7 +74,9 @@ class MlpMessagePassingLayer(AbstractMessagePassingLayer):
reference_node_graph_idx: Dict[str, torch.Tensor], reference_node_graph_idx: Dict[str, torch.Tensor],
edge_features: List[torch.Tensor], edge_features: List[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
assert len(adjacency_lists) == len(self.__edge_message_transformation_layers) assert len(adjacency_lists) == len(
self.__edge_message_transformation_layers
), "The number of adjacency lists must be equal to the number of edge types."
all_message_targets, all_messages = [], [] all_message_targets, all_messages = [], []
for edge_type_idx, (adj_list, features, edge_transformation_layer) in enumerate( for edge_type_idx, (adj_list, features, edge_transformation_layer) in enumerate(