Improve some code comments about message-passing layers.

This commit is contained in:
Miltos Allamanis 2021-07-20 17:21:38 +01:00 коммит произвёл Miltos
Родитель 27b3102184
Коммит 5020bb7f6f
2 изменённых файлов: 15 добавлений и 7 удалений

Просмотреть файл

@ -20,13 +20,19 @@ class AbstractMessagePassingLayer(nn.Module):
) -> torch.Tensor:
"""
:param node_states: A [num_nodes, D] matrix containing the states of all nodes.
:param adjacency_lists: A list of [num_edges, 2] adjacency lists, one per edge type.
:param node_to_graph_idx:
:param reference_node_ids:
:param reference_node_graph_idx:
:param adjacency_lists: A list with as many elements as edge types. Each element in the list
is a pair (tuple) of [num_edges_for_edge_type]-sized tensors indicating the left/right
hand-side of the elements in the adjacency list.
:param node_to_graph_idx: a vector of shape [num_nodes] indicating the graph that each node
belongs to.
:param reference_node_ids: a dictionary that maps each reference (key) to
the indices of the reference nodes in `node_states`.
:param reference_node_graph_idx: a dictionary that maps each reference (key) to
the graph it belongs to. For each reference `ref_name`,
len(reference_node_ids[ref_name])==len(reference_node_graph_idx[ref_name])
:param edge_features: A list of [num_edges, H] with edge features.
Has the size of `adjacency_lists`.
:return: the next node states in a [num_nodes, D'] matrix.
Has the same length as `adjacency_lists`, ie. len(adjacency_lists) == len(edge_features)
:return: the output node states in a [num_nodes, D'] matrix.
"""
def _aggregate_messages(

Просмотреть файл

@ -74,7 +74,9 @@ class MlpMessagePassingLayer(AbstractMessagePassingLayer):
reference_node_graph_idx: Dict[str, torch.Tensor],
edge_features: List[torch.Tensor],
) -> torch.Tensor:
assert len(adjacency_lists) == len(self.__edge_message_transformation_layers)
assert len(adjacency_lists) == len(
self.__edge_message_transformation_layers
), "The number of adjacency lists must be equal to the number of edge types."
all_message_targets, all_messages = [], []
for edge_type_idx, (adj_list, features, edge_transformation_layer) in enumerate(