Source code for torchdrug.models.mpnn

import torch
from torch import nn

from torchdrug import core, layers
from torchdrug.core import Registry as R


@R.register("models.MPNN")
class MessagePassingNeuralNetwork(nn.Module, core.Configurable):
    """
    Message Passing Neural Network from `Neural Message Passing for Quantum Chemistry`_.

    This is the enn-s2s variant described in the original paper: an edge-conditioned
    message function, a GRU node update and a set2set readout.

    .. _Neural Message Passing for Quantum Chemistry:
        https://arxiv.org/pdf/1704.01212.pdf

    Parameters:
        input_dim (int): input dimension
        hidden_dim (int): hidden dimension
        edge_input_dim (int): dimension of edge features
        num_layer (int, optional): number of hidden layers (message passing rounds)
        num_gru_layer (int, optional): number of GRU layers in each node update
        num_mlp_layer (int, optional): number of MLP layers in each message function
        num_s2s_step (int, optional): number of processing steps in set2set
        short_cut (bool, optional): use short cut or not
        batch_norm (bool, optional): apply batch normalization or not
        activation (str or function, optional): activation function
        concat_hidden (bool, optional): concat hidden representations from all layers as output
    """

    def __init__(self, input_dim, hidden_dim, edge_input_dim, num_layer=1, num_gru_layer=1, num_mlp_layer=2,
                 num_s2s_step=3, short_cut=False, batch_norm=False, activation="relu", concat_hidden=False):
        super(MessagePassingNeuralNetwork, self).__init__()

        # Width of the node features handed to the readout: every round
        # contributes ``hidden_dim`` channels when the rounds are concatenated.
        readout_dim = hidden_dim * num_layer if concat_hidden else hidden_dim

        self.input_dim = input_dim
        self.edge_input_dim = edge_input_dim
        # Recorded as twice the readout input width.
        # NOTE(review): presumably matches the width of the Set2Set output;
        # confirm against layers.Set2Set.
        self.output_dim = readout_dim * 2
        self.num_layer = num_layer
        self.short_cut = short_cut
        self.concat_hidden = concat_hidden

        # Submodules are created in a fixed order (linear, layer, gru, readout)
        # so that parameter registration stays stable.
        self.linear = nn.Linear(input_dim, hidden_dim)
        message_hidden_dims = [hidden_dim] * (num_mlp_layer - 1)
        self.layer = layers.MessagePassing(hidden_dim, edge_input_dim, message_hidden_dims,
                                           batch_norm, activation)
        self.gru = nn.GRU(hidden_dim, hidden_dim, num_gru_layer)
        self.readout = layers.Set2Set(readout_dim, num_s2s_step)

    def forward(self, graph, input, all_loss=None, metric=None):
        """
        Compute the node representations and the graph representation(s).

        Parameters:
            graph (Graph): :math:`n` graph(s)
            input (Tensor): input node representations
            all_loss (Tensor, optional): if specified, add loss to this tensor
                (not used by this model)
            metric (dict, optional): if specified, output metrics to this dict
                (not used by this model)

        Returns:
            dict with ``node_feature`` and ``graph_feature`` fields:
                node representations of shape :math:`(|V|, d)`, graph representations of shape :math:`(n, d)`
        """
        state = self.linear(input)
        # The GRU starts from the projected input, copied once per stacked GRU layer.
        gru_state = state.repeat(self.gru.num_layers, 1, 1)

        layer_outputs = []
        # The same message passing layer and the same GRU are applied in every round.
        for _ in range(self.num_layer):
            message = self.layer(graph, state)
            # A leading axis of size 1 is added before the GRU call and removed afterwards.
            updated, gru_state = self.gru(message.unsqueeze(0), gru_state)
            updated = updated.squeeze(0)
            if self.short_cut and updated.shape == state.shape:
                updated = updated + state
            layer_outputs.append(updated)
            state = updated

        if self.concat_hidden:
            node_feature = torch.cat(layer_outputs, dim=-1)
        else:
            node_feature = layer_outputs[-1]
        graph_feature = self.readout(graph, node_feature)

        return {
            "graph_feature": graph_feature,
            "node_feature": node_feature
        }