torchdrug.layers

Common Layers

GaussianSmearing

class GaussianSmearing(start=0, stop=5, num_kernel=100, centered=False, learnable=False)[source]

Gaussian smearing from SchNet: A continuous-filter convolutional neural network for modeling quantum interactions.

There are two modes for Gaussian smearing.

Non-centered mode:

\[\mu = [0, 1, ..., n], \sigma = [1, 1, ..., 1]\]

Centered mode:

\[\mu = [0, 0, ..., 0], \sigma = [0, 1, ..., n]\]
Parameters
  • start (int, optional) – minimal input value

  • stop (int, optional) – maximal input value

  • num_kernel (int, optional) – number of RBF kernels

  • centered (bool, optional) – centered mode or not

  • learnable (bool, optional) – learnable gaussian parameters or not

forward(x, y)[source]

Compute smeared gaussian features between data.

Parameters
  • x (Tensor) – data of shape \((..., d)\)

  • y (Tensor) – data of shape \((..., d)\)

Returns

features of shape \((..., num\_kernel)\)

Return type

Tensor
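
The non-centered mode can be sketched in plain Python as below. This is only an illustration under stated assumptions: the kernel centers are spaced evenly between start and stop, every kernel shares a width equal to that spacing, and the kernels are applied to the Euclidean distance between x and y. The helper name gaussian_smearing is hypothetical and is not the library's implementation.

```python
import math

def gaussian_smearing(x, y, start=0.0, stop=5.0, num_kernel=100):
    """Expand the distance between two points into num_kernel Gaussian features."""
    # Assumption: kernel centers are evenly spaced on [start, stop] (needs num_kernel >= 2).
    step = (stop - start) / (num_kernel - 1)
    centers = [start + i * step for i in range(num_kernel)]
    # Assumption: every kernel uses the spacing between centers as its width.
    sigma = step
    distance = math.dist(x, y)
    return [math.exp(-0.5 * ((distance - mu) / sigma) ** 2) for mu in centers]

# Two points at distance 5, expanded over 6 kernels centered at 0, 1, ..., 5.
features = gaussian_smearing([0.0, 0.0, 0.0], [3.0, 4.0, 0.0], num_kernel=6)
```

The kernel whose center matches the distance responds most strongly, and the response decays for kernels further away.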

MultiLayerPerceptron

class MultiLayerPerceptron(input_dim, hidden_dims, short_cut=False, batch_norm=False, activation='relu', dropout=0)[source]

Multi-layer Perceptron.

Note there is no batch normalization, activation or dropout in the last layer.

Parameters
  • input_dim (int) – input dimension

  • hidden_dims (list of int) – hidden dimensions

  • short_cut (bool, optional) – use short cut or not

  • batch_norm (bool, optional) – apply batch normalization or not

  • activation (str or function, optional) – activation function

  • dropout (float, optional) – dropout rate

MutualInformation

class MutualInformation(input_dim, num_mlp_layer=2, activation='relu')[source]

Mutual information estimator from Learning deep representations by mutual information estimation and maximization.

Parameters
  • input_dim (int) – input dimension

  • num_mlp_layer (int, optional) – number of MLP layers

  • activation (str or function, optional) – activation function

PairNorm

class PairNorm(scale_individual=False)[source]

Pair normalization layer proposed in PairNorm: Tackling Oversmoothing in GNNs.

Parameters

scale_individual (bool, optional) – additionally normalize each node representation to have the same L2-norm

Sequential

class Sequential(*args, global_args=None, allow_unused=False)[source]

Improved sequential container. Modules will be called in the order they are passed to the constructor.

Compared to the vanilla nn.Sequential, this layer additionally supports the following features.

  1. Multiple input / output arguments.

>>> # layer1 signature: (...) -> (a, b)
>>> # layer2 signature: (a, b) -> (...)
>>> layer = layers.Sequential(layer1, layer2)
  2. Global arguments.

>>> # layer1 signature: (graph, a) -> b
>>> # layer2 signature: (graph, b) -> c
>>> layer = layers.Sequential(layer1, layer2, global_args=("graph",))

Note the global arguments don’t need to be present in every layer.

>>> # layer1 signature: (graph, a) -> b
>>> # layer2 signature: b -> c
>>> # layer3 signature: (graph, c) -> d
>>> layer = layers.Sequential(layer1, layer2, layer3, global_args=("graph",))
  3. Dict outputs.

>>> # layer1 signature: a -> {"b": b, "c": c}
>>> # layer2 signature: b -> d
>>> layer = layers.Sequential(layer1, layer2, allow_unused=True)

When dict outputs are used with global arguments, the global arguments can be explicitly overwritten by any layer outputs.

>>> # layer1 signature: (graph, a) -> {"graph": graph, "b": b}
>>> # layer2 signature: (graph, b) -> c
>>> # layer2 takes in the graph output by layer1
>>> layer = layers.Sequential(layer1, layer2, global_args=("graph",))
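
The routing of global arguments can be illustrated with a small stand-alone sketch. The helper run_sequential and its inputs and global_values parameters are hypothetical, and the sketch covers only positional outputs and global arguments, not dict outputs or allow_unused. It is not the library's implementation.

```python
import inspect

def run_sequential(layers, inputs, global_values):
    """Call layers in order; global arguments are matched by parameter name."""
    outputs = tuple(inputs)
    for layer in layers:
        names = list(inspect.signature(layer).parameters)
        # Global arguments are supplied only to layers that declare them.
        kwargs = {name: global_values[name] for name in names if name in global_values}
        # Remaining parameters are filled positionally from the previous outputs.
        free_names = [name for name in names if name not in global_values]
        kwargs.update(zip(free_names, outputs))
        result = layer(**kwargs)
        outputs = result if isinstance(result, tuple) else (result,)
    return outputs[0] if len(outputs) == 1 else outputs

def layer1(graph, a):
    return a + graph

def layer2(b):
    return b * 2

def layer3(graph, c):
    return c - graph

# "graph" reaches layer1 and layer3 but is skipped for layer2.
result = run_sequential([layer1, layer2, layer3], inputs=(1,), global_values={"graph": 10})
```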

Convolution Layers

MessagePassingBase

class MessagePassingBase[source]

Base module for message passing.

Any custom message passing module should be derived from this class.

aggregate(graph, message)[source]

Aggregate edge messages to nodes.

Parameters
  • graph (Graph) – graph(s)

  • message (Tensor) – edge messages of shape \((|E|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor

combine(input, update)[source]

Combine node input and node update.

Parameters
  • input (Tensor) – node representations of shape \((|V|, ...)\)

  • update (Tensor) – node updates of shape \((|V|, ...)\)

forward(graph, input)[source]

Perform message passing over the graph(s).

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

message(graph, input)[source]

Compute edge messages for the graph.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

edge messages of shape \((|E|, ...)\)

Return type

Tensor

message_and_aggregate(graph, input)[source]

Fused computation of message and aggregation over the graph. This may provide better time or memory complexity than separate calls of message and aggregate.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor
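
How the hooks above fit together can be sketched with a toy layer in plain Python. The sketch assumes that forward calls message_and_aggregate followed by combine, and that message_and_aggregate falls back to separate message and aggregate calls. The class name SumMessagePassing and the plain edge-list arguments are hypothetical simplifications; the documented methods take a Graph object and tensors.

```python
class SumMessagePassing:
    """Toy layer: messages are source node values, aggregation is a sum."""

    def message(self, edges, node_values):
        # One message per edge (source, target): the value of the source node.
        return [node_values[source] for source, _ in edges]

    def aggregate(self, edges, messages, num_node):
        # Sum the messages arriving at each target node.
        update = [0.0] * num_node
        for (_, target), msg in zip(edges, messages):
            update[target] += msg
        return update

    def message_and_aggregate(self, edges, node_values):
        # Unfused fallback: separate calls. A subclass could fuse the two steps.
        messages = self.message(edges, node_values)
        return self.aggregate(edges, messages, len(node_values))

    def combine(self, node_values, update):
        # Merge each node's own value with its aggregated update.
        return [value + delta for value, delta in zip(node_values, update)]

    def forward(self, edges, node_values):
        update = self.message_and_aggregate(edges, node_values)
        return self.combine(node_values, update)

edges = [(0, 1), (1, 2), (0, 2)]
output = SumMessagePassing().forward(edges, [1.0, 2.0, 3.0])
```

A custom layer would typically override message, aggregate and combine, and override message_and_aggregate only when a fused computation is cheaper.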

ChebyshevConv

class ChebyshevConv(input_dim, output_dim, edge_input_dim=None, k=1, batch_norm=False, activation='relu')[source]

Chebyshev spectral graph convolution operator from Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering.

Parameters
  • input_dim (int) – input dimension

  • output_dim (int) – output dimension

  • edge_input_dim (int, optional) – dimension of edge features

  • k (int, optional) – number of Chebyshev polynomials. This also corresponds to the radius of the receptive field.

  • hidden_dims (list of int, optional) – hidden dims of edge network

  • batch_norm (bool, optional) – apply batch normalization on nodes or not

  • activation (str or function, optional) – activation function

aggregate(graph, message)[source]

Aggregate edge messages to nodes.

Parameters
  • graph (Graph) – graph(s)

  • message (Tensor) – edge messages of shape \((|E|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor

combine(input, update)[source]

Combine node input and node update.

Parameters
  • input (Tensor) – node representations of shape \((|V|, ...)\)

  • update (Tensor) – node updates of shape \((|V|, ...)\)

forward(graph, input)[source]

Perform message passing over the graph(s).

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

message(graph, input)[source]

Compute edge messages for the graph.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

edge messages of shape \((|E|, ...)\)

Return type

Tensor

message_and_aggregate(graph, input)[source]

Fused computation of message and aggregation over the graph. This may provide better time or memory complexity than separate calls of message and aggregate.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor

ContinuousFilterConv

class ContinuousFilterConv(input_dim, output_dim, edge_input_dim=None, hidden_dim=None, cutoff=5, num_gaussian=100, batch_norm=False, activation='shifted_softplus')[source]

Continuous filter operator from SchNet: A continuous-filter convolutional neural network for modeling quantum interactions.

Parameters
  • input_dim (int) – input dimension

  • output_dim (int) – output dimension

  • edge_input_dim (int, optional) – dimension of edge features

  • hidden_dim (int, optional) – hidden dimension. By default, same as output_dim

  • cutoff (float, optional) – maximal scale for RBF kernels

  • num_gaussian (int, optional) – number of RBF kernels

  • batch_norm (bool, optional) – apply batch normalization on nodes or not

  • activation (str or function, optional) – activation function

aggregate(graph, message)[source]

Aggregate edge messages to nodes.

Parameters
  • graph (Graph) – graph(s)

  • message (Tensor) – edge messages of shape \((|E|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor

combine(input, update)[source]

Combine node input and node update.

Parameters
  • input (Tensor) – node representations of shape \((|V|, ...)\)

  • update (Tensor) – node updates of shape \((|V|, ...)\)

message(graph, input)[source]

Compute edge messages for the graph.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

edge messages of shape \((|E|, ...)\)

Return type

Tensor

message_and_aggregate(graph, input)[source]

Fused computation of message and aggregation over the graph. This may provide better time or memory complexity than separate calls of message and aggregate.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor

GraphAttentionConv

class GraphAttentionConv(input_dim, output_dim, edge_input_dim=None, num_head=1, negative_slope=0.2, concat=True, batch_norm=False, activation='relu')[source]

Graph attentional convolution operator from Graph Attention Networks.

Parameters
  • input_dim (int) – input dimension

  • output_dim (int) – output dimension

  • edge_input_dim (int, optional) – dimension of edge features

  • num_head (int, optional) – number of attention heads

  • negative_slope (float, optional) – negative slope of leaky relu activation

  • batch_norm (bool, optional) – apply batch normalization on nodes or not

  • activation (str or function, optional) – activation function

aggregate(graph, message)[source]

Aggregate edge messages to nodes.

Parameters
  • graph (Graph) – graph(s)

  • message (Tensor) – edge messages of shape \((|E|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor

combine(input, update)[source]

Combine node input and node update.

Parameters
  • input (Tensor) – node representations of shape \((|V|, ...)\)

  • update (Tensor) – node updates of shape \((|V|, ...)\)

message(graph, input)[source]

Compute edge messages for the graph.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

edge messages of shape \((|E|, ...)\)

Return type

Tensor

GraphConv

class GraphConv(input_dim, output_dim, edge_input_dim=None, batch_norm=False, activation='relu')[source]

Graph convolution operator from Semi-Supervised Classification with Graph Convolutional Networks.

Parameters
  • input_dim (int) – input dimension

  • output_dim (int) – output dimension

  • edge_input_dim (int, optional) – dimension of edge features

  • batch_norm (bool, optional) – apply batch normalization on nodes or not

  • activation (str or function, optional) – activation function

aggregate(graph, message)[source]

Aggregate edge messages to nodes.

Parameters
  • graph (Graph) – graph(s)

  • message (Tensor) – edge messages of shape \((|E|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor

combine(input, update)[source]

Combine node input and node update.

Parameters
  • input (Tensor) – node representations of shape \((|V|, ...)\)

  • update (Tensor) – node updates of shape \((|V|, ...)\)

message(graph, input)[source]

Compute edge messages for the graph.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

edge messages of shape \((|E|, ...)\)

Return type

Tensor

message_and_aggregate(graph, input)[source]

Fused computation of message and aggregation over the graph. This may provide better time or memory complexity than separate calls of message and aggregate.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor

GraphIsomorphismConv

class GraphIsomorphismConv(input_dim, output_dim, edge_input_dim=None, hidden_dims=None, eps=0, learn_eps=False, batch_norm=False, activation='relu')[source]

Graph isomorphism convolution operator from How Powerful are Graph Neural Networks?

Parameters
  • input_dim (int) – input dimension

  • output_dim (int) – output dimension

  • edge_input_dim (int, optional) – dimension of edge features

  • hidden_dims (list of int, optional) – hidden dimensions

  • eps (float, optional) – initial epsilon

  • learn_eps (bool, optional) – learn epsilon or not

  • batch_norm (bool, optional) – apply batch normalization on nodes or not

  • activation (str or function, optional) – activation function

aggregate(graph, message)[source]

Aggregate edge messages to nodes.

Parameters
  • graph (Graph) – graph(s)

  • message (Tensor) – edge messages of shape \((|E|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor

combine(input, update)[source]

Combine node input and node update.

Parameters
  • input (Tensor) – node representations of shape \((|V|, ...)\)

  • update (Tensor) – node updates of shape \((|V|, ...)\)

message(graph, input)[source]

Compute edge messages for the graph.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

edge messages of shape \((|E|, ...)\)

Return type

Tensor

message_and_aggregate(graph, input)[source]

Fused computation of message and aggregation over the graph. This may provide better time or memory complexity than separate calls of message and aggregate.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor

MessagePassing

class MessagePassing(input_dim, edge_input_dim, hidden_dims=None, batch_norm=False, activation='relu')[source]

Message passing operator from Neural Message Passing for Quantum Chemistry.

This implements the edge network variant in the original paper.

Parameters
  • input_dim (int) – input dimension

  • edge_input_dim (int) – dimension of edge features

  • hidden_dims (list of int, optional) – hidden dims of edge network

  • batch_norm (bool, optional) – apply batch normalization on nodes or not

  • activation (str or function, optional) – activation function

aggregate(graph, message)[source]

Aggregate edge messages to nodes.

Parameters
  • graph (Graph) – graph(s)

  • message (Tensor) – edge messages of shape \((|E|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor

combine(input, update)[source]

Combine node input and node update.

Parameters
  • input (Tensor) – node representations of shape \((|V|, ...)\)

  • update (Tensor) – node updates of shape \((|V|, ...)\)

message(graph, input)[source]

Compute edge messages for the graph.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

edge messages of shape \((|E|, ...)\)

Return type

Tensor

NeuralFingerprintConv

class NeuralFingerprintConv(input_dim, output_dim, edge_input_dim=None, batch_norm=False, activation='relu')[source]

Graph neural network operator from Convolutional Networks on Graphs for Learning Molecular Fingerprints.

Note this operator doesn’t include the sparsifying step of the original paper.

Parameters
  • input_dim (int) – input dimension

  • output_dim (int) – output dimension

  • edge_input_dim (int, optional) – dimension of edge features

  • batch_norm (bool, optional) – apply batch normalization on nodes or not

  • activation (str or function, optional) – activation function

aggregate(graph, message)[source]

Aggregate edge messages to nodes.

Parameters
  • graph (Graph) – graph(s)

  • message (Tensor) – edge messages of shape \((|E|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor

combine(input, update)[source]

Combine node input and node update.

Parameters
  • input (Tensor) – node representations of shape \((|V|, ...)\)

  • update (Tensor) – node updates of shape \((|V|, ...)\)

message(graph, input)[source]

Compute edge messages for the graph.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

edge messages of shape \((|E|, ...)\)

Return type

Tensor

message_and_aggregate(graph, input)[source]

Fused computation of message and aggregation over the graph. This may provide better time or memory complexity than separate calls of message and aggregate.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor

RelationalGraphConv

class RelationalGraphConv(input_dim, output_dim, num_relation, edge_input_dim=None, batch_norm=False, activation='relu')[source]

Relational graph convolution operator from Modeling Relational Data with Graph Convolutional Networks.

Parameters
  • input_dim (int) – input dimension

  • output_dim (int) – output dimension

  • num_relation (int) – number of relations

  • edge_input_dim (int, optional) – dimension of edge features

  • batch_norm (bool, optional) – apply batch normalization on nodes or not

  • activation (str or function, optional) – activation function

aggregate(graph, message)[source]

Aggregate edge messages to nodes.

Parameters
  • graph (Graph) – graph(s)

  • message (Tensor) – edge messages of shape \((|E|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor

combine(input, update)[source]

Combine node input and node update.

Parameters
  • input (Tensor) – node representations of shape \((|V|, ...)\)

  • update (Tensor) – node updates of shape \((|V|, ...)\)

message(graph, input)[source]

Compute edge messages for the graph.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

edge messages of shape \((|E|, ...)\)

Return type

Tensor

message_and_aggregate(graph, input)[source]

Fused computation of message and aggregation over the graph. This may provide better time or memory complexity than separate calls of message and aggregate.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations of shape \((|V|, ...)\)

Returns

node updates of shape \((|V|, ...)\)

Return type

Tensor

Readout Layers

MeanReadout

class MeanReadout[source]

Mean readout operator over graphs with variadic sizes.

forward(graph, input)[source]

Perform readout over the graph(s).

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations

Returns

graph representations

Return type

Tensor

SumReadout

class SumReadout[source]

Sum readout operator over graphs with variadic sizes.

forward(graph, input)[source]

Perform readout over the graph(s).

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations

Returns

graph representations

Return type

Tensor

MaxReadout

class MaxReadout[source]

Max readout operator over graphs with variadic sizes.

forward(graph, input)[source]

Perform readout over the graph(s).

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations

Returns

graph representations

Return type

Tensor

Set2Set

class Set2Set(input_dim, num_step=3, num_lstm_layer=1)[source]

Set2Set operator from Order Matters: Sequence to sequence for sets.

Parameters
  • input_dim (int) – input dimension

  • num_step (int, optional) – number of process steps

  • num_lstm_layer (int, optional) – number of LSTM layers

forward(graph, input)[source]

Perform Set2Set readout over graph(s).

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node representations

Returns

graph representations

Return type

Tensor

Softmax

class Softmax[source]

Softmax operator over graphs with variadic sizes.

forward(graph, input)[source]

Perform softmax over the graph(s).

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node logits

Returns

node probabilities

Return type

Tensor

Sort

class Sort(descending=False)[source]

Sort operator over graphs with variadic sizes.

Parameters

descending (bool, optional) – use descending sort order or not

forward(graph, input)[source]

Perform sort over graph(s).

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – node values

Returns

sorted values, sorted indices

Return type

(Tensor, LongTensor)

Pooling Layers

DiffPool

class DiffPool(input_dim, output_node, feature_layer=None, pool_layer=None, loss_weight=1, zero_diagonal=False, sparse=False)[source]

Differentiable pooling operator from Hierarchical Graph Representation Learning with Differentiable Pooling.

Parameters
  • input_dim (int) – input dimension

  • output_node (int) – number of nodes after pooling

  • feature_layer (Module, optional) – graph convolution layer for embedding

  • pool_layer (Module, optional) – graph convolution layer for pooling assignment

  • loss_weight (float, optional) – weight of entropy regularization

  • zero_diagonal (bool, optional) – remove self loops in the pooled graph or not

  • sparse (bool, optional) – use sparse assignment or not

forward(graph, input, all_loss=None, metric=None)[source]

Compute the node cluster assignment and pool the nodes.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – input node representations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

Returns

pooled graph, output node representations, node-to-cluster assignment

Return type

(PackedGraph, Tensor, Tensor)
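
The pooling step in the cited paper can be summarized as follows, where \(Z\) denotes the node embeddings produced by feature_layer, \(S\) the soft node-to-cluster assignment produced by pool_layer (one row per node, output_node columns) and \(A\) the adjacency matrix. This restates the paper's formulation; details of this implementation, such as the sparse option, may differ.

\[X' = S^\top Z, \qquad A' = S^\top A S\]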

MinCutPool

class MinCutPool(input_dim, output_node, feature_layer=None, pool_layer=None, loss_weight=1, zero_diagonal=True, sparse=False)[source]

Min cut pooling operator from Spectral Clustering with Graph Neural Networks for Graph Pooling.

Parameters
  • input_dim (int) – input dimension

  • output_node (int) – number of nodes after pooling

  • feature_layer (Module, optional) – graph convolution layer for embedding

  • pool_layer (Module, optional) – graph convolution layer for pooling assignment

  • loss_weight (float, optional) – weight of entropy regularization

  • zero_diagonal (bool, optional) – remove self loops in the pooled graph or not

  • sparse (bool, optional) – use sparse assignment or not

forward(graph, input, all_loss=None, metric=None)[source]

Compute the node cluster assignment and pool the nodes.

Parameters
  • graph (Graph) – graph(s)

  • input (Tensor) – input node representations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

Returns

pooled graph, output node representations, node-to-cluster assignment

Return type

(PackedGraph, Tensor, Tensor)

Sampler Layers

EdgeSampler

class EdgeSampler(budget=None, ratio=None)[source]

Edge sampler from GraphSAINT: Graph Sampling Based Inductive Learning Method.

Parameters
  • budget (int, optional) – number of nodes to keep

  • ratio (int, optional) – ratio of nodes to keep

forward(graph)[source]

Sample a subgraph from the graph.

Parameters

graph (Graph) – graph(s)

NodeSampler

class NodeSampler(budget=None, ratio=None)[source]

Node sampler from GraphSAINT: Graph Sampling Based Inductive Learning Method.

Parameters
  • budget (int, optional) – number of nodes to keep

  • ratio (int, optional) – ratio of nodes to keep

forward(graph)[source]

Sample a subgraph from the graph.

Parameters

graph (Graph) – graph(s)

Flow Layers

ConditionalFlow

class ConditionalFlow(input_dim, condition_dim, hidden_dims=None, activation='relu')[source]

Conditional flow transformation from Masked Autoregressive Flow for Density Estimation.

Parameters
  • input_dim (int) – input & output dimension

  • condition_dim (int) – condition dimension

  • hidden_dims (list of int, optional) – hidden dimensions

  • activation (str or function, optional) – activation function

forward(input, condition)[source]

Transform data into latent representations.

Parameters
  • input (Tensor) – input representations

  • condition (Tensor) – conditional representations

Returns

latent representations, log-likelihood of the transformation

Return type

(Tensor, Tensor)

reverse(latent, condition)[source]

Transform latent representations into data.

Parameters
  • latent (Tensor) – latent representations

  • condition (Tensor) – conditional representations

Returns

input representations, log-likelihood of the transformation

Return type

(Tensor, Tensor)
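
The relation between forward and reverse can be illustrated with an elementwise affine flow in plain Python. This is a sketch under assumptions: in the layer, scale and shift would be predicted from the condition by a network, whereas here they are passed in directly, and the exact parameterization used by the library may differ. The function names flow_forward and flow_reverse are hypothetical.

```python
import math

def flow_forward(x, scale, shift):
    """Map data to latent space with an elementwise affine transformation."""
    latent = [(xi + bi) * math.exp(si) for xi, si, bi in zip(x, scale, shift)]
    # For an elementwise affine map, the log-determinant is the sum of log-scales.
    log_det = sum(scale)
    return latent, log_det

def flow_reverse(latent, scale, shift):
    """Invert flow_forward exactly, given the same scale and shift."""
    x = [zi * math.exp(-si) - bi for zi, si, bi in zip(latent, scale, shift)]
    return x, -sum(scale)

x = [1.0, -2.0]
scale, shift = [0.5, -0.3], [0.1, 0.2]
latent, log_det = flow_forward(x, scale, shift)
recovered, inverse_log_det = flow_reverse(latent, scale, shift)
```

Applying reverse to the output of forward with the same condition recovers the input, and the two log-determinant terms cancel.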

Distribution Layers

These layers belong to torchdrug.layers.distribution.

IndependentGaussian

class IndependentGaussian(mu, sigma2, learnable=False)[source]

Independent Gaussian distribution.

Parameters
  • mu (Tensor) – mean of shape \((N,)\)

  • sigma2 (Tensor) – variance of shape \((N,)\)

  • learnable (bool, optional) – learnable parameters or not

forward(input)[source]

Compute the likelihood of input data.

Parameters

input (Tensor) – input data of shape \((..., N)\)

sample(*size)[source]

Draw samples from the distribution.

Parameters

size (tuple of int) – shape of the samples
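
A minimal reference for a Gaussian with independent dimensions, written in plain Python. The sketch computes the log-density, which is an assumption about the returned quantity, and the function names are hypothetical.

```python
import math
import random

def log_likelihood(x, mu, sigma2):
    """Sum of per-dimension Gaussian log-densities (dimensions are independent)."""
    return sum(
        -0.5 * (math.log(2 * math.pi * v) + (xi - m) ** 2 / v)
        for xi, m, v in zip(x, mu, sigma2)
    )

def sample(mu, sigma2, rng):
    """Draw one sample; sigma2 holds variances, so take the square root."""
    return [rng.gauss(m, math.sqrt(v)) for m, v in zip(mu, sigma2)]

mu, sigma2 = [0.0, 0.0], [1.0, 1.0]
value = log_likelihood([0.0, 0.0], mu, sigma2)
draw = sample(mu, sigma2, random.Random(0))
```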

Functional Layers

These layers belong to torchdrug.layers.functional.

Embedding Score Functions

transe_score(entity, relation, h_index, t_index, r_index)[source]

TransE score function from Translating Embeddings for Modeling Multi-relational Data.

Parameters
  • entity (Tensor) – entity embeddings of shape \((|V|, d)\)

  • relation (Tensor) – relation embeddings of shape \((|R|, d)\)

  • h_index (LongTensor) – index of head entities

  • t_index (LongTensor) – index of tail entities

  • r_index (LongTensor) – index of relations

distmult_score(entity, relation, h_index, t_index, r_index)[source]

DistMult score function from Embedding Entities and Relations for Learning and Inference in Knowledge Bases.

Parameters
  • entity (Tensor) – entity embeddings of shape \((|V|, d)\)

  • relation (Tensor) – relation embeddings of shape \((|R|, d)\)

  • h_index (LongTensor) – index of head entities

  • t_index (LongTensor) – index of tail entities

  • r_index (LongTensor) – index of relations

complex_score(entity, relation, h_index, t_index, r_index)[source]

ComplEx score function from Complex Embeddings for Simple Link Prediction.

Parameters
  • entity (Tensor) – entity embeddings of shape \((|V|, 2d)\)

  • relation (Tensor) – relation embeddings of shape \((|R|, 2d)\)

  • h_index (LongTensor) – index of head entities

  • t_index (LongTensor) – index of tail entities

  • r_index (LongTensor) – index of relations

simple_score(entity, relation, h_index, t_index, r_index)[source]

SimplE score function from SimplE Embedding for Link Prediction in Knowledge Graphs.

Parameters
  • entity (Tensor) – entity embeddings of shape \((|V|, 2d)\)

  • relation (Tensor) – relation embeddings of shape \((|R|, d)\)

  • h_index (LongTensor) – index of head entities

  • t_index (LongTensor) – index of tail entities

  • r_index (LongTensor) – index of relations

rotate_score(entity, relation, h_index, t_index, r_index)[source]

RotatE score function from RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space.

Parameters
  • entity (Tensor) – entity embeddings of shape \((|V|, 2d)\)

  • relation (Tensor) – relation embeddings of shape \((|R|, d)\)

  • h_index (LongTensor) – index of head entities

  • t_index (LongTensor) – index of tail entities

  • r_index (LongTensor) – index of relations
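
All score functions above share the same calling pattern: h_index, t_index and r_index select rows of the embedding tables, and one score is produced per triplet. The plain-Python sketch below shows this for TransE and DistMult. The choice of the L1 norm for TransE and the sign convention (a distance rather than a margin-adjusted score) are assumptions, and the function names are hypothetical.

```python
def transe_distance(entity, relation, h_index, t_index, r_index):
    """Per-triplet L1 distance |h + r - t|; smaller means more plausible."""
    scores = []
    for h, t, r in zip(h_index, t_index, r_index):
        triple = zip(entity[h], relation[r], entity[t])
        scores.append(sum(abs(eh + rr - et) for eh, rr, et in triple))
    return scores

def distmult_reference(entity, relation, h_index, t_index, r_index):
    """Per-triplet trilinear product sum(h * r * t); larger means more plausible."""
    scores = []
    for h, t, r in zip(h_index, t_index, r_index):
        triple = zip(entity[h], relation[r], entity[t])
        scores.append(sum(eh * rr * et for eh, rr, et in triple))
    return scores

entity = [[1.0, 0.0], [1.0, 1.0], [0.0, 2.0]]   # |V| = 3, d = 2
relation = [[0.0, 1.0], [2.0, 0.5]]             # |R| = 2, d = 2
h_index, t_index, r_index = [0, 1], [1, 2], [0, 1]

distances = transe_distance(entity, relation, h_index, t_index, r_index)
products = distmult_reference(entity, relation, h_index, t_index, r_index)
```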

Sparse Matrix Multiplication

generalized_spmm(sparse, input, sum='add', mul='mul')[source]

Generalized sparse-dense matrix multiplication.

This function computes the matrix multiplication of a sparse matrix and a dense input matrix. The output dense matrix satisfies

\[output_{i,k} = \bigoplus_{j: sparse_{i,j} \neq 0} sparse_{i,j} \otimes input_{j,k}\]

where \(\oplus\) and \(\otimes\) are the summation and the multiplication operators respectively.

Warning

Gradient w.r.t. the sparse matrix is only computed for non-zero entries of the sparse matrix. This behaves differently from dense-dense matrix multiplication with zero entries.

Parameters
  • sparse (SparseTensor) – 2D sparse tensor

  • input (Tensor) – 2D dense tensor

  • sum (str, optional) – generalized summation operator. Available operators are add, min and max.

  • mul (str, optional) – generalized multiplication operator. Available operators are add and mul.
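
The formula above can be checked against a small reference written in plain Python. The sparse matrix is represented as a dict from (i, j) to its non-zero value, which is an illustrative choice rather than the library's SparseTensor, and the function and parameter names are hypothetical. Rows without non-zero entries are left as None because the formula does not define them.

```python
import operator

SUM_OPS = {"add": operator.add, "min": min, "max": max}
MUL_OPS = {"add": operator.add, "mul": operator.mul}

def generalized_spmm_reference(entries, num_row, dense, sum_op="add", mul_op="mul"):
    """entries maps (i, j) to a non-zero value of the sparse matrix."""
    oplus, otimes = SUM_OPS[sum_op], MUL_OPS[mul_op]
    num_col = len(dense[0])
    output = [[None] * num_col for _ in range(num_row)]
    for (i, j), value in entries.items():
        for k in range(num_col):
            term = otimes(value, dense[j][k])
            # The first term initializes the cell; later terms are reduced with oplus.
            output[i][k] = term if output[i][k] is None else oplus(output[i][k], term)
    return output

entries = {(0, 0): 1.0, (0, 1): 2.0, (1, 1): 3.0}
dense = [[1.0, 4.0], [2.0, 5.0]]

standard = generalized_spmm_reference(entries, 2, dense)                  # ordinary product
min_plus = generalized_spmm_reference(entries, 2, dense, "min", "add")   # min-plus semiring
```

With sum "min" and mul "add", one call performs a relaxation step of a shortest-path computation, which is a common reason to generalize the two operators.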

generalized_rspmm(sparse, relation, input, sum='add', mul='mul')[source]

Generalized relational sparse-dense matrix multiplication.

This function computes the matrix multiplication of a sparse matrix, a dense relation matrix and a dense input matrix. The output dense matrix satisfies

\[output_{i,l} = \bigoplus_{j,k: sparse_{i,j,k} \neq 0} sparse_{i, j, k} \times (relation_{k,l} \otimes input_{j,l})\]

where \(\oplus\) and \(\otimes\) are the summation and the multiplication operators respectively.

Warning

Gradient w.r.t. the sparse matrix is only computed for non-zero entries of the sparse matrix. This behaves differently from dense-dense matrix multiplication with zero entries.

Parameters
  • sparse (SparseTensor) – 3D sparse tensor

  • relation (Tensor) – 2D dense tensor

  • input (Tensor) – 2D dense tensor

  • sum (str, optional) – generalized summation operator. Available operators are add, min and max.

  • mul (str, optional) – generalized multiplication operator. Available operators are add and mul.

Variadic

variadic_sum(input, size)[source]

Compute sum over sets with variadic sizes.

Suppose there are \(N\) sets, and the sizes of all sets are summed to \(B\).

Parameters
  • input (Tensor) – input of shape \((B, ...)\)

  • size (LongTensor) – size of sets of shape \((N,)\)

variadic_mean(input, size)[source]

Compute mean over sets with variadic sizes.

Suppose there are \(N\) sets, and the sizes of all sets are summed to \(B\).

Parameters
  • input (Tensor) – input of shape \((B, ...)\)

  • size (LongTensor) – size of sets of shape \((N,)\)

variadic_max(input, size)[source]

Compute max over sets with variadic sizes.

Suppose there are \(N\) sets, and the sizes of all sets are summed to \(B\).

Parameters
  • input (Tensor) – input of shape \((B, ...)\)

  • size (LongTensor) – size of sets of shape \((N,)\)

Returns

(Tensor, LongTensor): max values and indexes
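
The layout shared by variadic_sum, variadic_mean and variadic_max is a flat input of length \(B\) in which the sets are stored back to back, with size giving the length of each set. A plain-Python sketch of the semantics for 1-D input follows; it assumes every set is non-empty, and the helper split_by_size is hypothetical.

```python
def split_by_size(values, size):
    """Split a flat list of length B into N consecutive sets."""
    sets, offset = [], 0
    for n in size:
        sets.append(values[offset:offset + n])
        offset += n
    return sets

values = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]   # B = 6
size = [2, 1, 3]                           # N = 3 sets

sets = split_by_size(values, size)
sums = [sum(s) for s in sets]
means = [sum(s) / len(s) for s in sets]
maxes = [max(s) for s in sets]
```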

variadic_cross_entropy(input, target, size, reduction='mean')[source]

Compute cross entropy loss over categories with variadic sizes.

Suppose there are \(N\) samples, and the numbers of categories in all samples are summed to \(B\).

Parameters
  • input (Tensor) – prediction of shape \((B, ...)\)

  • target (Tensor) – target of shape \((N, ...)\). Each target is a relative index in a sample.

  • size (LongTensor) – number of categories of shape \((N,)\)

  • reduction (string, optional) – reduction to apply to the output. Available reductions are none, sum and mean.
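
A plain-Python reference for the 1-D case may help clarify the role of the relative target index. It is a sketch of the semantics only, with a hypothetical function name: each sample owns a consecutive slice of the flat prediction, and its target indexes into that slice.

```python
import math

def variadic_cross_entropy_reference(logits, target, size, reduction="mean"):
    losses, offset = [], 0
    for n, t in zip(size, target):
        chunk = logits[offset:offset + n]
        peak = max(chunk)
        # Log of the softmax normalizer, computed stably.
        log_norm = peak + math.log(sum(math.exp(v - peak) for v in chunk))
        # t is relative to the start of this sample's slice.
        losses.append(log_norm - chunk[t])
        offset += n
    if reduction == "none":
        return losses
    if reduction == "sum":
        return sum(losses)
    return sum(losses) / len(losses)

# Two samples with 2 and 4 categories; uniform logits give log(2) and log(4).
loss = variadic_cross_entropy_reference([0.0] * 6, target=[1, 3], size=[2, 4])
```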

variadic_log_softmax(input, size)[source]

Compute log softmax over categories with variadic sizes.

Suppose there are \(N\) samples, and the numbers of categories in all samples are summed to \(B\).

Parameters
  • input (Tensor) – input of shape \((B, ...)\)

  • size (LongTensor) – number of categories of shape \((N,)\)

variadic_softmax(input, size)[source]

Compute softmax over categories with variadic sizes.

Suppose there are \(N\) samples, and the numbers of categories in all samples are summed to \(B\).

Parameters
  • input (Tensor) – input of shape \((B, ...)\)

  • size (LongTensor) – number of categories of shape \((N,)\)
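
The normalization happens independently inside each sample's slice of the flat input. A plain-Python sketch for 1-D input, with a hypothetical function name:

```python
import math

def variadic_softmax_reference(logits, size):
    probs, offset = [], 0
    for n in size:
        chunk = logits[offset:offset + n]
        peak = max(chunk)  # subtract the maximum for numerical stability
        exps = [math.exp(v - peak) for v in chunk]
        total = sum(exps)
        probs.extend(e / total for e in exps)
        offset += n
    return probs

# Two samples with 2 and 3 categories respectively.
probs = variadic_softmax_reference([0.0, 0.0, 1.0, 2.0, 3.0], [2, 3])
```

The output keeps the flat layout of the input, and the probabilities of each sample sum to one.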

variadic_sort(input, size, descending=False)[source]

Sort elements in sets with variadic sizes.

Suppose there are \(N\) sets, and the sizes of all sets are summed to \(B\).

Parameters
  • input (Tensor) – input of shape \((B, ...)\)

  • size (LongTensor) – size of sets of shape \((N,)\)

  • descending (bool, optional) – return ascending or descending order

variadic_topk(input, size, k, largest=True)[source]

Compute the \(k\) largest elements over sets with variadic sizes.

Suppose there are \(N\) sets, and the sizes of all sets are summed to \(B\).

If any set has fewer than \(k\) elements, the size-th largest element will be repeated to pad the output to \(k\).

Parameters
  • input (Tensor) – input of shape \((B, ...)\)

  • size (LongTensor) – size of sets of shape \((N,)\)

  • k (int) – the k in “top-k”

  • largest (bool, optional) – return largest or smallest elements

Returns

(Tensor, LongTensor): top-k values and indexes
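
The padding rule can be made concrete with a plain-Python reference for 1-D input. The function name is hypothetical and every set is assumed to be non-empty. The sketch reports positions in the flat input; whether the library reports flat or per-set positions is not specified here.

```python
def variadic_topk_reference(values, size, k, largest=True):
    top_values, top_indexes, offset = [], [], 0
    for n in size:
        positions = range(offset, offset + n)
        order = sorted(positions, key=lambda i: values[i], reverse=largest)[:k]
        # Pad short sets by repeating the last selected element.
        order += [order[-1]] * (k - len(order))
        top_indexes.append(order)
        top_values.append([values[i] for i in order])
        offset += n
    return top_values, top_indexes

# The second set has only 2 elements, so its last element is repeated.
top_values, top_indexes = variadic_topk_reference([5.0, 1.0, 3.0, 2.0, 4.0], [3, 2], k=3)
```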

variadic_arange(size)[source]

Return a 1-D tensor that contains integer intervals of variadic sizes. This is a variadic variant of torch.arange(stop).expand(batch_size, -1).

Suppose there are \(N\) intervals.

Parameters

size (LongTensor) – size of intervals of shape \((N,)\)
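
In plain Python, the result corresponds to restarting a counter at the beginning of each interval. The function name below is hypothetical.

```python
def variadic_arange_reference(size):
    """Concatenate range(n) for every interval length n."""
    return [i for n in size for i in range(n)]

positions = variadic_arange_reference([2, 3])
```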

variadic_randperm(size)[source]

Return random permutations for sets with variadic sizes. The i-th permutation contains integers from 0 to size[i] - 1.

Suppose there are \(N\) sets.

Parameters
  • size (LongTensor) – size of sets of shape \((N,)\)

  • device (torch.device, optional) – device of the tensor

variadic_sample(input, size, num_sample)[source]

Draw samples with replacement from sets with variadic sizes.

Suppose there are \(N\) sets, and the sizes of all sets are summed to \(B\).

Parameters
  • input (Tensor) – input of shape \((B, ...)\)

  • size (LongTensor) – size of sets of shape \((N,)\)

  • num_sample (int) – number of samples to draw from each set

Tensor Reduction

masked_mean(input, mask, dim=None, keepdim=False)[source]

Masked mean of a tensor.

Parameters
  • input (Tensor) – input tensor

  • mask (BoolTensor) – mask tensor

  • dim (int or tuple of int, optional) – dimension to reduce

  • keepdim (bool, optional) – whether to retain dim or not

mean_with_nan(input, dim=None, keepdim=False)[source]

Mean of a tensor. Ignore all nan values.

Parameters
  • input (Tensor) – input tensor

  • dim (int or tuple of int, optional) – dimension to reduce

  • keepdim (bool, optional) – whether to retain dim or not
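
Both reductions average only over a subset of the entries. A plain-Python sketch for 1-D input, ignoring dim and keepdim, with hypothetical function names; it assumes at least one entry is kept.

```python
import math

def masked_mean_reference(values, mask):
    """Average the entries whose mask is True."""
    kept = [v for v, m in zip(values, mask) if m]
    return sum(kept) / len(kept)

def mean_with_nan_reference(values):
    """Average the entries that are not NaN, expressed as a masked mean."""
    return masked_mean_reference(values, [not math.isnan(v) for v in values])

masked = masked_mean_reference([1.0, 2.0, 3.0, 4.0], [True, False, True, False])
without_nan = mean_with_nan_reference([1.0, float("nan"), 3.0])
```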

Tensor Construction

as_mask(indexes, length)[source]

Convert indexes into a binary mask.

Parameters
  • indexes (LongTensor) – positive indexes

  • length (int) – maximal possible value of indexes

one_hot(index, size)[source]

Expand indexes into one-hot vectors.

Parameters
  • index (Tensor) – index

  • size (int) – size of the one-hot dimension
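
The two constructors can be sketched in plain Python as follows. The function names are hypothetical, and the sketch assumes that the mask has length entries and that one_hot appends the one-hot dimension last.

```python
def as_mask_reference(indexes, length):
    """Boolean mask that is True exactly at the given indexes."""
    mask = [False] * length
    for i in indexes:
        mask[i] = True
    return mask

def one_hot_reference(index, size):
    """One row per index, with a single 1 at that index."""
    return [[1 if j == i else 0 for j in range(size)] for i in index]

mask = as_mask_reference([0, 3], 5)
vectors = one_hot_reference([2, 0], 3)
```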

multi_slice(starts, ends)[source]

Compute the union of indexes in multiple slices.

Example:

>>> mask = multi_slice(torch.tensor([0, 1, 4]), torch.tensor([2, 3, 6]))
>>> assert (mask == torch.tensor([0, 1, 2, 4, 5])).all()
Parameters
  • starts (LongTensor) – start indexes of slices

  • ends (LongTensor) – end indexes of slices

multi_slice_mask(starts, ends, length)[source]

Compute the union of multiple slices into a binary mask.

Example:

>>> mask = multi_slice_mask(torch.tensor([0, 1, 4]), torch.tensor([2, 3, 6]), 6)
>>> assert (mask == torch.tensor([1, 1, 1, 0, 1, 1])).all()
Parameters
  • starts (LongTensor) – start indexes of slices

  • ends (LongTensor) – end indexes of slices

  • length (int) – length of mask

Sampling

multinomial(input, num_sample, replacement=False)[source]

Fast multinomial sampling. This is the default implementation in PyTorch v1.6.0+.

Parameters
  • input (Tensor) – unnormalized distribution

  • num_sample (int) – number of samples

  • replacement (bool, optional) – sample with replacement or not

Activation

shifted_softplus(input)[source]

Shifted softplus function.

Parameters

input (Tensor) – input tensor
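
A scalar reference in plain Python, assuming the SchNet definition softplus(x) - log(2), which makes the function pass through the origin. The function name is hypothetical.

```python
import math

def shifted_softplus_reference(x):
    # softplus(x) - log(2), written so that exp never overflows for large |x|.
    return max(x, 0.0) + math.log1p(math.exp(-abs(x))) - math.log(2.0)

at_zero = shifted_softplus_reference(0.0)
at_large = shifted_softplus_reference(50.0)
```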