# Source code for torchdrug.layers.conv

import functools

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import checkpoint
from torch_scatter import scatter_mean, scatter_add, scatter_max

from torchdrug import data, layers, utils
from torchdrug.layers import functional


class MessagePassingBase(nn.Module):
    """
    Base module for message passing.

    Any custom message passing module should be derived from this class.
    A subclass supplies :meth:`message`, :meth:`aggregate` and :meth:`combine`,
    and may replace :meth:`message_and_aggregate` with a fused implementation.
    """

    # If True, forward() runs the message/aggregate step through torch.utils.checkpoint,
    # recomputing it during the backward pass instead of keeping its activations.
    gradient_checkpoint = False

    def message(self, graph, input):
        """
        Compute edge messages for the graph.

        Parameters:
            graph (Graph): graph(s)
            input (Tensor): node representations of shape :math:`(|V|, ...)`

        Returns:
            Tensor: edge messages of shape :math:`(|E|, ...)`
        """
        raise NotImplementedError

    def aggregate(self, graph, message):
        """
        Aggregate edge messages to nodes.

        Parameters:
            graph (Graph): graph(s)
            message (Tensor): edge messages of shape :math:`(|E|, ...)`

        Returns:
            Tensor: node updates of shape :math:`(|V|, ...)`
        """
        raise NotImplementedError

    def message_and_aggregate(self, graph, input):
        """
        Fused computation of message and aggregation over the graph.
        This may provide better time or memory complexity than separate calls of
        :meth:`message <MessagePassingBase.message>` and :meth:`aggregate <MessagePassingBase.aggregate>`.
        The default implementation simply chains the two.

        Parameters:
            graph (Graph): graph(s)
            input (Tensor): node representations of shape :math:`(|V|, ...)`

        Returns:
            Tensor: node updates of shape :math:`(|V|, ...)`
        """
        edge_message = self.message(graph, input)
        return self.aggregate(graph, edge_message)

    def _message_and_aggregate(self, *tensors):
        # checkpoint.checkpoint() only forwards tensors, so forward() flattens the graph
        # with graph.to_tensors() and it is rebuilt here; the last tensor is the node input.
        graph_tensors, node_input = tensors[:-1], tensors[-1]
        graph = data.Graph.from_tensors(graph_tensors)
        return self.message_and_aggregate(graph, node_input)

    def combine(self, input, update):
        """
        Combine node input and node update.

        Parameters:
            input (Tensor): node representations of shape :math:`(|V|, ...)`
            update (Tensor): node updates of shape :math:`(|V|, ...)`
        """
        raise NotImplementedError

    def forward(self, graph, input):
        """
        Perform message passing over the graph(s).

        Parameters:
            graph (Graph): graph(s)
            input (Tensor): node representations of shape :math:`(|V|, ...)`

        Returns:
            whatever :meth:`combine` returns for ``input`` and the aggregated update
        """
        if not self.gradient_checkpoint:
            update = self.message_and_aggregate(graph, input)
        else:
            update = checkpoint.checkpoint(self._message_and_aggregate, *graph.to_tensors(), input)
        return self.combine(input, update)
