torchdrug.layers#

Common Layers#

GaussianSmearing#

class GaussianSmearing(start=0, stop=5, num_kernel=100, centered=False, learnable=False)[source]#

Gaussian smearing from SchNet: A continuous-filter convolutional neural network for modeling quantum interactions.

There are two modes for Gaussian smearing.

Non-centered mode:

\[\mu = [0, 1, ..., n], \sigma = [1, 1, ..., 1]\]

Centered mode:

\[\mu = [0, 0, ..., 0], \sigma = [0, 1, ..., n]\]
Parameters
  • start (int, optional) – minimal input value

  • stop (int, optional) – maximal input value

  • num_kernel (int, optional) – number of RBF kernels

  • centered (bool, optional) – centered mode or not

  • learnable (bool, optional) – learnable gaussian parameters or not

forward(x, y)[source]#

Compute smeared Gaussian features between pairs of data points.

Parameters
  • x (Tensor) – data of shape \((..., d)\)

  • y (Tensor) – data of shape \((..., d)\)

Returns

features of shape \((..., num\_kernel)\)

Return type

Tensor
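A minimal usage sketch (the coordinate tensors and shapes below are illustrative, not taken from the library): the layer expands the distances between x and y into RBF features.

>>> import torch
>>> from torchdrug import layers
>>> # hypothetical 3D coordinates for 5 pairs of points
>>> x = torch.randn(5, 3)
>>> y = torch.randn(5, 3)
>>> smearing = layers.GaussianSmearing(start=0, stop=5, num_kernel=100)
>>> feature = smearing(x, y)
>>> feature.shape   # expected: torch.Size([5, 100])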

MultiLayerPerceptron#

class MultiLayerPerceptron(input_dim, hidden_dims, short_cut=False, batch_norm=False, activation='relu', dropout=0)[source]#

Multi-layer Perceptron. Note there is no batch normalization, activation or dropout in the last layer.

Parameters
  • input_dim (int) – input dimension

  • hidden_dims (list of int) – hidden dimensions

  • short_cut (bool, optional) – use short cut or not

  • batch_norm (bool, optional) – apply batch normalization or not

  • activation (str or function, optional) – activation function

  • dropout (float, optional) – dropout rate
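A quick usage sketch (dimensions and hyperparameters are illustrative): a 3-layer MLP mapping 128-dimensional inputs to 64-dimensional outputs.

>>> import torch
>>> from torchdrug import layers
>>> mlp = layers.MultiLayerPerceptron(128, [256, 256, 64], batch_norm=True, dropout=0.1)
>>> x = torch.randn(32, 128)
>>> mlp(x).shape   # expected: torch.Size([32, 64]); no activation or dropout after the last layer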

MutualInformation#

class MutualInformation(input_dim, num_mlp_layer=2, activation='relu')[source]#

Mutual information estimator from Learning deep representations by mutual information estimation and maximization.

Parameters
  • input_dim (int) – input dimension

  • num_mlp_layer (int, optional) – number of MLP layers

  • activation (str or function, optional) – activation function

SinusoidalPositionEmbedding#

class SinusoidalPositionEmbedding(output_dim)[source]#

Positional embedding based on sine and cosine functions, proposed in Attention Is All You Need.

Parameters

output_dim (int) – output dimension

PairNorm#

class PairNorm(scale_individual=False)[source]#

Pair normalization layer proposed in PairNorm: Tackling Oversmoothing in GNNs.

Parameters

scale_individual (bool, optional) – additionally normalize each node representation to have the same L2-norm

Sequential#

class Sequential(*args, global_args=None, allow_unused=False)[source]#

Improved sequential container. Modules will be called in the order they are passed to the constructor.

Compared to the vanilla nn.Sequential, this layer additionally supports the following features.

  1. Multiple input / output arguments.

>>> # layer1 signature: (...) -> (a, b)
>>> # layer2 signature: (a, b) -> (...)
>>> layer = layers.Sequential(layer1, layer2)
  2. Global arguments.

>>> # layer1 signature: (graph, a) -> b
>>> # layer2 signature: (graph, b) -> c
>>> layer = layers.Sequential(layer1, layer2, global_args=("graph",))

Note the global arguments don’t need to be present in every layer.

>>> # layer1 signature: (graph, a) -> b
>>> # layer2 signature: b -> c
>>> # layer3 signature: (graph, c) -> d
>>> layer = layers.Sequential(layer1, layer2, layer3, global_args=("graph",))
  3. Dict outputs.

>>> # layer1 signature: a -> {"b": b, "c": c}
>>> # layer2 signature: b -> d
>>> layer = layers.Sequential(layer1, layer2, allow_unused=True)

When dict outputs are used with global arguments, the global arguments can be explicitly overwritten by any layer outputs.

>>> # layer1 signature: (graph, a) -> {"graph": graph, "b": b}
>>> # layer2 signature: (graph, b) -> c
>>> # layer2 takes in the graph output by layer1
>>> layer = layers.Sequential(layer1, layer2, global_args=("graph",))
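For concreteness, a hedged sketch of how the container might chain two graph convolution layers that share the graph as a global argument (the dimensions are arbitrary):

>>> from torchdrug import layers
>>> model = layers.Sequential(
...     layers.GraphConv(input_dim=64, output_dim=128),
...     layers.GraphConv(input_dim=128, output_dim=128),
...     global_args=("graph",),
... )
>>> # output = model(graph, node_feature)   # the graph is forwarded to both layers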

Convolution Layers#

class MessagePassingBase[source]#

Base module for message passing.

Any custom message passing module should be derived from this class.

aggregate(graph, message)[source]#

Aggregate edge messages to nodes.

Parameters
  • graph (Graph) – graph(s)

  • message (Tensor) – edge messages of shape \((|E|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor

combine(input, update)[source]#

Combine node input and node update.

Parameters
  • input (Tensor) – node representations of shape \((|V|, ...)\)

  • update (Tensor) – node updates of shape \((|V|, ...)\)

forward(graph, input)[source]#

Perform message passing over the graph(s).

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

message(graph, input)[source]#

Compute edge messages for the graph.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

edge messages of shape \((|E|, ...)\)

Return type

Tensor

message_and_aggregate(graph, input)[source]#

Fused computation of message and aggregation over the graph. This may provide better time or memory complexity than separate calls of message and aggregate.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor
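As a hedged illustration of this interface (the class below is made up for this example and is not part of the library), a custom layer only needs to implement message, aggregate and combine; forward and message_and_aggregate are inherited from MessagePassingBase.

import torch
from torch import nn
from torchdrug import layers

class MyConv(layers.MessagePassingBase):
    """Toy example: linear messages from source nodes, summed over incoming edges."""

    def __init__(self, input_dim, output_dim):
        super(MyConv, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def message(self, graph, input):
        node_in = graph.edge_list[:, 0]
        return self.linear(input[node_in])          # one message per edge: (|E|, output_dim)

    def aggregate(self, graph, message):
        node_out = graph.edge_list[:, 1]
        update = torch.zeros(int(graph.num_node), message.shape[-1], device=message.device)
        return update.index_add_(0, node_out, message)   # sum messages into target nodes

    def combine(self, input, update):
        # the update simply replaces the input here; real layers usually also add a
        # (transformed) residual of the input
        return update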

ChebyshevConv#

class ChebyshevConv(input_dim, output_dim, edge_input_dim=None, k=1, batch_norm=False, activation='relu')[source]#

Chebyshev spectral graph convolution operator from Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering.

Parameters
  • input_dim (int) – input dimension

  • output_dim (int) – output dimension

  • edge_input_dim (int, optional) – dimension of edge features

  • k (int, optional) – number of Chebyshev polynomials. This also corresponds to the radius of the receptive field.

  • batch_norm (bool, optional) – apply batch normalization on nodes or not

  • activation (str or function, optional) – activation function

aggregate(graph, message)[source]#

Aggregate edge messages to nodes.

Parameters
  • graph (Graph) – graph(s)

  • message (Tensor) – edge messages of shape \((|E|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor

combine(input, update)[source]#

Combine node input and node update.

Parameters
  • input (Tensor) – node representations of shape \((|V|, ...)\)

  • update (Tensor) – node updates of shape \((|V|, ...)\)

forward(graph, input)[source]#

Perform message passing over the graph(s).

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

message(graph, input)[source]#

Compute edge messages for the graph.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

edge messages of shape \((|E|, ...)\)

Return type

Tensor

message_and_aggregate(graph, input)[source]#

Fused computation of message and aggregation over the graph. This may provide better time or memory complexity than separate calls of message and aggregate.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor

ContinuousFilterConv#

class ContinuousFilterConv(input_dim, output_dim, edge_input_dim=None, hidden_dim=None, cutoff=5, num_gaussian=100, batch_norm=False, activation='shifted_softplus')[source]#

Continuous filter operator from SchNet: A continuous-filter convolutional neural network for modeling quantum interactions.

Parameters
  • input_dim (int) – input dimension

  • output_dim (int) – output dimension

  • edge_input_dim (int, optional) – dimension of edge features

  • hidden_dim (int, optional) – hidden dimension. By default, same as output_dim

  • cutoff (float, optional) – maximal scale for RBF kernels

  • num_gaussian (int, optional) – number of RBF kernels

  • batch_norm (bool, optional) – apply batch normalization on nodes or not

  • activation (str or function, optional) – activation function

aggregate(graph, message)[source]#

Aggregate edge messages to nodes.

Parameters
  • graph (Graph) – graph(s)

  • message (Tensor) – edge messages of shape \((|E|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor

combine(input, update)[source]#

Combine node input and node update.

Parameters
  • input (Tensor) – node representations of shape \((|V|, ...)\)

  • update (Tensor) – node updates of shape \((|V|, ...)\)

message(graph, input)[source]#

Compute edge messages for the graph.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

edge messages of shape \((|E|, ...)\)

Return type

Tensor

message_and_aggregate(graph, input)[source]#

Fused computation of message and aggregation over the graph. This may provide better time or memory complexity than separate calls of message and aggregate.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor

GeometricRelationalGraphConv#

class GeometricRelationalGraphConv(input_dim, output_dim, num_relation, edge_input_dim=None, batch_norm=False, activation='relu')[source]#

Geometry-aware relational graph convolution operator from Protein Representation Learning by Geometric Structure Pretraining.

Parameters
  • input_dim (int) – input dimension

  • output_dim (int) – output dimension

  • num_relation (int) – number of relations

  • edge_input_dim (int, optional) – dimension of edge features

  • batch_norm (bool, optional) – apply batch normalization on nodes or not

  • activation (str or function, optional) – activation function

aggregate(graph, message)[source]#

Aggregate edge messages to nodes.

Parameters
  • graph (Graph) – graph(s)

  • message (Tensor) – edge messages of shape \((|E|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor

message_and_aggregate(graph, input)[source]#

Fused computation of message and aggregation over the graph. This may provide better time or memory complexity than separate calls of message and aggregate.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor

GraphAttentionConv#

class GraphAttentionConv(input_dim, output_dim, edge_input_dim=None, num_head=1, negative_slope=0.2, concat=True, batch_norm=False, activation='relu')[source]#

Graph attentional convolution operator from Graph Attention Networks.

Parameters
  • input_dim (int) – input dimension

  • output_dim (int) – output dimension

  • edge_input_dim (int, optional) – dimension of edge features

  • num_head (int, optional) – number of attention heads

  • negative_slope (float, optional) – negative slope of leaky relu activation

  • concat (bool, optional) – concatenate the outputs of attention heads or not

  • batch_norm (bool, optional) – apply batch normalization on nodes or not

  • activation (str or function, optional) – activation function

aggregate(graph, message)[source]#

Aggregate edge messages to nodes.

Parameters
  • graph (Graph) – graph(s)

  • message (Tensor) – edge messages of shape \((|E|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor

combine(input, update)[source]#

Combine node input and node update.

Parameters
  • input (Tensor) – node representations of shape \((|V|, ...)\)

  • update (Tensor) – node updates of shape \((|V|, ...)\)

message(graph, input)[source]#

Compute edge messages for the graph.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

edge messages of shape \((|E|, ...)\)

Return type

Tensor

GraphConv#

class GraphConv(input_dim, output_dim, edge_input_dim=None, batch_norm=False, activation='relu')[source]#

Graph convolution operator from Semi-Supervised Classification with Graph Convolutional Networks.

Parameters
  • input_dim (int) – input dimension

  • output_dim (int) – output dimension

  • edge_input_dim (int, optional) – dimension of edge features

  • batch_norm (bool, optional) – apply batch normalization on nodes or not

  • activation (str or function, optional) – activation function

aggregate(graph, message)[source]#

Aggregate edge messages to nodes.

Parameters
  • graph (Graph) – graph(s)

  • message (Tensor) – edge messages of shape \((|E|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor

combine(input, update)[source]#

Combine node input and node update.

Parameters
  • input (Tensor) – node representations of shape \((|V|, ...)\)

  • update (Tensor) – node updates of shape \((|V|, ...)\)

message(graph, input)[source]#

Compute edge messages for the graph.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

edge messages of shape \((|E|, ...)\)

Return type

Tensor

message_and_aggregate(graph, input)[source]#

Fused computation of message and aggregation over the graph. This may provide better time or memory complexity than separate calls of message and aggregate.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor
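A usage sketch on a toy graph (the graph and feature dimensions are invented for illustration):

>>> import torch
>>> from torchdrug import data, layers
>>> graph = data.Graph(edge_list=[[0, 1], [1, 2], [2, 3]], num_node=4)
>>> node_feature = torch.randn(4, 16)
>>> conv = layers.GraphConv(input_dim=16, output_dim=32)
>>> conv(graph, node_feature).shape   # expected: torch.Size([4, 32])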

GraphIsomorphismConv#

class GraphIsomorphismConv(input_dim, output_dim, edge_input_dim=None, hidden_dims=None, eps=0, learn_eps=False, batch_norm=False, activation='relu')[source]#

Graph isomorphism convolution operator from How Powerful are Graph Neural Networks?

Parameters
  • input_dim (int) – input dimension

  • output_dim (int) – output dimension

  • edge_input_dim (int, optional) – dimension of edge features

  • hidden_dims (list of int, optional) – hidden dimensions

  • eps (float, optional) – initial epsilon

  • learn_eps (bool, optional) – learn epsilon or not

  • batch_norm (bool, optional) – apply batch normalization on nodes or not

  • activation (str or function, optional) – activation function

aggregate(graph, message)[source]#

Aggregate edge messages to nodes.

Parameters
  • graph (Graph) – graph(s)

  • message (Tensor) – edge messages of shape \((|E|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor

combine(input, update)[source]#

Combine node input and node update.

Parameters
  • input (Tensor) – node representations of shape \((|V|, ...)\)

  • update (Tensor) – node updates of shape \((|V|, ...)\)

message(graph, input)[source]#

Compute edge messages for the graph.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

edge messages of shape \((|E|, ...)\)

Return type

Tensor

message_and_aggregate(graph, input)[source]#

Fused computation of message and aggregation over the graph. This may provide better time or memory complexity than separate calls of message and aggregate.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor

MessagePassing#

class MessagePassing(input_dim, edge_input_dim, hidden_dims=None, batch_norm=False, activation='relu')[source]#

Message passing operator from Neural Message Passing for Quantum Chemistry.

This implements the edge network variant in the original paper.

Parameters
  • input_dim (int) – input dimension

  • edge_input_dim (int) – dimension of edge features

  • hidden_dims (list of int, optional) – hidden dims of edge network

  • batch_norm (bool, optional) – apply batch normalization on nodes or not

  • activation (str or function, optional) – activation function

aggregate(graph, message)[source]#

Aggregate edge messages to nodes.

Parameters
  • graph (Graph) – graph(s)

  • message (Tensor) – edge messages of shape \((|E|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor

combine(input, update)[source]#

Combine node input and node update.

Parameters
  • input (Tensor) – node representations of shape \((|V|, ...)\)

  • update (Tensor) – node updates of shape \((|V|, ...)\)

message(graph, input)[source]#

Compute edge messages for the graph.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

edge messages of shape \((|E|, ...)\)

Return type

Tensor

NeuralFingerprintConv#

class NeuralFingerprintConv(input_dim, output_dim, edge_input_dim=None, batch_norm=False, activation='relu')[source]#

Graph neural network operator from Convolutional Networks on Graphs for Learning Molecular Fingerprints.

Note this operator doesn’t include the sparsifying step of the original paper.

Parameters
  • input_dim (int) – input dimension

  • output_dim (int) – output dimension

  • edge_input_dim (int, optional) – dimension of edge features

  • batch_norm (bool, optional) – apply batch normalization on nodes or not

  • activation (str or function, optional) – activation function

aggregate(graph, message)[source]#

Aggregate edge messages to nodes.

Parameters
  • graph (Graph) – graph(s)

  • message (Tensor) – edge messages of shape \((|E|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor

combine(input, update)[source]#

Combine node input and node update.

Parameters
  • input (Tensor) – node representations of shape \((|V|, ...)\)

  • update (Tensor) – node updates of shape \((|V|, ...)\)

message(graph, input)[source]#

Compute edge messages for the graph.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

edge messages of shape \((|E|, ...)\)

Return type

Tensor

message_and_aggregate(graph, input)[source]#

Fused computation of message and aggregation over the graph. This may provide better time or memory complexity than separate calls of message and aggregate.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor

RelationalGraphConv#

class RelationalGraphConv(input_dim, output_dim, num_relation, edge_input_dim=None, batch_norm=False, activation='relu')[source]#

Relational graph convolution operator from Modeling Relational Data with Graph Convolutional Networks.

Parameters
  • input_dim (int) – input dimension

  • output_dim (int) – output dimension

  • num_relation (int) – number of relations

  • edge_input_dim (int, optional) – dimension of edge features

  • batch_norm (bool, optional) – apply batch normalization on nodes or not

  • activation (str or function, optional) – activation function

aggregate(graph, message)[source]#

Aggregate edge messages to nodes.

Parameters
  • graph (Graph) – graph(s)

  • message (Tensor) – edge messages of shape \((|E|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor

combine(input, update)[source]#

Combine node input and node update.

Parameters
  • input (Tensor) – node representations of shape \((|V|, ...)\)

  • update (Tensor) – node updates of shape \((|V|, ...)\)

message(graph, input)[source]#

Compute edge messages for the graph.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

edge messages of shape \((|E|, ...)\)

Return type

Tensor

message_and_aggregate(graph, input)[source]#

Fused computation of message and aggregation over the graph. This may provide better time or memory complexity than separate calls of message and aggregate.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor

Readout Layers#

MeanReadout#

class MeanReadout(type='node')[source]#

Mean readout operator over graphs with variadic sizes.

forward(graph, input)[source]#

Perform readout over the graph(s).

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations

Returns

graph representations

Return type

Tensor
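A hedged sketch of readout over a batch of graphs (the toy graphs below are illustrative); the readout reduces node representations into one vector per graph:

>>> import torch
>>> from torchdrug import data, layers
>>> g1 = data.Graph(edge_list=[[0, 1]], num_node=2)
>>> g2 = data.Graph(edge_list=[[0, 1], [1, 2]], num_node=3)
>>> batch = data.Graph.pack([g1, g2])
>>> node_feature = torch.randn(5, 8)    # 2 + 3 = 5 nodes in total
>>> readout = layers.MeanReadout()
>>> readout(batch, node_feature).shape  # expected: torch.Size([2, 8])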

SumReadout#

class SumReadout(type='node')[source]#

Sum readout operator over graphs with variadic sizes.

forward(graph, input)[source]#

Perform readout over the graph(s).

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations

Returns

graph representations

Return type

Tensor

MaxReadout#

class MaxReadout(type='node')[source]#

Max readout operator over graphs with variadic sizes.

forward(graph, input)[source]#

Perform readout over the graph(s).

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations

Returns

graph representations

Return type

Tensor

Set2Set#

class Set2Set(input_dim, type='node', num_step=3, num_lstm_layer=1)[source]#

Set2Set operator from Order Matters: Sequence to sequence for sets.

Parameters
  • input_dim (int) – input dimension

  • num_step (int, optional) – number of process steps

  • num_lstm_layer (int, optional) – number of LSTM layers

forward(graph, input)[source]#

Perform Set2Set readout over graph(s).

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations

Returns

graph representations

Return type

Tensor

Softmax#

class Softmax(type='node')[source]#

Softmax operator over graphs with variadic sizes.

forward(graph, input)[source]#

Perform softmax over the graph(s).

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node logits

Returns

node probabilities

Return type

Tensor

Sort#

class Sort(type='node', descending=False)[source]#

Sort operator over graphs with variadic sizes.

Parameters

descending (bool, optional) – use descending sort order or not

forward(graph, input)[source]#

Perform sort over graph(s).

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node values

Returns

sorted values, sorted indices

Return type

(Tensor, LongTensor)

Pooling Layers#

DiffPool#

class DiffPool(input_dim, output_node, feature_layer=None, pool_layer=None, loss_weight=1, zero_diagonal=False, sparse=False)[source]#

Differentiable pooling operator from Hierarchical Graph Representation Learning with Differentiable Pooling.

Parameters
  • input_dim (int) – input dimension

  • output_node (int) – number of nodes after pooling

  • feature_layer (Module, optional) – graph convolution layer for embedding

  • pool_layer (Module, optional) – graph convolution layer for pooling assignment

  • loss_weight (float, optional) – weight of entropy regularization

  • zero_diagonal (bool, optional) – remove self loops in the pooled graph or not

  • sparse (bool, optional) – use sparse assignment or not

forward(graph, input, all_loss=None, metric=None)[source]#

Compute the node cluster assignment and pool the nodes.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – input node representations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

Returns

pooled graph, output node representations, node-to-cluster assignment

Return type

(PackedGraph, Tensor, Tensor)

MinCutPool#

class MinCutPool(input_dim, output_node, feature_layer=None, pool_layer=None, loss_weight=1, zero_diagonal=True, sparse=False)[source]#

Min cut pooling operator from Spectral Clustering with Graph Neural Networks for Graph Pooling.

Parameters
  • input_dim (int) – input dimension

  • output_node (int) – number of nodes after pooling

  • feature_layer (Module, optional) – graph convolution layer for embedding

  • pool_layer (Module, optional) – graph convolution layer for pooling assignment

  • loss_weight (float, optional) – weight of entropy regularization

  • zero_diagonal (bool, optional) – remove self loops in the pooled graph or not

  • sparse (bool, optional) – use sparse assignment or not

forward(graph, input, all_loss=None, metric=None)[source]#

Compute the node cluster assignment and pool the nodes.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – input node representations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

Returns

pooled graph, output node representations, node-to-cluster assignment

Return type

(PackedGraph, Tensor, Tensor)

Sampler Layers#

EdgeSampler#

class EdgeSampler(budget=None, ratio=None)[source]#

Edge sampler from GraphSAINT: Graph Sampling Based Inductive Learning Method.

Parameters
  • budget (int, optional) – number of edges to keep

  • ratio (float, optional) – ratio of edges to keep

forward(graph)[source]#

Sample a subgraph from the graph.

Parameters

graph (Graph) – graph(s)

NodeSampler#

class NodeSampler(budget=None, ratio=None)[source]#

Node sampler from GraphSAINT: Graph Sampling Based Inductive Learning Method.

Parameters
  • budget (int, optional) – number of nodes to keep

  • ratio (float, optional) – ratio of nodes to keep

forward(graph)[source]#

Sample a subgraph from the graph.

Parameters

graph (Graph) – graph(s)
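A hedged usage sketch (the budget and ratio values are arbitrary): a sampler takes a graph and returns a sampled subgraph, e.g. inside a GraphSAINT-style training loop.

>>> from torchdrug import layers
>>> node_sampler = layers.NodeSampler(ratio=0.5)    # keep roughly half of the nodes
>>> edge_sampler = layers.EdgeSampler(budget=1000)  # keep a budget of 1000 edges
>>> # subgraph = node_sampler(graph)
>>> # subgraph = edge_sampler(graph)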

Flow Layers#

ConditionalFlow#

class ConditionalFlow(input_dim, condition_dim, hidden_dims=None, activation='relu')[source]#

Conditional flow transformation from Masked Autoregressive Flow for Density Estimation.

Parameters
  • input_dim (int) – input & output dimension

  • condition_dim (int) – condition dimension

  • hidden_dims (list of int, optional) – hidden dimensions

  • activation (str or function, optional) – activation function

forward(input, condition)[source]#

Transform data into latent representations.

Parameters
  • input (Tensor) – input representations

  • condition (Tensor) – conditional representations

Returns

latent representations, log-likelihood of the transformation

Return type

(Tensor, Tensor)

reverse(latent, condition)[source]#

Transform latent representations into data.

Parameters
  • latent (Tensor) – latent representations

  • condition (Tensor) – conditional representations

Returns

input representations, log-likelihood of the transformation

Return type

(Tensor, Tensor)
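A minimal sketch (dimensions and variable names are illustrative): forward maps data to latent codes along with the log-likelihood of the transformation, and reverse maps latent codes back to the data space.

>>> import torch
>>> from torchdrug import layers
>>> flow = layers.ConditionalFlow(input_dim=16, condition_dim=8, hidden_dims=[32, 32])
>>> x = torch.randn(4, 16)
>>> condition = torch.randn(4, 8)
>>> latent, log_likelihood = flow(x, condition)
>>> x_recon, inv_log_likelihood = flow.reverse(latent, condition)
>>> latent.shape, x_recon.shape   # expected: (torch.Size([4, 16]), torch.Size([4, 16]))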

Sequence Encoder Blocks#

ProteinResNetBlock#

class ProteinResNetBlock(input_dim, output_dim, kernel_size=3, stride=1, padding=1, activation='gelu')[source]#

Convolutional block with residual connection from Deep Residual Learning for Image Recognition.

Parameters
  • input_dim (int) – input dimension

  • output_dim (int) – output dimension

  • kernel_size (int, optional) – size of convolutional kernel

  • stride (int, optional) – stride of convolution

  • padding (int, optional) – padding added to both sides of the input

  • activation (str or function, optional) – activation function

forward(input, mask)[source]#

Perform 1D convolutions over the input.

Parameters
  • input (Tensor) – input representations of shape (…, length, dim)

  • mask (Tensor) – bool mask of shape (…, length, dim)

SelfAttentionBlock#

class SelfAttentionBlock(hidden_dim, num_heads, dropout=0.0)[source]#

Multi-head self-attention block from Attention Is All You Need.

Parameters
  • hidden_dim (int) – hidden dimension

  • num_heads (int) – number of attention heads

  • dropout (float, optional) – dropout ratio of attention maps

forward(input, mask)[source]#

Perform self attention over the input.

Parameters
  • input (Tensor) – input representations of shape (…, length, dim)

  • mask (Tensor) – bool mask of shape (…, length)

ProteinBERTBlock#

class ProteinBERTBlock(input_dim, hidden_dim, num_heads, attention_dropout=0, hidden_dropout=0, activation='relu')[source]#

Transformer encoding block from Attention Is All You Need.

Parameters
  • input_dim (int) – input dimension

  • hidden_dim (int) – hidden dimension

  • num_heads (int) – number of attention heads

  • attention_dropout (float, optional) – dropout ratio of attention maps

  • hidden_dropout (float, optional) – dropout ratio of hidden features

  • activation (str or function, optional) – activation function

forward(input, mask)[source]#

Perform a BERT-block transformation over the input.

Parameters
  • input (Tensor) – input representations of shape (…, length, dim)

  • mask (Tensor) – bool mask of shape (…, length)

Distribution Layers#

These layers belong to torchdrug.layers.distribution.

IndependentGaussian#

class IndependentGaussian(mu, sigma2, learnable=False)[source]#

Independent Gaussian distribution.

Parameters
  • mu (Tensor) – mean of shape \((N,)\)

  • sigma2 (Tensor) – variance of shape \((N,)\)

  • learnable (bool, optional) – learnable parameters or not

forward(input)[source]#

Compute the likelihood of input data.

Parameters

input (Tensor) – input data of shape \((..., N)\)

sample(*size)[source]#

Draw samples from the distribution.

Parameters

size (tuple of int) – shape of the samples

Graph Construction Layers#

These layers belong to torchdrug.layers.geometry.

GraphConstruction#

class GraphConstruction(node_layers=None, edge_layers=None, edge_feature='residue_type')[source]#

Construct a new graph from an existing graph.

See torchdrug.layers.geometry for a full list of available node and edge layers.

Parameters
  • node_layers (list of nn.Module, optional) – modules to construct nodes of the new graph

  • edge_layers (list of nn.Module, optional) – modules to construct edges of the new graph

  • edge_feature (str, optional) – edge features in the new graph. Available features are residue_type, gearnet.

    1. For residue_type, the feature of the edge \(e_{ij}\) between residue \(i\) and residue \(j\) is the concatenation [residue_type(i), residue_type(j)].

    2. For gearnet, the feature of the edge \(e_{ij}\) between residue \(i\) and residue \(j\) is the concatenation [residue_type(i), residue_type(j), edge_type(e_ij), sequential_distance(i,j), spatial_distance(i,j)].

Note

You may customize your own edge features by inheriting this class and defining a member function for your features. Use edge_feature="my_feature" to call the following feature function.

def edge_my_feature(self, graph, edge_list, num_relation):
    ...
    return feature  # the first dimension must be graph.num_edge

forward(graph)[source]#

Generate a new graph based on the input graph and pre-defined node and edge layers.

Parameters

graph (Graph) – \(n\) graph(s)

Returns

new graph(s)

Return type

graph (Graph)
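For example, a typical protein graph construction pipeline (the particular node and edge layers and their hyperparameters below are one illustrative configuration, not the only valid one):

>>> from torchdrug import layers
>>> from torchdrug.layers import geometry
>>> graph_construction = layers.GraphConstruction(
...     node_layers=[geometry.AlphaCarbonNode()],
...     edge_layers=[
...         geometry.SequentialEdge(max_distance=2),
...         geometry.SpatialEdge(radius=10.0, min_distance=5),
...         geometry.KNNEdge(k=10, min_distance=5),
...     ],
...     edge_feature="gearnet",
... )
>>> # new_graph = graph_construction(protein_graph)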

SpatialLineGraph#

class SpatialLineGraph(num_angle_bin=8)[source]#

Spatial line graph construction module from Protein Representation Learning by Geometric Structure Pretraining.

Parameters

num_angle_bin (int, optional) – number of bins to discretize angles between edges

forward(graph)[source]#

Generate the spatial line graph of the input graph. The edge types are decided by the angles between two adjacent edges in the input graph.

Parameters

graph (PackedGraph) – \(n\) graph(s)

Returns

the spatial line graph

Return type

graph (PackedGraph)

BondEdge#

class BondEdge[source]#

Construct all bond edges.

forward(graph)[source]#

Return bond edges from the input graph. Edge types are inherited from the input graph.

Parameters

graph (Graph) – \(n\) graph(s)

Returns

edge list of shape \((|E|, 3)\), number of relations

Return type

(Tensor, int)

KNNEdge#

class KNNEdge(k=10, min_distance=5, max_distance=None)[source]#

Construct edges between each node and its nearest neighbors.

Parameters
  • k (int, optional) – number of neighbors

  • min_distance (int, optional) – minimum distance between the residues of two nodes

forward(graph)[source]#

Return KNN edges constructed from the input graph.

Parameters

graph (Graph) – \(n\) graph(s)

Returns

edge list of shape \((|E|, 3)\), number of relations

Return type

(Tensor, int)

SpatialEdge#

class SpatialEdge(radius=5, min_distance=5, max_distance=None, max_num_neighbors=32)[source]#

Construct edges between nodes within a specified radius.

Parameters
  • radius (float, optional) – spatial radius

  • min_distance (int, optional) – minimum distance between the residues of two nodes

forward(graph)[source]#

Return spatial radius edges constructed based on the input graph.

Parameters

graph (Graph) – \(n\) graph(s)

Returns

edge list of shape \((|E|, 3)\), number of relations

Return type

(Tensor, int)

SequentialEdge#

class SequentialEdge(max_distance=2, only_backbone=False)[source]#

Construct edges between atoms within close residues.

Parameters

max_distance (int, optional) – maximum distance between two residues in the sequence

forward(graph)[source]#

Return sequential edges constructed based on the input graph. Edge types are defined by the relative distance between two residues in the sequence.

Parameters

graph (Graph) – \(n\) graph(s)

Returns

edge list of shape \((|E|, 3)\), number of relations

Return type

(Tensor, int)

AlphaCarbonNode#

class AlphaCarbonNode[source]#

Construct only alpha carbon atoms.

forward(graph)[source]#

Return a subgraph that only consists of alpha carbon nodes.

Parameters

graph (Graph) – \(n\) graph(s)

IdentityNode#

class IdentityNode[source]#

Construct all nodes as the input.

forward(graph)[source]#

Return the input graph as is.

Parameters

graph (Graph) – \(n\) graph(s)

RandomEdgeMask#

class RandomEdgeMask(mask_rate=0.15)[source]#

Construct nodes by random edge masking.

Parameters

mask_rate (float, optional) – rate of masked edges

forward(graph)[source]#

Return a graph with some edges masked out.

Parameters

graph (Graph) – \(n\) graph(s)

SubsequenceNode#

class SubsequenceNode(max_length=100)[source]#

Construct nodes by taking a random subsequence of the original graph.

Parameters

max_length (int, optional) – maximal length of the sequence after cropping

forward(graph)[source]#

Randomly take a subsequence of the specified length. Return the full sequence if the sequence is shorter than the specified length.

Parameters

graph (Graph) – \(n\) graph(s)

SubspaceNode#

class SubspaceNode(entity_level='node', min_radius=15.0, min_neighbor=50)[source]#

Construct nodes by taking a spatial ball of the original graph.

Parameters
  • entity_level (str, optional) – level to perform cropping. Available options are node, atom and residue.

  • min_radius (float, optional) – minimum radius of the spatial ball

  • min_neighbor (int, optional) – minimum number of nodes in the spatial ball

forward(graph)[source]#

Randomly pick a node as the center, and crop a spatial ball that has a radius of at least min_radius and contains at least min_neighbor nodes.

Parameters

graph (Graph) – \(n\) graph(s)

Functional Layers#

These layers belong to torchdrug.layers.functional.

Embedding Score Functions#

transe_score(entity, relation, h_index, t_index, r_index)[source]#

TransE score function from Translating Embeddings for Modeling Multi-relational Data.

Parameters
  • entity (Tensor) – entity embeddings of shape \((|V|, d)\)

  • relation (Tensor) – relation embeddings of shape \((|R|, d)\)

  • h_index (LongTensor) – index of head entities

  • t_index (LongTensor) – index of tail entities

  • r_index (LongTensor) – index of relations
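A hedged sketch of scoring a handful of random triplets with randomly initialized embeddings (all sizes are illustrative); the other score functions below share the same calling convention.

>>> import torch
>>> from torchdrug.layers import functional
>>> entity = torch.randn(100, 32)      # 100 entities, 32-dimensional embeddings
>>> relation = torch.randn(10, 32)     # 10 relations
>>> h_index = torch.randint(100, (5,))
>>> t_index = torch.randint(100, (5,))
>>> r_index = torch.randint(10, (5,))
>>> score = functional.transe_score(entity, relation, h_index, t_index, r_index)
>>> score.shape   # expected: torch.Size([5])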

distmult_score(entity, relation, h_index, t_index, r_index)[source]#

DistMult score function from Embedding Entities and Relations for Learning and Inference in Knowledge Bases.

Parameters
  • entity (Tensor) – entity embeddings of shape \((|V|, d)\)

  • relation (Tensor) – relation embeddings of shape \((|R|, d)\)

  • h_index (LongTensor) – index of head entities

  • t_index (LongTensor) – index of tail entities

  • r_index (LongTensor) – index of relations

complex_score(entity, relation, h_index, t_index, r_index)[source]#

ComplEx score function from Complex Embeddings for Simple Link Prediction.

Parameters
  • entity (Tensor) – entity embeddings of shape \((|V|, 2d)\)

  • relation (Tensor) – relation embeddings of shape \((|R|, 2d)\)

  • h_index (LongTensor) – index of head entities

  • t_index (LongTensor) – index of tail entities

  • r_index (LongTensor) – index of relations

simple_score(entity, relation, h_index, t_index, r_index)[source]#

SimplE score function from SimplE Embedding for Link Prediction in Knowledge Graphs.

Parameters
  • entity (Tensor) – entity embeddings of shape \((|V|, 2d)\)

  • relation (Tensor) – relation embeddings of shape \((|R|, d)\)

  • h_index (LongTensor) – index of head entities

  • t_index (LongTensor) – index of tail entities

  • r_index (LongTensor) – index of relations

rotate_score(entity, relation, h_index, t_index, r_index)[source]#

RotatE score function from RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space.

Parameters
  • entity (Tensor) – entity embeddings of shape \((|V|, 2d)\)

  • relation (Tensor) – relation embeddings of shape \((|R|, d)\)

  • h_index (LongTensor) – index of head entities

  • t_index (LongTensor) – index of tail entities

  • r_index (LongTensor) – index of relations

Sparse Matrix Multiplication#

generalized_spmm(sparse, input, sum='add', mul='mul')[source]#

Generalized sparse-dense matrix multiplication.

This function computes the matrix multiplication of a sparse matrix and a dense input matrix. The output dense matrix satisfies

\[output_{i,k} = \bigoplus_{j: sparse_{i,j} \neq 0} sparse_{i,j} \otimes input_{j,k}\]

where \(\oplus\) and \(\otimes\) are the summation and the multiplication operators respectively.

Warning

Gradient w.r.t. the sparse matrix is only computed for non-zero entries of the sparse matrix. This behaves differently from dense-dense matrix multiplication with zero entries.

Parameters
  • sparse (SparseTensor) – 2D sparse tensor

  • input (Tensor) – 2D dense tensor

  • sum (str, optional) – generalized summation operator. Available operators are add, min and max.

  • mul (str, optional) – generalized multiplication operator. Available operators are add and mul.
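A hedged sketch, assuming a coalesced torch.sparse COO tensor is accepted for the sparse argument (as the signature suggests); the second call swaps the summation operator for max.

>>> import torch
>>> from torchdrug.layers import functional
>>> indices = torch.tensor([[0, 2], [1, 0]])     # row indexes, column indexes
>>> values = torch.tensor([1.0, 2.0])
>>> sparse = torch.sparse_coo_tensor(indices, values, (3, 3)).coalesce()
>>> dense = torch.randn(3, 4)
>>> out = functional.generalized_spmm(sparse, dense)                  # ordinary sum-product spmm
>>> out_max = functional.generalized_spmm(sparse, dense, sum="max")   # max instead of sum
>>> out.shape, out_max.shape   # expected: (torch.Size([3, 4]), torch.Size([3, 4]))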

generalized_rspmm(sparse, relation, input, sum='add', mul='mul')[source]#

Generalized relational sparse-dense matrix multiplication.

This function computes the matrix multiplication of a sparse matrix, a dense relation matrix and a dense input matrix. The output dense matrix satisfies

\[output_{i,l} = \bigoplus_{j,k: sparse_{i,j,k} \neq 0} sparse_{i, j, k} \times (relation_{k,l} \otimes input_{j,l})\]

where \(\oplus\) and \(\otimes\) are the summation and the multiplication operators respectively.

Warning

Gradient w.r.t. the sparse matrix is only computed for non-zero entries of the sparse matrix. This behaves differently from dense-dense matrix multiplication with zero entries.

Parameters
  • sparse (SparseTensor) – 3D sparse tensor

  • relation (Tensor) – 2D dense tensor

  • input (Tensor) – 2D dense tensor

  • sum (str, optional) – generalized summation operator. Available operators are add, min and max.

  • mul (str, optional) – generalized multiplication operator. Available operators are add and mul.

Variadic#

variadic_sum(input, size)[source]#

Compute sum over sets with variadic sizes.

Suppose there are \(N\) sets, and the sizes of all sets are summed to \(B\).

Parameters
  • input (Tensor) – input of shape \((B, ...)\)

  • size (LongTensor) – size of sets of shape \((N,)\)

variadic_mean(input, size)[source]#

Compute mean over sets with variadic sizes.

Suppose there are \(N\) sets, and the sizes of all sets are summed to \(B\).

Parameters
  • input (Tensor) – input of shape \((B, ...)\)

  • size (LongTensor) – size of sets of shape \((N,)\)

variadic_max(input, size)[source]#

Compute max over sets with variadic sizes.

Suppose there are \(N\) sets, and the sizes of all sets are summed to \(B\).

Parameters
  • input (Tensor) – input of shape \((B, ...)\)

  • size (LongTensor) – size of sets of shape \((N,)\)

Returns

max values and indexes

Return type

(Tensor, LongTensor)
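A small worked example covering the three reductions above (values chosen by hand): two sets of sizes 3 and 2, flattened into one tensor.

>>> import torch
>>> from torchdrug.layers import functional
>>> input = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
>>> size = torch.tensor([3, 2])
>>> functional.variadic_sum(input, size)    # expected: tensor([6., 9.])
>>> functional.variadic_mean(input, size)   # expected: tensor([2.0000, 4.5000])
>>> value, index = functional.variadic_max(input, size)
>>> value                                   # expected: tensor([3., 5.])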

variadic_cross_entropy(input, target, size, reduction='mean')[source]#

Compute cross entropy loss over categories with variadic sizes.

Suppose there are \(N\) samples, and the numbers of categories in all samples are summed to \(B\).

Parameters
  • input (Tensor) – prediction of shape \((B, ...)\)

  • target (Tensor) – target of shape \((N, ...)\). Each target is a relative index in a sample.

  • size (LongTensor) – number of categories of shape \((N,)\)

  • reduction (string, optional) – reduction to apply to the output. Available reductions are none, sum and mean.

variadic_log_softmax(input, size)[source]#

Compute log softmax over categories with variadic sizes.

Suppose there are \(N\) samples, and the numbers of categories in all samples are summed to \(B\).

Parameters
  • input (Tensor) – input of shape \((B, ...)\)

  • size (LongTensor) – number of categories of shape \((N,)\)

variadic_softmax(input, size)[source]#

Compute softmax over categories with variadic sizes.

Suppose there are \(N\) samples, and the numbers of categories in all samples are summed to \(B\).

Parameters
  • input (Tensor) – input of shape \((B, ...)\)

  • size (LongTensor) – number of categories of shape \((N,)\)

variadic_sort(input, size, descending=False)[source]#

Sort elements in sets with variadic sizes.

Suppose there are \(N\) sets, and the sizes of all sets are summed to \(B\).

Parameters
  • input (Tensor) – input of shape \((B, ...)\)

  • size (LongTensor) – size of sets of shape \((N,)\)

  • descending (bool, optional) – return ascending or descending order

Returns

sorted values and indexes

Return type

(Tensor, LongTensor)

variadic_topk(input, size, k, largest=True)[source]#

Compute the \(k\) largest elements over sets with variadic sizes.

Suppose there are \(N\) sets, and the sizes of all sets are summed to \(B\).

If any set has fewer than \(k\) elements, the size-th largest element will be repeated to pad the output to length \(k\).

Parameters
  • input (Tensor) – input of shape \((B, ...)\)

  • size (LongTensor) – size of sets of shape \((N,)\)

  • k (int or LongTensor) – the k in “top-k”. Can be a fixed value for all sets, or different values for different sets of shape \((N,)\).

  • largest (bool, optional) – return largest or smallest elements

Returns

top-k values and indexes

Return type

(Tensor, LongTensor)
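A small worked example (values chosen by hand): the top-2 elements of two sets of sizes 3 and 2.

>>> import torch
>>> from torchdrug.layers import functional
>>> input = torch.tensor([5.0, 1.0, 3.0, 2.0, 4.0])
>>> size = torch.tensor([3, 2])       # first set: [5, 1, 3]; second set: [2, 4]
>>> value, index = functional.variadic_topk(input, size, k=2)
>>> value    # expected: tensor([[5., 3.], [4., 2.]])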

variadic_arange(size)[source]#

Return a 1-D tensor that contains integer intervals of variadic sizes. This is a variadic variant of torch.arange(stop).expand(batch_size, -1).

Suppose there are \(N\) intervals.

Parameters

size (LongTensor) – size of intervals of shape \((N,)\)

variadic_randperm(size, device=None)[source]#

Return random permutations for sets with variadic sizes. The i-th permutation contains integers from 0 to size[i] - 1.

Suppose there are \(N\) sets.

Parameters
  • size (LongTensor) – size of sets of shape \((N,)\)

  • device (torch.device, optional) – device of the tensor

variadic_sample(input, size, num_sample)[source]#

Draw samples with replacement from sets with variadic sizes.

Suppose there are \(N\) sets, and the sizes of all sets are summed to \(B\).

Parameters
  • input (Tensor) – input of shape \((B, ...)\)

  • size (LongTensor) – size of sets of shape \((N,)\)

  • num_sample (int) – number of samples to draw from each set

variadic_meshgrid(input1, size1, input2, size2)[source]#

Compute the Cartesian product for two batches of sets with variadic sizes.

Suppose there are \(N\) sets in each input, and the sizes of all sets are summed to \(B_1\) and \(B_2\) respectively.

Parameters
  • input1 (Tensor) – input of shape \((B_1, ...)\)

  • size1 (LongTensor) – size of input1 of shape \((N,)\)

  • input2 (Tensor) – input of shape \((B_2, ...)\)

  • size2 (LongTensor) – size of input2 of shape \((N,)\)

Returns

the first and the second elements in the Cartesian product

Return type

(Tensor, Tensor)

variadic_to_padded(input, size, value=0)[source]#

Convert a variadic tensor to a padded tensor.

Suppose there are \(N\) sets, and the sizes of all sets are summed to \(B\).

Parameters
  • input (Tensor) – input of shape \((B, ...)\)

  • size (LongTensor) – size of sets of shape \((N,)\)

  • value (scalar) – fill value for padding

Returns

padded tensor and mask

Return type

(Tensor, BoolTensor)

padded_to_variadic(padded, size)[source]#

Convert a padded tensor to a variadic tensor.

Parameters
  • padded (Tensor) – padded tensor of shape \((N, ...)\)

  • size (LongTensor) – size of sets of shape \((N,)\)
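A small round-trip example (values chosen by hand) for the two conversions above:

>>> import torch
>>> from torchdrug.layers import functional
>>> input = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
>>> size = torch.tensor([3, 2])
>>> padded, mask = functional.variadic_to_padded(input, size)
>>> padded   # expected: tensor([[1., 2., 3.], [4., 5., 0.]])
>>> mask     # expected: tensor([[True, True, True], [True, True, False]])
>>> functional.padded_to_variadic(padded, size)   # expected: tensor([1., 2., 3., 4., 5.])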

Tensor Reduction#

masked_mean(input, mask, dim=None, keepdim=False)[source]#

Masked mean of a tensor.

Parameters
  • input (Tensor) – input tensor

  • mask (BoolTensor) – mask tensor

  • dim (int or tuple of int, optional) – dimension to reduce

  • keepdim (bool, optional) – whether retain dim or not

mean_with_nan(input, dim=None, keepdim=False)[source]#

Mean of a tensor. Ignore all nan values.

Parameters
  • input (Tensor) – input tensor

  • dim (int or tuple of int, optional) – dimension to reduce

  • keepdim (bool, optional) – whether retain dim or not

Tensor Construction#

as_mask(indexes, length)[source]#

Convert indexes into a binary mask.

Parameters
  • indexes (LongTensor) – positive indexes

  • length (int) – maximal possible value of indexes

one_hot(index, size)[source]#

Expand indexes into one-hot vectors.

Parameters
  • index (Tensor) – index

  • size (int) – size of the one-hot dimension

multi_slice(starts, ends)[source]#

Compute the union of indexes in multiple slices.

Example:

>>> indexes = multi_slice(torch.tensor([0, 1, 4]), torch.tensor([2, 3, 6]))
>>> assert (indexes == torch.tensor([0, 1, 2, 4, 5])).all()
Parameters
  • starts (LongTensor) – start indexes of slices

  • ends (LongTensor) – end indexes of slices

multi_slice_mask(starts, ends, length)[source]#

Compute the union of multiple slices into a binary mask.

Example:

>>> mask = multi_slice_mask(torch.tensor([0, 1, 4]), torch.tensor([2, 3, 6]), 6)
>>> assert (mask == torch.tensor([1, 1, 1, 0, 1, 1])).all()
Parameters
  • starts (LongTensor) – start indexes of slices

  • ends (LongTensor) – end indexes of slices

  • length (int) – length of mask

Sampling#

multinomial(input, num_sample, replacement=False)[source]#

Fast multinomial sampling. This is the default implementation in PyTorch v1.6.0+.

Parameters
  • input (Tensor) – unnormalized distribution

  • num_sample (int) – number of samples

  • replacement (bool, optional) – sample with replacement or not

Activation#

shifted_softplus(input)[source]#

Shifted softplus function.

Parameters

input (Tensor) – input tensor
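Assuming the SchNet definition, this is softplus(x) - log 2, which is zero at x = 0 and approaches a linear function for large x. A quick self-check sketch:

>>> import torch
>>> import torch.nn.functional as F
>>> from torchdrug.layers import functional
>>> x = torch.randn(5)
>>> torch.allclose(functional.shifted_softplus(x), F.softplus(x) - torch.log(torch.tensor(2.0)))   # expected: True under the SchNet definition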