torchdrug.models#

Knowledge Graph Reasoning Models#

TransE#

class TransE(num_entity, num_relation, embedding_dim, max_score=12)[source]#

TransE embedding proposed in Translating Embeddings for Modeling Multi-relational Data.

Parameters
  • num_entity (int) – number of entities

  • num_relation (int) – number of relations

  • embedding_dim (int) – dimension of embeddings

  • max_score (float, optional) – maximal score for triplets

forward(graph, h_index, t_index, r_index, all_loss=None, metric=None)[source]#

Compute the score for each triplet.

Parameters
  • graph (Graph) – fact graph

  • h_index (Tensor) – indexes of head entities

  • t_index (Tensor) – indexes of tail entities

  • r_index (Tensor) – indexes of relations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict
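
As a rough illustration of the scoring rule, the sketch below computes a TransE-style score for a single triplet in plain Python. It assumes the score is `max_score` minus the L1 distance between `h + r` and `t`; the function name `transe_score` and the toy vectors are hypothetical, and the library's own implementation may differ in details such as the norm.

```python
def transe_score(h, r, t, max_score=12.0):
    """TransE-style score: max_score - ||h + r - t||_1 (the L1 norm is an assumption)."""
    distance = sum(abs(hi + ri - ti) for hi, ri, ti in zip(h, r, t))
    return max_score - distance


# Toy 3-dimensional embeddings (hypothetical values).
head = [1.0, 0.0, 2.0]
relation = [0.5, 1.0, -1.0]
tail_good = [1.5, 1.0, 1.0]  # exactly head + relation, so the distance is 0
tail_bad = [0.0, 0.0, 0.0]

print(transe_score(head, relation, tail_good))  # 12.0
print(transe_score(head, relation, tail_bad))   # 8.5
```

A triplet whose tail is close to `head + relation` approaches `max_score`, which is why `max_score` acts as an upper bound on the score.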

DistMult#

class DistMult(num_entity, num_relation, embedding_dim, l3_regularization=0)[source]#

DistMult embedding proposed in Embedding Entities and Relations for Learning and Inference in Knowledge Bases.

Parameters
  • num_entity (int) – number of entities

  • num_relation (int) – number of relations

  • embedding_dim (int) – dimension of embeddings

  • l3_regularization (float, optional) – weight for l3 regularization

forward(graph, h_index, t_index, r_index, all_loss=None, metric=None)[source]#

Compute the score for each triplet.

Parameters
  • graph (Graph) – fact graph

  • h_index (Tensor) – indexes of head entities

  • t_index (Tensor) – indexes of tail entities

  • r_index (Tensor) – indexes of relations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict
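
The following sketch shows the trilinear product behind DistMult and one plausible form of the L3 penalty controlled by `l3_regularization`. The helper names `distmult_score` and `l3_penalty` are hypothetical, and the exact way the library accumulates the penalty is an assumption.

```python
def distmult_score(h, r, t):
    """DistMult score: sum_i h_i * r_i * t_i."""
    return sum(hi * ri * ti for hi, ri, ti in zip(h, r, t))


def l3_penalty(vectors, weight):
    """Assumed L3 regularizer: weight * sum of |x|^3 over all embedding entries."""
    return weight * sum(abs(x) ** 3 for vector in vectors for x in vector)


h = [1.0, 2.0]
r = [3.0, 4.0]
t = [5.0, 6.0]

print(distmult_score(h, r, t))
print(distmult_score(t, r, h))  # same value: the score is symmetric in h and t
print(l3_penalty([[1.0, -2.0]], weight=0.5))
```

Because the product is symmetric in head and tail, this scoring function cannot distinguish `(h, r, t)` from `(t, r, h)`.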

ComplEx#

class ComplEx(num_entity, num_relation, embedding_dim, l3_regularization=0)[source]#

ComplEx embedding proposed in Complex Embeddings for Simple Link Prediction.

Parameters
  • num_entity (int) – number of entities

  • num_relation (int) – number of relations

  • embedding_dim (int) – dimension of embeddings

  • l3_regularization (float, optional) – weight for l3 regularization

forward(graph, h_index, t_index, r_index, all_loss=None, metric=None)[source]#

Compute the score for triplets.

Parameters
  • graph (Graph) – fact graph

  • h_index (Tensor) – indexes of head entities

  • t_index (Tensor) – indexes of tail entities

  • r_index (Tensor) – indexes of relations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict
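
For intuition, the ComplEx score can be sketched with Python's built-in complex numbers as the real part of a trilinear product in which the tail is conjugated. The function name `complex_score` and the toy values are hypothetical; how the library packs real and imaginary parts into `embedding_dim` is not shown here.

```python
def complex_score(h, r, t):
    """ComplEx score: Re(sum_i h_i * r_i * conj(t_i)) over complex embeddings."""
    return sum((hi * ri * ti.conjugate()).real for hi, ri, ti in zip(h, r, t))


h = [1 + 1j]
r = [1j]
t = [1 - 1j]

print(complex_score(h, r, t))
print(complex_score(t, r, h))  # differs from the line above
```

Conjugating the tail makes the score asymmetric, so swapping head and tail generally changes the result.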

SimplE#

class SimplE(num_entity, num_relation, embedding_dim, l3_regularization=0)[source]#

SimplE embedding proposed in SimplE Embedding for Link Prediction in Knowledge Graphs.

Parameters
  • num_entity (int) – number of entities

  • num_relation (int) – number of relations

  • embedding_dim (int) – dimension of embeddings

  • l3_regularization (float, optional) – weight for l3 regularization

forward(graph, h_index, t_index, r_index, all_loss=None, metric=None)[source]#

Compute the score for each triplet.

Parameters
  • graph (Graph) – fact graph

  • h_index (Tensor) – indexes of head entities

  • t_index (Tensor) – indexes of tail entities

  • r_index (Tensor) – indexes of relations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict
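
The sketch below follows the scoring rule from the SimplE paper, where every entity has a head-role and a tail-role vector and every relation has a forward and an inverse vector. All names (`simple_score`, `trilinear`, the argument names) are hypothetical, and how the library stores these vectors inside `embedding_dim` is an assumption not covered here.

```python
def trilinear(a, b, c):
    """Sum of element-wise products of three vectors."""
    return sum(x * y * z for x, y, z in zip(a, b, c))


def simple_score(h_head, h_tail, r, r_inv, t_head, t_tail):
    """SimplE score: average of the forward and the inverse trilinear terms."""
    forward = trilinear(h_head, r, t_tail)      # h acts as head, t as tail
    backward = trilinear(t_head, r_inv, h_tail)  # t acts as head of the inverse relation
    return 0.5 * (forward + backward)


score = simple_score(
    h_head=[1.0, 2.0], h_tail=[0.0, 1.0],
    r=[1.0, 1.0], r_inv=[2.0, 0.0],
    t_head=[3.0, 1.0], t_tail=[1.0, 1.0],
)
print(score)
```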

RotatE#

class RotatE(num_entity, num_relation, embedding_dim, max_score=12)[source]#

RotatE embedding proposed in RotatE: Knowledge Graph Embedding by Relational Rotation in Complex Space.

Parameters
  • num_entity (int) – number of entities

  • num_relation (int) – number of relations

  • embedding_dim (int) – dimension of embeddings

  • max_score (float, optional) – maximal score for triplets

forward(graph, h_index, t_index, r_index, all_loss=None, metric=None)[source]#

Compute the score for each triplet.

Parameters
  • graph (Graph) – fact graph

  • h_index (Tensor) – indexes of head entities

  • t_index (Tensor) – indexes of tail entities

  • r_index (Tensor) – indexes of relations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict
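
A minimal sketch of the RotatE idea, assuming each relation is a vector of phases that rotates the head embedding in the complex plane and that the score is `max_score` minus the summed complex moduli of `h * r - t`. The name `rotate_score` and the choice of distance are assumptions for illustration.

```python
import cmath
import math


def rotate_score(h, phases, t, max_score=12.0):
    """RotatE-style score: max_score - sum_i |h_i * exp(i * phase_i) - t_i|."""
    distance = sum(
        abs(hi * cmath.exp(1j * phase) - ti)
        for hi, phase, ti in zip(h, phases, t)
    )
    return max_score - distance


h = [1 + 0j]
quarter_turn = [math.pi / 2]  # rotates 1 to (approximately) 1j

print(rotate_score(h, quarter_turn, [1j]))   # tail matches the rotated head
print(rotate_score(h, quarter_turn, [-1j]))  # tail is on the opposite side
```

Since every relation entry has unit modulus, the relation only rotates the head and never rescales it.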

NeuralLP#

class NeuralLogicProgramming(num_relation, hidden_dim, num_step, num_lstm_layer=1)[source]#

Neural Logic Programming proposed in Differentiable Learning of Logical Rules for Knowledge Base Reasoning.

Parameters
  • num_relation (int) – number of relations

  • hidden_dim (int) – dimension of hidden units in LSTM

  • num_step (int) – number of recurrent steps

  • num_lstm_layer (int, optional) – number of LSTM layers

forward(graph, h_index, t_index, r_index, all_loss=None, metric=None)[source]#

Compute the score for triplets.

Parameters
  • graph (Graph) – fact graph

  • h_index (Tensor) – indexes of head entities

  • t_index (Tensor) – indexes of tail entities

  • r_index (Tensor) – indexes of relations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

NeuralLP#

alias of torchdrug.models.neurallp.NeuralLogicProgramming

KBGAT#

class KnowledgeBaseGraphAttentionNetwork(num_entity, num_relation, embedding_dim, hidden_dims, max_score=12, edge_input_dim=None, num_head=1, negative_slope=0.2, short_cut=False, batch_norm=False, activation='relu', concat_hidden=False, readout='sum')[source]#

Knowledge Base Graph Attention Network proposed in Learning Attention-based Embeddings for Relation Prediction in Knowledge Graphs.

Parameters
  • num_entity (int) – number of entities

  • num_relation (int) – number of relations

  • embedding_dim (int) – dimension of embeddings

  • hidden_dims (list of int) – hidden dimensions

  • max_score (float, optional) – maximal score for triplets

  • edge_input_dim (int, optional) – dimension of edge features

  • num_head (int, optional) – number of attention heads

  • negative_slope (float, optional) – negative slope of leaky relu activation

  • short_cut (bool, optional) – use short cut or not

  • batch_norm (bool, optional) – apply batch normalization or not

  • activation (str or function, optional) – activation function

  • concat_hidden (bool, optional) – concat hidden representations from all layers as output

  • readout (str, optional) – readout function. Available functions are sum and mean.

forward(graph, h_index, t_index, r_index, all_loss=None, metric=None)[source]#

Compute the score for triplets.

Parameters
  • graph (Graph) – fact graph

  • h_index (Tensor) – indexes of head entities

  • t_index (Tensor) – indexes of tail entities

  • r_index (Tensor) – indexes of relations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

KBGAT#

alias of torchdrug.models.kbgat.KnowledgeBaseGraphAttentionNetwork

Graph Neural Networks#

ChebNet#

class ChebyshevConvolutionalNetwork(input_dim, hidden_dims, edge_input_dim=None, k=1, short_cut=False, batch_norm=False, activation='relu', concat_hidden=False, readout='sum')[source]#

Chebyshev convolutional network proposed in Convolutional Neural Networks on Graphs with Fast Localized Spectral Filtering.

Parameters
  • input_dim (int) – input dimension

  • hidden_dims (list of int) – hidden dimensions

  • edge_input_dim (int, optional) – dimension of edge features

  • k (int, optional) – number of Chebyshev polynomials

  • short_cut (bool, optional) – use short cut or not

  • batch_norm (bool, optional) – apply batch normalization or not

  • activation (str or function, optional) – activation function

  • concat_hidden (bool, optional) – concat hidden representations from all layers as output

  • readout (str, optional) – readout function. Available functions are sum and mean.

forward(graph, input, all_loss=None, metric=None)[source]#

Compute the node representations and the graph representation(s).

Parameters
  • graph (Graph) – \(n\) graph(s)

  • input (Tensor) – input node representations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

Returns

node representations of shape \((|V|, d)\), graph representations of shape \((n, d)\)

Return type

dict with node_feature and graph_feature fields
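
To make the two output shapes concrete, here is a plain-Python sketch of a `sum` or `mean` readout that pools node representations of shape (|V|, d) into graph representations of shape (n, d). The function `readout` and the `node2graph` index list are illustrative assumptions, not the library's implementation.

```python
def readout(node_feature, node2graph, num_graph, mode="sum"):
    """Pool per-node features into per-graph features.

    node_feature: list of |V| feature vectors of length d
    node2graph:   list of |V| graph indexes in [0, num_graph)
    """
    if mode not in ("sum", "mean"):
        raise ValueError("mode must be 'sum' or 'mean'")
    dim = len(node_feature[0])
    graph_feature = [[0.0] * dim for _ in range(num_graph)]
    counts = [0] * num_graph
    for feature, graph_index in zip(node_feature, node2graph):
        counts[graph_index] += 1
        for k, value in enumerate(feature):
            graph_feature[graph_index][k] += value
    if mode == "mean":
        graph_feature = [
            [value / max(count, 1) for value in row]
            for row, count in zip(graph_feature, counts)
        ]
    return graph_feature


nodes = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]  # |V| = 3, d = 2
membership = [0, 0, 1]                         # two graphs in the batch

print(readout(nodes, membership, num_graph=2, mode="sum"))
print(readout(nodes, membership, num_graph=2, mode="mean"))
```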

ChebNet#

alias of torchdrug.models.chebnet.ChebyshevConvolutionalNetwork

GCN#

class GraphConvolutionalNetwork(input_dim, hidden_dims, edge_input_dim=None, short_cut=False, batch_norm=False, activation='relu', concat_hidden=False, readout='sum')[source]#

Graph Convolutional Network proposed in Semi-Supervised Classification with Graph Convolutional Networks.

Parameters
  • input_dim (int) – input dimension

  • hidden_dims (list of int) – hidden dimensions

  • edge_input_dim (int, optional) – dimension of edge features

  • short_cut (bool, optional) – use short cut or not

  • batch_norm (bool, optional) – apply batch normalization or not

  • activation (str or function, optional) – activation function

  • concat_hidden (bool, optional) – concat hidden representations from all layers as output

  • readout (str, optional) – readout function. Available functions are sum and mean.

forward(graph, input, all_loss=None, metric=None)[source]#

Compute the node representations and the graph representation(s).

Parameters
  • graph (Graph) – \(n\) graph(s)

  • input (Tensor) – input node representations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

Returns

node representations of shape \((|V|, d)\), graph representations of shape \((n, d)\)

Return type

dict with node_feature and graph_feature fields

GCN#

alias of torchdrug.models.gcn.GraphConvolutionalNetwork

GAT#

class GraphAttentionNetwork(input_dim, hidden_dims, edge_input_dim=None, num_head=1, negative_slope=0.2, short_cut=False, batch_norm=False, activation='relu', concat_hidden=False, readout='sum')[source]#

Graph Attention Network proposed in Graph Attention Networks.

Parameters
  • input_dim (int) – input dimension

  • hidden_dims (list of int) – hidden dimensions

  • edge_input_dim (int, optional) – dimension of edge features

  • num_head (int, optional) – number of attention heads

  • negative_slope (float, optional) – negative slope of leaky relu activation

  • short_cut (bool, optional) – use short cut or not

  • batch_norm (bool, optional) – apply batch normalization or not

  • activation (str or function, optional) – activation function

  • concat_hidden (bool, optional) – concat hidden representations from all layers as output

  • readout (str, optional) – readout function. Available functions are sum and mean.

forward(graph, input, all_loss=None, metric=None)[source]#

Compute the node representations and the graph representation(s).

Parameters
  • graph (Graph) – \(n\) graph(s)

  • input (Tensor) – input node representations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

Returns

node representations of shape \((|V|, d)\), graph representations of shape \((n, d)\)

Return type

dict with node_feature and graph_feature fields
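
The role of `negative_slope` can be sketched as follows: raw attention scores over a node's neighbors pass through a leaky ReLU and are then normalized with a softmax. The helper names and the raw scores are hypothetical; computing the raw scores from learned projections is omitted.

```python
import math


def leaky_relu(x, negative_slope=0.2):
    """Leaky ReLU: identity for x >= 0, scaled by negative_slope otherwise."""
    return x if x >= 0 else negative_slope * x


def attention_weights(raw_scores, negative_slope=0.2):
    """Softmax over leaky-ReLU-activated scores of one node's neighbors."""
    activated = [leaky_relu(score, negative_slope) for score in raw_scores]
    shift = max(activated)  # subtract the max for numerical stability
    exps = [math.exp(value - shift) for value in activated]
    total = sum(exps)
    return [value / total for value in exps]


weights = attention_weights([2.0, -5.0, 0.0])
print(weights)
print(sum(weights))
```

With several heads (`num_head` > 1), one such set of weights would be computed per head.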

GAT#

alias of torchdrug.models.gat.GraphAttentionNetwork

GIN#

class GraphIsomorphismNetwork(input_dim, hidden_dims, edge_input_dim=None, num_mlp_layer=2, eps=0, learn_eps=False, short_cut=False, batch_norm=False, activation='relu', concat_hidden=False, readout='sum')[source]#

Graph Isomorphism Network proposed in How Powerful are Graph Neural Networks?

Parameters
  • input_dim (int) – input dimension

  • hidden_dims (list of int) – hidden dimensions

  • edge_input_dim (int, optional) – dimension of edge features

  • num_mlp_layer (int, optional) – number of MLP layers

  • eps (int, optional) – initial epsilon

  • learn_eps (bool, optional) – learn epsilon or not

  • short_cut (bool, optional) – use short cut or not

  • batch_norm (bool, optional) – apply batch normalization or not

  • activation (str or function, optional) – activation function

  • concat_hidden (bool, optional) – concat hidden representations from all layers as output

  • readout (str, optional) – readout function. Available functions are sum and mean.

forward(graph, input, all_loss=None, metric=None)[source]#

Compute the node representations and the graph representation(s).

Parameters
  • graph (Graph) – \(n\) graph(s)

  • input (Tensor) – input node representations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

Returns

node representations of shape \((|V|, d)\), graph representations of shape \((n, d)\)

Return type

dict with node_feature and graph_feature fields
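
The effect of `eps` can be illustrated with the aggregation step from the GIN paper, where each node adds its own features scaled by `1 + eps` to the sum of its neighbors' features. The sketch omits the MLP that follows the aggregation, and the function name `gin_aggregate` and the edge-list format are assumptions.

```python
def gin_aggregate(node_feature, edge_list, eps=0.0):
    """Return (1 + eps) * h_v + sum of h_u over incoming edges (u, v)."""
    output = [[(1 + eps) * value for value in feature] for feature in node_feature]
    for source, target in edge_list:
        for k, value in enumerate(node_feature[source]):
            output[target][k] += value
    return output


features = [[1.0], [2.0], [4.0]]
edges = [(0, 1), (2, 1), (1, 0)]  # (source, target) pairs

print(gin_aggregate(features, edges, eps=0.0))
print(gin_aggregate(features, edges, eps=0.5))
```

When `learn_eps` is enabled, `eps` would be a trainable parameter rather than a fixed constant.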

GIN#

alias of torchdrug.models.gin.GraphIsomorphismNetwork

MPNN#

class MessagePassingNeuralNetwork(input_dim, hidden_dim, edge_input_dim, num_layer=1, num_gru_layer=1, num_mlp_layer=2, num_s2s_step=3, short_cut=False, batch_norm=False, activation='relu', concat_hidden=False)[source]#

Message Passing Neural Network proposed in Neural Message Passing for Quantum Chemistry.

This implements the enn-s2s variant in the original paper.

Parameters
  • input_dim (int) – input dimension

  • hidden_dim (int) – hidden dimension

  • edge_input_dim (int) – dimension of edge features

  • num_layer (int, optional) – number of hidden layers

  • num_gru_layer (int, optional) – number of GRU layers in each node update

  • num_mlp_layer (int, optional) – number of MLP layers in each message function

  • num_s2s_step (int, optional) – number of processing steps in set2set

  • short_cut (bool, optional) – use short cut or not

  • batch_norm (bool, optional) – apply batch normalization or not

  • activation (str or function, optional) – activation function

  • concat_hidden (bool, optional) – concat hidden representations from all layers as output

forward(graph, input, all_loss=None, metric=None)[source]#

Compute the node representations and the graph representation(s).

Parameters
  • graph (Graph) – \(n\) graph(s)

  • input (Tensor) – input node representations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

Returns

node representations of shape \((|V|, d)\), graph representations of shape \((n, d)\)

Return type

dict with node_feature and graph_feature fields

MPNN#

alias of torchdrug.models.mpnn.MessagePassingNeuralNetwork

NFP#

class NeuralFingerprint(input_dim, output_dim, hidden_dims, edge_input_dim=None, short_cut=False, batch_norm=False, activation='relu', concat_hidden=False, readout='sum')[source]#

Neural Fingerprints from Convolutional Networks on Graphs for Learning Molecular Fingerprints.

Parameters
  • input_dim (int) – input dimension

  • output_dim (int) – fingerprint dimension

  • hidden_dims (list of int) – hidden dimensions

  • edge_input_dim (int, optional) – dimension of edge features

  • short_cut (bool, optional) – use short cut or not

  • batch_norm (bool, optional) – apply batch normalization or not

  • activation (str or function, optional) – activation function

  • concat_hidden (bool, optional) – concat hidden representations from all layers as output

  • readout (str, optional) – readout function. Available functions are sum and mean.

forward(graph, input, all_loss=None, metric=None)[source]#

Compute the node representations and the graph representation(s).

Parameters
  • graph (Graph) – \(n\) graph(s)

  • input (Tensor) – input node representations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

Returns

node representations of shape \((|V|, d)\), graph representations of shape \((n, d)\)

Return type

dict with node_feature and graph_feature fields

NFP#

alias of torchdrug.models.neuralfp.NeuralFingerprint

RGCN#

class RelationalGraphConvolutionalNetwork(input_dim, hidden_dims, num_relation, edge_input_dim=None, short_cut=False, batch_norm=False, activation='relu', concat_hidden=False, readout='sum')[source]#

Relational Graph Convolutional Network proposed in Modeling Relational Data with Graph Convolutional Networks.

Parameters
  • input_dim (int) – input dimension

  • hidden_dims (list of int) – hidden dimensions

  • num_relation (int) – number of relations

  • edge_input_dim (int, optional) – dimension of edge features

  • short_cut (bool, optional) – use short cut or not

  • batch_norm (bool, optional) – apply batch normalization or not

  • activation (str or function, optional) – activation function

  • concat_hidden (bool, optional) – concat hidden representations from all layers as output

  • readout (str, optional) – readout function. Available functions are sum and mean.

forward(graph, input, all_loss=None, metric=None)[source]#

Compute the node representations and the graph representation(s).

Require the graph(s) to have the same number of relations as this module.

Parameters
  • graph (Graph) – \(n\) graph(s)

  • input (Tensor) – input node representations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

Returns

node representations of shape \((|V|, d)\), graph representations of shape \((n, d)\)

Return type

dict with node_feature and graph_feature fields

RGCN#

alias of torchdrug.models.gcn.RelationalGraphConvolutionalNetwork

GearNet#

class GeometryAwareRelationalGraphNeuralNetwork(input_dim, hidden_dims, num_relation, edge_input_dim=None, num_angle_bin=None, short_cut=False, batch_norm=False, activation='relu', concat_hidden=False, readout='sum')[source]#

Geometry Aware Relational Graph Neural Network proposed in Protein Representation Learning by Geometric Structure Pretraining.

Parameters
  • input_dim (int) – input dimension

  • hidden_dims (list of int) – hidden dimensions

  • num_relation (int) – number of relations

  • edge_input_dim (int, optional) – dimension of edge features

  • num_angle_bin (int, optional) – number of bins to discretize angles between edges. The discretized angles are used as relations in edge message passing. If not provided, edge message passing is disabled.

  • short_cut (bool, optional) – use short cut or not

  • batch_norm (bool, optional) – apply batch normalization or not

  • activation (str or function, optional) – activation function

  • concat_hidden (bool, optional) – concat hidden representations from all layers as output

  • readout (str, optional) – readout function. Available functions are sum and mean.

forward(graph, input, all_loss=None, metric=None)[source]#

Compute the node representations and the graph representation(s).

Parameters
  • graph (Graph) – \(n\) graph(s)

  • input (Tensor) – input node representations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

Returns

node representations of shape \((|V|, d)\), graph representations of shape \((n, d)\)

Return type

dict with node_feature and graph_feature fields

GearNet#

alias of torchdrug.models.gearnet.GeometryAwareRelationalGraphNeuralNetwork

SchNet#

class SchNet(input_dim, hidden_dims, edge_input_dim=None, cutoff=5, num_gaussian=100, short_cut=True, batch_norm=False, activation='shifted_softplus', concat_hidden=False)[source]#

SchNet from SchNet: A continuous-filter convolutional neural network for modeling quantum interactions.

Parameters
  • input_dim (int) – input dimension

  • hidden_dims (list of int) – hidden dimensions

  • edge_input_dim (int, optional) – dimension of edge features

  • cutoff (float, optional) – maximal scale for RBF kernels

  • num_gaussian (int, optional) – number of RBF kernels

  • short_cut (bool, optional) – use short cut or not

  • batch_norm (bool, optional) – apply batch normalization or not

  • activation (str or function, optional) – activation function

  • concat_hidden (bool, optional) – concat hidden representations from all layers as output

forward(graph, input, all_loss=None, metric=None)[source]#

Compute the node representations and the graph representation(s).

Require the graph(s) to have node attribute node_position.

Parameters
  • graph (Graph) – \(n\) graph(s)

  • input (Tensor) – input node representations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

Returns

node representations of shape \((|V|, d)\), graph representations of shape \((n, d)\)

Return type

dict with node_feature and graph_feature fields
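
The `cutoff`, `num_gaussian` and `shifted_softplus` settings can be pictured with the sketch below. It assumes Gaussian kernel centers spaced evenly on [0, cutoff] and a hypothetical width parameter `gamma`; the library's actual kernel placement and width may differ. The shifted softplus follows the definition ln(0.5 e^x + 0.5) from the SchNet paper.

```python
import math


def shifted_softplus(x):
    """ln(0.5 * e^x + 0.5), written in a numerically stable form."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x))) - math.log(2.0)


def gaussian_rbf(distance, cutoff=5.0, num_gaussian=100, gamma=10.0):
    """Expand a scalar distance into num_gaussian radial basis features.

    Centers are assumed to be evenly spaced on [0, cutoff]; gamma is a
    hypothetical width parameter chosen only for this illustration.
    """
    if num_gaussian < 2:
        raise ValueError("num_gaussian must be at least 2")
    step = cutoff / (num_gaussian - 1)
    centers = [k * step for k in range(num_gaussian)]
    return [math.exp(-gamma * (distance - center) ** 2) for center in centers]


features = gaussian_rbf(2.0, cutoff=5.0, num_gaussian=6)
print(len(gaussian_rbf(1.0)))
print(features)
print(shifted_softplus(0.0))
```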

Protein Sequence Encoders#

ESM#

class EvolutionaryScaleModeling(path, model='ESM-1b', readout='mean')[source]#

The protein language model, Evolutionary Scale Modeling (ESM) proposed in Biological Structure and Function Emerge from Scaling Unsupervised Learning to 250 Million Protein Sequences.

Parameters
  • path (str) – path to store ESM model weights

  • model (str, optional) – model name. Available model names are ESM-1b, ESM-1v and ESM-1b-regression.

  • readout (str, optional) – readout function. Available functions are sum and mean.

forward(graph, input, all_loss=None, metric=None)[source]#

Compute the residue representations and the graph representation(s).

Parameters
  • graph (Protein) – \(n\) protein(s)

  • input (Tensor) – input node representations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

Returns

residue representations of shape \((|V_{res}|, d)\), graph representations of shape \((n, d)\)

Return type

dict with residue_feature and graph_feature fields

ProteinCNN#

class ProteinConvolutionalNetwork(input_dim, hidden_dims, kernel_size=3, stride=1, padding=1, activation='relu', short_cut=False, concat_hidden=False, readout='max')[source]#

Protein Shallow CNN proposed in Is Transfer Learning Necessary for Protein Landscape Prediction?

Parameters
  • input_dim (int) – input dimension

  • hidden_dims (list of int) – hidden dimensions

  • kernel_size (int, optional) – size of convolutional kernel

  • stride (int, optional) – stride of convolution

  • padding (int, optional) – padding added to both sides of the input

  • activation (str or function, optional) – activation function

  • short_cut (bool, optional) – use short cut or not

  • concat_hidden (bool, optional) – concat hidden representations from all layers as output

  • readout (str, optional) – readout function. Available functions are sum, mean, max and attention.

forward(graph, input, all_loss=None, metric=None)[source]#

Compute the residue representations and the graph representation(s).

Parameters
  • graph (Protein) – \(n\) protein(s)

  • input (Tensor) – input node representations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

Returns

residue representations of shape \((|V_{res}|, d)\), graph representations of shape \((n, d)\)

Return type

dict with residue_feature and graph_feature fields
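
How `kernel_size`, `stride` and `padding` interact can be checked with the standard output-length formula for a 1D convolution without dilation. The helper name `conv_output_length` is hypothetical.

```python
def conv_output_length(length, kernel_size=3, stride=1, padding=1):
    """Output length of a 1D convolution (no dilation)."""
    return (length + 2 * padding - kernel_size) // stride + 1


print(conv_output_length(100))                  # defaults keep the length
print(conv_output_length(100, kernel_size=5))   # larger kernel, same padding
print(conv_output_length(100, stride=2))        # stride 2 roughly halves it
```

With the default values (`kernel_size=3`, `stride=1`, `padding=1`) the sequence length is preserved, so each layer yields one representation per input position.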

ProteinResNet#

class ProteinResNet(input_dim, hidden_dims, kernel_size=3, stride=1, padding=1, activation='gelu', short_cut=False, concat_hidden=False, layer_norm=False, dropout=0, readout='attention')[source]#

Protein ResNet proposed in Evaluating Protein Transfer Learning with TAPE.

Parameters
  • input_dim (int) – input dimension

  • hidden_dims (list of int) – hidden dimensions

  • kernel_size (int, optional) – size of convolutional kernel

  • stride (int, optional) – stride of convolution

  • padding (int, optional) – padding added to both sides of the input

  • activation (str or function, optional) – activation function

  • short_cut (bool, optional) – use short cut or not

  • concat_hidden (bool, optional) – concat hidden representations from all layers as output

  • layer_norm (bool, optional) – apply layer normalization or not

  • dropout (float, optional) – dropout ratio of input features

  • readout (str, optional) – readout function. Available functions are sum, mean and attention.

forward(graph, input, all_loss=None, metric=None)[source]#

Compute the residue representations and the graph representation(s).

Parameters
  • graph (Protein) – \(n\) protein(s)

  • input (Tensor) – input node representations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

Returns

residue representations of shape \((|V_{res}|, d)\), graph representations of shape \((n, d)\)

Return type

dict with residue_feature and graph_feature fields

ProteinLSTM#

class ProteinLSTM(input_dim, hidden_dim, num_layers, activation='tanh', layer_norm=False, dropout=0)[source]#

Protein LSTM proposed in Evaluating Protein Transfer Learning with TAPE.

Parameters
  • input_dim (int) – input dimension

  • hidden_dim (int) – hidden dimension

  • num_layers (int) – number of LSTM layers

  • activation (str or function, optional) – activation function

  • layer_norm (bool, optional) – apply layer normalization or not

  • dropout (float, optional) – dropout ratio of input features

forward(graph, input, all_loss=None, metric=None)[source]#

Compute the residue representations and the graph representation(s).

Parameters
  • graph (Protein) – \(n\) protein(s)

  • input (Tensor) – input node representations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

Returns

residue representations of shape \((|V_{res}|, d)\), graph representations of shape \((n, d)\)

Return type

dict with residue_feature and graph_feature fields

ProteinBERT#

class ProteinBERT(input_dim, hidden_dim=768, num_layers=12, num_heads=12, intermediate_dim=3072, activation='gelu', hidden_dropout=0.1, attention_dropout=0.1, max_position=8192)[source]#

Protein BERT proposed in Evaluating Protein Transfer Learning with TAPE.

Parameters
  • input_dim (int) – input dimension

  • hidden_dim (int, optional) – hidden dimension

  • num_layers (int, optional) – number of Transformer blocks

  • num_heads (int, optional) – number of attention heads

  • intermediate_dim (int, optional) – intermediate hidden dimension of Transformer block

  • activation (str or function, optional) – activation function

  • hidden_dropout (float, optional) – dropout ratio of hidden features

  • attention_dropout (float, optional) – dropout ratio of attention maps

  • max_position (int, optional) – maximum number of positions

forward(graph, input, all_loss=None, metric=None)[source]#

Compute the residue representations and the graph representation(s).

Parameters
  • graph (Protein) – \(n\) protein(s)

  • input (Tensor) – input node representations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

Returns

residue representations of shape \((|V_{res}|, d)\), graph representations of shape \((n, d)\)

Return type

dict with residue_feature and graph_feature fields

Statistic Feature Engineering#

class Statistic(type='DDE', hidden_dims=(512,))[source]#

The statistic feature engineering for protein sequence proposed in Harnessing Computational Biology for Exact Linear B-cell Epitope Prediction.

Parameters
  • type (str, optional) – statistic feature. Available feature is DDE.

  • hidden_dims (list of int, optional) – hidden dimensions

forward(graph, input, all_loss=None, metric=None)[source]#

Compute the graph representation(s).

Parameters
  • graph (Protein) – \(n\) protein(s)

  • input (Tensor) – input node representations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

Returns

graph representations of shape \((n, d)\)

Return type

dict with graph_feature field

Physicochemical Feature Engineering#

class Physicochemical(path, type='moran', nlag=30, hidden_dims=(512,))[source]#

The physicochemical feature engineering for protein sequence proposed in Prediction of Membrane Protein Types based on the Hydrophobic Index of Amino Acids.

Parameters
  • path (str) – path to store feature file

  • type (str, optional) – physicochemical feature. Available features are moran, geary and nmbroto.

  • nlag (int, optional) – maximum position interval to compute features

  • hidden_dims (list of int, optional) – hidden dimensions

forward(graph, input, all_loss=None, metric=None)[source]#

Compute the graph representation(s).

Parameters
  • graph (Protein) – \(n\) protein(s)

  • input (Tensor) – input node representations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

Returns

graph representations of shape \((n, d)\)

Return type

dict with graph_feature field

Normalizing Flows#

GraphAutoregressiveFlow#

class GraphAutoregressiveFlow(model, prior, use_edge=False, num_layer=6, num_mlp_layer=2, dequantization_noise=0.9)[source]#

Graph autoregressive flow proposed in GraphAF: a Flow-based Autoregressive Model for Molecular Graph Generation.

Parameters
  • model (nn.Module) – graph representation model

  • prior (nn.Module) – prior distribution

  • use_edge (bool, optional) – use edge or not

  • num_layer (int, optional) – number of conditional flow layers

  • num_mlp_layer (int, optional) – number of MLP layers in each conditional flow

  • dequantization_noise (float, optional) – scale of dequantization noise

forward(graph, input, edge=None, all_loss=None, metric=None)[source]#

Compute the log-likelihood for the input given the graph(s).

Parameters
  • graph (Graph) – \(n\) graph(s)

  • input (Tensor) – discrete data of shape \((n,)\)

  • edge (Tensor, optional) – edge list of shape \((n, 2)\). If specified, additionally condition on the edge for each input.

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

sample(graph, edge=None, all_loss=None, metric=None)[source]#

Sample discrete data based on the given graph(s).

Parameters
  • graph (Graph) – \(n\) graph(s)

  • edge (Tensor, optional) – edge list of shape \((n, 2)\). If specified, additionally condition on the edge for each input.

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict
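
The purpose of `dequantization_noise` can be sketched as follows: discrete categories are turned into continuous vectors by adding scaled uniform noise to a one-hot encoding, and the category is recovered with an argmax. This is a simplified illustration; the helper names are hypothetical and the exact noise scheme used by the library is an assumption.

```python
import random


def dequantize(category, num_category, noise_scale=0.9, rng=random):
    """One-hot encode a category and add uniform noise in [0, noise_scale)."""
    one_hot = [1.0 if k == category else 0.0 for k in range(num_category)]
    return [value + noise_scale * rng.random() for value in one_hot]


def quantize(values):
    """Recover the category as the index of the largest entry."""
    return max(range(len(values)), key=values.__getitem__)


continuous = dequantize(2, num_category=4)
print(continuous)
print(quantize(continuous))
```

As long as the noise scale stays below 1, the entry of the true category is always the largest, so the round trip is lossless.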

GraphAF#

alias of torchdrug.models.flow.GraphAutoregressiveFlow

Self-supervised Models#

InfoGraph#

class InfoGraph(model, num_mlp_layer=2, activation='relu', loss_weight=1, separate_model=False)[source]#

InfoGraph proposed in InfoGraph: Unsupervised and Semi-supervised Graph-Level Representation Learning via Mutual Information Maximization.

Parameters
  • model (nn.Module) – node & graph representation model

  • num_mlp_layer (int, optional) – number of MLP layers in mutual information estimators

  • activation (str or function, optional) – activation function

  • loss_weight (float, optional) – weight of both unsupervised & transfer losses

  • separate_model (bool, optional) – separate supervised and unsupervised encoders. If true, the unsupervised loss will be applied on a separate encoder, and a transfer loss is applied between the two encoders.

forward(graph, input, all_loss=None, metric=None)[source]#

Compute the node representations and the graph representation(s). Add the mutual information between graph and nodes to the loss.

Parameters
  • graph (Graph) – \(n\) graph(s)

  • input (Tensor) – input node representations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

Returns

node representations of shape \((|V|, d)\), graph representations of shape \((n, d)\)

Return type

dict with node_feature and graph_feature fields

MultiviewContrast#

class MultiviewContrast(model, crop_funcs, noise_funcs, num_mlp_layer=2, activation='relu', tau=0.07)[source]#

Multiview Contrast proposed in Protein Representation Learning by Geometric Structure Pretraining.

Parameters
  • model (nn.Module) – node & graph representation model

  • crop_funcs (list of nn.Module) – list of cropping functions

  • noise_funcs (list of nn.Module) – list of noise functions

  • num_mlp_layer (int, optional) – number of MLP layers in mutual information estimators

  • activation (str or function, optional) – activation function

  • tau (float, optional) – temperature in InfoNCE loss

forward(graph, input, all_loss=None, metric=None)[source]#

Compute the graph representations of two augmented views. Each view is generated by randomly picking a cropping function and a noise function. Add the mutual information between two augmented views to the loss.

Parameters
  • graph (Graph) – \(n\) graph(s)

  • input (Tensor) – input node representations

  • all_loss (Tensor, optional) – if specified, add loss to this tensor

  • metric (dict, optional) – if specified, output metrics to this dict

Returns

node representations of shape \((|V|, d)\), graph representations of shape \((n, d)\) for two augmented views respectively

Return type

dict with node_feature1, node_feature2, graph_feature1 and graph_feature2 fields
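
The role of the temperature `tau` can be illustrated with a small InfoNCE sketch over graph representations of two views, where row `i` of the first view should match row `i` of the second. The use of cosine similarity and the names `cosine` and `info_nce` are assumptions; the library's loss may differ in details such as symmetrization.

```python
import math


def cosine(a, b):
    """Cosine similarity of two vectors (0 if either has zero norm)."""
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b)


def info_nce(view1, view2, tau=0.07):
    """Mean InfoNCE loss; positives are pairs with the same index."""
    losses = []
    for i, anchor in enumerate(view1):
        logits = [cosine(anchor, other) / tau for other in view2]
        shift = max(logits)  # log-sum-exp with a shift for stability
        log_norm = shift + math.log(sum(math.exp(l - shift) for l in logits))
        losses.append(log_norm - logits[i])
    return sum(losses) / len(losses)


aligned = info_nce([[1.0, 0.0], [0.0, 1.0]], [[1.0, 0.0], [0.0, 1.0]])
swapped = info_nce([[1.0, 0.0], [0.0, 1.0]], [[0.0, 1.0], [1.0, 0.0]])
print(aligned, swapped)
```

A smaller `tau` sharpens the softmax, so mismatched pairs are penalized more heavily.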