torchdrug.tasks

Property Prediction Tasks

PropertyPrediction

class PropertyPrediction(model, task=(), criterion='mse', metric=('mae', 'rmse'), verbose=0)[source]

Graph / molecule property prediction task.

This class is also compatible with semi-supervised learning.

Parameters
  • model (nn.Module) – graph representation model

  • task (str, list or dict, optional) – training task(s). For dict, the keys are tasks and the values are the corresponding weights.

  • criterion (str, list or dict, optional) – training criterion(s). For dict, the keys are criteria and the values are the corresponding weights. Available criteria are mse and bce.

  • metric (str or list of str, optional) – metric(s). Available metrics are mae, rmse, auprc and auroc.

  • verbose (int, optional) – output verbose level

preprocess(train_set, valid_set, test_set)[source]

Compute the mean and standard deviation for each task on the training set.
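
Example. A minimal training sketch, assuming the ClinTox dataset and a GIN encoder; the dataset path, hidden sizes and number of epochs are illustrative choices, not defaults:

    import torch
    from torchdrug import core, datasets, models, tasks

    # assumed dataset; any molecule dataset with binary labels works the same way
    dataset = datasets.ClinTox("~/molecule-datasets/")
    train_set, valid_set, test_set = dataset.split()

    # assumed encoder configuration
    model = models.GIN(input_dim=dataset.node_feature_dim,
                       hidden_dims=[256, 256, 256, 256],
                       batch_norm=True, readout="mean")

    # binary labels -> bce criterion with auprc / auroc metrics
    task = tasks.PropertyPrediction(model, task=dataset.tasks,
                                    criterion="bce", metric=("auprc", "auroc"))

    optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
    solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                         batch_size=256)
    solver.train(num_epoch=10)
    solver.evaluate("valid")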

Pre-trained Molecular Representation Tasks

EdgePrediction

class EdgePrediction(model)[source]

Edge prediction task proposed in Inductive Representation Learning on Large Graphs.

Parameters

model (nn.Module) – node representation model
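
Example. A minimal self-supervised pre-training sketch, assuming ZINC250k as the unlabeled corpus and a GIN encoder (both are illustrative choices):

    import torch
    from torchdrug import core, datasets, models, tasks

    # assumed unlabeled corpus for pre-training
    dataset = datasets.ZINC250k("~/molecule-datasets/")

    model = models.GIN(input_dim=dataset.node_feature_dim,
                       hidden_dims=[300, 300, 300], batch_norm=True)

    # the task samples positive and negative node pairs and trains the encoder
    # to distinguish them; no labels are required
    task = tasks.EdgePrediction(model)

    optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
    solver = core.Engine(task, dataset, None, None, optimizer, batch_size=256)
    solver.train(num_epoch=1)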

AttributeMasking

class AttributeMasking(model, mask_rate=0.15, num_mlp_layer=2)[source]

Attribute masking proposed in Strategies for Pre-training Graph Neural Networks.

Parameters
  • model (nn.Module) – node representation model

  • mask_rate (float, optional) – rate of masked nodes

  • num_mlp_layer (int, optional) – number of MLP layers
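
Example. A minimal pre-training sketch, assuming ZINC250k and a GIN encoder (illustrative choices):

    import torch
    from torchdrug import core, datasets, models, tasks

    dataset = datasets.ZINC250k("~/molecule-datasets/")

    model = models.GIN(input_dim=dataset.node_feature_dim,
                       hidden_dims=[300, 300, 300, 300, 300],
                       batch_norm=True, readout="mean")

    # randomly mask 15% of the atoms and train an MLP head on top of the node
    # representations to recover the masked atom types
    task = tasks.AttributeMasking(model, mask_rate=0.15, num_mlp_layer=2)

    optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
    solver = core.Engine(task, dataset, None, None, optimizer, batch_size=256)
    solver.train(num_epoch=1)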

ContextPrediction

class ContextPrediction(model, context_model=None, k=5, r1=4, r2=7, readout='mean', num_negative=1)[source]

Context prediction task proposed in Strategies for Pre-training Graph Neural Networks.

For a given center node, the subgraph is defined as a k-hop neighborhood (inclusive) around the selected node. The context graph is defined as the surrounding graph structure between r1- (exclusive) and r2-hop (inclusive) from the center node. Nodes between k- and r1-hop are picked as anchor nodes for the context representation.

Parameters
  • model (nn.Module) – node representation model for subgraphs.

  • context_model (nn.Module, optional) – node representation model for context graphs. By default, use the same architecture as model without parameter sharing.

  • k (int, optional) – radius for subgraphs

  • r1 (int, optional) – inner radius for context graphs

  • r2 (int, optional) – outer radius for context graphs

  • readout (nn.Module, optional) – readout function over context anchor nodes

  • num_negative (int, optional) – number of negative samples per positive sample
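
Example. A minimal pre-training sketch, assuming ZINC250k and a 5-layer GIN encoder so that the encoder depth matches the subgraph radius k; all hyperparameters are illustrative:

    import torch
    from torchdrug import core, datasets, models, tasks

    dataset = datasets.ZINC250k("~/molecule-datasets/")

    model = models.GIN(input_dim=dataset.node_feature_dim,
                       hidden_dims=[300, 300, 300, 300, 300], batch_norm=True)

    # context_model=None clones the architecture of `model` without sharing parameters
    task = tasks.ContextPrediction(model, context_model=None, k=5, r1=4, r2=7,
                                   readout="mean", num_negative=1)

    optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
    solver = core.Engine(task, dataset, None, None, optimizer, batch_size=256)
    solver.train(num_epoch=1)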

Unsupervised

class Unsupervised(model)[source]

Wrapper task for unsupervised learning.

The unsupervised loss should be computed by the model.

Parameters

model (nn.Module) – any model
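
Example. A minimal sketch, assuming the InfoGraph wrapper from torchdrug.models as the loss-producing model; the corpus and encoder sizes are illustrative:

    import torch
    from torchdrug import core, datasets, models, tasks

    dataset = datasets.ZINC250k("~/molecule-datasets/")

    gnn = models.GIN(input_dim=dataset.node_feature_dim,
                     hidden_dims=[300, 300, 300], batch_norm=True)
    # InfoGraph computes its own mutual-information loss, so the wrapper task
    # only needs to forward batches and collect that loss
    model = models.InfoGraph(gnn, separate_model=False)

    task = tasks.Unsupervised(model)

    optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
    solver = core.Engine(task, dataset, None, None, optimizer, batch_size=256)
    solver.train(num_epoch=1)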

Molecule Generation Tasks

AutoregressiveGeneration

class AutoregressiveGeneration(node_model, edge_model, task=(), num_node_sample=-1, num_edge_sample=-1, max_edge_unroll=None, max_node=None, criterion='nll', agent_update_interval=5, gamma=0.9, reward_temperature=1, baseline_momentum=0.9)[source]

Autoregressive graph generation task.

This class can be used to implement GraphAF proposed in GraphAF: A Flow-based Autoregressive Model for Molecular Graph Generation. To do so, instantiate the node model and the edge model with two GraphAutoregressiveFlow models.

Parameters
  • node_model (nn.Module) – node likelihood model

  • edge_model (nn.Module) – edge likelihood model

  • task (str or list of str, optional) – property optimization task(s). Available tasks are plogp and qed.

  • num_node_sample (int, optional) – number of node samples per graph. -1 for all samples.

  • num_edge_sample (int, optional) – number of edge samples per graph. -1 for all samples.

  • max_edge_unroll (int, optional) – max node id difference. If not provided, use the statistics from the training set.

  • max_node (int, optional) – max number of nodes. If not provided, use the statistics from the training set.

  • criterion (str, list or dict, optional) – training criterion(s). For dict, the keys are criteria and the values are the corresponding weights. Available criteria are nll and ppo.

  • agent_update_interval (int, optional) – update the agent every n batches

  • gamma (float, optional) – reward discount rate

  • reward_temperature (float, optional) – temperature for reward. Higher temperature encourages larger mean reward, while lower temperature encourages larger maximal reward.

  • baseline_momentum (float, optional) – momentum for value function baseline

preprocess(train_set, valid_set, test_set)[source]

Add atom id mapping and random BFS order to the training set.

Compute max_edge_unroll and max_node on the training set if not provided.
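
Example. A maximum-likelihood training sketch in the spirit of GraphAF, assuming ZINC250k, an RGCN encoder and Gaussian priors from torchdrug.layers.distribution; the flow constructor arguments and all hyperparameters are assumptions rather than documented defaults:

    import torch
    from torchdrug import core, datasets, layers, models, tasks

    # kekulized molecules with symbol-only atom features, as is common for generation
    dataset = datasets.ZINC250k("~/molecule-datasets/", kekulize=True,
                                atom_feature="symbol")

    model = models.RGCN(input_dim=dataset.num_atom_type,
                        num_relation=dataset.num_bond_type,
                        hidden_dims=[256, 256, 256], batch_norm=True)

    num_atom_type = dataset.num_atom_type
    num_bond_type = dataset.num_bond_type + 1   # one extra class for "no edge"

    # assumed prior / flow construction
    node_prior = layers.distribution.IndependentGaussian(torch.zeros(num_atom_type),
                                                         torch.ones(num_atom_type))
    edge_prior = layers.distribution.IndependentGaussian(torch.zeros(num_bond_type),
                                                         torch.ones(num_bond_type))
    node_flow = models.GraphAF(model, node_prior, num_layer=12)
    edge_flow = models.GraphAF(model, edge_prior, use_edge=True, num_layer=12)

    task = tasks.AutoregressiveGeneration(node_flow, edge_flow,
                                          max_node=38, max_edge_unroll=12,
                                          criterion="nll")

    optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
    solver = core.Engine(task, dataset, None, None, optimizer, batch_size=128)
    solver.train(num_epoch=1)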

GCPNGeneration

class GCPNGeneration(model, atom_types, max_edge_unroll=None, max_node=None, task=(), criterion='nll', hidden_dim_mlp=128, agent_update_interval=10, gamma=0.9, reward_temperature=1, baseline_momentum=0.9)[source]

The graph generative model from Graph Convolutional Policy Network for Goal-Directed Molecular Graph Generation.

Parameters
  • model (nn.Module) – graph representation model

  • atom_types (list or set) – set of all possible atom types

  • task (str or list of str, optional) – property optimization task(s)

  • max_edge_unroll (int, optional) – max node id difference. If not provided, use the statistics from the training set.

  • max_node (int, optional) – max number of nodes. If not provided, use the statistics from the training set.

  • criterion (str, list or dict, optional) – training criterion(s). For dict, the keys are criteria and the values are the corresponding weights. Available criteria are nll and ppo.

  • agent_update_interval (int, optional) – update the agent every n batches

  • gamma (float, optional) – reward discount rate

  • reward_temperature (float, optional) – temperature for reward. Higher temperature encourages larger mean reward, while lower temperature encourages larger maximal reward.

  • baseline_momentum (float, optional) – momentum for value function baseline

preprocess(train_set, valid_set, test_set)[source]

Add atom id mapping and random BFS order to the training set.

Compute max_edge_unroll and max_node on the training set if not provided.
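
Example. A maximum-likelihood training sketch, assuming ZINC250k and an RGCN encoder (hyperparameters are illustrative); switching criterion to ppo and setting task (e.g. plogp) turns this into goal-directed fine-tuning:

    import torch
    from torchdrug import core, datasets, models, tasks

    dataset = datasets.ZINC250k("~/molecule-datasets/", kekulize=True)

    model = models.RGCN(input_dim=dataset.node_feature_dim,
                        num_relation=dataset.num_bond_type,
                        hidden_dims=[256, 256, 256, 256], batch_norm=False)

    # dataset.atom_types provides the set of atom types seen in the corpus
    task = tasks.GCPNGeneration(model, dataset.atom_types,
                                max_edge_unroll=12, max_node=38,
                                criterion="nll")

    optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
    solver = core.Engine(task, dataset, None, None, optimizer, batch_size=128)
    solver.train(num_epoch=1)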

Retrosynthesis Tasks

CenterIdentification

class CenterIdentification(model, feature=('reaction', 'graph', 'atom', 'bond'), num_mlp_layer=2)[source]

Reaction center identification task.

This class is a part of retrosynthesis prediction.

Parameters
  • model (nn.Module) – graph representation model

  • feature (str or list of str, optional) – additional features for prediction. Available features are reaction (type of the reaction), graph (graph representation of the product), atom (original atom feature) and bond (original bond feature).

  • num_mlp_layer (int, optional) – number of MLP layers

predict_synthon(batch, k=1)[source]

Predict top-k synthons from target molecules.

Parameters
  • batch (dict) – batch of target molecules

  • k (int, optional) – return top-k results

Returns

top-k records.

Each record is a batch dict with keys synthon, num_synthon, reaction_center, log_likelihood and reaction.

Return type

list of dict
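
Example. A minimal sketch, assuming the USPTO50k reaction dataset and an RGCN encoder; dataset options, hidden sizes and the DataLoader batch are illustrative:

    import torch
    from torchdrug import core, data, datasets, models, tasks

    reaction_dataset = datasets.USPTO50k("~/molecule-datasets/")
    train_set, valid_set, test_set = reaction_dataset.split()

    model = models.RGCN(input_dim=reaction_dataset.node_feature_dim,
                        num_relation=reaction_dataset.num_bond_type,
                        hidden_dims=[256, 256, 256], batch_norm=True)

    task = tasks.CenterIdentification(model, feature=("graph", "atom", "bond"))

    optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
    solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                         batch_size=128)
    solver.train(num_epoch=1)

    # predict the top-2 synthons for one batch of products
    loader = data.DataLoader(test_set, batch_size=8)
    batch = next(iter(loader))
    synthon_batches = task.predict_synthon(batch, k=2)   # list of batch dicts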

SynthonCompletion

class SynthonCompletion(model, feature=('reaction', 'graph', 'atom'), num_mlp_layer=2)[source]

Synthon completion task.

This class is a part of retrosynthesis prediction.

Parameters
  • model (nn.Module) – graph representation model

  • feature (str or list of str, optional) – additional features for prediction. Available features are reaction (type of the reaction), graph (graph representation of the synthon) and atom (original atom feature).

  • num_mlp_layer (int, optional) – number of MLP layers
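
Example. A minimal sketch, assuming USPTO50k loaded in synthon mode via an as_synthon flag (an assumption here) and an RGCN encoder:

    import torch
    from torchdrug import core, datasets, models, tasks

    # as_synthon=True is assumed to yield synthon-reactant pairs instead of reactions
    synthon_dataset = datasets.USPTO50k("~/molecule-datasets/", as_synthon=True)
    train_set, valid_set, test_set = synthon_dataset.split()

    model = models.RGCN(input_dim=synthon_dataset.node_feature_dim,
                        num_relation=synthon_dataset.num_bond_type,
                        hidden_dims=[256, 256, 256], batch_norm=True)

    task = tasks.SynthonCompletion(model, feature=("graph",))

    optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
    solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                         batch_size=128)
    solver.train(num_epoch=1)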

Retrosynthesis

class Retrosynthesis(center_identification, synthon_completion, center_topk=2, num_synthon_beam=10, max_prediction=20, metric=('top-1', 'top-3', 'top-5', 'top-10'))[source]

Retrosynthesis task.

This class wraps pre-trained center identification and synthon completion modules into a pipeline.

Parameters
  • center_identification (CenterIdentification) – sub-task for center identification

  • synthon_completion (SynthonCompletion) – sub-task for synthon completion

  • center_topk (int, optional) – number of reaction centers to predict for each product

  • num_synthon_beam (int, optional) – size of beam search for each synthon

  • max_prediction (int, optional) – max number of final predictions for each product

  • metric (str or list of str, optional) – metric(s). Available metrics are top-K.

load_state_dict(state_dict, strict=True)[source]

Copies parameters and buffers from state_dict into this module and its descendants. If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.

Parameters
  • state_dict (dict) – a dict containing parameters and persistent buffers.

  • strict (bool, optional) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: True

Returns

  • missing_keys is a list of str containing the missing keys

  • unexpected_keys is a list of str containing the unexpected keys

Return type

NamedTuple with missing_keys and unexpected_keys fields
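
Example. A minimal evaluation sketch, assuming center_task and synthon_task are pre-trained sub-tasks built as in the CenterIdentification and SynthonCompletion sketches above, and that train_set, valid_set and test_set are the reaction splits from the CenterIdentification sketch:

    import torch
    from torchdrug import core, tasks

    # wrap the two pre-trained sub-tasks into a single retrosynthesis pipeline
    task = tasks.Retrosynthesis(center_task, synthon_task,
                                center_topk=2, num_synthon_beam=10, max_prediction=20,
                                metric=("top-1", "top-3", "top-5", "top-10"))

    optimizer = torch.optim.Adam(task.parameters(), lr=1e-3)
    solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                         batch_size=32)
    solver.evaluate("valid")   # reports top-K accuracy of the full pipeline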

Knowledge Graph Reasoning Tasks

KnowledgeGraphCompletion

class KnowledgeGraphCompletion(model, criterion='bce', metric=('mr', 'mrr', 'hits@1', 'hits@3', 'hits@10'), num_negative=128, margin=6, adversarial_temperature=0, strict_negative=True, filtered_ranking=True, fact_ratio=None, sample_weight=True)[source]

Knowledge graph completion task.

This class provides routines for the family of knowledge graph embedding models.

Parameters
  • model (nn.Module) – knowledge graph embedding model

  • criterion (str, list or dict, optional) – training criterion(s). For dict, the keys are criteria and the values are the corresponding weights. Available criteria are bce, ce and ranking.

  • metric (str or list of str, optional) – metric(s). Available metrics are mr, mrr and hits@K.

  • num_negative (int, optional) – number of negative samples per positive sample

  • margin (float, optional) – margin in ranking criterion

  • adversarial_temperature (float, optional) – temperature for self-adversarial negative sampling. Set 0 to disable self-adversarial negative sampling.

  • strict_negative (bool, optional) – use strict negative sampling or not

  • filtered_ranking (bool, optional) – use filtered or unfiltered ranking for evaluation

  • fact_ratio (float, optional) – split the training set into facts and labels. Set None to use the whole training set as both facts and labels.

  • sample_weight (bool, optional) – whether to down-weight triplets from entities of large degrees
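
Example. A minimal training sketch, assuming the FB15k-237 benchmark and a RotatE embedding model; embedding size, learning rate and number of epochs are illustrative:

    import torch
    from torchdrug import core, datasets, models, tasks

    dataset = datasets.FB15k237("~/kg-datasets/")
    train_set, valid_set, test_set = dataset.split()

    model = models.RotatE(num_entity=dataset.num_entity,
                          num_relation=dataset.num_relation,
                          embedding_dim=2048, max_score=9)

    # a non-zero temperature enables self-adversarial negative sampling
    task = tasks.KnowledgeGraphCompletion(model, criterion="bce",
                                          metric=("mr", "mrr", "hits@1", "hits@3", "hits@10"),
                                          num_negative=256,
                                          adversarial_temperature=1)

    optimizer = torch.optim.Adam(task.parameters(), lr=2e-5)
    solver = core.Engine(task, train_set, valid_set, test_set, optimizer,
                         batch_size=1024)
    solver.train(num_epoch=200)
    solver.evaluate("valid")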