torchdrug.tasks
Property Prediction Tasks
PropertyPrediction
- class PropertyPrediction(model, task=(), criterion='mse', metric=('mae', 'rmse'), num_mlp_layer=1, normalization=True, num_class=None, mlp_batch_norm=False, mlp_dropout=0, graph_construction_model=None, verbose=0)
Graph / molecule / protein property prediction task.
This class is also compatible with semi-supervised learning.
- Parameters
model (nn.Module) – graph representation model
task (str, list or dict, optional) – training task(s). For dict, the keys are tasks and the values are the corresponding weights.
criterion (str, list or dict, optional) – training criterion(s). For dict, the keys are criteria and the values are the corresponding weights. Available criteria are mse, bce and ce.
metric (str or list of str, optional) – metric(s). Available metrics are mae, rmse, auprc and auroc.
num_mlp_layer (int, optional) – number of layers in the MLP prediction head
normalization (bool, optional) – whether to normalize the target
num_class (int, optional) – number of classes
mlp_batch_norm (bool, optional) – whether to apply batch normalization in the MLP
mlp_dropout (float, optional) – dropout rate in the MLP
graph_construction_model (nn.Module, optional) – graph construction model
verbose (int, optional) – output verbose level
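To illustrate what normalization=True and a dict-valued task can mean together, the pure-Python sketch below standardizes each task's targets with training-set statistics and averages per-task MSE using the dict weights. It is a hypothetical sketch under these assumptions (fit_normalizer, weighted_task_loss and denormalize are invented names), not TorchDrug's internal code.

```python
import math

def fit_normalizer(values):
    """Mean and standard deviation of one task's training targets."""
    mean = sum(values) / len(values)
    std = math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))
    return mean, (std if std > 0 else 1.0)  # avoid division by zero

def weighted_task_loss(outputs, targets, stats, weights):
    """Weighted average of per-task MSE, computed against normalized targets.

    outputs: dict task -> model outputs (interpreted in normalized units)
    targets: dict task -> target values in original units
    stats:   dict task -> (mean, std) from fit_normalizer
    weights: dict task -> loss weight (like a dict-valued `task` argument)
    """
    total = 0.0
    for task, weight in weights.items():
        mean, std = stats[task]
        norm_targets = [(t - mean) / std for t in targets[task]]
        mse = sum((o - t) ** 2 for o, t in zip(outputs[task], norm_targets)) / len(norm_targets)
        total += weight * mse
    return total / sum(weights.values())

def denormalize(outputs, mean, std):
    """Map normalized outputs back to the original target scale for evaluation."""
    return [o * std + mean for o in outputs]
```

Training against normalized targets keeps tasks with very different scales from dominating the weighted sum; predictions are mapped back with denormalize before mae or rmse is computed.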
MultipleBinaryClassification
- class MultipleBinaryClassification(model, task=(), criterion='bce', metric=('auprc@micro', 'f1_max'), num_mlp_layer=1, normalization=True, reweight=False, graph_construction_model=None, verbose=0)
Multiple binary classification task for graphs / molecules / proteins.
- Parameters
model (nn.Module) – graph representation model
task (list of int, optional) – training task id(s).
criterion (str, list or dict, optional) – training criterion(s). For dict, the keys are criteria and the values are the corresponding weights. The only available criterion is bce.
metric (str or list of str, optional) – metric(s). Available metrics are auroc@macro, auprc@macro, auroc@micro, auprc@micro and f1_max.
num_mlp_layer (int, optional) – number of layers in the MLP prediction head
normalization (bool, optional) – whether to normalize the target
reweight (bool, optional) – whether to re-weight tasks according to the number of positive samples
graph_construction_model (nn.Module, optional) – graph construction model
verbose (int, optional) – output verbose level
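The @macro and @micro suffixes follow the usual averaging conventions: macro averages the per-task scores, while micro pools every (sample, task) entry into one ranking. A minimal pure-Python sketch for AUROC, assuming pred and target are lists of rows with one column per task; the function names are hypothetical and not TorchDrug's metric API.

```python
def auroc(scores, labels):
    """AUROC as the probability that a random positive outranks a random negative
    (ties count 1/2). O(P * N), which is fine for a small illustration."""
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    wins = sum(1.0 if p > n else 0.5 if p == n else 0.0 for p in pos for n in neg)
    return wins / (len(pos) * len(neg))

def auroc_macro(pred, target):
    """Mean of the per-task (per-column) AUROCs."""
    num_task = len(pred[0])
    per_task = [auroc([row[t] for row in pred], [row[t] for row in target])
                for t in range(num_task)]
    return sum(per_task) / num_task

def auroc_micro(pred, target):
    """AUROC over all (sample, task) entries flattened together."""
    flat_pred = [p for row in pred for p in row]
    flat_target = [y for row in target for y in row]
    return auroc(flat_pred, flat_target)

# Each task is ranked perfectly on its own, but the two tasks use different score
# ranges, so pooling them (micro) mixes positives of task 1 below negatives of task 0.
pred = [[0.9, 0.4], [0.8, 0.1]]
target = [[1, 1], [0, 0]]
print(auroc_macro(pred, target), auroc_micro(pred, target))
```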
- forward(batch)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
NodePropertyPrediction
- class NodePropertyPrediction(model, criterion='bce', metric=('macro_auprc', 'macro_auroc'), num_mlp_layer=1, normalization=True, num_class=None, verbose=0)
Node / atom / residue property prediction task.
- Parameters
model (nn.Module) – graph representation model
criterion (str, list or dict, optional) – training criterion(s). For dict, the keys are criteria and the values are the corresponding weights. Available criteria are mse, bce and ce.
metric (str or list of str, optional) – metric(s). Available metrics are mae, rmse, auprc and auroc.
num_mlp_layer (int, optional) – number of layers in the MLP prediction head
normalization (bool, optional) – whether to normalize the target. Available entities are node, atom and residue.
num_class (int, optional) – number of classes
verbose (int, optional) – output verbose level
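The auprc metric can be computed as average precision: the mean of the precision values measured at the rank of each positive sample. A pure-Python sketch, assuming binary labels and at least one positive; ties in score are broken by input order here, which real implementations may handle differently. The function name is hypothetical.

```python
def auprc(scores, labels):
    """Average precision over one binary task (higher score = more likely positive)."""
    order = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    hits, precisions = 0, []
    for rank, i in enumerate(order, start=1):
        if labels[i] == 1:
            hits += 1
            precisions.append(hits / rank)  # precision at this positive's rank
    return sum(precisions) / len(precisions)

# Positives at ranks 1 and 3: (1/1 + 2/3) / 2 = 5/6
print(auprc([0.9, 0.8, 0.7, 0.6], [1, 0, 1, 0]))
```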
InteractionPrediction
- class InteractionPrediction(model, model2=None, task=(), criterion='mse', metric=('mae', 'rmse'), num_mlp_layer=1, normalization=True, num_class=None, mlp_batch_norm=False, mlp_dropout=0, verbose=0)
Predict the interaction property of graph pairs.
- Parameters
model (nn.Module) – graph representation model
model2 (nn.Module, optional) – graph representation model for the second item. If None, the first model is reused for the second item (tied weights).
task (str, list or dict, optional) – training task(s). For dict, the keys are tasks and the values are the corresponding weights.
criterion (str, list or dict, optional) – training criterion(s). For dict, the keys are criteria and the values are the corresponding weights. Available criteria are mse, bce and ce.
metric (str or list of str, optional) – metric(s). Available metrics are mae, rmse, auprc and auroc.
num_mlp_layer (int, optional) – number of layers in the MLP prediction head
normalization (bool, optional) – whether to normalize the target
num_class (int, optional) – number of classes
mlp_batch_norm (bool, optional) – whether to apply batch normalization in the MLP
mlp_dropout (float, optional) – dropout rate in the MLP
verbose (int, optional) – output verbose level
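The effect of model2=None can be pictured with plain callables. This sketch assumes the two item embeddings are concatenated before a prediction head; make_pair_scorer, the toy encoders and the head are hypothetical stand-ins, not TorchDrug modules.

```python
def make_pair_scorer(encoder, head, encoder2=None):
    """Score a pair of items from the concatenation of their embeddings.

    If encoder2 is None, the same encoder (tied weights) embeds both items.
    """
    encoder2 = encoder if encoder2 is None else encoder2

    def score(item1, item2):
        features = encoder(item1) + encoder2(item2)  # list concatenation
        return head(features)

    return score

# Toy "graphs" are lists of numbers; the toy encoder returns [size, sum].
encoder = lambda xs: [float(len(xs)), float(sum(xs))]
head = lambda features: sum(features)

tied = make_pair_scorer(encoder, head)
untied = make_pair_scorer(encoder, head, encoder2=lambda xs: [0.0, 0.0])
print(tied([1, 2], [3]), untied([1, 2], [3]))
```

Because the features are concatenated in a fixed order, the score is not symmetric in the two items even when the encoder is shared.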
Pre-trained Molecular Representation Tasks
EdgePrediction
- class EdgePrediction(model)
Edge prediction task proposed in Inductive Representation Learning on Large Graphs.
- Parameters
model (nn.Module) – node representation model
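In the spirit of the unsupervised GraphSAGE objective, edge prediction scores a node pair by the dot product of its embeddings and trains with binary cross-entropy on observed edges versus randomly corrupted pairs. A pure-Python sketch; the one-negative-per-edge sampling and the function name are assumptions, and TorchDrug's exact negative sampling may differ.

```python
import math
import random

def edge_prediction_loss(embeddings, edges, num_nodes, rng):
    """Mean BCE over positive edges and one random negative pair per edge.

    embeddings: list of node vectors; edges: list of (u, v) pairs.
    A sampled negative may occasionally coincide with a real edge; this sketch ignores that.
    """
    def score(u, v):
        return sum(a * b for a, b in zip(embeddings[u], embeddings[v]))

    def log_sigmoid(x):
        # numerically stable log(sigmoid(x))
        return -math.log1p(math.exp(-x)) if x >= 0 else x - math.log1p(math.exp(x))

    loss = 0.0
    for u, v in edges:
        neg_v = rng.randrange(num_nodes)          # corrupt the tail node
        loss -= log_sigmoid(score(u, v))          # positive pair: push score up
        loss -= log_sigmoid(-score(u, neg_v))     # negative pair: push score down
    return loss / (2 * len(edges))

# With all-zero embeddings every term is log 2, so the mean loss is log 2.
print(edge_prediction_loss([[0.0]] * 3, [(0, 1), (1, 2)], 3, random.Random(0)))
```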
AttributeMasking
- class AttributeMasking(model, mask_rate=0.15, num_mlp_layer=2, graph_construction_model=None)
Attribute masking proposed in Strategies for Pre-training Graph Neural Networks.
- Parameters
model (nn.Module) – node representation model
mask_rate (float, optional) – rate of masked nodes
num_mlp_layer (int, optional) – number of MLP layers
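One way to read mask_rate is as a per-graph fraction of nodes whose attributes are hidden and later predicted. The sketch below samples those nodes for a batch laid out graph after graph; masking at least one node per graph is an assumption, and sample_masked_nodes is a hypothetical helper.

```python
import random

def sample_masked_nodes(num_nodes_per_graph, mask_rate, rng):
    """Return global indices of nodes to mask, sampled independently per graph.

    Nodes of graph g occupy a contiguous block after the nodes of graphs 0..g-1.
    """
    masked, offset = [], 0
    for n in num_nodes_per_graph:
        k = max(1, int(n * mask_rate))            # at least one masked node per graph
        masked.extend(offset + i for i in rng.sample(range(n), k))
        offset += n
    return masked

result = sample_masked_nodes([10, 20, 3], 0.5, random.Random(0))
print(sorted(result))  # 5 + 10 + 1 = 16 distinct indices in [0, 33)
```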
ContextPrediction
- class ContextPrediction(model, context_model=None, k=5, r1=4, r2=7, readout='mean', num_negative=1)
Context prediction task proposed in Strategies for Pre-training Graph Neural Networks.
For a given center node, the subgraph is defined as the k-hop neighborhood (inclusive) around the center node. The context graph is defined as the surrounding graph structure between r1-hop (exclusive) and r2-hop (inclusive) from the center node. Nodes between r1-hop (exclusive) and k-hop (inclusive), i.e. the nodes shared by the subgraph and the context graph, are picked as anchor nodes for the context representation.
- Parameters
model (nn.Module) – node representation model for subgraphs.
context_model (nn.Module, optional) – node representation model for context graphs. By default, use the same architecture as model without parameter sharing.
k (int, optional) – radius for subgraphs
r1 (int, optional) – inner radius for context graphs
r2 (int, optional) – outer radius for context graphs
readout (nn.Module, optional) – readout function over context anchor nodes
num_negative (int, optional) – number of negative samples per positive sample
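The hop-based definitions above reduce to breadth-first-search distances from the center node. A pure-Python sketch on an adjacency list, assuming hops are counted as unweighted shortest-path lengths; hop_distances and split_context are hypothetical helpers.

```python
from collections import deque

def hop_distances(adjacency, center):
    """BFS distances (in hops) from center; unreachable nodes are omitted."""
    dist = {center: 0}
    queue = deque([center])
    while queue:
        u = queue.popleft()
        for v in adjacency[u]:
            if v not in dist:
                dist[v] = dist[u] + 1
                queue.append(v)
    return dist

def split_context(adjacency, center, k=5, r1=4, r2=7):
    """Return (subgraph nodes, context nodes, anchor nodes) for one center node.

    subgraph: distance <= k; context: r1 < distance <= r2;
    anchors: nodes in both sets, i.e. r1 < distance <= k.
    """
    dist = hop_distances(adjacency, center)
    subgraph = {v for v, d in dist.items() if d <= k}
    context = {v for v, d in dist.items() if r1 < d <= r2}
    return subgraph, context, subgraph & context

# Path graph 0-1-2-...-9 with the center at node 0.
adjacency = {i: [j for j in (i - 1, i + 1) if 0 <= j < 10] for i in range(10)}
sub, ctx, anchors = split_context(adjacency, 0)
print(sorted(sub), sorted(ctx), sorted(anchors))
```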
DistancePrediction
- class DistancePrediction(model, num_sample=256, num_mlp_layer=2, graph_construction_model=None)
Pairwise spatial distance prediction task proposed in Protein Representation Learning by Geometric Structure Pretraining.
Randomly select some edges and predict their lengths from the representations of their two end nodes. The selected edges are removed from the input graph to prevent trivial solutions.
- Parameters
model (nn.Module) – node representation model
num_sample (int, optional) – number of edges selected from each graph
num_mlp_layer (int, optional) – number of MLP layers in distance predictor
graph_construction_model (nn.Module, optional) – graph construction model
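The data side of this task can be sketched as: sample edges, drop them from the graph, and use their Euclidean lengths as regression targets. A pure-Python sketch for a single graph with 3D node coordinates; sample_distance_targets is hypothetical and ignores details such as paired directed edges.

```python
import math
import random

def sample_distance_targets(coords, edges, num_sample, rng):
    """Pick up to num_sample edges; return (remaining edges, selected edges, lengths)."""
    num_sample = min(num_sample, len(edges))
    chosen = set(rng.sample(range(len(edges)), num_sample))
    selected = [edges[i] for i in sorted(chosen)]
    # selected edges are removed so the model cannot read the answer off the input
    remaining = [e for i, e in enumerate(edges) if i not in chosen]
    lengths = [math.dist(coords[u], coords[v]) for u, v in selected]
    return remaining, selected, lengths

coords = [(0, 0, 0), (3, 4, 0), (3, 4, 12)]
edges = [(0, 1), (1, 2)]
print(sample_distance_targets(coords, edges, 2, random.Random(0)))
```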
- forward(batch)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
AnglePrediction
- class AnglePrediction(model, num_sample=256, num_class=8, num_mlp_layer=2, graph_construction_model=None)
Angle prediction task proposed in Protein Representation Learning by Geometric Structure Pretraining.
Randomly select pairs of adjacent edges and predict the angles between them using the representations of three nodes. The selected edges are removed from the input graph to prevent trivial solutions.
- Parameters
model (nn.Module) – node representation model
num_sample (int, optional) – number of edge pairs selected from each graph
num_class (int, optional) – number of classes to discretize the angles
num_mlp_layer (int, optional) – number of MLP layers in angle predictor
graph_construction_model (nn.Module, optional) – graph construction model
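Since the angle is predicted as one of num_class classes, the continuous angle has to be discretized. A pure-Python sketch that computes the angle at the shared node of two adjacent edges and bins it, assuming equal-width bins over [0, pi]; the binning scheme and helper names are assumptions, not TorchDrug's documented behaviour.

```python
import math

def bond_angle(a, b, c):
    """Angle at b (radians, within [0, pi]) between the edges b->a and b->c."""
    u = [x - y for x, y in zip(a, b)]
    v = [x - y for x, y in zip(c, b)]
    cos = sum(x * y for x, y in zip(u, v)) / (math.hypot(*u) * math.hypot(*v))
    return math.acos(max(-1.0, min(1.0, cos)))  # clamp floating-point overshoot

def angle_class(angle, num_class=8):
    """Discretize an angle in [0, pi] into num_class equal-width bins."""
    return min(int(angle / math.pi * num_class), num_class - 1)

right = bond_angle((1, 0, 0), (0, 0, 0), (0, 1, 0))
print(right, angle_class(right))  # pi/2 falls into bin 4 of 8
```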
- forward(batch)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
DihedralPrediction
- class DihedralPrediction(model, num_sample=256, num_class=8, num_mlp_layer=2, graph_construction_model=None)
Dihedral prediction task proposed in Protein Representation Learning by Geometric Structure Pretraining.
Randomly select three consecutive edges and predict the dihedrals among them using the representations of four nodes. The selected edges are removed from the input graph to prevent trivial solutions.
- Parameters
model (nn.Module) – node representation model
num_sample (int, optional) – number of edge triplets selected from each graph
num_class (int, optional) – number of classes for discretizing the dihedrals
num_mlp_layer (int, optional) – number of MLP layers in dihedral angle predictor
graph_construction_model (nn.Module, optional) – graph construction model
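For three consecutive edges p0-p1, p1-p2, p2-p3, the dihedral is the angle between the plane (p0, p1, p2) and the plane (p1, p2, p3). The sketch below uses the standard atan2 formulation, which gives a signed angle in [-pi, pi], and bins it into num_class equal-width classes. The signed range and the binning are assumptions of this sketch; TorchDrug may use a different convention.

```python
import math

def _sub(p, q): return [a - b for a, b in zip(p, q)]
def _dot(p, q): return sum(a * b for a, b in zip(p, q))
def _cross(p, q):
    return [p[1] * q[2] - p[2] * q[1],
            p[2] * q[0] - p[0] * q[2],
            p[0] * q[1] - p[1] * q[0]]

def dihedral(p0, p1, p2, p3):
    """Signed dihedral angle (radians, in [-pi, pi]) about the middle edge p1-p2."""
    b1, b2, b3 = _sub(p1, p0), _sub(p2, p1), _sub(p3, p2)
    n1, n2 = _cross(b1, b2), _cross(b2, b3)      # normals of the two planes
    b2_len = math.sqrt(_dot(b2, b2))
    m1 = _cross(n1, [x / b2_len for x in b2])    # in-plane vector orthogonal to n1
    return math.atan2(_dot(m1, n2), _dot(n1, n2))

def dihedral_class(angle, num_class=8):
    """Discretize a signed angle in [-pi, pi] into num_class equal-width bins."""
    return min(int((angle + math.pi) / (2 * math.pi) * num_class), num_class - 1)

cis = dihedral((1, 0, 0), (0, 0, 0), (0, 0, 1), (1, 0, 1))     # about 0
trans = dihedral((1, 0, 0), (0, 0, 0), (0, 0, 1), (-1, 0, 1))  # about +/- pi
perp = dihedral((1, 0, 0), (0, 0, 0), (0, 0, 1), (0, 1, 1))    # about +/- pi/2
print(cis, trans, perp, dihedral_class(cis))
```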
- forward(batch)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
Unsupervised
Molecule Generation Tasks
AutoregressiveGeneration
- class AutoregressiveGeneration(node_model, edge_model, task=(), num_node_sample=-1, num_edge_sample=-1, max_edge_unroll=None, max_node=None, criterion='nll', agent_update_interval=5, gamma=0.9, reward_temperature=1, baseline_momentum=0.9)
Autoregressive graph generation task.
This class can be used to implement GraphAF proposed in GraphAF: A Flow-based Autoregressive Model for Molecular Graph Generation. To do so, instantiate the node model and the edge model with two GraphAutoregressiveFlow models.
- Parameters
node_model (nn.Module) – node likelihood model
edge_model (nn.Module) – edge likelihood model
task (str or list of str, optional) – property optimization task(s). Available tasks are plogp and qed.
num_node_sample (int, optional) – number of node samples per graph. -1 for all samples.
num_edge_sample (int, optional) – number of edge samples per graph. -1 for all samples.
max_edge_unroll (int, optional) – maximum node id difference. If not provided, use the statistics from the training set.
max_node (int, optional) – maximum number of nodes. If not provided, use the statistics from the training set.
criterion (str, list or dict, optional) – training criterion(s). For dict, the keys are criteria and the values are the corresponding weights. Available criteria are nll and ppo.
agent_update_interval (int, optional) – update the agent every n batches
gamma (float, optional) – reward discount rate
reward_temperature (float, optional) – temperature for reward. Higher temperature encourages larger mean reward, while lower temperature encourages larger maximal reward.
baseline_momentum (float, optional) – momentum for value function baseline
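Two of the reinforcement-learning parameters have standard readings: gamma discounts future rewards into a return-to-go, and baseline_momentum keeps an exponential moving average of rewards that is subtracted to reduce policy-gradient variance. A pure-Python sketch of both under those standard definitions; how TorchDrug combines them internally (and how reward_temperature enters) is not modeled here, and the names are hypothetical.

```python
def discounted_returns(rewards, gamma=0.9):
    """Return-to-go G_t = r_t + gamma * G_{t+1} for one generated trajectory."""
    returns, running = [], 0.0
    for r in reversed(rewards):
        running = r + gamma * running
        returns.append(running)
    return returns[::-1]

class MomentumBaseline:
    """Exponential moving average of rewards used as a value baseline."""

    def __init__(self, momentum=0.9):
        self.momentum = momentum
        self.value = None  # no rewards seen yet

    def advantage(self, reward):
        """Return reward minus the current baseline, then fold reward into the baseline."""
        if self.value is None:
            self.value = reward
            return 0.0
        adv = reward - self.value
        self.value = self.momentum * self.value + (1 - self.momentum) * reward
        return adv

print(discounted_returns([0.0, 0.0, 1.0]))  # approximately [0.81, 0.9, 1.0]
```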
GCPNGeneration
- class GCPNGeneration(model, atom_types, max_edge_unroll=None, max_node=None, task=(), criterion='nll', hidden_dim_mlp=128, agent_update_interval=10, gamma=0.9, reward_temperature=1, baseline_momentum=0.9)
The graph generative model from Graph Convolutional Policy Network for Goal-Directed Molecular Graph Generation.
- Parameters
model (nn.Module) – graph representation model
atom_types (list or set) – set of all possible atom types
task (str or list of str, optional) – property optimization task(s)
max_edge_unroll (int, optional) – maximum node id difference. If not provided, use the statistics from the training set.
max_node (int, optional) – maximum number of nodes. If not provided, use the statistics from the training set.
criterion (str, list or dict, optional) – training criterion(s). For dict, the keys are criteria and the values are the corresponding weights. Available criteria are nll and ppo.
agent_update_interval (int, optional) – update the agent every n batches
gamma (float, optional) – reward discount rate
reward_temperature (float, optional) – temperature for reward. Higher temperature encourages larger mean reward, while lower temperature encourages larger maximal reward.
baseline_momentum (float, optional) – momentum for value function baseline
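When max_node and max_edge_unroll are not given, they are taken from training-set statistics. A pure-Python sketch of what those statistics could be, assuming node ids follow the generation order and max_edge_unroll is the largest id difference between the two ends of any training edge; infer_limits is a hypothetical helper.

```python
def infer_limits(training_graphs):
    """Derive (max_node, max_edge_unroll) from a training set.

    training_graphs: iterable of (num_nodes, edge_list) pairs.
    """
    max_node, max_edge_unroll = 0, 0
    for num_nodes, edges in training_graphs:
        max_node = max(max_node, num_nodes)
        for u, v in edges:
            max_edge_unroll = max(max_edge_unroll, abs(u - v))
    return max_node, max_edge_unroll

print(infer_limits([(3, [(0, 1), (1, 2)]), (5, [(0, 4), (2, 3)])]))  # (5, 4)
```

Under this reading, a generator that only proposes edges to the previous max_edge_unroll nodes scans O(n * max_edge_unroll) candidate pairs per graph instead of O(n^2).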
Retrosynthesis Tasks
CenterIdentification
- class CenterIdentification(model, feature=('reaction', 'graph', 'atom', 'bond'), num_mlp_layer=2)
Reaction center identification task.
This class is a part of retrosynthesis prediction.
- Parameters
model (nn.Module) – graph representation model
feature (str or list of str, optional) – additional features for prediction. Available features are
reaction: type of the reaction
graph: graph representation of the product
atom: original atom feature
bond: original bond feature
num_mlp_layer (int, optional) – number of MLP layers
- predict_synthon(batch, k=1)
Predict top-k synthons from target molecules.
- Parameters
batch (dict) – batch of target molecules
k (int, optional) – return top-k results
- Returns
- top-k records.
Each record is a batch dict of keys synthon, num_synthon, reaction_center, log_likelihood and reaction.
- Return type
list of dict
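Ranking candidate reaction centers by log-likelihood is what makes a top-k result possible. A pure-Python sketch, assuming each candidate center (a bond or an atom) of one product gets an unnormalized score that is turned into a log-likelihood with a log-softmax over all candidates of that product; the candidate representation and top_k_centers are hypothetical.

```python
import heapq
import math

def top_k_centers(scores, k=1):
    """Return up to k (center, log_likelihood) pairs, best first.

    scores: dict mapping a candidate center (e.g. a bond (u, v) or an atom id)
    to an unnormalized score.
    """
    m = max(scores.values())
    log_z = m + math.log(sum(math.exp(s - m) for s in scores.values()))  # log-sum-exp
    ranked = heapq.nlargest(k, scores.items(), key=lambda item: item[1])
    return [(center, s - log_z) for center, s in ranked]

scores = {(0, 1): 2.0, (1, 2): 0.0, 3: 0.0}
print(top_k_centers(scores, k=2))
```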
SynthonCompletion
- class SynthonCompletion(model, feature=('reaction', 'graph', 'atom'), num_mlp_layer=2)
Synthon completion task.
This class is a part of retrosynthesis prediction.
- Parameters
model (nn.Module) – graph representation model
feature (str or list of str, optional) – additional features for prediction. Available features are
reaction: type of the reaction
graph: graph representation of the synthon
atom: original atom feature
num_mlp_layer (int, optional) – number of MLP layers
Retrosynthesis
- class Retrosynthesis(center_identification, synthon_completion, center_topk=2, num_synthon_beam=10, max_prediction=20, metric=('top-1', 'top-3', 'top-5', 'top-10'))
Retrosynthesis task.
This class wraps pretrained center identification and synthon completion modules into a pipeline.
- Parameters
center_identification (CenterIdentification) – sub task of center identification
synthon_completion (SynthonCompletion) – sub task of synthon completion
center_topk (int, optional) – number of reaction centers to predict for each product
num_synthon_beam (int, optional) – size of beam search for each synthon
max_prediction (int, optional) – max number of final predictions for each product
metric (str or list of str, optional) – metric(s). Available metrics are top-K.
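A top-K metric counts a product as solved when its ground-truth reactants appear among the first K ranked predictions. A pure-Python sketch, assuming each prediction and answer is already a canonical string (for example a canonical SMILES produced by RDKit) so that equality means the same molecule set; top_k_accuracy is a hypothetical helper.

```python
def top_k_accuracy(predictions, answers, ks=(1, 3, 5, 10)):
    """Fraction of products whose answer appears in the first K predictions.

    predictions: list (one per product) of ranked candidate reactant strings.
    answers: list of ground-truth reactant strings.
    """
    result = {}
    for k in ks:
        hits = sum(answer in preds[:k] for preds, answer in zip(predictions, answers))
        result[f"top-{k}"] = hits / len(answers)
    return result

# Letters stand in for canonical reactant SMILES.
print(top_k_accuracy([["A", "B"], ["C", "D", "E"], ["F"]], ["A", "E", "G"]))
```

Since max_prediction caps the number of ranked predictions per product, a top-K value with K larger than max_prediction cannot exceed top-max_prediction.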
- load_state_dict(state_dict, strict=True)
Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
- Parameters
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
- Returns
missing_keys is a list of str containing the missing keys
unexpected_keys is a list of str containing the unexpected keys
- Return type
NamedTuple with missing_keys and unexpected_keys fields
Knowledge Graph Reasoning Tasks
KnowledgeGraphCompletion
- class KnowledgeGraphCompletion(model, criterion='bce', metric=('mr', 'mrr', 'hits@1', 'hits@3', 'hits@10'), num_negative=128, margin=6, adversarial_temperature=0, strict_negative=True, fact_ratio=None, sample_weight=True, filtered_ranking=True, full_batch_eval=False)
Knowledge graph completion task.
This class provides routines for the family of knowledge graph embedding models.
- Parameters
model (nn.Module) – knowledge graph completion model
criterion (str, list or dict, optional) – training criterion(s). For dict, the keys are criteria and the values are the corresponding weights. Available criteria are bce, ce and ranking.
metric (str or list of str, optional) – metric(s). Available metrics are mr, mrr and hits@K.
num_negative (int, optional) – number of negative samples per positive sample
margin (float, optional) – margin in the ranking criterion
adversarial_temperature (float, optional) – temperature for self-adversarial negative sampling. Set 0 to disable self-adversarial negative sampling.
strict_negative (bool, optional) – whether to use strict negative sampling
fact_ratio (float, optional) – split the training set into facts and labels. Set None to use the whole training set as both facts and labels.
sample_weight (bool, optional) – whether to down-weight triplets from entities of large degrees
filtered_ranking (bool, optional) – whether to use filtered (True) or unfiltered (False) ranking for evaluation
full_batch_eval (bool, optional) – whether to feed test negative samples by full batch or mini batch. Full batch speeds up evaluation significantly, but may cause OOM problems for some models and datasets.
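The ranking metrics follow the standard link-prediction protocol: rank the true entity against all candidates, optionally skipping other answers already known to be correct (filtered ranking), then summarize the ranks as mean rank (mr), mean reciprocal rank (mrr) and hits@K. A pure-Python sketch; ties are resolved in favour of the true entity here, which libraries handle differently, and the helper names are hypothetical.

```python
def filtered_rank(scores, true_tail, known_tails, filtered=True):
    """Rank (1 = best) of the true tail among all candidate entities.

    scores: one score per candidate entity (higher is better).
    known_tails: other tails known to be correct for this (head, relation);
    with filtered=True they are not counted as competitors.
    """
    target = scores[true_tail]
    rank = 1
    for entity, score in enumerate(scores):
        if entity == true_tail or (filtered and entity in known_tails):
            continue
        if score > target:
            rank += 1
    return rank

def ranking_metrics(ranks, ks=(1, 3, 10)):
    """Mean rank, mean reciprocal rank and hits@K from a list of ranks."""
    metrics = {"mr": sum(ranks) / len(ranks),
               "mrr": sum(1.0 / r for r in ranks) / len(ranks)}
    for k in ks:
        metrics[f"hits@{k}"] = sum(r <= k for r in ranks) / len(ranks)
    return metrics

scores = [0.1, 0.9, 0.5, 0.8]
print(filtered_rank(scores, 2, {1}, filtered=False), filtered_rank(scores, 2, {1}))  # 3 2
print(ranking_metrics([1, 2, 4]))
```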
Protein-Protein Interaction Prediction Tasks
ContactPrediction
- class ContactPrediction(model, max_length=500, random_truncate=True, threshold=8.0, gap=6, criterion='bce', metric=('accuracy', 'prec@L5'), num_mlp_layer=1, verbose=0)
Predict whether each pair of amino acids is in contact in the folded structure.
- Parameters
model (nn.Module) – protein sequence representation model
max_length (int, optional) – maximum sequence length. The sequence is truncated if it exceeds this limit.
random_truncate (bool, optional) – truncate the sequence at a random position. If not, truncate the suffix of the sequence.
threshold (float, optional) – distance threshold for contact
gap (int, optional) – sequential distance cutoff for evaluation
criterion (str or dict, optional) – training criterion. For dict, the key is the criterion and the value is the corresponding weight. The only available criterion is bce.
metric (str or list of str, optional) – metric(s). Available metrics are accuracy, prec@Lk and prec@k.
num_mlp_layer (int, optional) – number of layers in the MLP prediction head
verbose (int, optional) – output verbose level
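The threshold and gap parameters suggest the usual contact-prediction setup: a residue pair is a contact when its distance is below threshold, and prec@L5 is the precision of the top L/5 scored pairs among pairs at least gap positions apart in sequence (L is the sequence length). A pure-Python sketch under those assumptions; whether the comparisons are strict or inclusive in TorchDrug is not stated here, and the helper names are hypothetical.

```python
import math

def contact_map(coords, threshold=8.0):
    """Residue pairs (i < j) whose coordinates are closer than threshold."""
    n = len(coords)
    return {(i, j) for i in range(n) for j in range(i + 1, n)
            if math.dist(coords[i], coords[j]) < threshold}

def precision_at_l_div(pred, contacts, length, gap=6, divisor=5):
    """prec@L/divisor over pairs with j - i >= gap.

    pred: dict mapping (i, j) with i < j to a contact score.
    """
    candidates = [(s, pair) for pair, s in pred.items() if pair[1] - pair[0] >= gap]
    top = sorted(candidates, reverse=True)[: max(1, length // divisor)]
    return sum(pair in contacts for _, pair in top) / len(top)

print(contact_map([(0, 0, 0), (5, 0, 0), (20, 0, 0)]))  # {(0, 1)}
pred = {(0, 6): 0.9, (0, 7): 0.8, (1, 9): 0.7, (0, 1): 0.95}  # (0, 1) is below the gap
print(precision_at_l_div(pred, {(0, 6)}, length=10))  # top 2 pairs, 1 correct
```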