torchdrug.tasks#

Property Prediction Tasks#

PropertyPrediction#

class PropertyPrediction(model, task=(), criterion='mse', metric=('mae', 'rmse'), num_mlp_layer=1, normalization=True, num_class=None, graph_construction_model=None, verbose=0)[source]#

Graph / molecule / protein property prediction task.

This class is also compatible with semi-supervised learning.

Parameters
  • model (nn.Module) – graph representation model

  • task (str, list or dict, optional) – training task(s). For dict, the keys are tasks and the values are the corresponding weights.

  • criterion (str, list or dict, optional) – training criterion(s). For dict, the keys are criterions and the values are the corresponding weights. Available criterions are mse and bce.

  • metric (str or list of str, optional) – metric(s). Available metrics are mae, rmse, auprc and auroc.

  • num_mlp_layer (int, optional) – number of layers in the MLP prediction head

  • normalization (bool, optional) – whether to normalize the target

  • num_class (int, optional) – number of classes

  • graph_construction_model (nn.Module, optional) – graph construction model

  • verbose (int, optional) – output verbose level
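The dict form of criterion can be read as a weighted sum of per-criterion losses. The following plain-Python sketch illustrates that reading only; the helper names (mse, bce, CRITERIA, combined_loss) are hypothetical and are not part of torchdrug.

```python
import math

def mse(pred, target):
    """Mean squared error over paired predictions and targets."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

def bce(logits, target):
    """Binary cross entropy computed from raw logits."""
    total = 0.0
    for z, t in zip(logits, target):
        p = 1.0 / (1.0 + math.exp(-z))  # sigmoid
        total += -(t * math.log(p) + (1 - t) * math.log(1 - p))
    return total / len(logits)

# Hypothetical registry used only by this sketch
CRITERIA = {"mse": mse, "bce": bce}

def combined_loss(criterion, pred, target):
    """Accept a str, list or dict and return the weighted sum of losses."""
    if isinstance(criterion, str):
        criterion = {criterion: 1.0}
    elif not isinstance(criterion, dict):
        criterion = {name: 1.0 for name in criterion}
    return sum(w * CRITERIA[name](pred, target) for name, w in criterion.items())

pred, target = [0.5, -1.0, 2.0], [1.0, 0.0, 1.0]
loss = combined_loss({"mse": 1.0, "bce": 0.5}, pred, target)
```

A str or list is treated here as a dict with unit weights, which is an assumption of the sketch rather than a documented guarantee.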

preprocess(train_set, valid_set, test_set)[source]#

Compute the mean and standard deviation for each task on the training set.
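As a rough illustration of target normalization, the sketch below fits a per-task mean and standard deviation on training targets and uses them to normalize and denormalize values. The function names are hypothetical, and the population standard deviation is an assumption made here; the library may use a different estimator.

```python
import statistics

def fit_normalizer(train_targets):
    """Return {task: (mean, std)} computed on the training targets only."""
    stats = {}
    for task, values in train_targets.items():
        # Population standard deviation: an assumption of this sketch
        stats[task] = (statistics.fmean(values), statistics.pstdev(values))
    return stats

def normalize(value, task, stats):
    """Map a raw target to zero mean and unit variance."""
    mean, std = stats[task]
    return (value - mean) / std

def denormalize(value, task, stats):
    """Map a normalized prediction back to the original scale."""
    mean, std = stats[task]
    return value * std + mean

train_targets = {"logP": [1.0, 2.0, 3.0], "mu": [0.0, 4.0]}
stats = fit_normalizer(train_targets)
z = normalize(3.0, "logP", stats)
```

Fitting the statistics on the training set only keeps validation and test targets from leaking into the scaling.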

MultipleBinaryClassification#

class MultipleBinaryClassification(model, task=(), criterion='bce', metric=('auprc@micro', 'f1_max'), num_mlp_layer=1, normalization=True, reweight=False, graph_construction_model=None, verbose=0)[source]#

Multiple binary classification task for graphs / molecules / proteins.

Parameters
  • model (nn.Module) – graph representation model

  • task (list of int, optional) – training task id(s).

  • criterion (list or dict, optional) – training criterion(s). For dict, the keys are criterions and the values are the corresponding weights. The only available criterion is bce.

  • metric (str or list of str, optional) – metric(s). Available metrics are auroc@macro, auprc@macro, auroc@micro, auprc@micro and f1_max.

  • num_mlp_layer (int, optional) – number of layers in the MLP prediction head

  • normalization (bool, optional) – whether to normalize the target

  • reweight (bool, optional) – whether to re-weight tasks according to the number of positive samples

  • graph_construction_model (nn.Module, optional) – graph construction model

  • verbose (int, optional) – output verbose level

forward(batch)[source]#

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

preprocess(train_set, valid_set, test_set)[source]#

Compute the weight for each task on the training set.
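One plausible way to re-weight tasks by their number of positive samples is inverse frequency relative to the mean count. The sketch below shows that scheme as an assumption; it is not necessarily the formula the library uses, and task_weights is a hypothetical helper.

```python
def task_weights(labels, reweight=True):
    """labels: list of samples, each a list of 0/1 values, one per task."""
    num_task = len(labels[0])
    if not reweight:
        return [1.0] * num_task
    num_positive = [sum(sample[t] for sample in labels) for t in range(num_task)]
    mean_positive = sum(num_positive) / num_task
    # Tasks with few positives get a larger weight; guard against zero counts
    return [mean_positive / max(n, 1) for n in num_positive]

labels = [[1, 0], [1, 0], [1, 1], [1, 0]]
weights = task_weights(labels)
```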

NodePropertyPrediction#

class NodePropertyPrediction(model, criterion='bce', metric=('macro_auprc', 'macro_auroc'), num_mlp_layer=1, normalization=True, num_class=None, verbose=0)[source]#

Node / atom / residue property prediction task.

Parameters
  • model (nn.Module) – graph representation model

  • criterion (str, list or dict, optional) – training criterion(s). For dict, the keys are criterions and the values are the corresponding weights. Available criterions are mse and bce.

  • metric (str or list of str, optional) – metric(s). Available metrics are mae, rmse, auprc and auroc.

  • num_mlp_layer (int, optional) – number of layers in the MLP prediction head

  • normalization (bool, optional) – whether to normalize the target. Available entities are node, atom and residue.

  • num_class (int, optional) – number of classes

  • verbose (int, optional) – output verbose level

preprocess(train_set, valid_set, test_set)[source]#

Compute the mean and standard deviation on the training set.

InteractionPrediction#

class InteractionPrediction(model, model2=None, task=(), criterion='mse', metric=('mae', 'rmse'), num_mlp_layer=1, normalization=True, num_class=None, verbose=0)[source]#

Predict the interaction property of graph pairs.

Parameters
  • model (nn.Module) – graph representation model

  • model2 (nn.Module, optional) – graph representation model for the second item. If None, use tied-weight model for the second item.

  • task (str, list or dict, optional) – training task(s). For dict, the keys are tasks and the values are the corresponding weights.

  • criterion (str, list or dict, optional) – training criterion(s). For dict, the keys are criterions and the values are the corresponding weights. Available criterions are mse and bce.

  • metric (str or list of str, optional) – metric(s). Available metrics are mae, rmse, auprc and auroc.

  • num_mlp_layer (int, optional) – number of layers in the MLP prediction head

  • normalization (bool, optional) – whether to normalize the target

  • num_class (int, optional) – number of classes

  • verbose (int, optional) – output verbose level

preprocess(train_set, valid_set, test_set)[source]#

Compute the mean and standard deviation for each task on the training set.

Pre-trained Molecular Representation Tasks#

EdgePrediction#

class EdgePrediction(model)[source]#

Edge prediction task proposed in Inductive Representation Learning on Large Graphs.

Parameters

model (nn.Module) – node representation model

AttributeMasking#

class AttributeMasking(model, mask_rate=0.15, num_mlp_layer=2, graph_construction_model=None)[source]#

Attribute masking proposed in Strategies for Pre-training Graph Neural Networks.

Parameters
  • model (nn.Module) – node representation model

  • mask_rate (float, optional) – rate of masked nodes

  • num_mlp_layer (int, optional) – number of MLP layers
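The role of mask_rate can be pictured with a small sketch: a fraction of nodes is hidden and their original attributes become the prediction targets. The helper mask_nodes, the mask_token value and the rounding rule are assumptions made for illustration, not the library's implementation.

```python
import random

def mask_nodes(node_types, mask_rate=0.15, mask_token=-1, seed=None):
    """Hide a random subset of node attributes and return the targets."""
    rng = random.Random(seed)
    # Rounding and the minimum of one masked node are choices of this sketch
    num_mask = max(1, round(mask_rate * len(node_types)))
    masked_index = sorted(rng.sample(range(len(node_types)), num_mask))
    masked_input = list(node_types)
    for i in masked_index:
        masked_input[i] = mask_token
    targets = [node_types[i] for i in masked_index]
    return masked_input, masked_index, targets

node_types = list(range(20))  # toy attributes: each node's type equals its index
masked_input, masked_index, targets = mask_nodes(node_types, mask_rate=0.15, seed=0)
```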

ContextPrediction#

class ContextPrediction(model, context_model=None, k=5, r1=4, r2=7, readout='mean', num_negative=1)[source]#

Context prediction task proposed in Strategies for Pre-training Graph Neural Networks.

For a given center node, the subgraph is defined as a k-hop neighborhood (inclusive) around the selected node. The context graph is defined as the surrounding graph structure between r1- (exclusive) and r2-hop (inclusive) from the center node. Nodes between k- and r1-hop are picked as anchor nodes for the context representation.

Parameters
  • model (nn.Module) – node representation model for subgraphs.

  • context_model (nn.Module, optional) – node representation model for context graphs. By default, use the same architecture as model without parameter sharing.

  • k (int, optional) – radius for subgraphs

  • r1 (int, optional) – inner radius for context graphs

  • r2 (int, optional) – outer radius for context graphs

  • readout (nn.Module, optional) – readout function over context anchor nodes

  • num_negative (int, optional) – number of negative samples per positive sample
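The subgraph, context graph and anchor nodes described above can be sketched with a breadth-first search over hop distances. This is an illustration under one reading of the text: anchors are taken to be the nodes that fall in both the subgraph and the context graph. The function names are hypothetical.

```python
from collections import deque

def hop_distances(adjacency, center):
    """Breadth-first search returning the hop distance of each reachable node."""
    dist = {center: 0}
    queue = deque([center])
    while queue:
        u = queue.popleft()
        for v in adjacency[u]:
            if v not in dist:
                dist[v] = dist[u] + 1
                queue.append(v)
    return dist

def split_context(adjacency, center, k=5, r1=4, r2=7):
    """Return (subgraph, context, anchors) node sets around `center`."""
    dist = hop_distances(adjacency, center)
    subgraph = {n for n, d in dist.items() if d <= k}       # k-hop, inclusive
    context = {n for n, d in dist.items() if r1 < d <= r2}  # (r1, r2] ring
    anchors = subgraph & context                            # assumed overlap
    return subgraph, context, anchors

# Path graph 0-1-...-9, so the hop distance from node 0 equals the node id
adjacency = {i: [j for j in (i - 1, i + 1) if 0 <= j <= 9] for i in range(10)}
subgraph, context, anchors = split_context(adjacency, center=0)
```

With the default radii, the overlap is non-empty only because r1 is smaller than k.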

DistancePrediction#

class DistancePrediction(model, num_sample=256, num_mlp_layer=2, graph_construction_model=None)[source]#

Pairwise spatial distance prediction task proposed in Protein Representation Learning by Geometric Structure Pretraining.

Randomly select some edges and predict the lengths of the edges using the representations of two nodes. The selected edges are removed from the input graph to prevent trivial solutions.

Parameters
  • model (nn.Module) – node representation model

  • num_sample (int, optional) – number of edges selected from each graph

  • num_mlp_layer (int, optional) – number of MLP layers in distance predictor

  • graph_construction_model (nn.Module, optional) – graph construction model
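The sampling step described above can be sketched in plain Python: pick some edges, use their Euclidean lengths as regression targets, and keep only the remaining edges as model input. The helper name and data layout are assumptions for illustration.

```python
import math
import random

def sample_distance_targets(positions, edges, num_sample=256, seed=None):
    """Select edges, compute their lengths, and drop them from the graph."""
    rng = random.Random(seed)
    num_sample = min(num_sample, len(edges))
    chosen = set(rng.sample(range(len(edges)), num_sample))
    selected = [e for i, e in enumerate(edges) if i in chosen]
    remaining = [e for i, e in enumerate(edges) if i not in chosen]
    # The target of each selected edge is the distance between its end points
    targets = [math.dist(positions[u], positions[v]) for u, v in selected]
    return selected, targets, remaining

positions = {0: (0.0, 0.0, 0.0), 1: (3.0, 4.0, 0.0), 2: (3.0, 4.0, 12.0)}
edges = [(0, 1), (1, 2), (0, 2)]
selected, targets, remaining = sample_distance_targets(positions, edges, num_sample=2, seed=0)
```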

forward(batch)[source]#


AnglePrediction#

class AnglePrediction(model, num_sample=256, num_class=8, num_mlp_layer=2, graph_construction_model=None)[source]#

Angle prediction task proposed in Protein Representation Learning by Geometric Structure Pretraining.

Randomly select pairs of adjacent edges and predict the angles between them using the representations of three nodes. The selected edges are removed from the input graph to prevent trivial solutions.

Parameters
  • model (nn.Module) – node representation model

  • num_sample (int, optional) – number of edge pairs selected from each graph

  • num_class (int, optional) – number of classes to discretize the angles

  • num_mlp_layer (int, optional) – number of MLP layers in angle predictor

  • graph_construction_model (nn.Module, optional) – graph construction model
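To make the role of num_class concrete, the sketch below computes the angle between two adjacent edges and maps it to one of num_class equal-width bins over [0, pi]. The equal-width binning and the helper names are assumptions; the library's discretization may differ.

```python
import math

def angle_between(center, a, b):
    """Angle in [0, pi] at `center` between the edges center-a and center-b."""
    u = [x - c for x, c in zip(a, center)]
    v = [x - c for x, c in zip(b, center)]
    dot = sum(p * q for p, q in zip(u, v))
    cos = dot / (math.hypot(*u) * math.hypot(*v))
    # Clamp to guard against rounding slightly outside [-1, 1]
    return math.acos(max(-1.0, min(1.0, cos)))

def angle_to_class(angle, num_class=8):
    """Equal-width bins over [0, pi]; the last bin includes pi itself."""
    index = int(angle / math.pi * num_class)
    return min(index, num_class - 1)

origin = (0.0, 0.0, 0.0)
right_angle = angle_between(origin, (1.0, 0.0, 0.0), (0.0, 1.0, 0.0))
sixty_degrees = angle_between(origin, (1.0, 0.0, 0.0), (1.0, math.sqrt(3.0), 0.0))
```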

forward(batch)[source]#


DihedralPrediction#

class DihedralPrediction(model, num_sample=256, num_class=8, num_mlp_layer=2, graph_construction_model=None)[source]#

Dihedral prediction task proposed in Protein Representation Learning by Geometric Structure Pretraining.

Randomly select three consecutive edges and predict the dihedrals among them using the representations of four nodes. The selected edges are removed from the input graph to prevent trivial solutions.

Parameters
  • model (nn.Module) – node representation model

  • num_sample (int, optional) – number of edge triplets selected from each graph

  • num_class (int, optional) – number of classes for discretizing the dihedrals

  • num_mlp_layer (int, optional) – number of MLP layers in dihedral angle predictor

  • graph_construction_model (nn.Module, optional) – graph construction model
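A dihedral over three consecutive edges is determined by four node positions. The sketch below uses the common atan2 formulation and then bins the result into num_class equal-width classes over [-pi, pi]. Sign conventions vary between references, and the binning is an assumption rather than the library's scheme.

```python
import math

def sub(a, b):
    return tuple(x - y for x, y in zip(a, b))

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def cross(a, b):
    return (a[1] * b[2] - a[2] * b[1],
            a[2] * b[0] - a[0] * b[2],
            a[0] * b[1] - a[1] * b[0])

def dihedral(p0, p1, p2, p3):
    """Signed dihedral in [-pi, pi] along the chain p0-p1-p2-p3."""
    b1, b2, b3 = sub(p1, p0), sub(p2, p1), sub(p3, p2)
    n1, n2 = cross(b1, b2), cross(b2, b3)  # normals of the two planes
    norm_b2 = math.sqrt(dot(b2, b2))
    m1 = cross(n1, tuple(x / norm_b2 for x in b2))
    return math.atan2(dot(m1, n2), dot(n1, n2))

def dihedral_to_class(angle, num_class=8):
    """Equal-width bins over [-pi, pi]; the last bin includes pi itself."""
    index = int((angle + math.pi) / (2 * math.pi) * num_class)
    return min(index, num_class - 1)

p0, p1, p2 = (0.0, 1.0, 0.0), (0.0, 0.0, 0.0), (1.0, 0.0, 0.0)
cis = dihedral(p0, p1, p2, (1.0, 1.0, 0.0))
trans = dihedral(p0, p1, p2, (1.0, -1.0, 0.0))
perpendicular = dihedral(p0, p1, p2, (1.0, 0.0, 1.0))
```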

forward(batch)[source]#


Unsupervised#

class Unsupervised(model, graph_construction_model=None)[source]#

Wrapper task for unsupervised learning.

The unsupervised loss should be computed by the model.

Parameters

model (nn.Module) – any model

Molecule Generation Tasks#

AutoregressiveGeneration#

class AutoregressiveGeneration(node_model, edge_model, task=(), num_node_sample=-1, num_edge_sample=-1, max_edge_unroll=None, max_node=None, criterion='nll', agent_update_interval=5, gamma=0.9, reward_temperature=1, baseline_momentum=0.9)[source]#

Autoregressive graph generation task.

This class can be used to implement GraphAF proposed in GraphAF: A Flow-based Autoregressive Model for Molecular Graph Generation. To do so, instantiate the node model and the edge model with two GraphAutoregressiveFlow models.

Parameters
  • node_model (nn.Module) – node likelihood model

  • edge_model (nn.Module) – edge likelihood model

  • task (str or list of str, optional) – property optimization task(s). Available tasks are plogp and qed.

  • num_node_sample (int, optional) – number of node samples per graph. -1 for all samples.

  • num_edge_sample (int, optional) – number of edge samples per graph. -1 for all samples.

  • max_edge_unroll (int, optional) – max node id difference. If not provided, use the statistics from the training set.

  • max_node (int, optional) – max number of nodes. If not provided, use the statistics from the training set.

  • criterion (str, list or dict, optional) – training criterion(s). For dict, the keys are criterions and the values are the corresponding weights. Available criterions are nll and ppo.

  • agent_update_interval (int, optional) – update the agent every n batches

  • gamma (float, optional) – reward discount rate

  • reward_temperature (float, optional) – temperature for reward. Higher temperature encourages larger mean reward, while lower temperature encourages larger maximal reward.

  • baseline_momentum (float, optional) – momentum for value function baseline

preprocess(train_set, valid_set, test_set)[source]#

Add atom id mapping and random BFS order to the training set.

Compute max_edge_unroll and max_node on the training set if not provided.
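The relationship between a random BFS order and max_edge_unroll can be sketched as follows: nodes are relabeled by their BFS visiting order, and the largest id difference across any edge is recorded. This is an illustration for a single connected graph with hypothetical helper names; taking the maximum of that value (and of the node count, for max_node) over the training set is an assumed reading of "statistics from the training set".

```python
import random
from collections import deque

def random_bfs_order(adjacency, start, seed=None):
    """Visit nodes breadth-first, shuffling the neighbors at each step."""
    rng = random.Random(seed)
    order, seen, queue = [], {start}, deque([start])
    while queue:
        u = queue.popleft()
        order.append(u)
        neighbors = [v for v in adjacency[u] if v not in seen]
        rng.shuffle(neighbors)
        seen.update(neighbors)
        queue.extend(neighbors)
    return order

def edge_unroll(adjacency, order):
    """Largest node id difference over all edges after relabeling by `order`."""
    new_id = {node: i for i, node in enumerate(order)}
    return max(abs(new_id[u] - new_id[v]) for u in adjacency for v in adjacency[u])

path = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2, 4], 4: [3]}
star = {0: [1, 2, 3, 4], 1: [0], 2: [0], 3: [0], 4: [0]}
path_order = random_bfs_order(path, start=0, seed=0)
star_order = random_bfs_order(star, start=0, seed=0)
```

A BFS order tends to keep this difference small, which bounds how many earlier nodes need to be considered when generating the edges of a new node.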

GCPNGeneration#

class GCPNGeneration(model, atom_types, max_edge_unroll=None, max_node=None, task=(), criterion='nll', hidden_dim_mlp=128, agent_update_interval=10, gamma=0.9, reward_temperature=1, baseline_momentum=0.9)[source]#

The graph generative model from Graph Convolutional Policy Network for Goal-Directed Molecular Graph Generation.

Parameters
  • model (nn.Module) – graph representation model

  • atom_types (list or set) – set of all possible atom types

  • task (str or list of str, optional) – property optimization task(s)

  • max_edge_unroll (int, optional) – max node id difference. If not provided, use the statistics from the training set.

  • max_node (int, optional) – max number of nodes. If not provided, use the statistics from the training set.

  • criterion (str, list or dict, optional) – training criterion(s). For dict, the keys are criterions and the values are the corresponding weights. Available criterions are nll and ppo.

  • agent_update_interval (int, optional) – update the agent every n batches

  • gamma (float, optional) – reward discount rate

  • reward_temperature (float, optional) – temperature for reward. Higher temperature encourages larger mean reward, while lower temperature encourages larger maximal reward.

  • baseline_momentum (float, optional) – momentum for value function baseline

preprocess(train_set, valid_set, test_set)[source]#

Add atom id mapping and random BFS order to the training set.

Compute max_edge_unroll and max_node on the training set if not provided.

Retrosynthesis Tasks#

CenterIdentification#

class CenterIdentification(model, feature=('reaction', 'graph', 'atom', 'bond'), num_mlp_layer=2)[source]#

Reaction center identification task.

This class is a part of retrosynthesis prediction.

Parameters
  • model (nn.Module) – graph representation model

  • feature (str or list of str, optional) – additional features for prediction. Available features are:
      • reaction: type of the reaction
      • graph: graph representation of the product
      • atom: original atom feature
      • bond: original bond feature

  • num_mlp_layer (int, optional) – number of MLP layers

predict_synthon(batch, k=1)[source]#

Predict top-k synthons from target molecules.

Parameters
  • batch (dict) – batch of target molecules

  • k (int, optional) – return top-k results

Returns

top k records.

Each record is a batch dict of keys synthon, num_synthon, reaction_center, log_likelihood and reaction.

Return type

list of dict

SynthonCompletion#

class SynthonCompletion(model, feature=('reaction', 'graph', 'atom'), num_mlp_layer=2)[source]#

Synthon completion task.

This class is a part of retrosynthesis prediction.

Parameters
  • model (nn.Module) – graph representation model

  • feature (str or list of str, optional) – additional features for prediction. Available features are:
      • reaction: type of the reaction
      • graph: graph representation of the synthon
      • atom: original atom feature

  • num_mlp_layer (int, optional) – number of MLP layers

Retrosynthesis#

class Retrosynthesis(center_identification, synthon_completion, center_topk=2, num_synthon_beam=10, max_prediction=20, metric=('top-1', 'top-3', 'top-5', 'top-10'))[source]#

Retrosynthesis task.

This class wraps pretrained center identification and synthon completion modules into a pipeline.

Parameters
  • center_identification (CenterIdentification) – sub task of center identification

  • synthon_completion (SynthonCompletion) – sub task of synthon completion

  • center_topk (int, optional) – number of reaction centers to predict for each product

  • num_synthon_beam (int, optional) – size of beam search for each synthon

  • max_prediction (int, optional) – max number of final predictions for each product

  • metric (str or list of str, optional) – metric(s). Available metrics are top-K.
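A top-K metric of this kind is usually the fraction of products whose ground-truth reactants appear among the first K ranked predictions. The sketch below computes it on toy data; the function name and the use of plain strings to stand for reactant sets are assumptions for illustration.

```python
def top_k_accuracy(predictions, truths, ks=(1, 3, 5, 10)):
    """predictions: ranked candidate lists, one per product, best first."""
    result = {}
    for k in ks:
        hit = sum(truth in ranked[:k] for ranked, truth in zip(predictions, truths))
        result["top-%d" % k] = hit / len(truths)
    return result

predictions = [["A", "B", "C"], ["X", "Y", "Z"], ["M", "N"]]
truths = ["A", "Z", "Q"]
accuracy = top_k_accuracy(predictions, truths)
```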

load_state_dict(state_dict, strict=True)[source]#

Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.

Parameters
  • state_dict (dict) – a dict containing parameters and persistent buffers.

  • strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True

Returns

  • missing_keys is a list of str containing the missing keys

  • unexpected_keys is a list of str containing the unexpected keys

Return type

NamedTuple with missing_keys and unexpected_keys fields

Knowledge Graph Reasoning Tasks#

KnowledgeGraphCompletion#

class KnowledgeGraphCompletion(model, criterion='bce', metric=('mr', 'mrr', 'hits@1', 'hits@3', 'hits@10'), num_negative=128, margin=6, adversarial_temperature=0, strict_negative=True, filtered_ranking=True, fact_ratio=None, sample_weight=True)[source]#

Knowledge graph completion task.

This class provides routines for the family of knowledge graph embedding models.

Parameters
  • model (nn.Module) – knowledge graph embedding model

  • criterion (str, list or dict, optional) – training criterion(s). For dict, the keys are criterions and the values are the corresponding weights. Available criterions are bce, ce and ranking.

  • metric (str or list of str, optional) – metric(s). Available metrics are mr, mrr and hits@K.

  • num_negative (int, optional) – number of negative samples per positive sample

  • margin (float, optional) – margin in ranking criterion

  • adversarial_temperature (float, optional) – temperature for self-adversarial negative sampling. Set 0 to disable self-adversarial negative sampling.

  • strict_negative (bool, optional) – use strict negative sampling or not

  • filtered_ranking (bool, optional) – use filtered or unfiltered ranking for evaluation

  • fact_ratio (float, optional) – split the training set into facts and labels. Set None to use the whole training set as both facts and labels.

  • sample_weight (bool, optional) – whether to down-weight triplets from entities of large degrees
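The mr, mrr and hits@K metrics are all derived from the rank of the true entity among the candidates, and filtered ranking ignores other candidates already known to be true. The sketch below follows those standard definitions; the helper names and the strict greater-than tie handling are assumptions of this example.

```python
def filtered_rank(scores, target, known_true=()):
    """Rank of `target` among candidates, skipping other known true answers."""
    target_score = scores[target]
    skip = set(known_true) - {target}
    better = sum(1 for entity, s in scores.items()
                 if entity not in skip and s > target_score)
    return better + 1

def ranking_metrics(ranks, ks=(1, 3, 10)):
    """Mean rank, mean reciprocal rank and hits@K from a list of ranks."""
    n = len(ranks)
    metrics = {"mr": sum(ranks) / n, "mrr": sum(1.0 / r for r in ranks) / n}
    for k in ks:
        metrics["hits@%d" % k] = sum(r <= k for r in ranks) / n
    return metrics

scores = {"a": 0.9, "b": 0.8, "c": 0.7, "d": 0.1}
unfiltered = filtered_rank(scores, "c")
filtered = filtered_rank(scores, "c", known_true=["a", "c"])
metrics = ranking_metrics([1, 2, 4])
```

Filtering can only improve (lower) a rank, since it removes competitors without adding any.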

Protein-Protein Interaction Prediction Tasks#

ContactPrediction#

class ContactPrediction(model, max_length=500, random_truncate=True, threshold=8.0, gap=6, criterion='bce', metric=('accuracy', 'prec@L5'), num_mlp_layer=1, verbose=0)[source]#

Predict whether each pair of amino acids is in contact in the folded structure.

Parameters
  • model (nn.Module) – protein sequence representation model

  • max_length (int, optional) – maximal length of sequence. Truncate the sequence if it exceeds this limit.

  • random_truncate (bool, optional) – truncate the sequence at a random position. If not, truncate the suffix of the sequence.

  • threshold (float, optional) – distance threshold for contact

  • gap (int, optional) – sequential distance cutoff for evaluation

  • criterion (str or dict, optional) – training criterion. For dict, the key is the criterion and the value is the corresponding weight. The only available criterion is bce.

  • metric (str or list of str, optional) – metric(s). Available metrics are accuracy, prec@Lk and prec@k.

  • num_mlp_layer (int, optional) – number of layers in the MLP prediction head

  • verbose (int, optional) – output verbose level
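The threshold and gap parameters, and a metric such as prec@L5, can be illustrated with a small sketch. It assumes that a contact means a spatial distance strictly below threshold, that only pairs at least gap positions apart in sequence are evaluated, and that prec@Lk is the precision over the L/k highest-scoring pairs. These boundary choices and the helper names are assumptions, not confirmed library behavior.

```python
import math

def contact_labels(coords, threshold=8.0, gap=6):
    """Return {(i, j): bool} for residue pairs at least `gap` apart in sequence."""
    labels = {}
    for i in range(len(coords)):
        for j in range(i + gap, len(coords)):
            labels[(i, j)] = math.dist(coords[i], coords[j]) < threshold
    return labels

def precision_at_l_over_k(scores, labels, length, k=5):
    """Precision among the top length // k scored pairs."""
    top = max(1, length // k)
    ranked = sorted(labels, key=lambda pair: scores[pair], reverse=True)[:top]
    return sum(labels[pair] for pair in ranked) / len(ranked)

# Toy chain of 10 residues on a straight line, one unit apart
coords = [(float(i), 0.0, 0.0) for i in range(10)]
labels = contact_labels(coords)
# Toy scores that favor pairs closer together in sequence
scores = {pair: -abs(pair[0] - pair[1]) for pair in labels}
prec_l5 = precision_at_l_over_k(scores, labels, length=len(coords), k=5)
prec_l1 = precision_at_l_over_k(scores, labels, length=len(coords), k=1)
```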