torchdrug.models

Knowledge Graph Embedding

TransE

class TransE(num_entity, num_relation, embedding_dim, max_score=12)

TransE embedding proposed in Translating Embeddings for Modeling Multi-relational Data.

Parameters
  • num_entity (int) – number of entities

  • num_relation (int) – number of relations

  • embedding_dim (int) – dimension of embeddings

  • max_score (float, optional) – maximal score for triplets

forward(graph, h_index, t_index, r_index, all_loss=None, metric=None)

Compute the score for each triplet.

Parameters
  • graph (Graph) – fact graph

  • h_index (Tensor) – indexes of head entities

  • t_index (Tensor) – indexes of tail entities

  • r_index (Tensor) – indexes of relations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict
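A minimal plain-Python sketch of the TransE scoring rule, for illustration only. It assumes an L1 distance (the paper allows L1 or L2) and assumes max_score is an offset from which the distance is subtracted, so a perfect translation scores exactly max_score. The helper transe_score is hypothetical and is not part of torchdrug.

```python
def transe_score(h, r, t, max_score=12.0):
    """Plausibility of a (head, relation, tail) triplet under TransE; higher is better."""
    # TransE treats a relation as a translation, so h + r should land close to t.
    # Assumption: L1 distance, subtracted from max_score.
    distance = sum(abs(hi + ri - ti) for hi, ri, ti in zip(h, r, t))
    return max_score - distance


# Hypothetical toy embeddings (dyadic values keep the arithmetic exact).
head = [0.5, 1.0, -0.25]
rel = [0.25, -0.5, 0.5]
true_tail = [0.75, 0.5, 0.25]   # exactly head + rel
false_tail = [-1.0, 1.0, 1.0]

print(transe_score(head, rel, true_tail))   # 12.0
print(transe_score(head, rel, false_tail))  # 9.0
```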

DistMult

class DistMult(num_entity, num_relation, embedding_dim, l3_regularization=0)

DistMult embedding proposed in Embedding Entities and Relations for Learning and Inference in Knowledge Bases.

Parameters
  • num_entity (int) – number of entities

  • num_relation (int) – number of relations

  • embedding_dim (int) – dimension of embeddings

  • l3_regularization (float, optional) – weight for l3 regularization

forward(graph, h_index, t_index, r_index, all_loss=None, metric=None)

Compute the score for each triplet.

Parameters
  • graph (Graph) – fact graph

  • h_index (Tensor) – indexes of head entities

  • t_index (Tensor) – indexes of tail entities

  • r_index (Tensor) – indexes of relations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict
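A short sketch of the DistMult score and of an L3 penalty in its usual form (weight times the sum of cubed absolute values), written in plain Python. Both helpers, distmult_score and l3_penalty, are hypothetical, and exactly which embeddings torchdrug includes in its penalty is not asserted here. The example also shows that the trilinear product is symmetric in head and tail.

```python
def distmult_score(h, r, t):
    """DistMult score: the trilinear product sum_i h_i * r_i * t_i."""
    return sum(hi * ri * ti for hi, ri, ti in zip(h, r, t))


def l3_penalty(vectors, weight):
    """Hypothetical helper: weight * sum of |x|^3 over all embedding entries."""
    return weight * sum(abs(x) ** 3 for vec in vectors for x in vec)


head, rel, tail = [1.0, 2.0, -1.0], [0.5, 1.0, 2.0], [2.0, 0.5, -1.0]
print(distmult_score(head, rel, tail))  # 4.0
# The product is symmetric in h and t, so (t, r, h) gets the same score.
print(distmult_score(tail, rel, head))  # 4.0
print(l3_penalty([head, rel], weight=0.5))  # 9.5625
```

That symmetry is why DistMult cannot distinguish a relation from its inverse.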

ComplEx

class ComplEx(num_entity, num_relation, embedding_dim, l3_regularization=0)

ComplEx embedding proposed in Complex Embeddings for Simple Link Prediction.

Parameters
  • num_entity (int) – number of entities

  • num_relation (int) – number of relations

  • embedding_dim (int) – dimension of embeddings

  • l3_regularization (float, optional) – weight for l3 regularization

forward(graph, h_index, t_index, r_index, all_loss=None, metric=None)

Compute the score for triplets.

Parameters
  • graph (Graph) – fact graph

  • h_index (Tensor) – indexes of head entities

  • t_index (Tensor) – indexes of tail entities

  • r_index (Tensor) – indexes of relations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict
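A sketch of the ComplEx score Re(sum_i h_i * r_i * conj(t_i)). For brevity it uses Python's built-in complex numbers rather than storing real and imaginary parts as separate halves of a real embedding. How torchdrug splits embedding_dim is not asserted. The helper complex_score is hypothetical.

```python
def complex_score(h, r, t):
    """ComplEx score: Re(sum_i h_i * r_i * conj(t_i)) for complex-valued embeddings."""
    return sum(hi * ri * ti.conjugate() for hi, ri, ti in zip(h, r, t)).real


# Hypothetical 2-dimensional embeddings.
head = [1 + 0j, 2 + 0j]
rel = [1j, 1j]          # purely imaginary relation
tail = [1j, 1j]

print(complex_score(head, rel, tail))  # 3.0
# Conjugating the tail breaks the head/tail symmetry: swapping them flips the sign here.
print(complex_score(tail, rel, head))  # -3.0
```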

SimplE

class SimplE(num_entity, num_relation, embedding_dim, l3_regularization=0)

SimplE embedding proposed in SimplE Embedding for Link Prediction in Knowledge Graphs.

Parameters
  • num_entity (int) – number of entities

  • num_relation (int) – number of relations

  • embedding_dim (int) – dimension of embeddings

  • l3_regularization (float, optional) – weight for l3 regularization

forward(graph, h_index, t_index, r_index, all_loss=None, metric=None)

Compute the score for each triplet.

Parameters
  • graph (Graph) – fact graph

  • h_index (Tensor) – indexes of head entities

  • t_index (Tensor) – indexes of tail entities

  • r_index (Tensor) – indexes of relations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict
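A sketch of the SimplE score as described in the SimplE paper. Each entity has one vector for its head role and one for its tail role, and each relation has a vector for itself and one for its inverse. The score averages the two trilinear products. The helpers and the tuple layout (as_head, as_tail) are hypothetical choices for this illustration.

```python
def trilinear(a, b, c):
    return sum(x * y * z for x, y, z in zip(a, b, c))


def simple_score(head, tail, r, r_inv):
    """SimplE score for (head, r, tail).

    head and tail are (as_head, as_tail) vector pairs; r_inv is the inverse-relation vector.
    """
    h_as_head, h_as_tail = head
    t_as_head, t_as_tail = tail
    # Average the forward view <h_head, r, t_tail> and the inverse view <t_head, r_inv, h_tail>.
    return 0.5 * (trilinear(h_as_head, r, t_as_tail) + trilinear(t_as_head, r_inv, h_as_tail))


# Hypothetical 2-dimensional embeddings.
alice = ([1.0, 2.0], [0.5, 0.5])
bob = ([2.0, 0.0], [1.0, 1.0])
rel, rel_inv = [1.0, 0.5], [0.0, 1.0]
print(simple_score(alice, bob, rel, rel_inv))  # 1.0
print(simple_score(bob, alice, rel, rel_inv))  # 1.5
```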

RotatE

class RotatE(num_entity, num_relation, embedding_dim, max_score=12)

RotatE embedding proposed in RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space.

Parameters
  • num_entity (int) – number of entities

  • num_relation (int) – number of relations

  • embedding_dim (int) – dimension of embeddings

  • max_score (float, optional) – maximal score for triplets

forward(graph, h_index, t_index, r_index, all_loss=None, metric=None)

Compute the score for each triplet.

Parameters
  • graph (Graph) – fact graph

  • h_index (Tensor) – indexes of head entities

  • t_index (Tensor) – indexes of tail entities

  • r_index (Tensor) – indexes of relations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict
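A sketch of the RotatE idea. Each relation is stored as a vector of phases theta, acting as the unit-modulus rotation exp(i * theta), and a true triplet should satisfy t ≈ h * r elementwise. As an assumption, the summed modulus of the residual is subtracted from max_score. The helper rotate_score is hypothetical.

```python
import cmath
import math


def rotate_score(h, r_phase, t, max_score=12.0):
    """RotatE score for complex embeddings h, t and relation phases r_phase."""
    # Rotation r_i = exp(i * theta_i) has modulus 1; compare h * r with t elementwise.
    distance = sum(abs(hi * cmath.exp(1j * theta) - ti) for hi, theta, ti in zip(h, r_phase, t))
    # Assumption: the distance is subtracted from max_score, so higher means more plausible.
    return max_score - distance


head = [1 + 0j, 1j]
phases = [math.pi / 2, math.pi]   # rotate by 90 and 180 degrees
tail = [1j, -1j]                  # head rotated by those phases
print(round(rotate_score(head, phases, tail), 6))  # 12.0
# Negating the phases gives the inverse relation, which rotates tail back onto head.
print(round(rotate_score(tail, [-p for p in phases], head), 6))  # 12.0
```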

NeuralLP

class NeuralLogicProgramming(num_entity, num_relation, hidden_dim, num_step, num_lstm_layer=1)

Neural Logic Programming proposed in Differentiable Learning of Logical Rules for Knowledge Base Reasoning.

Parameters
  • num_entity (int) – number of entities

  • num_relation (int) – number of relations

  • hidden_dim (int) – dimension of hidden units in LSTM

  • num_step (int) – number of recurrent steps

  • num_lstm_layer (int, optional) – number of LSTM layers

forward(graph, h_index, t_index, r_index, all_loss=None, metric=None)

Compute the score for triplets.

Parameters
  • graph (Graph) – fact graph

  • h_index (Tensor) – indexes of head entities

  • t_index (Tensor) – indexes of tail entities

  • r_index (Tensor) – indexes of relations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict
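A simplified sketch of the rule-chaining idea behind Neural LP. Starting from a one-hot vector on the head entity, each of num_step steps applies relation adjacency operators weighted by an attention distribution over relations. In Neural LP those weights come from an LSTM, and the model also uses an identity operator and attention over earlier steps. All of that is omitted here, and the weights are fixed by hand. The helper chain_rules is hypothetical.

```python
from collections import defaultdict


def chain_rules(facts, head, step_attention):
    """Score candidate tails for head by soft rule chaining (simplified Neural LP).

    facts: list of (source, relation, target) triplets.
    step_attention: one dict per step, mapping relation -> attention weight.
    """
    state = {head: 1.0}  # sparse one-hot vector over entities
    for weights in step_attention:
        new_state = defaultdict(float)
        for src, rel, dst in facts:
            if src in state and rel in weights:
                # Apply the relation's adjacency operator, scaled by its attention weight.
                new_state[dst] += weights[rel] * state[src]
        state = dict(new_state)
    return state


facts = [("alice", "parent_of", "bob"),
         ("bob", "parent_of", "carol"),
         ("alice", "friend_of", "dave")]
# Attention encoding grandparent_of(X, Z) <- parent_of(X, Y), parent_of(Y, Z).
print(chain_rules(facts, "alice", [{"parent_of": 1.0}, {"parent_of": 1.0}]))  # {'carol': 1.0}
# With softer attention, the mass sent to dave finds no parent_of edge and vanishes.
print(chain_rules(facts, "alice", [{"parent_of": 0.9, "friend_of": 0.1}, {"parent_of": 1.0}]))
# {'carol': 0.9}
```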

NeuralLP

alias of torchdrug.models.neurallp.NeuralLogicProgramming

KBGAT

class KnowledgeBaseGraphAttentionNetwork(num_entity, num_relation, embedding_dim, hidden_dims, max_score=12, edge_input_dim=None, num_head=1, negative_slope=0.2, short_cut=False, batch_norm=False, activation='relu', concat_hidden=False, readout='sum')

Knowledge Base Graph Attention Network proposed in Learning Attention-based Embeddings for Relation Prediction in Knowledge Graphs.

Parameters
  • num_entity (int) – number of entities

  • num_relation (int) – number of relations

  • embedding_dim (int) – dimension of embeddings

  • hidden_dims (list of int) – hidden dimensions

  • max_score (float, optional) – maximal score for triplets

  • edge_input_dim (int, optional) – dimension of edge features

  • num_head (int, optional) – number of attention heads

  • negative_slope (float, optional) – negative slope of leaky relu activation

  • short_cut (bool, optional) – use shortcut connections or not

  • batch_norm (bool, optional) – apply batch normalization or not

  • activation (str or function, optional) – activation function

  • concat_hidden (bool, optional) – concat hidden representations from all layers as output

  • readout (str, optional) – readout function. Available functions are sum and mean.

forward(graph, h_index, t_index, r_index, all_loss=None, metric=None)

Compute the score for triplets.

Parameters
  • graph (Graph) – fact graph

  • h_index (Tensor) – indexes of head entities

  • t_index (Tensor) – indexes of tail entities

  • r_index (Tensor) – indexes of relations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict
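A simplified single-head sketch of the relation-aware attention described in the KBGAT paper. Each neighbor contributes a triplet representation [h_center || h_neighbor || g_relation]. The triplet is scored with a LeakyReLU (using negative_slope), and the scores are normalized with a softmax over the neighborhood. The paper's learned projection of the triplet is treated as the identity, and the attention vector w is a hypothetical fixed value. This follows the paper's description, not necessarily torchdrug's code.

```python
import math


def kbgat_attention(h, g, center, neighbors, w, negative_slope=0.2):
    """Relation-aware attention over the neighborhood of one entity (simplified).

    h: entity -> embedding list; g: relation -> embedding list;
    neighbors: (neighbor entity, relation) pairs around center;
    w: hypothetical attention vector of length 3 * dim.
    """
    # Triplet representation c = [h_center || h_neighbor || g_relation] (list concatenation).
    triples = [h[center] + h[j] + g[k] for j, k in neighbors]
    logits = []
    for c in triples:
        s = sum(wi * ci for wi, ci in zip(w, c))
        logits.append(s if s > 0 else negative_slope * s)  # LeakyReLU
    m = max(logits)  # subtract the max for a numerically stable softmax
    exps = [math.exp(x - m) for x in logits]
    z = sum(exps)
    alpha = [e / z for e in exps]
    # New representation: attention-weighted sum of the triplet representations.
    out = [sum(a * c[d] for a, c in zip(alpha, triples)) for d in range(len(triples[0]))]
    return alpha, out


h = {"a": [1.0], "b": [2.0], "c": [0.0]}
g = {"r": [1.0], "s": [-1.0]}
alpha, out = kbgat_attention(h, g, "a", [("b", "r"), ("c", "s")], w=[0.0, 1.0, 1.0])
print([round(x, 3) for x in alpha])  # [0.961, 0.039]
```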

KBGAT

alias of torchdrug.models.kbgat.KnowledgeBaseGraphAttentionNetwork

Graph Neural Networks

ChebNet

class ChebyshevConvolutionalNetwork(input_dim, hidden_dims, edge_input_dim=None, k=1, short_cut=False, batch_norm=False, activation='relu', concat_hidden=False, readout='sum')

Chebyshev convolutional network proposed in Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering.

Parameters
  • input_dim (int) – input dimension

  • hidden_dims (list of int) – hidden dimensions

  • edge_input_dim (int, optional) – dimension of edge features

  • k (int, optional) – number of Chebyshev polynomials

  • short_cut (bool, optional) – use shortcut connections or not

  • batch_norm (bool, optional) – apply batch normalization or not

  • activation (str or function, optional) – activation function

  • concat_hidden (bool, optional) – concat hidden representations from all layers as output

  • readout (str, optional) – readout function. Available functions are sum and mean.

forward(graph, input, all_loss=None, metric=None)

Compute the node representations and the graph representation(s).

Parameters
  • graph (Graph) – \(n\) graph(s)

  • input (Tensor) – input node representations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

Returns

node representations of shape \((|V|, d)\), graph representations of shape \((n, d)\)

Return type

dict with node_feature and graph_feature fields
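A sketch of the Chebyshev filter from the ChebNet paper, y = sum_k theta_k T_k(L~) x. It uses the recurrence T_k = 2 L~ T_{k-1} - T_{k-2} on a scalar signal. It assumes the normalized Laplacian with the common approximation lambda_max = 2, so L~ = -D^{-1/2} A D^{-1/2}. The number of terms is len(theta) here. How that maps onto torchdrug's k is not asserted, and the helpers are hypothetical.

```python
import math


def apply_scaled_laplacian(adj, x):
    """Compute L~ x with L~ = L - I = -D^{-1/2} A D^{-1/2} (assumes lambda_max = 2)."""
    deg = [sum(row) for row in adj]
    n = len(adj)
    return [-sum(adj[i][j] * x[j] / math.sqrt(deg[i] * deg[j]) for j in range(n) if adj[i][j])
            for i in range(n)]


def chebyshev_filter(adj, x, theta):
    """y = sum_k theta[k] * T_k(L~) x via T_k = 2 L~ T_{k-1} - T_{k-2}."""
    terms = [x]                                       # T_0(L~) x = x
    if len(theta) > 1:
        terms.append(apply_scaled_laplacian(adj, x))  # T_1(L~) x = L~ x
    while len(terms) < len(theta):
        lt = apply_scaled_laplacian(adj, terms[-1])
        terms.append([2 * a - b for a, b in zip(lt, terms[-2])])
    return [sum(c * term[i] for c, term in zip(theta, terms)) for i in range(len(x))]


# Path graph 0 - 1 - 2 with a signal on node 0.
adj = [[0, 1, 0], [1, 0, 1], [0, 1, 0]]
x = [1.0, 0.0, 0.0]
y1 = chebyshev_filter(adj, x, [0.0, 1.0])       # order-1 term: node 2 is untouched
y2 = chebyshev_filter(adj, x, [0.0, 0.0, 1.0])  # order-2 term: the signal reaches node 2
print(y1, y2)
```

The example illustrates the locality of the filter: a term of order k only reaches nodes within k hops.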

ChebNet

alias of torchdrug.models.chebnet.ChebyshevConvolutionalNetwork

GCN

class GraphConvolutionalNetwork(input_dim, hidden_dims, edge_input_dim=None, short_cut=False, batch_norm=False, activation='relu', concat_hidden=False, readout='sum')

Graph Convolutional Network proposed in Semi-Supervised Classification with Graph Convolutional Networks.

Parameters
  • input_dim (int) – input dimension

  • hidden_dims (list of int) – hidden dimensions

  • edge_input_dim (int, optional) – dimension of edge features

  • short_cut (bool, optional) – use shortcut connections or not

  • batch_norm (bool, optional) – apply batch normalization or not

  • activation (str or function, optional) – activation function

  • concat_hidden (bool, optional) – concat hidden representations from all layers as output

  • readout (str, optional) – readout function. Available functions are sum and mean.

forward(graph, input, all_loss=None, metric=None)

Compute the node representations and the graph representation(s).

Parameters
  • graph (Graph) – \(n\) graph(s)

  • input (Tensor) – input node representations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

Returns

node representations of shape \((|V|, d)\), graph representations of shape \((n, d)\)

Return type

dict with node_feature and graph_feature fields
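A dense, single-graph sketch of one GCN layer, relu(D^{-1/2} (A + I) D^{-1/2} X W), where D is the degree matrix of A + I. It is followed by the sum and mean readouts named in the readout parameter. Batching of n graphs, edge features, and torchdrug's own normalization details are omitted. gcn_layer and readout are hypothetical helpers.

```python
import math


def gcn_layer(adj, features, weight):
    """One GCN layer: relu(D^{-1/2} (A + I) D^{-1/2} X W)."""
    n = len(adj)
    a_hat = [[adj[i][j] + (1 if i == j else 0) for j in range(n)] for i in range(n)]  # self-loops
    deg = [sum(row) for row in a_hat]
    dim = len(features[0])
    # Symmetrically normalized neighborhood aggregation.
    agg = [[sum(a_hat[i][j] * features[j][f] / math.sqrt(deg[i] * deg[j]) for j in range(n))
            for f in range(dim)] for i in range(n)]
    # Linear transform followed by ReLU.
    return [[max(0.0, sum(row[f] * weight[f][o] for f in range(dim))) for o in range(len(weight[0]))]
            for row in agg]


def readout(node_features, mode="sum"):
    """Graph representation of a single graph from its node representations."""
    total = [sum(col) for col in zip(*node_features)]
    return total if mode == "sum" else [v / len(node_features) for v in total]


adj = [[0, 1], [1, 0]]
nodes = gcn_layer(adj, [[1.0], [3.0]], weight=[[1.0]])
print(nodes)                   # [[2.0], [2.0]]
print(readout(nodes))          # [4.0]
print(readout(nodes, "mean"))  # [2.0]
```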

GCN

alias of torchdrug.models.gcn.GraphConvolutionalNetwork

GAT

class GraphAttentionNetwork(input_dim, hidden_dims, edge_input_dim=None, num_head=1, negative_slope=0.2, short_cut=False, batch_norm=False, activation='relu', concat_hidden=False, readout='sum')

Graph Attention Network proposed in Graph Attention Networks.

Parameters
  • input_dim (int) – input dimension

  • hidden_dims (list of int) – hidden dimensions

  • edge_input_dim (int, optional) – dimension of edge features

  • num_head (int, optional) – number of attention heads

  • negative_slope (float, optional) – negative slope of leaky relu activation

  • short_cut (bool, optional) – use shortcut connections or not

  • batch_norm (bool, optional) – apply batch normalization or not

  • activation (str or function, optional) – activation function

  • concat_hidden (bool, optional) – concat hidden representations from all layers as output

  • readout (str, optional) – readout function. Available functions are sum and mean.

forward(graph, input, all_loss=None, metric=None)

Compute the node representations and the graph representation(s).

Parameters
  • graph (Graph) – \(n\) graph(s)

  • input (Tensor) – input node representations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

Returns

node representations of shape \((|V|, d)\), graph representations of shape \((n, d)\)

Return type

dict with node_feature and graph_feature fields
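A single-head sketch of GAT aggregation from the Graph Attention Networks paper. The attention logits are e_ij = LeakyReLU(a . [h_i || h_j]), with slope negative_slope for negative inputs. These logits are softmax-normalized over each neighborhood and used to average neighbor features. For brevity the shared linear map W is treated as the identity, and the attention vector a is a hypothetical fixed value. gat_layer is hypothetical.

```python
import math


def gat_layer(features, neighbors, a, negative_slope=0.2):
    """Single-head GAT aggregation with W taken as the identity.

    neighbors[i] lists the nodes that node i attends to (include i for a self-loop);
    a is a hypothetical attention vector of length 2 * dim.
    """
    out = []
    for i, nbrs in enumerate(neighbors):
        logits = []
        for j in nbrs:
            s = sum(w * v for w, v in zip(a, features[i] + features[j]))
            logits.append(s if s > 0 else negative_slope * s)  # LeakyReLU
        # Softmax over the neighborhood of i (max subtracted for stability).
        m = max(logits)
        exps = [math.exp(e - m) for e in logits]
        z = sum(exps)
        alpha = [e / z for e in exps]
        out.append([sum(al * features[j][d] for al, j in zip(alpha, nbrs))
                    for d in range(len(features[i]))])
    return out


features = [[1.0], [2.0], [3.0]]
neighbors = [[0, 1], [0, 1, 2], [1, 2]]              # path graph with self-loops
uniform = gat_layer(features, neighbors, a=[0.0, 0.0])  # zero attention vector: plain mean
skewed = gat_layer(features, neighbors, a=[0.0, 1.0])   # favors neighbors with larger features
print(uniform, skewed)
```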

GAT

alias of torchdrug.models.gat.GraphAttentionNetwork

GIN

class GraphIsomorphismNetwork(input_dim=None, hidden_dims=None, edge_input_dim=None, num_mlp_layer=2, eps=0, learn_eps=False, short_cut=False, batch_norm=False, activation='relu', concat_hidden=False, readout='sum')

Graph Isomorphism Network proposed in How Powerful are Graph Neural Networks?

Parameters
  • input_dim (int) – input dimension

  • hidden_dims (list of int) – hidden dimensions

  • edge_input_dim (int, optional) – dimension of edge features

  • num_mlp_layer (int, optional) – number of MLP layers

  • eps (float, optional) – initial epsilon

  • learn_eps (bool, optional) – learn epsilon or not

  • short_cut (bool, optional) – use shortcut connections or not

  • batch_norm (bool, optional) – apply batch normalization or not

  • activation (str or function, optional) – activation function

  • concat_hidden (bool, optional) – concat hidden representations from all layers as output

  • readout (str, optional) – readout function. Available functions are sum and mean.

forward(graph, input, all_loss=None, metric=None)

Compute the node representations and the graph representation(s).

Parameters
  • graph (Graph) – \(n\) graph(s)

  • input (Tensor) – input node representations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

Returns

node representations of shape \((|V|, d)\), graph representations of shape \((n, d)\)

Return type

dict with node_feature and graph_feature fields
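A sketch of the GIN update h_v' = MLP((1 + eps) * h_v + sum of neighbor features). eps is kept fixed here; the learn_eps option above would make it trainable. The MLP is replaced by a hypothetical fixed function, toy_mlp, and gin_layer is likewise hypothetical.

```python
def gin_layer(features, neighbors, mlp, eps=0.0):
    """h_v' = MLP((1 + eps) * h_v + sum of h_u over neighbors u of v)."""
    out = []
    for v, nbrs in enumerate(neighbors):
        agg = [(1 + eps) * x for x in features[v]]
        for u in nbrs:
            agg = [a + b for a, b in zip(agg, features[u])]  # sum aggregation
        out.append(mlp(agg))
    return out


def toy_mlp(x):
    """Hypothetical stand-in for a learned MLP: ReLU, then scale by 2."""
    return [2.0 * max(0.0, v) for v in x]


features = [[1.0], [2.0], [3.0]]
neighbors = [[1], [0, 2], [1]]  # path graph 0 - 1 - 2
print(gin_layer(features, neighbors, toy_mlp))           # [[6.0], [12.0], [10.0]]
print(gin_layer(features, neighbors, toy_mlp, eps=0.5))  # [[7.0], [14.0], [13.0]]
```

Sum aggregation, rather than mean, is what lets the layer tell apart neighborhoods that contain the same features with different multiplicities.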

GIN

alias of torchdrug.models.gin.GraphIsomorphismNetwork

MPNN

class MessagePassingNeuralNetwork(input_dim, hidden_dim, edge_input_dim, num_layer=1, num_gru_layer=1, num_mlp_layer=2, num_s2s_step=3, short_cut=False, batch_norm=False, activation='relu', concat_hidden=False)

Message Passing Neural Network proposed in Neural Message Passing for Quantum Chemistry.

This implements the enn-s2s variant in the original paper.

Parameters
  • input_dim (int) – input dimension

  • hidden_dim (int) – hidden dimension

  • edge_input_dim (int) – dimension of edge features

  • num_layer (int, optional) – number of hidden layers

  • num_gru_layer (int, optional) – number of GRU layers in each node update

  • num_mlp_layer (int, optional) – number of MLP layers in each message function

  • num_s2s_step (int, optional) – number of processing steps in set2set

  • short_cut (bool, optional) – use shortcut connections or not

  • batch_norm (bool, optional) – apply batch normalization or not

  • activation (str or function, optional) – activation function

  • concat_hidden (bool, optional) – concat hidden representations from all layers as output

forward(graph, input, all_loss=None, metric=None)

Compute the node representations and the graph representation(s).

Parameters
  • graph (Graph) – \(n\) graph(s)

  • input (Tensor) – input node representations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

Returns

node representations of shape \((|V|, d)\), graph representations of shape \((n, d)\)

Return type

dict with node_feature and graph_feature fields
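A sketch of the edge-network message function behind the enn-s2s variant: m_v = sum over edges (w -> v) of A(e_wv) @ h_w. Here A maps an edge feature to a dim x dim matrix. In the paper A is a learned network, whereas here toy_edge_matrix is a hypothetical hand-written stand-in. The GRU node update (num_gru_layer) and the set2set readout (num_s2s_step) are not shown.

```python
def edge_network_messages(features, edges, edge_matrix):
    """m_v = sum over edges (w -> v) of A(e_wv) @ h_w.

    edges: list of (source, target, edge_feature);
    edge_matrix: maps an edge feature to a dim x dim matrix (hypothetical stand-in here).
    """
    dim = len(features[0])
    messages = [[0.0] * dim for _ in features]
    for src, dst, e in edges:
        A = edge_matrix(e)
        for i in range(dim):
            messages[dst][i] += sum(A[i][j] * features[src][j] for j in range(dim))
    return messages


def toy_edge_matrix(bond_order):
    """Hypothetical: single bonds pass features through, other bonds swap the two channels."""
    return [[1.0, 0.0], [0.0, 1.0]] if bond_order == 1 else [[0.0, 1.0], [1.0, 0.0]]


features = [[1.0, 0.0], [0.0, 1.0], [2.0, 3.0]]
edges = [(0, 1, 1), (2, 1, 2)]  # node 1 receives from node 0 (single) and node 2 (double)
print(edge_network_messages(features, edges, toy_edge_matrix))
# [[0.0, 0.0], [4.0, 2.0], [0.0, 0.0]]
```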

MPNN

alias of torchdrug.models.mpnn.MessagePassingNeuralNetwork

NFP

class NeuralFingerprint(input_dim, output_dim, hidden_dims, edge_input_dim=None, short_cut=False, batch_norm=False, activation='relu', concat_hidden=False, readout='sum')

Neural Fingerprint proposed in Convolutional Networks on Graphs for Learning Molecular Fingerprints.

Parameters
  • input_dim (int) – input dimension

  • output_dim (int) – fingerprint dimension

  • hidden_dims (list of int) – hidden dimensions

  • edge_input_dim (int, optional) – dimension of edge features

  • short_cut (bool, optional) – use shortcut connections or not

  • batch_norm (bool, optional) – apply batch normalization or not

  • activation (str or function, optional) – activation function

  • concat_hidden (bool, optional) – concat hidden representations from all layers as output

  • readout (str, optional) – readout function. Available functions are sum and mean.

forward(graph, input, all_loss=None, metric=None)

Compute the node representations and the graph representation(s).

Parameters
  • graph (Graph) – \(n\) graph(s)

  • input (Tensor) – input node representations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

Returns

node representations of shape \((|V|, d)\), graph representations of shape \((n, d)\)

Return type

dict with node_feature and graph_feature fields
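A sketch of how the differentiable fingerprint in the neural fingerprints paper accumulates an output_dim vector. For every layer and every atom, a projection of the atom's features is passed through a softmax and added to the fingerprint. The per-layer node updates are taken as given, the projections are hypothetical fixed matrices, and whether torchdrug builds its output_dim fingerprint exactly this way is not asserted.

```python
import math


def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    z = sum(exps)
    return [e / z for e in exps]


def accumulate_fingerprint(layer_features, projections):
    """Sum softmax(W_L @ r_a) over every layer L and atom a.

    layer_features[L]: atom vectors after layer L;
    projections[L]: hypothetical output_dim x hidden_dim matrix W_L.
    """
    fingerprint = [0.0] * len(projections[0])
    for feats, W in zip(layer_features, projections):
        for r in feats:
            logits = [sum(w * x for w, x in zip(row, r)) for row in W]
            # Each atom writes a soft, one-hot-like contribution into the fingerprint.
            fingerprint = [f + p for f, p in zip(fingerprint, softmax(logits))]
    return fingerprint


layer_features = [[[1.0, 0.0], [0.0, 1.0]],   # layer 1: two atoms
                  [[0.5, 0.5], [1.0, 1.0]]]   # layer 2: the same two atoms
W = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]      # output_dim = 3, hidden_dim = 2
fp = accumulate_fingerprint(layer_features, [W, W])
print(len(fp), round(sum(fp), 6))  # 3 4.0
```

Because every atom-layer pair contributes a probability vector, the fingerprint entries sum to (number of layers) x (number of atoms).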

NFP

alias of torchdrug.models.neuralfp.NeuralFingerprint

RGCN

class RelationalGraphConvolutionalNetwork(input_dim, hidden_dims, num_relation, edge_input_dim=None, short_cut=False, batch_norm=False, activation='relu', concat_hidden=False, readout='sum')

Relational Graph Convolutional Network proposed in Modeling Relational Data with Graph Convolutional Networks.

Parameters
  • input_dim (int) – input dimension

  • hidden_dims (list of int) – hidden dimensions

  • num_relation (int) – number of relations

  • edge_input_dim (int, optional) – dimension of edge features

  • short_cut (bool, optional) – use shortcut connections or not

  • batch_norm (bool, optional) – apply batch normalization or not

  • activation (str or function, optional) – activation function

  • concat_hidden (bool, optional) – concat hidden representations from all layers as output

  • readout (str, optional) – readout function. Available functions are sum and mean.

forward(graph, input, all_loss=None, metric=None)

Compute the node representations and the graph representation(s).

Requires the graph(s) to have the same number of relations as this module.

Parameters
  • graph (Graph) – \(n\) graph(s)

  • input (Tensor) – input node representations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

Returns

node representations of shape \((|V|, d)\), graph representations of shape \((n, d)\)

Return type

dict with node_feature and graph_feature fields
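A scalar-feature sketch of the R-GCN propagation rule: a self-connection plus, for each relation r, a relation-specific weight applied to the mean of the incoming neighbors under r. The normalization c_{i,r} = |N_i^r| is one of the choices mentioned in the R-GCN paper. Whether torchdrug normalizes the same way is not asserted. Every relation in the graph must have a weight, which mirrors the requirement that the graph and the module agree on the number of relations. rgcn_layer and the relation names are hypothetical.

```python
from collections import defaultdict


def rgcn_layer(features, edges, relation_weight, self_weight):
    """h_i' = relu(w_0 * h_i + sum_r (1 / |N_i^r|) * sum_{j in N_i^r} w_r * h_j).

    edges: list of (source, relation, target); relation_weight: relation -> scalar w_r.
    """
    incoming = defaultdict(list)  # (target, relation) -> list of sources
    for src, rel, dst in edges:
        incoming[(dst, rel)].append(src)
    out = []
    for i, h in enumerate(features):
        total = self_weight * h  # self-connection
        for (dst, rel), srcs in incoming.items():
            if dst == i:
                # Per-relation mean, i.e. normalization constant c_{i,r} = |N_i^r|.
                total += relation_weight[rel] * sum(features[j] for j in srcs) / len(srcs)
        out.append(max(0.0, total))
    return out


features = [1.0, 2.0, 4.0]
edges = [(1, "single", 0), (2, "single", 0), (2, "aromatic", 0)]
print(rgcn_layer(features, edges, {"single": 1.0, "aromatic": -0.5}, self_weight=1.0))
# [2.0, 2.0, 4.0]
```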

RGCN

alias of torchdrug.models.gcn.RelationalGraphConvolutionalNetwork

SchNet

class SchNet(input_dim, hidden_dims, edge_input_dim=None, cutoff=5, num_gaussian=100, short_cut=True, batch_norm=False, activation='shifted_softplus', concat_hidden=False)

SchNet proposed in SchNet: A continuous-filter convolutional neural network for modeling quantum interactions.

Parameters
  • input_dim (int) – input dimension

  • hidden_dims (list of int) – hidden dimensions

  • edge_input_dim (int, optional) – dimension of edge features

  • cutoff (float, optional) – maximal scale for RBF kernels

  • num_gaussian (int, optional) – number of RBF kernels

  • short_cut (bool, optional) – use shortcut connections or not

  • batch_norm (bool, optional) – apply batch normalization or not

  • activation (str or function, optional) – activation function

  • concat_hidden (bool, optional) – concat hidden representations from all layers as output

forward(graph, input, all_loss=None, metric=None)

Compute the node representations and the graph representation(s).

Requires the graph(s) to have the node attribute node_position.

Parameters
  • graph (Graph) – \(n\) graph(s)

  • input (Tensor) – input node representations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

Returns

node representations of shape \((|V|, d)\), graph representations of shape \((n, d)\)

Return type

dict with node_feature and graph_feature fields
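A sketch of how cutoff and num_gaussian can define a radial basis expansion of interatomic distances computed from node positions. num_gaussian Gaussians are centered evenly on [0, cutoff]. Tying the Gaussian width to the center spacing is an assumption of this sketch, not a statement about torchdrug's exact kernel. gaussian_rbf is hypothetical, and math.dist requires Python 3.8 or later.

```python
import math


def gaussian_rbf(distance, cutoff=5.0, num_gaussian=100):
    """Expand a distance into num_gaussian Gaussian features centered on [0, cutoff].

    The width (gamma tied to the center spacing) is an assumption. Requires num_gaussian >= 2.
    """
    step = cutoff / (num_gaussian - 1)
    gamma = 0.5 / step ** 2
    return [math.exp(-gamma * (distance - k * step) ** 2) for k in range(num_gaussian)]


# Distance between two hypothetical atom positions, 1.5 apart.
d = math.dist((0.0, 0.0, 0.0), (1.5, 0.0, 0.0))
features = gaussian_rbf(d, cutoff=5.0, num_gaussian=11)  # centers 0.0, 0.5, ..., 5.0
print(len(features), features.index(max(features)))     # 11 3
```

The resulting vector is a smooth, fixed-length encoding of the distance, which a continuous-filter convolution can turn into filter weights.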

Normalizing Flow

GraphAutoregressiveFlow

class GraphAutoregressiveFlow(model, prior, use_edge=False, num_layer=6, num_mlp_layer=2, dequantization_noise=0.9)

Graph autoregressive flow proposed in GraphAF: a Flow-based Autoregressive Model for Molecular Graph Generation.

Parameters
  • model (nn.Module) – graph representation model

  • prior (nn.Module) – prior distribution

  • use_edge (bool, optional) – use edge or not

  • num_layer (int, optional) – number of conditional flow layers

  • num_mlp_layer (int, optional) – number of MLP layers in each conditional flow

  • dequantization_noise (float, optional) – scale of dequantization noise

forward(graph, input, edge=None, all_loss=None, metric=None)

Compute the log-likelihood for the input given the graph(s).

Parameters
  • graph (Graph) – \(n\) graph(s)

  • input (Tensor) – discrete data of shape \((n,)\)

  • edge (Tensor, optional) – edge list of shape \((n, 2)\). If specified, additionally condition on the edge for each input.

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

sample(graph, edge=None, all_loss=None, metric=None)

Sample discrete data based on the given graph(s).

Parameters
  • graph (Graph) – \(n\) graph(s)

  • edge (Tensor, optional) – edge list of shape \((n, 2)\). If specified, additionally condition on the edge for each input.

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict
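A sketch of dequantization, which lets a continuous flow model discrete data. A one-hot label plus uniform noise becomes a continuous vector, and argmax maps it back. It assumes dequantization_noise scales U[0, 1) noise added to the one-hot encoding. The flow layers themselves are not shown. dequantize and quantize are hypothetical helpers.

```python
import random


def dequantize(label, num_classes, noise_scale=0.9, rng=random):
    """One-hot encoding of label plus noise_scale * U[0, 1) noise in every coordinate."""
    return [(1.0 if k == label else 0.0) + noise_scale * rng.random() for k in range(num_classes)]


def quantize(vector):
    """Recover the label with argmax.

    With noise_scale < 1 the true class always wins: its entry lies in [1, 1 + noise_scale)
    while every other entry lies in [0, noise_scale).
    """
    return max(range(len(vector)), key=vector.__getitem__)


rng = random.Random(0)
z = dequantize(2, num_classes=4, rng=rng)
print(quantize(z))  # 2
```

This is one reason a noise scale below 1, such as the default 0.9, keeps the mapping invertible in this sketch.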

GraphAF

alias of torchdrug.models.flow.GraphAutoregressiveFlow

Self-supervised Models

InfoGraph

class InfoGraph(model, num_mlp_layer=2, activation='relu', loss_weight=1, separate_model=False)

InfoGraph proposed in InfoGraph: Unsupervised and Semi-supervised Graph-Level Representation Learning via Mutual Information Maximization.

Parameters
  • model (nn.Module) – node & graph representation model

  • num_mlp_layer (int, optional) – number of MLP layers in mutual information estimators

  • activation (str or function, optional) – activation function

  • loss_weight (float, optional) – weight of both unsupervised & transfer losses

  • separate_model (bool, optional) – use separate encoders for the supervised and unsupervised objectives or not. If True, the unsupervised loss is applied to a separate encoder, and a transfer loss is applied between the two encoders.

forward(graph, input, all_loss=None, metric=None)

Compute the node representations and the graph representation(s). Add the mutual information between graph and nodes to the loss.

Parameters
  • graph (Graph) – \(n\) graph(s)

  • input (Tensor) – input node representations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

Returns

node representations of shape \((|V|, d)\), graph representations of shape \((n, d)\)

Return type

dict with node_feature and graph_feature fields
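A sketch of a Jensen-Shannon mutual information estimator of the kind used in the InfoGraph paper. Pairs of a node and its own graph are positives, and pairs of a node and any other graph are negatives. The loss is the negated estimate E_pos[-softplus(-T)] - E_neg[softplus(T)]. The discriminator T is a plain dot product here, standing in for the MLP estimators controlled by num_mlp_layer. jsd_mi_loss is hypothetical and needs at least two graphs so that negatives exist.

```python
import math


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def jsd_mi_loss(node_repr, node2graph, graph_repr):
    """Negative Jensen-Shannon MI estimate between node and graph representations."""
    pos, neg = [], []
    for v, h in enumerate(node_repr):
        for g, s in enumerate(graph_repr):
            score = sum(a * b for a, b in zip(h, s))  # discriminator T = dot product
            (pos if node2graph[v] == g else neg).append(score)
    mi = sum(-softplus(-x) for x in pos) / len(pos) - sum(softplus(x) for x in neg) / len(neg)
    return -mi


nodes = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
graphs = [[2.0, 0.0], [0.0, 2.0]]
aligned = jsd_mi_loss(nodes, [0, 0, 1], graphs)   # graph summaries match their own nodes
shuffled = jsd_mi_loss(nodes, [1, 1, 0], graphs)  # node-to-graph assignment scrambled
print(aligned < shuffled)  # True
```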