зеркало из https://github.com/microsoft/ptgnn.git
Improve some code comments about message-passing layers.
This commit is contained in:
Родитель
27b3102184
Коммит
5020bb7f6f
|
@ -20,13 +20,19 @@ class AbstractMessagePassingLayer(nn.Module):
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
:param node_states: A [num_nodes, D] matrix containing the states of all nodes.
|
:param node_states: A [num_nodes, D] matrix containing the states of all nodes.
|
||||||
:param adjacency_lists: A list of [num_edges, 2] adjacency lists, one per edge type.
|
:param adjacency_lists: A list with as many elements as edge types. Each element in the list
|
||||||
:param node_to_graph_idx:
|
is a pair (tuple) of [num_edges_for_edge_type]-sized tensors indicating the left/right
|
||||||
:param reference_node_ids:
|
hand-side of the elements in the adjacency list.
|
||||||
:param reference_node_graph_idx:
|
:param node_to_graph_idx: a vector of shape [num_nodes] indicating the graph that each node
|
||||||
|
belongs to.
|
||||||
|
:param reference_node_ids: a dictionary that maps each reference (key) to
|
||||||
|
the indices of the reference nodes in `node_states`.
|
||||||
|
:param reference_node_graph_idx: a dictionary that maps each reference (key) to
|
||||||
|
the graph it belongs to. For each reference `ref_name`,
|
||||||
|
len(reference_node_ids[ref_name])==len(reference_node_graph_idx[ref_name])
|
||||||
:param edge_features: A list of [num_edges, H] with edge features.
|
:param edge_features: A list of [num_edges, H] with edge features.
|
||||||
Has the size of `adjacency_lists`.
|
Has the same length as `adjacency_lists`, ie. len(adjacency_lists) == len(edge_features)
|
||||||
:return: the next node states in a [num_nodes, D'] matrix.
|
:return: the output node states in a [num_nodes, D'] matrix.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def _aggregate_messages(
|
def _aggregate_messages(
|
||||||
|
|
|
@ -74,7 +74,9 @@ class MlpMessagePassingLayer(AbstractMessagePassingLayer):
|
||||||
reference_node_graph_idx: Dict[str, torch.Tensor],
|
reference_node_graph_idx: Dict[str, torch.Tensor],
|
||||||
edge_features: List[torch.Tensor],
|
edge_features: List[torch.Tensor],
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert len(adjacency_lists) == len(self.__edge_message_transformation_layers)
|
assert len(adjacency_lists) == len(
|
||||||
|
self.__edge_message_transformation_layers
|
||||||
|
), "The number of adjacency lists must be equal to the number of edge types."
|
||||||
|
|
||||||
all_message_targets, all_messages = [], []
|
all_message_targets, all_messages = [], []
|
||||||
for edge_type_idx, (adj_list, features, edge_transformation_layer) in enumerate(
|
for edge_type_idx, (adj_list, features, edge_transformation_layer) in enumerate(
|
||||||
|
|
Загрузка…
Ссылка в новой задаче