torchdrug.tasks
Property Prediction Tasks
PropertyPrediction
- class PropertyPrediction(model, task=(), criterion='mse', metric=('mae', 'rmse'), num_mlp_layer=1, normalization=True, entity_level='residue', num_class=None, verbose=0)
Graph / molecule property prediction task.
This class is also compatible with semi-supervised learning.
- Parameters
model (nn.Module) – graph representation model
task (str, list or dict, optional) – training task(s). For dict, the keys are tasks and the values are the corresponding weights.
criterion (str, list or dict, optional) – training criterion(s). For dict, the keys are criteria and the values are the corresponding weights. Available criteria are mse and bce.
metric (str or list of str, optional) – metric(s). Available metrics are mae, rmse, auprc and auroc.
verbose (int, optional) – verbosity level of the output
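The sketch below makes the regression metrics and the dict-weighted criterion concrete. It assumes that semi-supervised compatibility means unlabeled entries are skipped when losses and metrics are computed, with missing labels represented here as None. That convention and all helper names are assumptions for illustration, not TorchDrug's implementation.

```python
import math

# Assumption: a missing label is represented as None, and metrics average over
# labeled entries only. This illustrates the idea; it is not TorchDrug's code.

def labeled_pairs(pred, target):
    """Keep (prediction, label) pairs whose label is present."""
    return [(p, t) for p, t in zip(pred, target) if t is not None]

def mae(pred, target):
    """Mean absolute error over labeled entries."""
    pairs = labeled_pairs(pred, target)
    return sum(abs(p - t) for p, t in pairs) / len(pairs)

def rmse(pred, target):
    """Root mean squared error over labeled entries."""
    pairs = labeled_pairs(pred, target)
    return math.sqrt(sum((p - t) ** 2 for p, t in pairs) / len(pairs))

def weighted_total(losses, weights):
    """Sum per-criterion losses using dict weights, e.g. {"mse": 1.0, "bce": 0.5}."""
    return sum(weight * losses[name] for name, weight in weights.items())

pred = [1.0, 2.0, 3.0, 4.0]
target = [1.5, None, 2.0, 4.0]  # the second sample is unlabeled and ignored
print(mae(pred, target))   # (0.5 + 1.0 + 0.0) / 3
print(rmse(pred, target))  # sqrt((0.25 + 1.0 + 0.0) / 3)
```

The same dict-of-weights pattern applies to the task argument, where each task's loss would be scaled by its weight.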
Pre-trained Molecular Representation Tasks
EdgePrediction
- class EdgePrediction(model)
Edge prediction task proposed in Inductive Representation Learning on Large Graphs.
- Parameters
model (nn.Module) – node representation model
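Edge prediction trains node representations so that nodes joined by an edge score higher than sampled non-adjacent pairs. The following is a minimal sketch of that objective under stated assumptions: embeddings are plain lists, the pair score is a dot product, and the loss is binary cross-entropy over positive and negative pairs. How TorchDrug samples negatives and scores pairs is not shown, and the helper names are hypothetical.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def log_sigmoid(x):
    """Numerically stable log(sigmoid(x))."""
    if x >= 0:
        return -math.log1p(math.exp(-x))
    return x - math.log1p(math.exp(x))

def edge_prediction_loss(emb, pos_edges, neg_edges):
    """BCE that pushes observed edges toward 1 and sampled non-edges toward 0.

    Assumption: dot-product scoring; the library's scoring function may differ.
    """
    pos = [-log_sigmoid(dot(emb[u], emb[v])) for u, v in pos_edges]
    neg = [-log_sigmoid(-dot(emb[u], emb[v])) for u, v in neg_edges]
    return (sum(pos) + sum(neg)) / (len(pos) + len(neg))

emb = {0: [1.0, 0.0], 1: [1.0, 0.0], 2: [-1.0, 0.0]}
good = edge_prediction_loss(emb, pos_edges=[(0, 1)], neg_edges=[(0, 2)])
bad = edge_prediction_loss(emb, pos_edges=[(0, 2)], neg_edges=[(0, 1)])
print(good < bad)
```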
AttributeMasking
- class AttributeMasking(model, mask_rate=0.15, num_mlp_layer=2)
Attribute masking proposed in Strategies for Pre-training Graph Neural Networks.
- Parameters
model (nn.Module) – node representation model
mask_rate (float, optional) – rate of masked nodes
num_mlp_layer (int, optional) – number of MLP layers
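To illustrate what mask_rate controls, here is a small sketch of the masking step. A fraction of nodes has its attribute (here, the atom type) replaced by a mask token, and the MLP head is trained to recover the originals. The rounding rule (at least one masked node), the "[MASK]" token, and the helper names are assumptions, not TorchDrug's code.

```python
import random

def sample_masked_nodes(num_nodes, mask_rate, rng):
    """Choose distinct node indices to mask; at least one node is always masked (assumed rule)."""
    num_mask = max(1, int(num_nodes * mask_rate))
    return sorted(rng.sample(range(num_nodes), num_mask))

def mask_attributes(atom_types, masked, mask_token="[MASK]"):
    """Return the corrupted input and the original attributes the MLP head must recover."""
    corrupted = list(atom_types)
    for i in masked:
        corrupted[i] = mask_token
    targets = [atom_types[i] for i in masked]
    return corrupted, targets

atom_types = ["C", "C", "O", "N", "C", "C", "C", "Cl", "C", "C"]
masked = sample_masked_nodes(len(atom_types), mask_rate=0.15, rng=random.Random(0))
corrupted, targets = mask_attributes(atom_types, masked)
print(masked, targets)
```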
ContextPrediction
- class ContextPrediction(model, context_model=None, k=5, r1=4, r2=7, readout='mean', num_negative=1)
Context prediction task proposed in Strategies for Pre-training Graph Neural Networks.
For a given center node, the subgraph is defined as a k-hop neighborhood (inclusive) around the selected node. The context graph is defined as the surrounding graph structure between r1- (exclusive) and r2-hop (inclusive) from the center node. Nodes between k- and r1-hop are picked as anchor nodes for the context representation.
- Parameters
model (nn.Module) – node representation model for subgraphs.
context_model (nn.Module, optional) – node representation model for context graphs. By default, use the same architecture as model without parameter sharing.
k (int, optional) – radius for subgraphs
r1 (int, optional) – inner radius for context graphs
r2 (int, optional) – outer radius for context graphs
readout (str, optional) – readout function over context anchor nodes
num_negative (int, optional) – number of negative samples per positive sample
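The hop-based split described above can be sketched with a breadth-first search. Nodes within k hops form the subgraph, nodes with r1 < distance ≤ r2 form the context graph, and the anchors are the nodes shared by both (distance in (r1, k] when r1 < k). This is an illustration with plain adjacency lists and hypothetical helper names, not TorchDrug's implementation.

```python
from collections import deque

def hop_distances(adj, center):
    """Breadth-first search: shortest hop count from center to every reachable node."""
    dist = {center: 0}
    queue = deque([center])
    while queue:
        u = queue.popleft()
        for v in adj[u]:
            if v not in dist:
                dist[v] = dist[u] + 1
                queue.append(v)
    return dist

def split_neighborhood(adj, center, k=5, r1=4, r2=7):
    dist = hop_distances(adj, center)
    subgraph = {v for v, d in dist.items() if d <= k}        # k-hop neighborhood, inclusive
    context = {v for v, d in dist.items() if r1 < d <= r2}   # r1 exclusive, r2 inclusive
    anchors = subgraph & context                              # nodes in both graphs
    return subgraph, context, anchors

# A path graph 0-1-2-...-9 centered at node 0, using the default radii.
adj = {i: [j for j in (i - 1, i + 1) if 0 <= j < 10] for i in range(10)}
subgraph, context, anchors = split_neighborhood(adj, 0)
print(sorted(subgraph), sorted(context), sorted(anchors))
```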
Unsupervised
Molecule Generation Tasks
AutoregressiveGeneration
- class AutoregressiveGeneration(node_model, edge_model, task=(), num_node_sample=-1, num_edge_sample=-1, max_edge_unroll=None, max_node=None, criterion='nll', agent_update_interval=5, gamma=0.9, reward_temperature=1, baseline_momentum=0.9)
Autoregressive graph generation task.
This class can be used to implement GraphAF proposed in GraphAF: A Flow-based Autoregressive Model for Molecular Graph Generation. To do so, instantiate the node model and the edge model with two GraphAutoregressiveFlow models.
- Parameters
node_model (nn.Module) – node likelihood model
edge_model (nn.Module) – edge likelihood model
task (str or list of str, optional) – property optimization task(s). Available tasks are plogp and qed.
num_node_sample (int, optional) – number of node samples per graph. -1 for all samples.
num_edge_sample (int, optional) – number of edge samples per graph. -1 for all samples.
max_edge_unroll (int, optional) – max node id difference. If not provided, use the statistics from the training set.
max_node (int, optional) – max number of nodes. If not provided, use the statistics from the training set.
criterion (str, list or dict, optional) – training criterion(s). For dict, the keys are criteria and the values are the corresponding weights. Available criteria are nll and ppo.
agent_update_interval (int, optional) – update the agent every n batches
gamma (float, optional) – reward discount rate
reward_temperature (float, optional) – temperature for reward. Higher temperature encourages larger mean reward, while lower temperature encourages larger maximal reward.
baseline_momentum (float, optional) – momentum for value function baseline
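Below is one plausible reading of "use the statistics from the training set" for max_edge_unroll and max_node. It assumes that node ids follow the generation order and that max_edge_unroll is the largest absolute id difference over all training edges. The library may reorder nodes first (for example in BFS order), and that step is not shown. The helper names are hypothetical.

```python
def graph_statistics(edge_lists, num_nodes):
    """Fallback statistics when max_edge_unroll / max_node are not given (assumed rule).

    edge_lists: one list of (u, v) node-id pairs per training graph
    num_nodes: number of nodes in each training graph
    """
    max_edge_unroll = max(abs(u - v) for edges in edge_lists for u, v in edges)
    max_node = max(num_nodes)
    return max_edge_unroll, max_node

def candidate_partners(new_node, max_edge_unroll):
    """Existing nodes that a newly generated node may connect to under the id-difference limit."""
    return list(range(max(0, new_node - max_edge_unroll), new_node))

edge_lists = [[(0, 1), (1, 2), (0, 2)], [(0, 1), (1, 2), (2, 3), (3, 0)]]
max_edge_unroll, max_node = graph_statistics(edge_lists, [3, 4])
print(max_edge_unroll, max_node, candidate_partners(5, max_edge_unroll))
```

Limiting the id difference bounds how many edge decisions each new node requires, which keeps autoregressive generation tractable.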
GCPNGeneration
- class GCPNGeneration(model, atom_types, max_edge_unroll=None, max_node=None, task=(), criterion='nll', hidden_dim_mlp=128, agent_update_interval=10, gamma=0.9, reward_temperature=1, baseline_momentum=0.9)
The graph generative model from Graph Convolutional Policy Network for Goal-Directed Molecular Graph Generation.
- Parameters
model (nn.Module) – graph representation model
atom_types (list or set) – set of all possible atom types
task (str or list of str, optional) – property optimization task(s)
max_edge_unroll (int, optional) – max node id difference. If not provided, use the statistics from the training set.
max_node (int, optional) – max number of nodes. If not provided, use the statistics from the training set.
criterion (str, list or dict, optional) – training criterion(s). For dict, the keys are criteria and the values are the corresponding weights. Available criteria are nll and ppo.
agent_update_interval (int, optional) – update the agent every n batches
gamma (float, optional) – reward discount rate
reward_temperature (float, optional) – temperature for reward. Higher temperature encourages larger mean reward, while lower temperature encourages larger maximal reward.
baseline_momentum (float, optional) – momentum for value function baseline
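The reinforcement learning parameters shared by both generation tasks can be illustrated as follows. gamma discounts a terminal reward back over the generation steps. baseline_momentum keeps an exponential moving average of the reward. For reward_temperature, the sketch assumes an exponential shaping exp(score / temperature), which matches the description above: a high temperature makes the reward nearly linear in the score (favoring the mean), while a low temperature lets the best sample dominate. These forms and names are assumptions, not TorchDrug's code.

```python
import math

def discounted_rewards(final_reward, num_steps, gamma):
    """Spread a terminal reward over the steps: step t gets gamma ** (T - 1 - t) * reward."""
    return [final_reward * gamma ** (num_steps - 1 - t) for t in range(num_steps)]

def update_baseline(baseline, batch_rewards, momentum):
    """Exponential moving average of the batch mean reward, used as a value baseline."""
    batch_mean = sum(batch_rewards) / len(batch_rewards)
    return momentum * baseline + (1 - momentum) * batch_mean

def tempered_reward(score, temperature):
    """Assumed shaping: exp(score / temperature)."""
    return math.exp(score / temperature)

def share_of_best(scores, temperature):
    """Fraction of the total shaped reward earned by the best-scoring sample."""
    rewards = [tempered_reward(s, temperature) for s in scores]
    return max(rewards) / sum(rewards)

print(discounted_rewards(1.0, num_steps=3, gamma=0.9))
print(update_baseline(0.0, [1.0, 3.0], momentum=0.9))
print(share_of_best([1.0, 2.0, 3.0], 0.1), share_of_best([1.0, 2.0, 3.0], 10.0))
```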
Retrosynthesis Tasks
CenterIdentification
- class CenterIdentification(model, feature=('reaction', 'graph', 'atom', 'bond'), num_mlp_layer=2)
Reaction center identification task.
This class is a part of retrosynthesis prediction.
- Parameters
model (nn.Module) – graph representation model
feature (str or list of str, optional) – additional features for prediction. Available features are reaction (type of the reaction), graph (graph representation of the product), atom (original atom feature) and bond (original bond feature).
num_mlp_layer (int, optional) – number of MLP layers
- predict_synthon(batch, k=1)
Predict top-k synthons from target molecules.
- Parameters
batch (dict) – batch of target molecules
k (int, optional) – return top-k results
- Returns
- top-k records. Each record is a batch dict with keys synthon, num_synthon, reaction_center, log_likelihood and reaction.
- Return type
list of dict
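The top-k selection behind predict_synthon can be pictured as ranking candidate reaction centers by log-likelihood. The sketch simplifies heavily: it ranks plain per-molecule candidate dicts instead of producing batch dicts. The candidate values and the helper name are hypothetical.

```python
import heapq

def top_k_candidates(candidates, k=1):
    """Return the k candidates with the highest log-likelihood, best first."""
    return heapq.nlargest(k, candidates, key=lambda c: c["log_likelihood"])

# Hypothetical candidates for one product; "reaction_center" holds an atom-index pair.
candidates = [
    {"reaction_center": (3, 4), "log_likelihood": -0.2},
    {"reaction_center": (1, 2), "log_likelihood": -1.5},
    {"reaction_center": (0, 5), "log_likelihood": -0.7},
]
for rank, cand in enumerate(top_k_candidates(candidates, k=2), start=1):
    print(rank, cand["reaction_center"], cand["log_likelihood"])
```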
SynthonCompletion
- class SynthonCompletion(model, feature=('reaction', 'graph', 'atom'), num_mlp_layer=2)
Synthon completion task.
This class is a part of retrosynthesis prediction.
- Parameters
model (nn.Module) – graph representation model
feature (str or list of str, optional) – additional features for prediction. Available features are reaction (type of the reaction), graph (graph representation of the synthon) and atom (original atom feature).
num_mlp_layer (int, optional) – number of MLP layers
Retrosynthesis
- class Retrosynthesis(center_identification, synthon_completion, center_topk=2, num_synthon_beam=10, max_prediction=20, metric=('top-1', 'top-3', 'top-5', 'top-10'))
Retrosynthesis task.
This class wraps pretrained center identification and synthon completion modules into a pipeline.
- Parameters
center_identification (CenterIdentification) – sub task of center identification
synthon_completion (SynthonCompletion) – sub task of synthon completion
center_topk (int, optional) – number of reaction centers to predict for each product
num_synthon_beam (int, optional) – size of beam search for each synthon
max_prediction (int, optional) – max number of final predictions for each product
metric (str or list of str, optional) – metric(s). Available metrics are top-K.
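A top-K metric counts a product as correct when its ground-truth reactant set appears among the first K ranked predictions. The sketch below assumes that reactant sets are compared as frozensets of canonical SMILES strings. The example molecules and the helper name are illustrative only, not how TorchDrug matches predictions.

```python
def top_k_accuracy(predictions, ground_truth, ks=(1, 3, 5, 10)):
    """predictions: ranked candidate reactant sets per product, best first.
    ground_truth: the true reactant set per product.
    Returns {"top-K": fraction of products whose true set is within the first K}.
    """
    results = {}
    for k in ks:
        hits = sum(1 for preds, truth in zip(predictions, ground_truth) if truth in preds[:k])
        results["top-%d" % k] = hits / len(ground_truth)
    return results

predictions = [
    [frozenset({"CC(=O)O", "CCO"}), frozenset({"CC=O"})],        # truth ranked 1st
    [frozenset({"c1ccccc1"}), frozenset({"CCN", "CC(=O)Cl"})],   # truth ranked 2nd
    [frozenset({"CCCC"})],                                       # truth not predicted
]
ground_truth = [frozenset({"CCO", "CC(=O)O"}), frozenset({"CC(=O)Cl", "CCN"}), frozenset({"CCCO"})]
print(top_k_accuracy(predictions, ground_truth, ks=(1, 3)))
```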
- load_state_dict(state_dict, strict=True)
Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
- Parameters
state_dict (dict) – a dict containing parameters and persistent buffers.
strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True
- Returns
missing_keys is a list of str containing the missing keys
unexpected_keys is a list of str containing the unexpected keys
- Return type
NamedTuple with missing_keys and unexpected_keys fields
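To make the strict semantics concrete, here is a stdlib sketch of the key comparison only (no tensors are copied). With strict=False, mismatches are reported in the returned named tuple; with strict=True, they raise an error. This mimics the documented behavior and is not PyTorch's code. IncompatibleKeys, check_state_dict_keys, and the key names are hypothetical.

```python
from collections import namedtuple

# Hypothetical stand-in for the NamedTuple returned by load_state_dict.
IncompatibleKeys = namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"])

def check_state_dict_keys(module_keys, state_dict, strict=True):
    """Compare a module's expected keys with a checkpoint's keys."""
    expected = set(module_keys)
    missing = [k for k in module_keys if k not in state_dict]
    unexpected = [k for k in state_dict if k not in expected]
    if strict and (missing or unexpected):
        raise RuntimeError("missing keys: %s, unexpected keys: %s" % (missing, unexpected))
    return IncompatibleKeys(missing, unexpected)

# Hypothetical keys: the checkpoint lacks one sub-task's weights and has an extra entry.
module_keys = ["center_identification.mlp.weight", "synthon_completion.mlp.weight"]
checkpoint = {"center_identification.mlp.weight": 0.0, "extra.bias": 0.0}
print(check_state_dict_keys(module_keys, checkpoint, strict=False))
```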
Knowledge Graph Reasoning Tasks
KnowledgeGraphCompletion
- class KnowledgeGraphCompletion(model, criterion='bce', metric=('mr', 'mrr', 'hits@1', 'hits@3', 'hits@10'), num_negative=128, margin=6, adversarial_temperature=0, strict_negative=True, filtered_ranking=True, fact_ratio=None, sample_weight=True)
Knowledge graph completion task.
This class provides routines for the family of knowledge graph embedding models.
- Parameters
model (nn.Module) – knowledge graph embedding model
criterion (str, list or dict, optional) – training criterion(s). For dict, the keys are criteria and the values are the corresponding weights. Available criteria are bce, ce and ranking.
metric (str or list of str, optional) – metric(s). Available metrics are mr, mrr and hits@K.
num_negative (int, optional) – number of negative samples per positive sample
margin (float, optional) – margin in ranking criterion
adversarial_temperature (float, optional) – temperature for self-adversarial negative sampling. Set 0 to disable self-adversarial negative sampling.
strict_negative (bool, optional) – whether to use strict negative sampling
filtered_ranking (bool, optional) – use filtered or unfiltered ranking for evaluation
fact_ratio (float, optional) – split the training set into facts and labels. Set None to use the whole training set as both facts and labels.
sample_weight (bool, optional) – whether to down-weight triplets from entities of large degrees
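The evaluation metrics and the self-adversarial weighting can be sketched in plain Python. Filtered ranking ignores other candidates already known to be true for the same query, so they cannot push the target down. MR, MRR and hits@K are then computed from the ranks. For self-adversarial sampling, the sketch assumes a softmax over negative scores divided by the temperature, with uniform weights when the temperature is 0. These forms and the helper names are assumptions, not TorchDrug's code.

```python
import math

def filtered_rank(scores, target, known_true):
    """Rank (1 = best) of target among candidates; higher score is better.

    known_true: other entities that also form true triplets with the query.
    Filtered ranking skips them; pass an empty set for unfiltered ranking.
    """
    target_score = scores[target]
    better = sum(1 for e, s in scores.items()
                 if e != target and e not in known_true and s > target_score)
    return better + 1

def ranking_metrics(ranks, hits_at=(1, 3, 10)):
    metrics = {"mr": sum(ranks) / len(ranks),
               "mrr": sum(1.0 / r for r in ranks) / len(ranks)}
    for k in hits_at:
        metrics["hits@%d" % k] = sum(r <= k for r in ranks) / len(ranks)
    return metrics

def self_adversarial_weights(neg_scores, temperature):
    """Assumed form: softmax(score / temperature); harder negatives get larger weights."""
    if temperature == 0:  # disabled: average the negatives uniformly
        return [1.0 / len(neg_scores)] * len(neg_scores)
    logits = [s / temperature for s in neg_scores]
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

scores = {"a": 0.9, "b": 0.8, "c": 0.7, "d": 0.1}
print(filtered_rank(scores, "c", known_true={"a"}), filtered_rank(scores, "c", known_true=set()))
print(ranking_metrics([1, 2, 4]))
print(self_adversarial_weights([2.0, 0.0], 1.0))
```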