class GraphConv(MessagePassingBase):
    """
    Graph convolution operator from `Semi-Supervised Classification with Graph Convolutional Networks`_.

    .. _Semi-Supervised Classification with Graph Convolutional Networks:
        https://arxiv.org/pdf/1609.02907.pdf

    Parameters:
        input_dim (int): input dimension
        output_dim (int): output dimension
        edge_input_dim (int, optional): dimension of edge features
        batch_norm (bool, optional): apply batch normalization on nodes or not
        activation (str or function, optional): activation function
    """

    def __init__(self, input_dim, output_dim, edge_input_dim=None, batch_norm=False, activation="relu"):
        super(GraphConv, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.edge_input_dim = edge_input_dim

        self.batch_norm = nn.BatchNorm1d(output_dim) if batch_norm else None
        # a string names a function in torch.nn.functional; anything else is used as given
        self.activation = getattr(F, activation) if isinstance(activation, str) else activation

        self.linear = nn.Linear(input_dim, output_dim)
        # edge features are projected to input_dim so they can be added to node messages
        self.edge_linear = nn.Linear(edge_input_dim, input_dim) if edge_input_dim else None

    def message(self, graph, input):
        # one self-loop per node is appended after the real edges
        self_loop = torch.arange(graph.num_node, device=graph.device)
        source = torch.cat([graph.edge_list[:, 0], self_loop])
        # +1 counts the appended self-loop in each node's degree
        degree_in = graph.degree_in.unsqueeze(-1) + 1

        message = input[source]
        if self.edge_linear is not None:
            edge_input = self.edge_linear(graph.edge_feature.float())
            # self-loops carry no edge feature, so they are padded with zeros
            padding = torch.zeros(graph.num_node, self.input_dim, device=graph.device)
            message += torch.cat([edge_input, padding])
        message /= (degree_in[source].sqrt() + 1e-10)
        return message

    def aggregate(self, graph, message):
        self_loop = torch.arange(graph.num_node, device=graph.device)
        target = torch.cat([graph.edge_list[:, 1], self_loop])
        # self-loops get unit weight
        weight = torch.cat([graph.edge_weight, torch.ones(graph.num_node, device=graph.device)]).unsqueeze(-1)
        degree_out = graph.degree_out.unsqueeze(-1) + 1

        update = scatter_add(message * weight, target, dim=0, dim_size=graph.num_node)
        return update / (degree_out.sqrt() + 1e-10)

    def message_and_aggregate(self, graph, input):
        self_loop = torch.arange(graph.num_node, device=graph.device)
        source, target = graph.edge_list.t()[:2]
        source = torch.cat([source, self_loop])
        target = torch.cat([target, self_loop])

        weight = torch.cat([graph.edge_weight, torch.ones(graph.num_node, device=graph.device)])
        degree_in = graph.degree_in + 1
        degree_out = graph.degree_out + 1
        # symmetric normalization folded into the sparse adjacency values
        weight = weight / ((degree_in[source] * degree_out[target]).sqrt() + 1e-10)
        adjacency = utils.sparse_coo_tensor(torch.stack([source, target]), weight,
                                            (graph.num_node, graph.num_node))
        update = torch.sparse.mm(adjacency.t(), input)

        if self.edge_linear is not None:
            edge_input = self.edge_linear(graph.edge_feature.float())
            padding = torch.zeros(graph.num_node, self.input_dim, device=graph.device)
            edge_input = torch.cat([edge_input, padding])
            # `target` already contains the self-loop indices
            update += scatter_add(edge_input * weight.unsqueeze(-1), target, dim=0, dim_size=graph.num_node)
        return update

    def combine(self, input, update):
        output = self.linear(update)
        if self.batch_norm is not None:
            output = self.batch_norm(output)
        if self.activation:
            output = self.activation(output)
        return output
class GraphAttentionConv(MessagePassingBase):
    """
    Graph attentional convolution operator from `Graph Attention Networks`_.

    .. _Graph Attention Networks:
        https://arxiv.org/pdf/1710.10903.pdf

    Parameters:
        input_dim (int): input dimension
        output_dim (int): output dimension
        edge_input_dim (int, optional): dimension of edge features
        num_head (int, optional): number of attention heads
        negative_slope (float, optional): negative slope of leaky relu activation
        concat (bool, optional): stored as ``self.concat``; not read by any method of this class
        batch_norm (bool, optional): apply batch normalization on nodes or not
        activation (str or function, optional): activation function
    """

    eps = 1e-10

    def __init__(self, input_dim, output_dim, edge_input_dim=None, num_head=1, negative_slope=0.2, concat=True,
                 batch_norm=False, activation="relu"):
        super(GraphAttentionConv, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.edge_input_dim = edge_input_dim
        self.num_head = num_head
        self.concat = concat
        self.leaky_relu = functools.partial(F.leaky_relu, negative_slope=negative_slope)

        self.batch_norm = nn.BatchNorm1d(output_dim) if batch_norm else None
        self.activation = getattr(F, activation) if isinstance(activation, str) else activation

        if output_dim % num_head != 0:
            raise ValueError("Expect output_dim to be a multiplier of num_head, but found `%d` and `%d`"
                             % (output_dim, num_head))

        self.linear = nn.Linear(input_dim, output_dim)
        self.edge_linear = nn.Linear(edge_input_dim, output_dim) if edge_input_dim else None

        # one query vector per head, scoring the concatenated (source, target) head features
        self.query = nn.Parameter(torch.zeros(num_head, output_dim * 2 // num_head))
        nn.init.kaiming_uniform_(self.query, negative_slope, mode="fan_in")

    def message(self, graph, input):
        num_node = graph.num_node
        # one self-loop per node is appended after the real edges
        self_loop = torch.arange(num_node, device=graph.device)
        source = torch.cat([graph.edge_list[:, 0], self_loop])
        target = torch.cat([graph.edge_list[:, 1], self_loop])
        weight = torch.cat([graph.edge_weight, torch.ones(num_node, device=graph.device)]).unsqueeze(-1)
        hidden = self.linear(input)

        # (edges, output_dim, 2): source and target features side by side
        key = torch.stack([hidden[source], hidden[target]], dim=-1)
        if self.edge_linear is not None:
            edge_input = self.edge_linear(graph.edge_feature.float())
            padding = torch.zeros(num_node, self.output_dim, device=graph.device)
            key += torch.cat([edge_input, padding]).unsqueeze(-1)
        key = key.view(-1, *self.query.shape)

        logit = self.leaky_relu(torch.einsum("hd, nhd -> nh", self.query, key))
        # subtract the per-target maximum before exp() for numerical stability
        logit = logit - scatter_max(logit, target, dim=0, dim_size=num_node)[0][target]
        attention = logit.exp() * weight
        # why mean? because with mean we have normalized message scale across different node degrees
        normalizer = scatter_mean(attention, target, dim=0, dim_size=num_node)[target]
        attention = attention / (normalizer + self.eps)

        head_dim = self.query.shape[-1] // 2
        value = hidden[source].view(-1, self.num_head, head_dim)
        attention = attention.unsqueeze(-1).expand_as(value)
        return (attention * value).flatten(1)

    def aggregate(self, graph, message):
        self_loop = torch.arange(graph.num_node, device=graph.device)
        target = torch.cat([graph.edge_list[:, 1], self_loop])
        return scatter_mean(message, target, dim=0, dim_size=graph.num_node)

    def combine(self, input, update):
        # no self/neighbor mixing here: the update already includes self-loop messages
        output = update
        if self.batch_norm is not None:
            output = self.batch_norm(output)
        if self.activation:
            output = self.activation(output)
        return output
class GraphIsomorphismConv(MessagePassingBase):
    """
    Graph isomorphism convolution operator from `How Powerful are Graph Neural Networks?`_

    .. _How Powerful are Graph Neural Networks?:
        https://arxiv.org/pdf/1810.00826.pdf

    Parameters:
        input_dim (int): input dimension
        output_dim (int): output dimension
        edge_input_dim (int, optional): dimension of edge features
        hidden_dims (list of int, optional): hidden dimensions
        eps (float, optional): initial epsilon
        learn_eps (bool, optional): learn epsilon or not
        batch_norm (bool, optional): apply batch normalization on nodes or not
        activation (str or function, optional): activation function
    """

    def __init__(self, input_dim, output_dim, edge_input_dim=None, hidden_dims=None, eps=0, learn_eps=False,
                 batch_norm=False, activation="relu"):
        super(GraphIsomorphismConv, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.edge_input_dim = edge_input_dim

        eps_tensor = torch.tensor([eps], dtype=torch.float32)
        # a learnable epsilon is a parameter; a fixed one is a buffer so it still moves with .to()/state_dict
        if not learn_eps:
            self.register_buffer("eps", eps_tensor)
        else:
            self.eps = nn.Parameter(eps_tensor)

        self.batch_norm = nn.BatchNorm1d(output_dim) if batch_norm else None
        self.activation = getattr(F, activation) if isinstance(activation, str) else activation

        layer_dims = list(hidden_dims or []) + [output_dim]
        self.mlp = layers.MLP(input_dim, layer_dims, activation)
        self.edge_linear = nn.Linear(edge_input_dim, input_dim) if edge_input_dim else None

    def message(self, graph, input):
        message = input[graph.edge_list[:, 0]]
        if self.edge_linear is not None:
            message += self.edge_linear(graph.edge_feature.float())
        return message

    def aggregate(self, graph, message):
        weight = graph.edge_weight.unsqueeze(-1)
        return scatter_add(message * weight, graph.edge_list[:, 1], dim=0, dim_size=graph.num_node)

    def message_and_aggregate(self, graph, input):
        size = (graph.num_node, graph.num_node)
        adjacency = utils.sparse_coo_tensor(graph.edge_list.t()[:2], graph.edge_weight, size)
        update = torch.sparse.mm(adjacency.t(), input)
        if self.edge_linear is not None:
            edge_input = self.edge_linear(graph.edge_feature.float())
            weighted = edge_input * graph.edge_weight.unsqueeze(-1)
            update += scatter_add(weighted, graph.edge_list[:, 1], dim=0, dim_size=graph.num_node)
        return update

    def combine(self, input, update):
        # GIN update: MLP((1 + eps) * h_v + sum of neighbor messages)
        output = self.mlp((1 + self.eps) * input + update)
        if self.batch_norm is not None:
            output = self.batch_norm(output)
        if self.activation:
            output = self.activation(output)
        return output
class RelationalGraphConv(MessagePassingBase):
    """
    Relational graph convolution operator from `Modeling Relational Data with Graph Convolutional Networks`_.

    .. _Modeling Relational Data with Graph Convolutional Networks:
        https://arxiv.org/pdf/1703.06103.pdf

    Messages are grouped per (target node, relation), normalized by the total edge weight of
    the group, and concatenated across relations before the final linear layer.

    Parameters:
        input_dim (int): input dimension
        output_dim (int): output dimension
        num_relation (int): number of relations
        edge_input_dim (int, optional): dimension of edge features
        batch_norm (bool, optional): apply batch normalization on nodes or not
        activation (str or function, optional): activation function
    """

    # added to per-(node, relation) weight sums so empty or zero-weight groups don't divide by zero
    eps = 1e-10

    def __init__(self, input_dim, output_dim, num_relation, edge_input_dim=None, batch_norm=False,
                 activation="relu"):
        super(RelationalGraphConv, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.num_relation = num_relation
        self.edge_input_dim = edge_input_dim

        if batch_norm:
            self.batch_norm = nn.BatchNorm1d(output_dim)
        else:
            self.batch_norm = None
        if isinstance(activation, str):
            self.activation = getattr(F, activation)
        else:
            self.activation = activation

        self.self_loop = nn.Linear(input_dim, output_dim)
        self.linear = nn.Linear(num_relation * input_dim, output_dim)
        if edge_input_dim:
            self.edge_linear = nn.Linear(edge_input_dim, input_dim)
        else:
            self.edge_linear = None

    def message(self, graph, input):
        node_in = graph.edge_list[:, 0]
        message = input[node_in]
        if self.edge_linear:
            message += self.edge_linear(graph.edge_feature.float())
        return message

    def aggregate(self, graph, message):
        assert graph.num_relation == self.num_relation

        # flatten (target node, relation) into a single index in [0, |V| * |R|)
        node_out = graph.edge_list[:, 1] * self.num_relation + graph.edge_list[:, 2]
        edge_weight = graph.edge_weight.unsqueeze(-1)
        update = scatter_add(message * edge_weight, node_out, dim=0, dim_size=graph.num_node * self.num_relation) / \
                 (scatter_add(edge_weight, node_out, dim=0, dim_size=graph.num_node * self.num_relation) + self.eps)
        return update.view(graph.num_node, self.num_relation * self.input_dim)

    def message_and_aggregate(self, graph, input):
        assert graph.num_relation == self.num_relation

        node_in, node_out, relation = graph.edge_list.t()
        # flatten (target node, relation) into a single index in [0, |V| * |R|)
        node_out = node_out * self.num_relation + relation
        degree_out = scatter_add(graph.edge_weight, node_out, dim_size=graph.num_node * graph.num_relation)
        # self.eps matches aggregate(): without it a group whose edge weights are all zero
        # yields 0 / 0 = NaN here while the unfused path yields 0
        edge_weight = graph.edge_weight / (degree_out[node_out] + self.eps)
        adjacency = utils.sparse_coo_tensor(torch.stack([node_in, node_out]), edge_weight,
                                            (graph.num_node, graph.num_node * graph.num_relation))
        update = torch.sparse.mm(adjacency.t(), input)
        if self.edge_linear:
            edge_input = graph.edge_feature.float()
            edge_input = self.edge_linear(edge_input)
            edge_weight = edge_weight.unsqueeze(-1)
            edge_update = scatter_add(edge_input * edge_weight, node_out, dim=0,
                                      dim_size=graph.num_node * graph.num_relation)
            update += edge_update

        return update.view(graph.num_node, self.num_relation * self.input_dim)

    def combine(self, input, update):
        # relation-specific aggregate plus a separate linear transform of the node itself
        output = self.linear(update) + self.self_loop(input)
        if self.batch_norm:
            output = self.batch_norm(output)
        if self.activation:
            output = self.activation(output)
        return output
class NeuralFingerprintConv(MessagePassingBase):
    """
    Graph neural network operator from `Convolutional Networks on Graphs for Learning Molecular Fingerprints`_.

    Note this operator doesn't include the sparsifying step of the original paper.

    .. _Convolutional Networks on Graphs for Learning Molecular Fingerprints:
        https://arxiv.org/pdf/1509.09292.pdf

    Parameters:
        input_dim (int): input dimension
        output_dim (int): output dimension
        edge_input_dim (int, optional): dimension of edge features
        batch_norm (bool, optional): apply batch normalization on nodes or not
        activation (str or function, optional): activation function
    """

    def __init__(self, input_dim, output_dim, edge_input_dim=None, batch_norm=False, activation="relu"):
        super(NeuralFingerprintConv, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.edge_input_dim = edge_input_dim

        self.batch_norm = nn.BatchNorm1d(output_dim) if batch_norm else None
        self.activation = getattr(F, activation) if isinstance(activation, str) else activation

        self.linear = nn.Linear(input_dim, output_dim)
        self.edge_linear = nn.Linear(edge_input_dim, input_dim) if edge_input_dim else None

    def message(self, graph, input):
        source = graph.edge_list[:, 0]
        message = input[source]
        if self.edge_linear is not None:
            message += self.edge_linear(graph.edge_feature.float())
        return message

    def aggregate(self, graph, message):
        target = graph.edge_list[:, 1]
        weighted = message * graph.edge_weight.unsqueeze(-1)
        return scatter_add(weighted, target, dim=0, dim_size=graph.num_node)

    def message_and_aggregate(self, graph, input):
        num_node = graph.num_node
        adjacency = utils.sparse_coo_tensor(graph.edge_list.t()[:2], graph.edge_weight, (num_node, num_node))
        update = torch.sparse.mm(adjacency.t(), input)
        if self.edge_linear is not None:
            edge_input = self.edge_linear(graph.edge_feature.float())
            weighted = edge_input * graph.edge_weight.unsqueeze(-1)
            update += scatter_add(weighted, graph.edge_list[:, 1], dim=0, dim_size=num_node)
        return update

    def combine(self, input, update):
        # the node's own features are summed with its neighborhood before the linear layer
        output = self.linear(input + update)
        if self.batch_norm is not None:
            output = self.batch_norm(output)
        if self.activation:
            output = self.activation(output)
        return output
class ContinuousFilterConv(MessagePassingBase):
    r"""
    Continuous filter operator from
    `SchNet: A continuous-filter convolutional neural network for modeling quantum interactions`_.

    .. _SchNet\: A continuous-filter convolutional neural network for modeling quantum interactions:
        https://arxiv.org/pdf/1706.08566.pdf

    Each message is the projected source node feature scaled elementwise by a filter computed
    from RBF features of the source/target node positions.

    Parameters:
        input_dim (int): input dimension
        output_dim (int): output dimension
        edge_input_dim (int, optional): dimension of edge features
        hidden_dim (int, optional): hidden dimension. By default, same as :attr:`output_dim`
        cutoff (float, optional): maximal scale for RBF kernels
        num_gaussian (int, optional): number of RBF kernels
        batch_norm (bool, optional): apply batch normalization on nodes or not
        activation (str or function, optional): activation function
    """
    # The docstring is a raw string: the RST target needs a literal "\:" (to escape the colon in
    # the title), which is an invalid escape sequence in a normal string literal.

    def __init__(self, input_dim, output_dim, edge_input_dim=None, hidden_dim=None, cutoff=5, num_gaussian=100,
                 batch_norm=False, activation="shifted_softplus"):
        super(ContinuousFilterConv, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.edge_input_dim = edge_input_dim
        if hidden_dim is None:
            hidden_dim = output_dim
        self.hidden_dim = hidden_dim
        self.rbf = layers.RBF(stop=cutoff, num_kernel=num_gaussian)

        if batch_norm:
            self.batch_norm = nn.BatchNorm1d(output_dim)
        else:
            self.batch_norm = None
        # "shifted_softplus" is provided by torchdrug, not torch.nn.functional
        if activation == "shifted_softplus":
            self.activation = functional.shifted_softplus
        elif isinstance(activation, str):
            self.activation = getattr(F, activation)
        else:
            self.activation = activation

        self.input_layer = nn.Linear(input_dim, hidden_dim)
        self.rbf_layer = nn.Linear(num_gaussian, hidden_dim)
        self.output_layer = nn.Linear(hidden_dim, output_dim)
        if edge_input_dim:
            self.edge_linear = nn.Linear(edge_input_dim, hidden_dim)
        else:
            self.edge_linear = None

    def message(self, graph, input):
        node_in, node_out = graph.edge_list.t()[:2]
        position = graph.node_position
        message = self.input_layer(input)[node_in]
        if self.edge_linear:
            message += self.edge_linear(graph.edge_feature.float())
        # continuous filter from the RBF expansion of the two endpoint positions
        weight = self.rbf_layer(self.rbf(position[node_in], position[node_out]))
        message *= weight
        return message

    def aggregate(self, graph, message):
        node_out = graph.edge_list[:, 1]
        edge_weight = graph.edge_weight.unsqueeze(-1)
        update = scatter_add(message * edge_weight, node_out, dim=0, dim_size=graph.num_node)
        return update

    def message_and_aggregate(self, graph, input):
        node_in, node_out = graph.edge_list.t()[:2]
        position = graph.node_position
        rbf_weight = self.rbf_layer(self.rbf(position[node_in], position[node_out]))
        # 3-D index (target, source, edge id) so each edge can use its own RBF filter row
        indices = torch.stack([node_out, node_in, torch.arange(graph.num_edge, device=graph.device)])
        adjacency = utils.sparse_coo_tensor(indices, graph.edge_weight,
                                            (graph.num_node, graph.num_node, graph.num_edge))
        update = functional.generalized_rspmm(adjacency, rbf_weight, self.input_layer(input))
        if self.edge_linear:
            edge_input = graph.edge_feature.float()
            edge_input = self.edge_linear(edge_input)
            edge_weight = graph.edge_weight.unsqueeze(-1) * rbf_weight
            edge_update = scatter_add(edge_input * edge_weight, graph.edge_list[:, 1], dim=0,
                                      dim_size=graph.num_node)
            update += edge_update

        return update

    def combine(self, input, update):
        output = self.output_layer(update)
        if self.batch_norm:
            output = self.batch_norm(output)
        if self.activation:
            output = self.activation(output)
        return output
class MessagePassing(MessagePassingBase):
    """
    Message passing operator from `Neural Message Passing for Quantum Chemistry`_.

    This implements the edge network variant in the original paper.

    .. _Neural Message Passing for Quantum Chemistry:
        https://arxiv.org/pdf/1704.01212.pdf

    Parameters:
        input_dim (int): input dimension
        edge_input_dim (int): dimension of edge features
        hidden_dims (list of int, optional): hidden dims of edge network
        batch_norm (bool, optional): apply batch normalization on nodes or not
        activation (str or function, optional): activation function
    """

    def __init__(self, input_dim, edge_input_dim, hidden_dims=None, batch_norm=False, activation="relu"):
        super(MessagePassing, self).__init__()
        self.input_dim = input_dim
        # output width equals input width: each edge applies an input_dim x input_dim matrix
        self.output_dim = input_dim
        self.edge_input_dim = edge_input_dim
        hidden_dims = [] if hidden_dims is None else hidden_dims

        self.batch_norm = nn.BatchNorm1d(input_dim) if batch_norm else None
        self.activation = getattr(F, activation) if isinstance(activation, str) else activation

        # edge network: maps each edge feature to a flattened input_dim x input_dim matrix
        self.edge_mlp = layers.MLP(edge_input_dim, list(hidden_dims) + [input_dim * input_dim], activation)

    def message(self, graph, input):
        source = graph.edge_list[:, 0]
        transform = self.edge_mlp(graph.edge_feature.float()).view(-1, self.input_dim, self.input_dim)
        if not graph.num_edge:
            # no edges: return an empty message tensor with the right width
            return torch.zeros(0, self.input_dim, device=graph.device)
        return torch.einsum("bed, bd -> be", transform, input[source])

    def aggregate(self, graph, message):
        target = graph.edge_list[:, 1]
        weighted = message * graph.edge_weight.unsqueeze(-1)
        return scatter_add(weighted, target, dim=0, dim_size=graph.num_node)

    def combine(self, input, update):
        output = update
        if self.batch_norm is not None:
            output = self.batch_norm(output)
        if self.activation:
            output = self.activation(output)
        return output
class ChebyshevConv(MessagePassingBase):
    """
    Chebyshev spectral graph convolution operator from
    `Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering`_.

    .. _Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering:
        https://arxiv.org/pdf/1606.09375.pdf

    Parameters:
        input_dim (int): input dimension
        output_dim (int): output dimension
        edge_input_dim (int, optional): dimension of edge features
        k (int, optional): number of Chebyshev polynomials.
            This also corresponds to the radius of the receptive field.
        batch_norm (bool, optional): apply batch normalization on nodes or not
        activation (str or function, optional): activation function
    """

    def __init__(self, input_dim, output_dim, edge_input_dim=None, k=1, batch_norm=False, activation="relu"):
        super(ChebyshevConv, self).__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.k = k
        self.edge_input_dim = edge_input_dim

        if batch_norm:
            self.batch_norm = nn.BatchNorm1d(output_dim)
        else:
            self.batch_norm = None
        if isinstance(activation, str):
            self.activation = getattr(F, activation)
        else:
            self.activation = activation

        # the k + 1 polynomial bases are concatenated along the feature axis before this layer
        self.linear = nn.Linear((k + 1) * input_dim, output_dim)
        if edge_input_dim:
            self.edge_linear = nn.Linear(edge_input_dim, input_dim)
        else:
            self.edge_linear = None

    def message(self, graph, input):
        node_in = graph.edge_list[:, 0]
        degree_in = graph.degree_in.unsqueeze(-1)
        # because self-loop messages have a different scale, they are processed in combine()
        message = input[node_in]
        if self.edge_linear:
            message += self.edge_linear(graph.edge_feature.float())
        message /= (degree_in[node_in].sqrt() + 1e-10)
        return message

    def aggregate(self, graph, message):
        node_out = graph.edge_list[:, 1]
        edge_weight = graph.edge_weight.unsqueeze(-1)
        degree_out = graph.degree_out.unsqueeze(-1)
        # because self-loop messages have a different scale, they are processed in combine()
        # the negation makes combine() (input + update) apply I minus the normalized adjacency
        update = -scatter_add(message * edge_weight, node_out, dim=0, dim_size=graph.num_node)
        update = update / (degree_out.sqrt() + 1e-10)
        return update

    def message_and_aggregate(self, graph, input):
        node_in, node_out = graph.edge_list.t()[:2]
        edge_weight = -graph.edge_weight / ((graph.degree_in[node_in] * graph.degree_out[node_out]).sqrt() + 1e-10)
        adjacency = utils.sparse_coo_tensor(graph.edge_list.t()[:2], edge_weight, (graph.num_node, graph.num_node))
        update = torch.sparse.mm(adjacency.t(), input)
        if self.edge_linear:
            edge_input = graph.edge_feature.float()
            edge_input = self.edge_linear(edge_input)
            edge_weight = edge_weight.unsqueeze(-1)
            edge_update = scatter_add(edge_input * edge_weight, graph.edge_list[:, 1], dim=0,
                                      dim_size=graph.num_node)
            update += edge_update

        return update

    def forward(self, graph, input):
        """
        Compute ``k + 1`` polynomial bases of the input and mix them with a linear layer.

        Basis 0 is ``input``; basis 1 is one parent message-passing step (``combine``
        returns ``input + update``); basis ``i + 1`` for ``i > 0`` follows the Chebyshev
        recurrence ``2 * step(basis_i) - basis_{i-1}``.

        Parameters:
            graph (Graph): graph(s)
            input (Tensor): node representations of shape :math:`(|V|, input\\_dim)`

        Returns:
            Tensor: node representations of shape :math:`(|V|, output\\_dim)`
        """
        # Chebyshev polynomial bases
        bases = [input]
        for i in range(self.k):
            x = super(ChebyshevConv, self).forward(graph, bases[-1])
            if i > 0:
                x = 2 * x - bases[-2]
            bases.append(x)
        bases = torch.cat(bases, dim=-1)

        output = self.linear(bases)
        if self.batch_norm:
            # assign back to `output`; storing it elsewhere would silently skip batch norm
            output = self.batch_norm(output)
        if self.activation:
            output = self.activation(output)
        return output

    def combine(self, input, update):
        # add back the self term that message()/aggregate() leave out
        output = input + update
        return output
class GeometricRelationalGraphConv(RelationalGraphConv):
    """
    Geometry-aware relational graph convolution operator from
    `Protein Representation Learning by Geometric Structure Pretraining`_.

    .. _Protein Representation Learning by Geometric Structure Pretraining:
        https://arxiv.org/pdf/2203.06125.pdf

    Unlike :class:`RelationalGraphConv`, the per-(node, relation) sums are not divided by
    the total edge weight of the group.

    Parameters:
        input_dim (int): input dimension
        output_dim (int): output dimension
        num_relation (int): number of relations
        edge_input_dim (int, optional): dimension of edge features
        batch_norm (bool, optional): apply batch normalization on nodes or not
        activation (str or function, optional): activation function
    """

    def __init__(self, input_dim, output_dim, num_relation, edge_input_dim=None, batch_norm=False,
                 activation="relu"):
        super(GeometricRelationalGraphConv, self).__init__(input_dim, output_dim, num_relation,
                                                           edge_input_dim=edge_input_dim,
                                                           batch_norm=batch_norm,
                                                           activation=activation)

    def aggregate(self, graph, message):
        assert graph.num_relation == self.num_relation

        num_slot = graph.num_node * self.num_relation
        # flatten (target node, relation) into a single index in [0, |V| * |R|)
        slot = graph.edge_list[:, 1] * self.num_relation + graph.edge_list[:, 2]
        weighted = message * graph.edge_weight.unsqueeze(-1)
        update = scatter_add(weighted, slot, dim=0, dim_size=num_slot)
        return update.view(graph.num_node, self.num_relation * self.input_dim)

    def message_and_aggregate(self, graph, input):
        assert graph.num_relation == self.num_relation

        source, target, relation = graph.edge_list.t()
        slot = target * self.num_relation + relation
        num_slot = graph.num_node * graph.num_relation
        adjacency = utils.sparse_coo_tensor(torch.stack([source, slot]), graph.edge_weight,
                                            (graph.num_node, num_slot))
        update = torch.sparse.mm(adjacency.t(), input)
        if self.edge_linear:
            edge_input = self.edge_linear(graph.edge_feature.float())
            weighted = edge_input * graph.edge_weight.unsqueeze(-1)
            update += scatter_add(weighted, slot, dim=0, dim_size=num_slot)
        return update.view(graph.num_node, self.num_relation * self.input_dim)