torchdrug.data

Data Structures

Graph

class Graph(edge_list=None, edge_weight=None, num_node=None, num_relation=None, node_feature=None, edge_feature=None, graph_feature=None, **kwargs)[source]

Basic container for sparse graphs.

To batch graphs with variadic sizes, use data.Graph.pack. This will return a PackedGraph object with the following block diagonal adjacency matrix.

\[\begin{split}\begin{bmatrix} A_1 & \cdots & 0 \\ \vdots & \ddots & \vdots \\ 0 & \cdots & A_n \end{bmatrix}\end{split}\]

where \(A_i\) is the adjacency matrix of the \(i\)-th graph.
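The block diagonal layout above can be reproduced by shifting the node ids of each graph into a disjoint range. The following is only a minimal sketch of that idea: plain Python lists stand in for tensors, and `pack_edge_lists` and `dense_adjacency` are hypothetical helpers, not torchdrug functions.

```python
# Illustrative sketch only: plain lists stand in for tensors, and both
# helpers are hypothetical (they are not part of torchdrug).

def pack_edge_lists(edge_lists, num_nodes):
    """Shift the node ids of each graph so the graphs use disjoint id ranges."""
    packed = []
    offset = 0
    for edges, n in zip(edge_lists, num_nodes):
        packed.extend((u + offset, v + offset) for u, v in edges)
        offset += n  # the next graph starts after all nodes seen so far
    return packed, offset


def dense_adjacency(edges, num_node):
    """Build a dense 0/1 adjacency matrix from an edge list."""
    matrix = [[0] * num_node for _ in range(num_node)]
    for u, v in edges:
        matrix[u][v] = 1
    return matrix


g1 = [(0, 1), (1, 0)]          # graph with 2 nodes
g2 = [(0, 1), (1, 2), (2, 0)]  # graph with 3 nodes
packed_edges, total_nodes = pack_edge_lists([g1, g2], [2, 3])
adjacency = dense_adjacency(packed_edges, total_nodes)
```

Every entry of `adjacency` that couples a node of the first graph with a node of the second graph stays zero, which is exactly the off-diagonal part of the matrix above.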

You may register dynamic attributes for each graph. The registered attributes will be automatically processed during packing.

Example:

>>> graph = data.Graph(torch.randint(10, (30, 2)))
>>> with graph.node():
...     graph.my_node_attr = torch.rand(10, 5, 5)

Parameters
  • edge_list (array_like, optional) – list of edges of shape \((|E|, 2)\) or \((|E|, 3)\). Each tuple is (node_in, node_out) or (node_in, node_out, relation).

  • edge_weight (array_like, optional) – edge weights of shape \((|E|,)\)

  • num_node (int, optional) – number of nodes. By default, it will be inferred from the largest id in edge_list

  • num_relation (int, optional) – number of relations

  • node_feature (array_like, optional) – node features of shape \((|V|, ...)\)

  • edge_feature (array_like, optional) – edge features of shape \((|E|, ...)\)

  • graph_feature (array_like, optional) – graph feature of any shape

packed_type

alias of PackedGraph

clone()[source]

Clone this graph.

compact()[source]

Remove isolated nodes and compact node ids.

Returns

Graph

connected_components()[source]

Split this graph into connected components.

Returns

connected components, number of connected components per graph

Return type

(PackedGraph, LongTensor)

copy_(src)[source]

Copy data from src into self and return self.

The src graph must have the same set of attributes as self.

cpu()[source]

Return a copy of this graph in CPU memory.

This is a no-op if the graph is already in CPU memory.

cuda(*args, **kwargs)[source]

Return a copy of this graph in CUDA memory.

This is a no-op if the graph is already on the correct device.

detach()[source]

Detach this graph.

directed(order=None)[source]

Mask the edges to create a directed graph. Edges that go from a node index to a larger or equal node index will be kept.

Parameters

order (Tensor, optional) – topological order of the nodes
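The masking rule can be sketched in plain Python. This is an illustration under an assumption, not the library's implementation: `order` is taken to list the node ids in topological order, so a node's position in `order` replaces its raw index in the comparison. `directed_mask` is a hypothetical helper.

```python
def directed_mask(edges, order=None):
    """Return the indexes of the edges kept by the 'smaller or equal to larger' rule.

    Assumption: `order` lists node ids in topological order; without it,
    the raw node indexes are compared.
    """
    rank = None
    if order is not None:
        rank = {node: position for position, node in enumerate(order)}
    kept = []
    for index, (u, v) in enumerate(edges):
        ru, rv = (u, v) if rank is None else (rank[u], rank[v])
        if ru <= rv:  # keep edges that point "forward" or are self-loops
            kept.append(index)
    return kept


edges = [(0, 1), (1, 0), (1, 2), (2, 1), (2, 2)]
kept_default = directed_mask(edges)
kept_reversed = directed_mask(edges, order=[2, 1, 0])
```

With the default order, exactly one direction of each reciprocal pair survives, and self-loops are always kept because of the "or equal" part of the rule.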

edge()[source]

Context manager for edge attributes.

edge_mask(index)[source]

Return a masked graph based on the specified edges.

This function can also be used to re-order the edges.

Parameters

index (array_like) – edge index

Returns

Graph

classmethod from_dense(adjacency, node_feature=None, edge_feature=None)[source]

Create a sparse graph from a dense adjacency matrix. For zero entries in the adjacency matrix, their edge features will be ignored.

Parameters
  • adjacency (array_like) – adjacency matrix of shape \((|V|, |V|)\) or \((|V|, |V|, |R|)\)

  • node_feature (array_like) – node features of shape \((|V|, ...)\)

  • edge_feature (array_like) – edge features of shape \((|V|, |V|, ...)\) or \((|V|, |V|, |R|, ...)\)

full()[source]

Return a fully connected graph over the nodes.

Returns

Graph

get_edge(edge)[source]

Get the weight of an edge.

Parameters

edge (array_like) – index of shape \((2,)\) or \((3,)\)

Returns

weight of the edge

Return type

Tensor

graph()[source]

Context manager for graph attributes.

match(pattern)[source]

Return all matched indexes for each pattern. Supports patterns with -1 as the wildcard.

Parameters

pattern (array_like) – index of shape \((N, 2)\) or \((N, 3)\)

Returns

matched indexes, number of matches per pattern

Return type

(LongTensor, LongTensor)

Examples:

>>> graph = data.Graph([[0, 1], [1, 0], [1, 2], [2, 1], [2, 0], [0, 2]])
>>> index, num_match = graph.match([[0, -1], [1, 2]])
>>> assert (index == torch.tensor([0, 5, 2])).all()
>>> assert (num_match == torch.tensor([2, 1])).all()
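
The wildcard semantics of the example above can be spelled out with a brute-force pure-Python sketch. `match_patterns` is a hypothetical helper that scans every edge for every pattern; it illustrates the expected result only and says nothing about how the library performs the lookup.

```python
def match_patterns(edges, patterns):
    """Return (flat list of matched edge indexes, number of matches per pattern)."""
    index, num_match = [], []
    for pattern in patterns:
        hits = [
            i for i, edge in enumerate(edges)
            # -1 in the pattern matches any value at that position
            if all(p == -1 or p == e for p, e in zip(pattern, edge))
        ]
        index.extend(hits)
        num_match.append(len(hits))
    return index, num_match


edges = [[0, 1], [1, 0], [1, 2], [2, 1], [2, 0], [0, 2]]
index, num_match = match_patterns(edges, [[0, -1], [1, 2]])
```

The flat `index` list concatenates the matches of all patterns, and `num_match` tells how to cut it back into one segment per pattern.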
node()[source]

Context manager for node attributes.

node_mask(index, compact=False)[source]

Return a masked graph based on the specified nodes.

This function can also be used to re-order the nodes.

Parameters
  • index (array_like) – node index

  • compact (bool, optional) – compact node ids or not

Returns

Graph

Examples:

>>> graph = data.Graph.from_dense(torch.eye(3))
>>> assert graph.node_mask([1, 2]).adjacency.shape == (3, 3)
>>> assert graph.node_mask([1, 2], compact=True).adjacency.shape == (2, 2)
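
The effect of compact on the example above can be sketched on a plain edge list. `node_mask_sketch` is a hypothetical helper, and the relabelling rule (the node at position `i` of `index` becomes node `i`) is an assumption chosen to be consistent with the re-ordering behaviour mentioned above.

```python
def node_mask_sketch(edges, num_node, index, compact=False):
    """Keep edges whose endpoints are both in `index`; optionally relabel nodes."""
    keep = set(index)
    kept_edges = [(u, v) for u, v in edges if u in keep and v in keep]
    if not compact:
        # Node ids are unchanged, so the graph keeps its original size.
        return kept_edges, num_node
    # Assumption: position in `index` becomes the new node id.
    new_id = {old: new for new, old in enumerate(index)}
    return [(new_id[u], new_id[v]) for u, v in kept_edges], len(index)


# Edge list of the 3x3 identity adjacency: one self-loop per node.
edges = [(0, 0), (1, 1), (2, 2)]
masked_edges, masked_size = node_mask_sketch(edges, 3, [1, 2])
compact_edges, compact_size = node_mask_sketch(edges, 3, [1, 2], compact=True)
```

Without compact the adjacency keeps its \((3, 3)\) shape with node 0 isolated; with compact it shrinks to \((2, 2)\).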
classmethod pack(graphs)[source]

Pack a list of graphs into a PackedGraph object.

Parameters

graphs (list of Graph) – list of graphs

Returns

PackedGraph

repeat(count)[source]

Repeat this graph.

Parameters

count (int) – number of repetitions

Returns

PackedGraph

split(node2graph)[source]

Split a graph into multiple disconnected graphs.

Parameters

node2graph (array_like) – ID of the graph each node belongs to

Returns

PackedGraph

subgraph(index)[source]

Return a subgraph based on the specified nodes. Equivalent to node_mask(index, compact=True).

Parameters

index (array_like) – node index

Returns

Graph

undirected(add_inverse=False)[source]

Flip all the edges to create an undirected graph.

For knowledge graphs, the flipped edges can either have the original relation or an inverse relation. The inverse relation for relation \(r\) is defined as \(|R| + r\).

Parameters

add_inverse (bool, optional) – whether to use inverse relations for flipped edges
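
The inverse relation rule \(|R| + r\) can be illustrated on a list of (node_in, node_out, relation) tuples. `undirected_sketch` is a hypothetical helper; returning a doubled relation count when add_inverse is set is an assumption that follows from the inverse ids occupying the range \(|R|\) to \(2|R| - 1\).

```python
def undirected_sketch(edges, num_relation, add_inverse=False):
    """Append a flipped copy of every edge, optionally with inverse relations."""
    flipped = []
    for node_in, node_out, relation in edges:
        # Inverse of relation r is defined as |R| + r.
        new_relation = relation + num_relation if add_inverse else relation
        flipped.append((node_out, node_in, new_relation))
    # Assumption: the relation count doubles when inverse relations are added.
    new_num_relation = 2 * num_relation if add_inverse else num_relation
    return edges + flipped, new_num_relation


edges = [(0, 1, 0), (1, 2, 1)]
same_relation, n_same = undirected_sketch(edges, num_relation=2)
with_inverse, n_inverse = undirected_sketch(edges, num_relation=2, add_inverse=True)
```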

visualize(title=None, save_file=None, figure_size=(3, 3), ax=None, layout='spring')[source]

Visualize this graph with matplotlib.

Parameters
  • title (str, optional) – title for this graph

  • save_file (str, optional) – png or pdf file to save visualization. If not provided, show the figure in window.

  • figure_size (tuple of int, optional) – width and height of the figure

  • ax (matplotlib.axes.Axes, optional) – axis to plot the figure

  • layout (str, optional) – graph layout

property adjacency

Adjacency matrix of this graph.

If num_relation is specified, a sparse tensor of shape \((|V|, |V|, num\_relation)\) will be returned. Otherwise, a sparse tensor of shape \((|V|, |V|)\) will be returned.

property batch_size

Batch size.

property degree_in

Weighted number of edges containing each node as input.

Note this is the out-degree in graph theory.

property degree_out

Weighted number of edges containing each node as output.

Note this is the in-degree in graph theory.
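
Because the naming is the reverse of the graph theory convention, a small sketch may help. `weighted_degrees` is a hypothetical helper operating on plain lists: degree_in sums the weights of edges where the node appears as node_in, degree_out those where it appears as node_out.

```python
def weighted_degrees(edges, weights, num_node):
    """Return (degree_in, degree_out) following the naming used in this class."""
    degree_in = [0.0] * num_node
    degree_out = [0.0] * num_node
    for (node_in, node_out), weight in zip(edges, weights):
        degree_in[node_in] += weight    # node is the input (source) of the edge
        degree_out[node_out] += weight  # node is the output (target) of the edge
    return degree_in, degree_out


edges = [(0, 1), (0, 2), (2, 1)]
weights = [1.0, 0.5, 2.0]
degree_in, degree_out = weighted_degrees(edges, weights, num_node=3)
```

Node 0 is the source of two edges, so its degree_in is the sum of their weights, while its degree_out is zero.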

property device

Device.

property edge2graph

Edge id to graph id mapping.

property edge_list

List of edges.

property edge_weight

Edge weights.

property node2graph

Node id to graph id mapping.

Molecule

class Molecule(edge_list=None, atom_type=None, bond_type=None, formal_charge=None, explicit_hs=None, chiral_tag=None, radical_electrons=None, atom_map=None, bond_stereo=None, stereo_atoms=None, **kwargs)[source]

Molecule graph with chemical features.

Parameters
  • edge_list (array_like, optional) – list of edges of shape \((|E|, 3)\). Each tuple is (node_in, node_out, bond_type).

  • atom_type (array_like, optional) – atom types of shape \((|V|,)\)

  • bond_type (array_like, optional) – bond types of shape \((|E|,)\)

  • formal_charge (array_like, optional) – formal charges of shape \((|V|,)\)

  • explicit_hs (array_like, optional) – number of explicit hydrogens of shape \((|V|,)\)

  • chiral_tag (array_like, optional) – chirality tags of shape \((|V|,)\)

  • radical_electrons (array_like, optional) – number of radical electrons of shape \((|V|,)\)

  • atom_map (array_like, optional) – atom mappings of shape \((|V|,)\)

  • bond_stereo (array_like, optional) – bond stereochem of shape \((|E|,)\)

  • stereo_atoms (array_like, optional) – ids of stereo atoms of shape \((|E|,)\)

packed_type

alias of PackedMolecule

edge_mask(index)[source]

Return a masked graph based on the specified edges.

This function can also be used to re-order the edges.

Parameters

index (array_like) – edge index

Returns

Graph

classmethod from_molecule(mol, node_feature='default', edge_feature='default', graph_feature=None, with_hydrogen=False, kekulize=False)[source]

Create a molecule from a RDKit object.

Parameters
  • mol (rdchem.Mol) – molecule

  • node_feature (str or list of str, optional) – node features to extract

  • edge_feature (str or list of str, optional) – edge features to extract

  • graph_feature (str or list of str, optional) – graph features to extract

  • with_hydrogen (bool, optional) – store hydrogens in the molecule graph. By default, hydrogens are dropped

  • kekulize (bool, optional) – convert aromatic bonds to single/double bonds. Note this only affects the relation in edge_list. For bond_type, aromatic bonds are always stored explicitly. By default, aromatic bonds are stored.

classmethod from_smiles(smiles, node_feature='default', edge_feature='default', graph_feature=None, with_hydrogen=False, kekulize=False)[source]

Create a molecule from a SMILES string.

Parameters
  • smiles (str) – SMILES string

  • node_feature (str or list of str, optional) – node features to extract

  • edge_feature (str or list of str, optional) – edge features to extract

  • graph_feature (str or list of str, optional) – graph features to extract

  • with_hydrogen (bool, optional) – store hydrogens in the molecule graph. By default, hydrogens are dropped

  • kekulize (bool, optional) – convert aromatic bonds to single/double bonds. Note this only affects the relation in edge_list. For bond_type, aromatic bonds are always stored explicitly. By default, aromatic bonds are stored.

ion_to_molecule()[source]

Convert ions to molecules by adjusting hydrogens and electrons.

Note [N+] will not be converted.

node_mask(index, compact=False)[source]

Return a masked graph based on the specified nodes.

This function can also be used to re-order the nodes.

Parameters
  • index (array_like) – node index

  • compact (bool, optional) – compact node ids or not

Returns

Graph

Examples:

>>> graph = data.Graph.from_dense(torch.eye(3))
>>> assert graph.node_mask([1, 2]).adjacency.shape == (3, 3)
>>> assert graph.node_mask([1, 2], compact=True).adjacency.shape == (2, 2)

to_molecule(ignore_error=False)[source]

Return a RDKit object of this molecule.

Parameters

ignore_error (bool, optional) – if true, return None for illegal molecules. Otherwise, raise an exception.

Returns

rdchem.Mol

to_scaffold(chirality=False)[source]

Return a scaffold SMILES string of this molecule.

Parameters

chirality (bool, optional) – consider chirality in the scaffold or not

Returns

str

to_smiles(isomeric=True, atom_map=True, canonical=False)[source]

Return a SMILES string of this molecule.

Parameters
  • isomeric (bool, optional) – keep isomeric information or not

  • atom_map (bool, optional) – keep atom mapping or not

  • canonical (bool, optional) – if true, return the canonical form of smiles

Returns

str

undirected(add_inverse=False)[source]

Flip all the edges to create an undirected graph.

For knowledge graphs, the flipped edges can either have the original relation or an inverse relation. The inverse relation for relation \(r\) is defined as \(|R| + r\).

Parameters

add_inverse (bool, optional) – whether to use inverse relations for flipped edges

visualize(title=None, save_file=None, figure_size=(3, 3), ax=None, atom_map=False)[source]

Visualize this molecule with matplotlib.

Parameters
  • title (str, optional) – title for this molecule

  • save_file (str, optional) – png or pdf file to save visualization. If not provided, show the figure in window.

  • figure_size (tuple of int, optional) – width and height of the figure

  • ax (matplotlib.axes.Axes, optional) – axis to plot the figure

  • atom_map (bool, optional) – visualize atom mapping or not

property is_valid

A coarse implementation of valence check.

property num_atom

Number of atoms.

property num_bond

Number of bonds.

PackedGraph

class PackedGraph(edge_list=None, edge_weight=None, num_nodes=None, num_edges=None, num_relation=None, offsets=None, **kwargs)[source]

Container for sparse graphs with variadic sizes.

To create a PackedGraph from Graph objects

>>> batch = data.Graph.pack(graphs)

To retrieve Graph objects from a PackedGraph

>>> graphs = batch.unpack()

Parameters
  • edge_list (array_like, optional) – list of edges of shape \((|E|, 2)\) or \((|E|, 3)\). Each tuple is (node_in, node_out) or (node_in, node_out, relation).

  • edge_weight (array_like, optional) – edge weights of shape \((|E|,)\)

  • num_nodes (array_like, optional) – number of nodes in each graph. By default, it will be inferred from the largest id in edge_list

  • num_edges (array_like, optional) – number of edges in each graph

  • num_relation (int, optional) – number of relations

  • node_feature (array_like, optional) – node features of shape \((|V|, ...)\)

  • edge_feature (array_like, optional) – edge features of shape \((|E|, ...)\)

  • offsets (array_like, optional) – node id offsets of shape \((|E|,)\). If not provided, nodes in edge_list should be relative index, i.e., the index in each graph. If provided, nodes in edge_list should be absolute index, i.e., the index in the packed graph.
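
The relation between relative and absolute node indexes can be sketched as follows. Both helpers are hypothetical, and the sketch assumes that the offset of an edge is the number of nodes in all graphs that precede the graph the edge belongs to.

```python
def edge_offsets(num_nodes, num_edges):
    """One offset per edge: the count of nodes in all preceding graphs."""
    offsets = []
    start = 0
    for n, e in zip(num_nodes, num_edges):
        offsets.extend([start] * e)
        start += n
    return offsets


def to_absolute(edge_list, offsets):
    """Convert per-graph (relative) node indexes to packed (absolute) ones."""
    return [(u + o, v + o) for (u, v), o in zip(edge_list, offsets)]


# Two graphs: 2 nodes / 2 edges, then 3 nodes / 2 edges, in relative indexes.
relative = [(0, 1), (1, 0), (0, 2), (2, 1)]
offsets = edge_offsets(num_nodes=[2, 3], num_edges=[2, 2])
absolute = to_absolute(relative, offsets)
```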

unpacked_type

alias of Graph

clone()[source]

Clone this packed graph.

cpu()[source]

Return a copy of this packed graph in CPU memory.

This is a no-op if the graph is already in CPU memory.

cuda(*args, **kwargs)[source]

Return a copy of this packed graph in CUDA memory.

This is a no-op if the graph is already on the correct device.

detach()[source]

Detach this packed graph.

edge_mask(index)[source]

Return a masked packed graph based on the specified edges.

Parameters

index (array_like) – edge index

Returns

PackedGraph

full()[source]

Return a pack of fully connected graphs.

This is useful for computing node-pair-wise features. The computation can be implemented as message passing over a fully connected graph.

Returns

PackedGraph

get_item(index)[source]

Get the i-th graph from this packed graph.

Parameters

index (int) – graph index

Returns

Graph

graph_mask(index, compact=False)[source]

Return a masked packed graph based on the specified graphs.

This function can also be used to re-order the graphs.

Parameters
  • index (array_like) – graph index

  • compact (bool, optional) – compact graph ids or not

Returns

PackedGraph

merge(graph2graph)[source]

Merge multiple graphs into a single graph.

Parameters

graph2graph (array_like) – ID of the new graph each graph belongs to

node_mask(index, compact=False)[source]

Return a masked packed graph based on the specified nodes.

Note the compact option is only applied to node ids but not graph ids. To generate compact graph ids, use subbatch().

Parameters
  • index (array_like) – node index

  • compact (bool, optional) – compact node ids or not

Returns

PackedGraph

repeat(count)[source]

Repeat this packed graph. This function behaves similarly to torch.Tensor.repeat.

Parameters

count (int) – number of repetitions

Returns

PackedGraph

repeat_interleave(repeats)[source]

Repeat this packed graph. This function behaves similarly to torch.repeat_interleave.

Parameters

repeats (Tensor or int) – number of repetitions for each graph

Returns

PackedGraph
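
The difference between repeat and repeat_interleave is the order of the resulting graphs. The sketch below mimics that ordering on a plain list of graph labels; both functions are hypothetical stand-ins that follow the tiling and interleaving behaviour of the torch functions named above.

```python
def repeat_sketch(graphs, count):
    """Tile the whole batch: [A, B] -> [A, B, A, B] for count=2."""
    return list(graphs) * count


def repeat_interleave_sketch(graphs, repeats):
    """Repeat each graph in place: [A, B] -> [A, A, B, B] for repeats=2."""
    if isinstance(repeats, int):
        repeats = [repeats] * len(graphs)
    result = []
    for graph, times in zip(graphs, repeats):
        result.extend([graph] * times)
    return result


tiled = repeat_sketch(["A", "B"], 2)
interleaved = repeat_interleave_sketch(["A", "B"], 2)
per_graph = repeat_interleave_sketch(["A", "B"], [1, 3])
```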

subbatch(index)[source]

Return a subbatch based on the specified graphs. Equivalent to graph_mask(index, compact=True).

Parameters

index (array_like) – graph index

Returns

PackedGraph

undirected(add_inverse=False)[source]

Flip all the edges to create undirected graphs.

For knowledge graphs, the flipped edges can either have the original relation or an inverse relation. The inverse relation for relation \(r\) is defined as \(|R| + r\).

Parameters

add_inverse (bool, optional) – whether to use inverse relations for flipped edges

unpack()[source]

Unpack this packed graph into a list of graphs.

Returns

list of Graph

unpack_data(data, type='auto')[source]

Unpack node or edge data according to the packed graph.

Parameters
  • data (Tensor) – data to unpack

  • type (str, optional) – data type. Can be auto, node, or edge.

Returns

list of Tensor
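
Unpacking amounts to cutting a flat sequence into one chunk per graph. The sketch below is a hypothetical pure-Python version; in particular, resolving auto by comparing the data length with the total node and edge counts (preferring node on a tie) is an assumption, not documented behaviour.

```python
def unpack_data_sketch(data, num_nodes, num_edges, type="auto"):
    """Split flat node or edge data into one chunk per graph."""
    if type == "auto":
        # Assumption: infer the kind of data from its length.
        if len(data) == sum(num_nodes):
            type = "node"
        elif len(data) == sum(num_edges):
            type = "edge"
        else:
            raise ValueError("data length matches neither node nor edge count")
    if type not in ("node", "edge"):
        raise ValueError("type must be 'auto', 'node' or 'edge'")
    sizes = num_nodes if type == "node" else num_edges
    chunks, start = [], 0
    for size in sizes:
        chunks.append(data[start:start + size])
        start += size
    return chunks


num_nodes, num_edges = [2, 3], [1, 3]
node_chunks = unpack_data_sketch(["a", "b", "c", "d", "e"], num_nodes, num_edges)
edge_chunks = unpack_data_sketch([10, 20, 30, 40], num_nodes, num_edges)
```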

visualize(titles=None, save_file=None, figure_size=(3, 3), layout='spring', num_row=None, num_col=None)[source]

Visualize the packed graphs with matplotlib.

Parameters
  • titles (list of str, optional) – title for each graph. Default is the ID of each graph.

  • save_file (str, optional) – png or pdf file to save visualization. If not provided, show the figure in window.

  • figure_size (tuple of int, optional) – width and height of the figure

  • layout (str, optional) – graph layout

  • num_row (int, optional) – number of rows in the figure

  • num_col (int, optional) – number of columns in the figure

property batch_size

Batch size.

property edge2graph

Edge id to graph id mapping.

property node2graph

Node id to graph id mapping.

PackedMolecule

class PackedMolecule(edge_list=None, atom_type=None, bond_type=None, num_nodes=None, num_edges=None, offsets=None, **kwargs)[source]

Container for molecules with variadic sizes.

Parameters
  • edge_list (array_like, optional) – list of edges of shape \((|E|, 3)\). Each tuple is (node_in, node_out, bond_type).

  • atom_type (array_like, optional) – atom types of shape \((|V|,)\)

  • bond_type (array_like, optional) – bond types of shape \((|E|,)\)

  • num_nodes (array_like, optional) – number of nodes in each graph. By default, it will be inferred from the largest id in edge_list

  • num_edges (array_like, optional) – number of edges in each graph

  • num_relation (int, optional) – number of relations

  • offsets (array_like, optional) – node id offsets of shape \((|E|,)\). If not provided, nodes in edge_list should be relative index, i.e., the index in each graph. If provided, nodes in edge_list should be absolute index, i.e., the index in the packed graph.

unpacked_type

alias of Molecule

edge_mask(index)[source]

Return a masked packed graph based on the specified edges.

Parameters

index (array_like) – edge index

Returns

PackedGraph

classmethod from_molecule(mols, node_feature='default', edge_feature='default', graph_feature=None, with_hydrogen=False, kekulize=False)[source]

Create a packed molecule from a list of RDKit objects.

Parameters
  • mols (list of rdchem.Mol) – molecules

  • node_feature (str or list of str, optional) – node features to extract

  • edge_feature (str or list of str, optional) – edge features to extract

  • graph_feature (str or list of str, optional) – graph features to extract

  • with_hydrogen (bool, optional) – store hydrogens in the molecule graph. By default, hydrogens are dropped

  • kekulize (bool, optional) – convert aromatic bonds to single/double bonds. Note this only affects the relation in edge_list. For bond_type, aromatic bonds are always stored explicitly. By default, aromatic bonds are stored.

classmethod from_smiles(smiles_list, node_feature='default', edge_feature='default', graph_feature=None, with_hydrogen=False, kekulize=False)[source]

Create a packed molecule from a list of SMILES strings.

Parameters
  • smiles_list (list of str) – list of SMILES strings

  • node_feature (str or list of str, optional) – node features to extract

  • edge_feature (str or list of str, optional) – edge features to extract

  • graph_feature (str or list of str, optional) – graph features to extract

  • with_hydrogen (bool, optional) – store hydrogens in the molecule graph. By default, hydrogens are dropped

  • kekulize (bool, optional) – convert aromatic bonds to single/double bonds. Note this only affects the relation in edge_list. For bond_type, aromatic bonds are always stored explicitly. By default, aromatic bonds are stored.

ion_to_molecule()[source]

Convert ions to molecules by adjusting hydrogens and electrons.

Note [N+] will not be converted.

node_mask(index, compact=False)[source]

Return a masked packed graph based on the specified nodes.

Note the compact option is only applied to node ids but not graph ids. To generate compact graph ids, use subbatch().

Parameters
  • index (array_like) – node index

  • compact (bool, optional) – compact node ids or not

Returns

PackedGraph

to_molecule(ignore_error=False)[source]

Return a list of RDKit objects.

Parameters

ignore_error (bool, optional) – if true, return None for illegal molecules. Otherwise, raise an exception.

Returns

list of rdchem.Mol

to_smiles(isomeric=True, atom_map=True, canonical=False)[source]

Return a list of SMILES strings.

Parameters
  • isomeric (bool, optional) – keep isomeric information or not

  • atom_map (bool, optional) – keep atom mapping or not

  • canonical (bool, optional) – if true, return the canonical form of smiles

Returns

list of str

undirected(add_inverse=False)[source]

Flip all the edges to create undirected graphs.

For knowledge graphs, the flipped edges can either have the original relation or an inverse relation. The inverse relation for relation \(r\) is defined as \(|R| + r\).

Parameters

add_inverse (bool, optional) – whether to use inverse relations for flipped edges

visualize(titles=None, save_file=None, figure_size=(3, 3), num_row=None, num_col=None, atom_map=False)[source]

Visualize the packed molecules with matplotlib.

Parameters
  • titles (list of str, optional) – title for each molecule. Default is the ID of each molecule.

  • save_file (str, optional) – png or pdf file to save visualization. If not provided, show the figure in window.

  • figure_size (tuple of int, optional) – width and height of the figure

  • num_row (int, optional) – number of rows in the figure

  • num_col (int, optional) – number of columns in the figure

  • atom_map (bool, optional) – visualize atom mapping or not

property is_valid

A coarse implementation of valence check.

Dictionary

class Dictionary(keys, values, hash=None)[source]

Dictionary for mapping keys to values.

This class has the same behavior as the built-in dict, except it operates on tensors and supports batching.

Example:

>>> keys = torch.tensor([[0, 0], [1, 1], [2, 2]])
>>> values = torch.tensor([[0, 1], [1, 2], [2, 3]])
>>> d = data.Dictionary(keys, values)
>>> assert (d[[[0, 0], [2, 2]]] == values[[0, 2]]).all()
>>> assert (d.has_key([[0, 1], [1, 2]]) == torch.tensor([False, False])).all()

Parameters
  • keys (LongTensor) – keys of shape \((N,)\) or \((N, D)\)

  • values (Tensor) – values of shape \((N, ...)\)

  • hash (PerfectHash, optional) – hash function for keys

cpu()[source]

Return a copy of this dictionary in CPU memory.

This is a no-op if the dictionary is already in CPU memory.

cuda(*args, **kwargs)[source]

Return a copy of this dictionary in CUDA memory.

This is a no-op if the dictionary is already in CUDA memory.

get(keys, default=None)[source]

Return the value for each key if the key is in the dictionary, otherwise the default value is returned.

Parameters
  • keys (LongTensor) – keys of arbitrary shape

  • default (int or Tensor, optional) – default return value. By default, 0 is used.
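
The batched lookup semantics can be mimicked with a built-in dict keyed by tuples. `TupleDictionary` is a hypothetical stand-in that works on nested lists rather than tensors and uses no hashing scheme of its own; it only illustrates what get and has_key return for a batch of keys.

```python
class TupleDictionary:
    """Toy stand-in: maps multi-dimensional keys to values, one batch at a time."""

    def __init__(self, keys, values):
        self._table = {tuple(key): value for key, value in zip(keys, values)}

    def has_key(self, keys):
        """Return one boolean per queried key."""
        return [tuple(key) in self._table for key in keys]

    def get(self, keys, default=None):
        """Return the stored value per key, or the default for missing keys."""
        if default is None:
            default = 0  # mirrors the documented fallback of 0
        return [self._table.get(tuple(key), default) for key in keys]


keys = [[0, 0], [1, 1], [2, 2]]
values = [[0, 1], [1, 2], [2, 3]]
d = TupleDictionary(keys, values)
found = d.get([[0, 0], [2, 2]])
missing = d.get([[0, 1]])
present = d.has_key([[0, 1], [1, 1]])
```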

has_key(keys)[source]

Check whether each key exists in the dictionary.

to_dict()[source]

Return a built-in dict object of this dictionary.

property device

Device.

Datasets

KnowledgeGraphDataset

class KnowledgeGraphDataset[source]

Knowledge graph dataset.

The whole dataset contains one knowledge graph.

load_triplet(triplets, entity_vocab=None, relation_vocab=None, inv_entity_vocab=None, inv_relation_vocab=None)[source]

Load the dataset from triplets. The mapping between indexes and tokens is specified through either vocabularies or inverse vocabularies.

Parameters
  • triplets (array_like) – triplets of shape \((n, 3)\)

  • entity_vocab (dict of str, optional) – maps entity indexes to tokens

  • relation_vocab (dict of str, optional) – maps relation indexes to tokens

  • inv_entity_vocab (dict of str, optional) – maps tokens to entity indexes

  • inv_relation_vocab (dict of str, optional) – maps tokens to relation indexes
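
A vocabulary and its inverse carry the same information, so either one is enough to recover the other. The sketch below shows this with plain dicts. `complete_vocab` is a hypothetical helper, the entity and relation names are toy values, and the triplet column order (head, tail, relation) is an assumption borrowed from the edge_list convention used elsewhere on this page.

```python
def complete_vocab(vocab=None, inv_vocab=None):
    """Derive the missing mapping from whichever one is provided."""
    if vocab is None and inv_vocab is None:
        raise ValueError("provide a vocabulary or an inverse vocabulary")
    if vocab is None:
        vocab = {index: token for token, index in inv_vocab.items()}
    if inv_vocab is None:
        inv_vocab = {token: index for index, token in vocab.items()}
    return vocab, inv_vocab


inv_entity_vocab = {"Paris": 0, "France": 1}  # token -> index
relation_vocab = {0: "capital_of"}            # index -> token
entity_vocab, _ = complete_vocab(inv_vocab=inv_entity_vocab)
_, inv_relation_vocab = complete_vocab(vocab=relation_vocab)

# Assumed column order: (head, tail, relation).
triplets = [(0, 1, 0)]
readable = [
    (entity_vocab[h], relation_vocab[r], entity_vocab[t]) for h, t, r in triplets
]
```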

load_tsv(tsv_file, verbose=0)[source]

Load the dataset from a tsv file.

Parameters
  • tsv_file (str) – file name

  • verbose (int, optional) – output verbose level

load_tsvs(tsv_files, verbose=0)[source]

Load the dataset from multiple tsv files.

Parameters
  • tsv_files (list of str) – list of file names

  • verbose (int, optional) – output verbose level

property num_entity

Number of entities.

property num_relation

Number of relations.

property num_triplet

Number of triplets.

MoleculeDataset

class MoleculeDataset[source]

Molecule dataset.

Each sample contains a molecule graph, and any number of prediction targets.

load_csv(csv_file, smiles_field='smiles', target_fields=None, verbose=0, transform=None, lazy=False, node_feature='default', edge_feature='default', graph_feature=None, with_hydrogen=False, kekulize=False)[source]

Load the dataset from a csv file.

Parameters
  • csv_file (str) – file name

  • smiles_field (str, optional) – name of SMILES column in the table. Use None if there is no SMILES column.

  • target_fields (list of str, optional) – name of target columns in the table. Default is all columns other than the SMILES column.

  • verbose (int, optional) – output verbose level

  • transform (Callable, optional) – data transformation function

  • lazy (bool, optional) – if lazy mode is used, the molecules are processed in the dataloader. This may slow down the data loading process, but saves a lot of CPU memory and dataset loading time.

  • node_feature (str or list of str, optional) – node features to extract

  • edge_feature (str or list of str, optional) – edge features to extract

  • graph_feature (str or list of str, optional) – graph features to extract

  • with_hydrogen (bool, optional) – store hydrogens in the molecule graph. By default, hydrogens are dropped

  • kekulize (bool, optional) – convert aromatic bonds to single/double bonds. Note this only affects the relation in edge_list. For bond_type, aromatic bonds are always stored explicitly. By default, aromatic bonds are stored.
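
The roles of smiles_field and target_fields can be sketched with the standard csv module. `read_smiles_table` is a hypothetical helper that produces a list of SMILES strings and a dict of target lists; it does not build any molecule, the column names `y1` and `y2` and their values are toy data, and representing empty cells as None is an assumption.

```python
import csv
import io


def read_smiles_table(handle, smiles_field="smiles", target_fields=None):
    """Read a CSV table into (list of SMILES, dict of target lists)."""
    reader = csv.DictReader(handle)
    fields = reader.fieldnames or []
    if target_fields is None:
        # Default: every column other than the SMILES column is a target.
        target_fields = [name for name in fields if name != smiles_field]
    smiles_list = []
    targets = {name: [] for name in target_fields}
    for row in reader:
        smiles_list.append(row[smiles_field])
        for name in target_fields:
            value = row[name]
            # Assumption: an empty cell means the target is missing.
            targets[name].append(float(value) if value != "" else None)
    return smiles_list, targets


table = io.StringIO("smiles,y1,y2\nCCO,0.5,1\nc1ccccc1,,0\n")
smiles_list, targets = read_smiles_table(table)
```

The two return values have the same form as the smiles_list and targets arguments of load_smiles.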

load_smiles(smiles_list, targets, transform=None, lazy=False, verbose=0, node_feature='default', edge_feature='default', graph_feature=None, with_hydrogen=False, kekulize=False)[source]

Load the dataset from SMILES and targets.

Parameters
  • smiles_list (list of str) – SMILES strings

  • targets (dict of list) – prediction targets

  • transform (Callable, optional) – data transformation function

  • lazy (bool, optional) – if lazy mode is used, the molecules are processed in the dataloader. This may slow down the data loading process, but saves a lot of CPU memory and dataset loading time.

  • verbose (int, optional) – output verbose level

  • node_feature (str or list of str, optional) – node features to extract

  • edge_feature (str or list of str, optional) – edge features to extract

  • graph_feature (str or list of str, optional) – graph features to extract

  • with_hydrogen (bool, optional) – store hydrogens in the molecule graph. By default, hydrogens are dropped

  • kekulize (bool, optional) – convert aromatic bonds to single/double bonds. Note this only affects the relation in edge_list. For bond_type, aromatic bonds are always stored explicitly. By default, aromatic bonds are stored.

property atom_types

All atom types.

property bond_types

All bond types.

property edge_feature_dim

Dimension of edge features.

property node_feature_dim

Dimension of node features.

property num_atom_type

Number of different atom types.

property num_bond_type

Number of different bond types.

property tasks

List of tasks.

NodeClassificationDataset

class NodeClassificationDataset[source]

Node classification dataset.

The whole dataset contains one graph, where each node has its own node feature and label.

load_tsv(node_file, edge_file, verbose=0)[source]

Load the edge list from a tsv file.

Parameters
  • node_file (str) – node feature and label file

  • edge_file (str) – edge list file

  • verbose (int, optional) – output verbose level

property node_feature_dim

Dimension of node features.

property num_edge

Number of edges.

property num_node

Number of nodes.

ReactionDataset

class ReactionDataset[source]

Chemical reaction dataset.

Each sample contains two molecule graphs, and any number of prediction targets.

load_smiles(smiles_list, targets, transform=None, verbose=0, node_feature='default', edge_feature='default', graph_feature=None, with_hydrogen=False, kekulize=False)[source]

Load the dataset from SMILES and targets.

Parameters
  • smiles_list (list of str) – SMILES strings

  • targets (dict of list) – prediction targets

  • transform (Callable, optional) – data transformation function

  • verbose (int, optional) – output verbose level

  • node_feature (str or list of str, optional) – node features to extract

  • edge_feature (str or list of str, optional) – edge features to extract

  • graph_feature (str or list of str, optional) – graph features to extract

  • with_hydrogen (bool, optional) – store hydrogens in the molecule graph. By default, hydrogens are dropped

  • kekulize (bool, optional) – convert aromatic bonds to single/double bonds. Note this only affects the relation in edge_list. For bond_type, aromatic bonds are always stored explicitly. By default, aromatic bonds are stored.

property atom_types

All atom types.

property bond_types

All bond types.

property edge_feature_dim

Dimension of edge features.

property node_feature_dim

Dimension of node features.

property num_atom_type

Number of different atom types.

property num_bond_type

Number of different bond types.

SemiSupervised

class SemiSupervised(dataset, indices)[source]

Semi-supervised dataset.

Parameters
  • dataset (Dataset) – supervised dataset

  • indices (list of int) – sample indices to keep supervision

Data Processing

DataLoader

class DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0, collate_fn=<function graph_collate>, **kwargs)[source]

Extended data loader for batching graph structured data.

See torch.utils.data.DataLoader for more details.

Parameters
  • dataset (Dataset) – dataset from which to load the data

  • batch_size (int, optional) – how many samples per batch to load

  • shuffle (bool, optional) – set to True to have the data reshuffled at every epoch

  • sampler (Sampler, optional) – sampler that draws a single sample from the dataset

  • batch_sampler (Sampler, optional) – sampler that draws a mini-batch of data from the dataset

  • num_workers (int, optional) – how many subprocesses to use for data loading

  • collate_fn (callable, optional) – merge a list of samples into a mini-batch

  • kwargs – keyword arguments for torch.utils.data.DataLoader

Dataset Split Methods

graph_collate(batch)[source]

Convert a list of samples that share the same nested container structure into a single container of tensors.

For instances of data.Graph, they are collated by data.Graph.pack.

Parameters

batch (list) – list of samples with the same nested container
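
The recursion over nested containers can be sketched without tensors. `collate_sketch` is a hypothetical helper: it walks dicts and sequences in the same way, but simply gathers the leaves into lists where the library would stack tensors or pack graphs.

```python
def collate_sketch(batch):
    """Turn a list of identically nested samples into one nested container."""
    first = batch[0]
    if isinstance(first, dict):
        # Collate each key across all samples.
        return {key: collate_sketch([sample[key] for sample in batch]) for key in first}
    if isinstance(first, (list, tuple)):
        # Collate each position across all samples.
        return [collate_sketch(list(column)) for column in zip(*batch)]
    # Leaf: the library would stack tensors or pack graphs here.
    return list(batch)


batch = [
    {"graph": "g0", "targets": (1.0, 0.0)},
    {"graph": "g1", "targets": (0.0, 1.0)},
]
collated = collate_sketch(batch)
```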

key_split(dataset, keys, lengths=None, key_lengths=None)[source]

ordered_scaffold_split(dataset, lengths, chirality=True)[source]

Split a dataset into new datasets with non-overlapping scaffolds, with the scaffolds ordered by their number of occurrences.

Parameters
  • dataset (Dataset) – dataset to split

  • lengths (list of int) – expected length for each split. Note the results may be different in length due to rounding.

scaffold_split(dataset, lengths)[source]

Randomly split a dataset into new datasets with non-overlapping scaffolds.

Parameters
  • dataset (Dataset) – dataset to split

  • lengths (list of int) – expected length for each split. Note the results may be different in length due to rounding.
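
The non-overlapping property can be illustrated by grouping sample indexes by scaffold and assigning whole groups to splits. This is only a sketch under assumptions: `scaffold_split_sketch` is hypothetical, it takes precomputed scaffold strings instead of a dataset, and the greedy fill strategy and the `seed` parameter are illustrative choices rather than the library's algorithm.

```python
import random
from collections import defaultdict


def scaffold_split_sketch(scaffolds, lengths, seed=0):
    """Assign whole scaffold groups to splits so no scaffold is shared."""
    groups = defaultdict(list)
    for index, scaffold in enumerate(scaffolds):
        groups[scaffold].append(index)
    buckets = list(groups.values())
    random.Random(seed).shuffle(buckets)  # random order of scaffold groups

    splits = [[] for _ in lengths]
    current = 0
    for bucket in buckets:
        # Move on once the current split has reached its expected length.
        while current < len(lengths) - 1 and len(splits[current]) >= lengths[current]:
            current += 1
        splits[current].extend(bucket)
    return splits


scaffolds = ["c1ccccc1", "c1ccccc1", "C1CCCCC1", "c1ccncc1", "C1CCCCC1", "c1ccncc1"]
splits = scaffold_split_sketch(scaffolds, lengths=[4, 2])
split_scaffolds = [{scaffolds[i] for i in split} for split in splits]
```

Since a scaffold group is never divided, the realised split sizes can differ from the requested lengths.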

semisupervised(dataset, length)[source]

Randomly construct a semi-supervised dataset based on the given length.

Parameters
  • dataset (Dataset) – supervised dataset

  • length (int) – length of supervised data to keep
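
One way to picture this is to choose length samples at random to stay supervised and mark every other sample as unlabeled. The sketch below assumes exactly that behaviour. The helper name, the labeled flag, and the seed argument are all hypothetical and are not taken from the library.

```python
import random


def semisupervised_sketch(samples, length, seed=0):
    """Flag `length` randomly chosen samples as labeled and the rest as unlabeled."""
    if length > len(samples):
        raise ValueError("length must not exceed the dataset size")
    rng = random.Random(seed)  # seeded only to make the sketch reproducible
    labeled = set(rng.sample(range(len(samples)), length))
    # Copy each sample and attach a boolean flag instead of deleting the target.
    return [dict(sample, labeled=(index in labeled))
            for index, sample in enumerate(samples)]


dataset = [{"x": i, "y": i % 2} for i in range(10)]
semi = semisupervised_sketch(dataset, 3)
print(sum(item["labeled"] for item in semi))
```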

Feature Functions

Atom Features

atom_default(atom)[source]

Default atom feature.

Features:

GetSymbol(): one-hot embedding for the atomic symbol

GetChiralTag(): one-hot embedding for atomic chiral tag

GetTotalDegree(): one-hot embedding for the degree of the atom in the molecule including Hs

GetFormalCharge(): one-hot embedding for the formal charge of the atom

GetTotalNumHs(): one-hot embedding for the total number of Hs (explicit and implicit) on the atom

GetNumRadicalElectrons(): one-hot embedding for the number of radical electrons on the atom

GetHybridization(): one-hot embedding for the atom’s hybridization

GetIsAromatic(): whether the atom is aromatic

IsInRing(): whether the atom is in a ring

atom_position(): the 3D position of the atom
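
Most entries in these feature lists are one-hot embeddings that are concatenated with a few boolean flags. The sketch below shows that pattern in plain Python. The one_hot helper, its allow_unknown flag, and the two small vocabularies are hypothetical and serve only as illustration. The real features are computed from RDKit atom objects.

```python
def one_hot(value, vocabulary, allow_unknown=False):
    """Encode `value` as a one-hot list over `vocabulary`.

    With allow_unknown=True, an extra trailing slot catches unseen values.
    """
    size = len(vocabulary) + (1 if allow_unknown else 0)
    encoding = [0] * size
    if value in vocabulary:
        encoding[vocabulary.index(value)] = 1
    elif allow_unknown:
        encoding[-1] = 1
    else:
        raise ValueError(f"unknown value {value!r}")
    return encoding


# Hypothetical vocabularies, chosen only for illustration.
symbols = ["C", "N", "O"]
degrees = [0, 1, 2, 3, 4]

# A carbon atom with total degree 4, not aromatic (0), in a ring (1).
feature = one_hot("C", symbols, allow_unknown=True) + one_hot(4, degrees) + [0, 1]
print(feature)
```

The final feature vector is the concatenation of the individual encodings, so its length is the sum of the vocabulary sizes plus one slot per boolean flag.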

atom_symbol(atom)[source]

Symbol atom feature.

Features:

GetSymbol(): one-hot embedding for the atomic symbol

atom_position(atom)[source]

Atom position. Return 3D position if available, otherwise 2D position is returned.

atom_property_prediction(atom)[source]

Property prediction atom feature.

Features:

GetSymbol(): one-hot embedding for the atomic symbol

GetDegree(): one-hot embedding for the degree of the atom in the molecule

GetTotalNumHs(): one-hot embedding for the total number of Hs (explicit and implicit) on the atom

GetTotalValence(): one-hot embedding for the total valence (explicit + implicit) of the atom

GetFormalCharge(): one-hot embedding for the formal charge of the atom

GetIsAromatic(): whether the atom is aromatic

atom_explicit_property_prediction(atom)[source]

Explicit property prediction atom feature.

Features:

GetSymbol(): one-hot embedding for the atomic symbol

GetDegree(): one-hot embedding for the degree of the atom in the molecule

GetTotalValence(): one-hot embedding for the total valence (explicit + implicit) of the atom

GetFormalCharge(): one-hot embedding for the formal charge of the atom

GetIsAromatic(): whether the atom is aromatic

atom_pretrain(atom)[source]

Atom feature for pretraining.

Features:

GetSymbol(): one-hot embedding for the atomic symbol

GetChiralTag(): one-hot embedding for atomic chiral tag

atom_center_identification(atom)[source]

Reaction center identification atom feature.

Features:

GetSymbol(): one-hot embedding for the atomic symbol

GetTotalNumHs(): one-hot embedding for the total number of Hs (explicit and implicit) on the atom

GetTotalDegree(): one-hot embedding for the degree of the atom in the molecule including Hs

GetTotalValence(): one-hot embedding for the total valence (explicit + implicit) of the atom

GetIsAromatic(): whether the atom is aromatic

IsInRing(): whether the atom is in a ring

atom_synthon_completion(atom)[source]

Synthon completion atom feature.

Features:

GetSymbol(): one-hot embedding for the atomic symbol

GetTotalNumHs(): one-hot embedding for the total number of Hs (explicit and implicit) on the atom

GetTotalDegree(): one-hot embedding for the degree of the atom in the molecule including Hs

IsInRing(): whether the atom is in a ring

IsInRingSize(3, 4, 5, 6): whether the atom is in a ring of a particular size

IsInRing() and not IsInRingSize(3, 4, 5, 6): whether the atom is in a ring but not in a ring of size 3, 4, 5, or 6

Bond Features

bond_default(bond)[source]

Default bond feature.

Features:

GetBondType(): one-hot embedding for the type of the bond

GetBondDir(): one-hot embedding for the direction of the bond

GetStereo(): one-hot embedding for the stereo configuration of the bond

GetIsConjugated(): whether the bond is considered to be conjugated

bond_length(): the length of the bond

bond_length(bond)[source]

Bond length.

bond_property_prediction(bond)[source]

Property prediction bond feature.

Features:

GetBondType(): one-hot embedding for the type of the bond

GetIsConjugated(): whether the bond is considered to be conjugated

IsInRing(): whether the bond is in a ring

bond_pretrain(bond)[source]

Bond feature for pretraining.

Features:

GetBondType(): one-hot embedding for the type of the bond

GetBondDir(): one-hot embedding for the direction of the bond

Molecule Features

molecule_default(mol)[source]

Default molecule feature.

ExtendedConnectivityFingerprint(mol, radius=2, length=1024)[source]

Extended Connectivity Fingerprint molecule feature.

Features:

GetMorganFingerprintAsBitVect(): a Morgan fingerprint for a molecule as a bit vector

ECFP()

alias of torchdrug.data.feature.ExtendedConnectivityFingerprint

Element Constants

Element constants are provided for convenient manipulation of atom types. The atomic numbers can be accessed through uppercase element names at the root of the package. For example, we can get the carbon scaffold of a molecule with the following code.

import torchdrug as td
from torchdrug import data

smiles = "CC1=C(C=C(C=C1[N+](=O)[O-])[N+](=O)[O-])[N+](=O)[O-]"
mol = data.Molecule.from_smiles(smiles)
scaffold = mol.subgraph(mol.atom_type == td.CARBON)
mol.visualize()
scaffold.visualize()
[Figures: the TNT molecule (tnt.png) and its carbon scaffold (tnt_carbon_scaffold.png)]

There are also two constant arrays that map atomic numbers to element names: td.ATOM_NAME[i] returns the full name, while td.ATOM_SYMBOL[i] returns the abbreviated chemical symbol for atomic number i.
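
The lookup pattern can be tried without installing the package by using small local stand-ins. In the sketch below, the lists cover only the first eight elements, and the placeholder at index 0 (so that the list index equals the atomic number) is an assumption of this sketch. In real code the constants would come from the torchdrug package root, as described above.

```python
# Local stand-ins for the first few entries; index 0 is a placeholder so that
# the list index equals the atomic number (an assumption of this sketch).
ATOM_NAME = [None, "Hydrogen", "Helium", "Lithium", "Beryllium",
             "Boron", "Carbon", "Nitrogen", "Oxygen"]
ATOM_SYMBOL = [None, "H", "He", "Li", "Be", "B", "C", "N", "O"]
CARBON, NITROGEN, OXYGEN = 6, 7, 8

# Atomic numbers of a small hypothetical fragment.
atom_type = [6, 6, 7, 8, 8]

# Boolean mask selecting the carbon atoms, analogous to mol.atom_type == td.CARBON.
carbon_mask = [z == CARBON for z in atom_type]
fragment_symbols = [ATOM_SYMBOL[z] for z in atom_type]
print(carbon_mask, fragment_symbols, ATOM_NAME[NITROGEN])
```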

For a full list of elements, please refer to the periodic table.