import math
import warnings
from copy import copy
from collections.abc import Sequence
from matplotlib import pyplot as plt
from rdkit import Chem
from rdkit.Chem.Scaffolds import MurckoScaffold
import torch
from torch_scatter import scatter_add, scatter_min
from torchdrug import utils
from torchdrug.data import constant, Graph, PackedGraph
from torchdrug.core import Registry as R
from torchdrug.data.rdkit import draw
from torchdrug.utils import pretty
# Select matplotlib's "agg" backend at import time. This is presumably so `visualize` can render and
# save figures without a display. NOTE(review): this is a module-level side effect that affects every
# importer — confirm it is intended.
plt.switch_backend("agg")
class Molecule(Graph):
    """
    Molecules with predefined chemical features.

    By nature, molecules are undirected graphs. Each bond is stored as two directed edges in this class.

    .. warning::

        This class doesn't enforce any order on edges.

    Parameters:
        edge_list (array_like, optional): list of edges of shape :math:`(|E|, 3)`.
            Each tuple is (node_in, node_out, bond_type).
        atom_type (array_like): atom types of shape :math:`(|V|,)`. Must be provided.
        bond_type (array_like): bond types of shape :math:`(|E|,)`. Must be provided.
        atom_feature (array_like, optional): atom features, stored as an atom attribute
        bond_feature (array_like, optional): bond features, stored as a bond attribute
        mol_feature (array_like, optional): molecule features, stored as a molecule attribute
        formal_charge (array_like, optional): formal charges of shape :math:`(|V|,)`
        explicit_hs (array_like, optional): number of explicit hydrogens of shape :math:`(|V|,)`
        chiral_tag (array_like, optional): chirality tags of shape :math:`(|V|,)`
        radical_electrons (array_like, optional): number of radical electrons of shape :math:`(|V|,)`
        atom_map (array_like, optional): atom mappings of shape :math:`(|V|,)`
        bond_stereo (array_like, optional): bond stereochem of shape :math:`(|E|,)`
        stereo_atoms (array_like, optional): ids of stereo atoms of shape :math:`(|E|, 2)`
        node_position (array_like, optional): atom coordinates, stored as a float tensor
    """

    bond2id = {"SINGLE": 0, "DOUBLE": 1, "TRIPLE": 2, "AROMATIC": 3}
    # maximum valence per atomic number, used by the coarse `is_valid` check
    atom2valence = {1: 1, 5: 3, 6: 4, 7: 3, 8: 2, 9: 1, 14: 4, 15: 5, 16: 6, 17: 1, 35: 1, 53: 7}
    # valence contributed by each relation id (indexed like `bond2id` values)
    bond2valence = [1, 2, 3, 1.5]
    id2bond = {v: k for k, v in bond2id.items()}
    empty_mol = Chem.MolFromSmiles("")
    # Appended during feature extraction and sliced off afterwards. This is presumably so `torch.tensor`
    # always sees a non-empty, correctly shaped list, even for molecules with no atoms or bonds.
    dummy_mol = Chem.MolFromSmiles("CC")

    def __init__(self, edge_list=None, atom_type=None, bond_type=None, atom_feature=None, bond_feature=None,
                 mol_feature=None, formal_charge=None, explicit_hs=None, chiral_tag=None, radical_electrons=None,
                 atom_map=None, bond_stereo=None, stereo_atoms=None, node_position=None, **kwargs):
        if "num_relation" not in kwargs:
            kwargs["num_relation"] = len(self.bond2id)
        super(Molecule, self).__init__(edge_list=edge_list, **kwargs)
        atom_type, bond_type = self._standarize_atom_bond(atom_type, bond_type)

        # missing per-atom / per-bond attributes default to zero-filled tensors
        formal_charge = self._standarize_attribute(formal_charge, self.num_node)
        explicit_hs = self._standarize_attribute(explicit_hs, self.num_node)
        chiral_tag = self._standarize_attribute(chiral_tag, self.num_node)
        radical_electrons = self._standarize_attribute(radical_electrons, self.num_node)
        atom_map = self._standarize_attribute(atom_map, self.num_node)
        bond_stereo = self._standarize_attribute(bond_stereo, self.num_edge)
        stereo_atoms = self._standarize_attribute(stereo_atoms, (self.num_edge, 2))
        if node_position is not None:
            node_position = torch.as_tensor(node_position, dtype=torch.float, device=self.device)

        with self.atom():
            if atom_feature is not None:
                self.atom_feature = torch.as_tensor(atom_feature, device=self.device)
            self.atom_type = atom_type
            self.formal_charge = formal_charge
            self.explicit_hs = explicit_hs
            self.chiral_tag = chiral_tag
            self.radical_electrons = radical_electrons
            self.atom_map = atom_map
            if node_position is not None:
                self.node_position = node_position

        with self.bond():
            if bond_feature is not None:
                self.bond_feature = torch.as_tensor(bond_feature, device=self.device)
            self.bond_type = bond_type
            self.bond_stereo = bond_stereo
            self.stereo_atoms = stereo_atoms

        with self.mol():
            if mol_feature is not None:
                self.mol_feature = torch.as_tensor(mol_feature, device=self.device)

    def _standarize_atom_bond(self, atom_type, bond_type):
        """Validate that atom and bond types are given and convert them to long tensors on `self.device`."""
        if atom_type is None:
            raise ValueError("`atom_type` should be provided")
        if bond_type is None:
            raise ValueError("`bond_type` should be provided")

        atom_type = torch.as_tensor(atom_type, dtype=torch.long, device=self.device)
        bond_type = torch.as_tensor(bond_type, dtype=torch.long, device=self.device)
        return atom_type, bond_type

    def _standarize_attribute(self, attribute, size, dtype=torch.long, default=0):
        """Convert `attribute` to a tensor, or build one of shape `size` filled with `default` if it is None."""
        if attribute is not None:
            attribute = torch.as_tensor(attribute, dtype=dtype, device=self.device)
        else:
            if isinstance(size, torch.Tensor):
                size = size.tolist()
            if not isinstance(size, Sequence):
                size = [size]
            attribute = torch.full(size, default, dtype=dtype, device=self.device)
        return attribute

    @classmethod
    def _standarize_option(cls, option):
        """Normalize a feature option (None, a single name, or a list of names) into a list of names."""
        if option is None:
            option = []
        elif isinstance(option, str):
            option = [option]
        return option

    def _check_no_stereo(self):
        """Warn if any bond carries stereo information, since masking may then yield invalid molecules."""
        if (self.bond_stereo > 0).any():
            warnings.warn("Try to apply masks on molecules with stereo bonds. This may produce invalid molecules. "
                          "To discard stereo information, call `mol.bond_stereo[:] = 0` before applying masks.")

    def _maybe_num_node(self, edge_list):
        """Infer the number of atoms from the largest atom index in `edge_list` (0 if there are no edges)."""
        if len(edge_list):
            return edge_list[:, :2].max().item() + 1
        else:
            return 0

    @classmethod
    @utils.deprecated_alias(node_feature="atom_feature", edge_feature="bond_feature", graph_feature="mol_feature")
    def from_molecule(cls, mol, atom_feature="default", bond_feature="default", mol_feature=None,
                      with_hydrogen=False, kekulize=False):
        """
        Create a molecule from an RDKit object.

        Parameters:
            mol (rdchem.Mol): molecule
            atom_feature (str or list of str, optional): atom features to extract
            bond_feature (str or list of str, optional): bond features to extract
            mol_feature (str or list of str, optional): molecule features to extract
            with_hydrogen (bool, optional): store hydrogens in the molecule graph.
                By default, hydrogens are dropped
            kekulize (bool, optional): convert aromatic bonds to single/double bonds.
                Note this only affects the relation in ``edge_list``.
                For ``bond_type``, aromatic bonds are always stored explicitly.
                By default, aromatic bonds are stored.
        """
        if mol is None:
            mol = cls.empty_mol
        # some RDKit operations are in-place
        # copy the object to avoid undesired behavior in the caller
        mol = copy(mol)
        if with_hydrogen:
            mol = Chem.AddHs(mol)
        if kekulize:
            Chem.Kekulize(mol)

        atom_feature = cls._standarize_option(atom_feature)
        bond_feature = cls._standarize_option(bond_feature)
        mol_feature = cls._standarize_option(mol_feature)

        atom_type = []
        formal_charge = []
        explicit_hs = []
        chiral_tag = []
        radical_electrons = []
        atom_map = []
        _atom_feature = []
        # one dummy atom is appended and removed by the `[:-1]` slices below
        dummy_atom = copy(cls.dummy_mol).GetAtomWithIdx(0)
        atoms = [mol.GetAtomWithIdx(i) for i in range(mol.GetNumAtoms())] + [dummy_atom]
        if mol.GetNumConformers() > 0:
            node_position = torch.tensor(mol.GetConformer().GetPositions())
        else:
            node_position = None
        for atom in atoms:
            atom_type.append(atom.GetAtomicNum())
            formal_charge.append(atom.GetFormalCharge())
            explicit_hs.append(atom.GetNumExplicitHs())
            chiral_tag.append(atom.GetChiralTag())
            radical_electrons.append(atom.GetNumRadicalElectrons())
            atom_map.append(atom.GetAtomMapNum())
            feature = []
            for name in atom_feature:
                func = R.get("features.atom.%s" % name)
                feature += func(atom)
            _atom_feature.append(feature)
        atom_type = torch.tensor(atom_type)[:-1]
        atom_map = torch.tensor(atom_map)[:-1]
        formal_charge = torch.tensor(formal_charge)[:-1]
        explicit_hs = torch.tensor(explicit_hs)[:-1]
        chiral_tag = torch.tensor(chiral_tag)[:-1]
        radical_electrons = torch.tensor(radical_electrons)[:-1]
        if len(atom_feature) > 0:
            _atom_feature = torch.tensor(_atom_feature)[:-1]
        else:
            _atom_feature = None

        edge_list = []
        bond_type = []
        bond_stereo = []
        stereo_atoms = []
        _bond_feature = []
        # one dummy bond (stored as 2 directed edges) is appended and removed by the `[:-2]` slices below
        dummy_bond = copy(cls.dummy_mol).GetBondWithIdx(0)
        bonds = [mol.GetBondWithIdx(i) for i in range(mol.GetNumBonds())] + [dummy_bond]
        for bond in bonds:
            bond_name = str(bond.GetBondType())
            stereo = bond.GetStereo()
            if stereo:
                _atoms = list(bond.GetStereoAtoms())
            else:
                _atoms = [0, 0]
            # bonds whose type is not in `bond2id` are dropped from the graph
            if bond_name not in cls.bond2id:
                continue
            relation = cls.bond2id[bond_name]
            h, t = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            edge_list += [[h, t, relation], [t, h, relation]]
            # always explicitly store aromatic bonds, no matter kekulize or not
            bond_id = cls.bond2id["AROMATIC"] if bond.GetIsAromatic() else relation
            bond_type += [bond_id, bond_id]
            bond_stereo += [stereo, stereo]
            stereo_atoms += [_atoms, _atoms]
            feature = []
            for name in bond_feature:
                func = R.get("features.bond.%s" % name)
                feature += func(bond)
            _bond_feature += [feature, feature]
        edge_list = edge_list[:-2]
        bond_type = torch.tensor(bond_type)[:-2]
        bond_stereo = torch.tensor(bond_stereo)[:-2]
        stereo_atoms = torch.tensor(stereo_atoms)[:-2]
        if len(bond_feature) > 0:
            _bond_feature = torch.tensor(_bond_feature)[:-2]
        else:
            _bond_feature = None

        _mol_feature = []
        for name in mol_feature:
            func = R.get("features.molecule.%s" % name)
            _mol_feature += func(mol)
        if len(mol_feature) > 0:
            _mol_feature = torch.tensor(_mol_feature)
        else:
            _mol_feature = None

        # after kekulization no edge uses the AROMATIC relation, so one fewer relation is needed
        num_relation = len(cls.bond2id) - 1 if kekulize else len(cls.bond2id)
        return cls(edge_list, atom_type, bond_type,
                   formal_charge=formal_charge, explicit_hs=explicit_hs,
                   chiral_tag=chiral_tag, radical_electrons=radical_electrons, atom_map=atom_map,
                   bond_stereo=bond_stereo, stereo_atoms=stereo_atoms, node_position=node_position,
                   atom_feature=_atom_feature, bond_feature=_bond_feature, mol_feature=_mol_feature,
                   num_node=mol.GetNumAtoms(), num_relation=num_relation)

    @classmethod
    @utils.deprecated_alias(node_feature="atom_feature", edge_feature="bond_feature", graph_feature="mol_feature")
    def from_smiles(cls, smiles, atom_feature="default", bond_feature="default", mol_feature=None,
                    with_hydrogen=False, kekulize=False):
        """
        Create a molecule from a SMILES string.

        Parameters:
            smiles (str): SMILES string
            atom_feature (str or list of str, optional): atom features to extract
            bond_feature (str or list of str, optional): bond features to extract
            mol_feature (str or list of str, optional): molecule features to extract
            with_hydrogen (bool, optional): store hydrogens in the molecule graph.
                By default, hydrogens are dropped
            kekulize (bool, optional): convert aromatic bonds to single/double bonds.
                Note this only affects the relation in ``edge_list``.
                For ``bond_type``, aromatic bonds are always stored explicitly.
                By default, aromatic bonds are stored.

        Raises:
            ValueError: if RDKit cannot parse ``smiles``
        """
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError("Invalid SMILES `%s`" % smiles)

        return cls.from_molecule(mol, atom_feature, bond_feature, mol_feature, with_hydrogen, kekulize)

    def to_smiles(self, isomeric=True, atom_map=True, canonical=False):
        """
        Return a SMILES string of this molecule.

        Parameters:
            isomeric (bool, optional): keep isomeric information or not
            atom_map (bool, optional): keep atom mapping or not
            canonical (bool, optional): if true, return the canonical form of smiles

        Returns:
            str
        """
        mol = self.to_molecule()
        if not atom_map:
            for atom in mol.GetAtoms():
                atom.SetAtomMapNum(0)
        smiles = Chem.MolToSmiles(mol, isomericSmiles=isomeric)
        if canonical:
            # re-parse and re-write until the SMILES string reaches a fixed point (or repeats)
            smiles_set = set()
            while smiles not in smiles_set:
                smiles_set.add(smiles)
                mol = Chem.MolFromSmiles(smiles)
                smiles = Chem.MolToSmiles(mol, isomericSmiles=isomeric)
        return smiles

    def to_molecule(self, ignore_error=False):
        """
        Return an RDKit object of this molecule.

        Parameters:
            ignore_error (bool, optional): if true, return ``None`` for illegal molecules.
                Otherwise, raise an exception.

        Returns:
            rdchem.Mol
        """
        mol = Chem.RWMol()

        atom_type = self.atom_type.tolist()
        bond_type = self.bond_type.tolist()
        formal_charge = self.formal_charge.tolist()
        explicit_hs = self.explicit_hs.tolist()
        chiral_tag = self.chiral_tag.tolist()
        radical_electrons = self.radical_electrons.tolist()
        atom_map = self.atom_map.tolist()
        bond_stereo = self.bond_stereo.tolist()
        stereo_atoms = self.stereo_atoms.tolist()
        if hasattr(self, "node_position"):
            node_position = self.node_position.tolist()
            conformer = Chem.Conformer()
        else:
            conformer = None
        for i in range(self.num_node):
            atom = Chem.Atom(atom_type[i])
            atom.SetFormalCharge(formal_charge[i])
            atom.SetNumExplicitHs(explicit_hs[i])
            atom.SetChiralTag(Chem.ChiralType(chiral_tag[i]))
            atom.SetNumRadicalElectrons(radical_electrons[i])
            atom.SetNoImplicit(explicit_hs[i] > 0 or radical_electrons[i] > 0)
            atom.SetAtomMapNum(atom_map[i])
            if conformer is not None:
                conformer.SetAtomPosition(i, node_position[i])
            mol.AddAtom(atom)
        if conformer is not None:
            mol.AddConformer(conformer)

        # each bond is stored as two directed edges; only the h < t copy is added to RDKit
        edge_list = self.edge_list.tolist()
        for i in range(self.num_edge):
            h, t, relation = edge_list[i]
            if h < t:
                # AddBond returns the new number of bonds, so the new bond's index is one less
                j = mol.AddBond(h, t, Chem.BondType.names[self.id2bond[relation]])
                bond = mol.GetBondWithIdx(j - 1)
                bond.SetIsAromatic(bond_type[i] == self.bond2id["AROMATIC"])
                bond.SetStereo(Chem.BondStereo(bond_stereo[i]))
        # Stereo atoms are assigned in a second pass. This is presumably because they reference
        # neighbouring bonds that must already exist in the molecule.
        j = 0
        for i in range(self.num_edge):
            h, t, relation = edge_list[i]
            if h < t:
                if bond_stereo[i]:
                    bond = mol.GetBondWithIdx(j)
                    bond.SetStereoAtoms(*stereo_atoms[i])
                j += 1

        if ignore_error:
            try:
                with utils.no_rdkit_log():
                    mol.UpdatePropertyCache()
                    Chem.AssignStereochemistry(mol)
                    mol.ClearComputedProps()
                    mol.UpdatePropertyCache()
            except Exception:
                mol = None
        else:
            mol.UpdatePropertyCache()
            Chem.AssignStereochemistry(mol)
            mol.ClearComputedProps()
            mol.UpdatePropertyCache()

        return mol

    def ion_to_molecule(self):
        """
        Convert ions to molecules by adjusting hydrogens and electrons.

        Note [N+] will not be converted.
        """
        data_dict = self.data_dict
        formal_charge = data_dict.pop("formal_charge")
        explicit_hs = data_dict.pop("explicit_hs")
        radical_electrons = data_dict.pop("radical_electrons")
        # only nitrogens whose explicit valence exceeds 3 keep a +1 charge; all other charges are reset
        pos_nitrogen = (self.atom_type == 7) & (self.explicit_valence > 3)
        formal_charge = pos_nitrogen.long()
        explicit_hs = torch.zeros_like(explicit_hs)
        radical_electrons = torch.zeros_like(radical_electrons)

        return type(self)(self.edge_list, edge_weight=self.edge_weight,
                          num_node=self.num_node, num_relation=self.num_relation,
                          formal_charge=formal_charge, explicit_hs=explicit_hs, radical_electrons=radical_electrons,
                          meta_dict=self.meta_dict, **data_dict)

    def to_scaffold(self, chirality=False):
        """
        Return a scaffold SMILES string of this molecule.

        Parameters:
            chirality (bool, optional): consider chirality in the scaffold or not

        Returns:
            str
        """
        smiles = self.to_smiles()
        scaffold = MurckoScaffold.MurckoScaffoldSmiles(smiles, includeChirality=chirality)
        return scaffold

    def node_mask(self, index, compact=False):
        self._check_no_stereo()
        return super(Molecule, self).node_mask(index, compact)

    def edge_mask(self, index):
        self._check_no_stereo()
        return super(Molecule, self).edge_mask(index)

    def undirected(self, add_inverse=False):
        if add_inverse:
            raise ValueError("Bonds are undirected relations, but `add_inverse` is specified")
        return super(Molecule, self).undirected(add_inverse)

    def atom(self):
        """
        Context manager for atom attributes.
        """
        return self.node()

    def bond(self):
        """
        Context manager for bond attributes.
        """
        return self.edge()

    def mol(self):
        """
        Context manager for molecule attributes.
        """
        return self.graph()

    def atom_reference(self):
        """
        Context manager for atom references.
        """
        return self.node_reference()

    def bond_reference(self):
        """
        Context manager for bond references.
        """
        return self.edge_reference()

    def mol_reference(self):
        """
        Context manager for molecule references.
        """
        return self.graph_reference()

    # graph-level names are aliases of the chemistry-level attributes
    @property
    def num_node(self):
        return self.num_atom

    @num_node.setter
    def num_node(self, value):
        self.num_atom = value

    @property
    def num_edge(self):
        return self.num_bond

    @num_edge.setter
    def num_edge(self, value):
        self.num_bond = value

    atom2graph = Graph.node2graph
    bond2graph = Graph.edge2graph

    @property
    def node_feature(self):
        return self.atom_feature

    @node_feature.setter
    def node_feature(self, value):
        self.atom_feature = value

    @property
    def edge_feature(self):
        return self.bond_feature

    @edge_feature.setter
    def edge_feature(self, value):
        self.bond_feature = value

    @property
    def graph_feature(self):
        return self.mol_feature

    @graph_feature.setter
    def graph_feature(self, value):
        self.mol_feature = value

    @utils.cached_property
    def explicit_valence(self):
        """Per-atom sum of bond valences (from `bond2valence`) over outgoing edges, rounded to long."""
        bond2valence = torch.tensor(self.bond2valence, device=self.device)
        explicit_valence = scatter_add(bond2valence[self.edge_list[:, 2]], self.edge_list[:, 0], dim_size=self.num_node)
        return explicit_valence.round().long()

    @utils.cached_property
    def is_valid(self):
        """A coarse implementation of valence check."""
        # TODO: cross-check by any domain expert
        atom2valence = torch.tensor(float("nan")).repeat(constant.NUM_ATOM)
        for k, v in self.atom2valence.items():
            atom2valence[k] = v
        atom2valence = torch.as_tensor(atom2valence, device=self.device)

        max_atom_valence = atom2valence[self.atom_type]
        # special case for nitrogen
        pos_nitrogen = (self.atom_type == 7) & (self.formal_charge == 1)
        max_atom_valence[pos_nitrogen] = 4
        is_unknown = torch.isnan(max_atom_valence)
        if is_unknown.any():
            unknown_type = self.atom_type[is_unknown][0].item()
            raise ValueError("Fail to check valence. Unknown atom type %d" % unknown_type)

        is_valid = (self.explicit_valence <= max_atom_valence).all()
        return is_valid

    @utils.cached_property
    def is_valid_rdkit(self):
        """Whether RDKit can sanitize the molecule's properties, as a bool tensor of shape (1,)."""
        try:
            with utils.no_rdkit_log():
                mol = self.to_molecule()
                Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES)
            is_valid = torch.ones(1, dtype=torch.bool, device=self.device)
        except ValueError:
            is_valid = torch.zeros(1, dtype=torch.bool, device=self.device)
        return is_valid

    def __repr__(self):
        fields = ["num_atom=%d" % self.num_atom, "num_bond=%d" % self.num_bond]
        if self.device.type != "cpu":
            fields.append("device='%s'" % self.device)
        return "%s(%s)" % (self.__class__.__name__, ", ".join(fields))

    def visualize(self, title=None, save_file=None, figure_size=(3, 3), ax=None, atom_map=False):
        """
        Visualize this molecule with matplotlib.

        Parameters:
            title (str, optional): title for this molecule
            save_file (str, optional): ``png`` or ``pdf`` file to save visualization.
                If not provided, show the figure in window.
            figure_size (tuple of int, optional): width and height of the figure
            ax (matplotlib.axes.Axes, optional): axis to plot the figure
            atom_map (bool, optional): visualize atom mapping or not
        """
        # only save/show the figure when this call created it
        is_root = ax is None
        if ax is None:
            fig = plt.figure(figsize=figure_size)
            if title is not None:
                ax = plt.gca()
            else:
                ax = fig.add_axes([0, 0, 1, 1])
        if title is not None:
            ax.set_title(title)

        mol = self.to_molecule()
        if not atom_map:
            for atom in mol.GetAtoms():
                atom.SetAtomMapNum(0)
        draw.MolToMPL(mol, ax=ax)
        ax.set_frame_on(False)

        if is_root:
            if save_file:
                fig.savefig(save_file)
            else:
                fig.show()

    def __eq__(self, other):
        # Molecules compare equal when their canonical, non-isomeric, map-free SMILES match.
        # Let Python fall back to its default handling for unrelated types instead of raising AttributeError.
        if not isinstance(other, Molecule):
            return NotImplemented
        smiles = self.to_smiles(isomeric=False, atom_map=False, canonical=True)
        other_smiles = other.to_smiles(isomeric=False, atom_map=False, canonical=True)
        return smiles == other_smiles
class PackedMolecule(PackedGraph, Molecule):
    """
    Container for molecules with variadic sizes.

    .. warning::

        Edges of the same molecule are guaranteed to be consecutive in the edge list.
        However, this class doesn't enforce any order on the edges.

    Parameters:
        edge_list (array_like, optional): list of edges of shape :math:`(|E|, 3)`.
            Each tuple is (node_in, node_out, bond_type).
        atom_type (array_like, optional): atom types of shape :math:`(|V|,)`
        bond_type (array_like, optional): bond types of shape :math:`(|E|,)`
        num_nodes (array_like, optional): number of nodes in each graph
            By default, it will be inferred from the largest id in `edge_list`
        num_edges (array_like, optional): number of edges in each graph
        offsets (array_like, optional): node id offsets of shape :math:`(|E|,)`.
            If not provided, nodes in `edge_list` should be relative index, i.e., the index in each graph.
            If provided, nodes in `edge_list` should be absolute index, i.e., the index in the packed graph.
    """

    unpacked_type = Molecule
    atom2graph = PackedGraph.node2graph
    bond2graph = PackedGraph.edge2graph

    def __init__(self, edge_list=None, atom_type=None, bond_type=None, num_nodes=None, num_edges=None, offsets=None,
                 **kwargs):
        if "num_relation" not in kwargs:
            kwargs["num_relation"] = len(self.bond2id)
        super(PackedMolecule, self).__init__(edge_list=edge_list, num_nodes=num_nodes, num_edges=num_edges,
                                             offsets=offsets, atom_type=atom_type, bond_type=bond_type, **kwargs)

    def ion_to_molecule(self):
        """
        Convert ions to molecules by adjusting hydrogens and electrons.

        Note [N+] will not be converted.
        """
        data_dict = self.data_dict
        formal_charge = data_dict.pop("formal_charge")
        explicit_hs = data_dict.pop("explicit_hs")
        radical_electrons = data_dict.pop("radical_electrons")
        # only nitrogens whose explicit valence exceeds 3 keep a +1 charge; all other charges are reset
        pos_nitrogen = (self.atom_type == 7) & (self.explicit_valence > 3)
        formal_charge = pos_nitrogen.long()
        explicit_hs = torch.zeros_like(explicit_hs)
        radical_electrons = torch.zeros_like(radical_electrons)

        return type(self)(self.edge_list, edge_weight=self.edge_weight,
                          num_nodes=self.num_nodes, num_edges=self.num_edges, num_relation=self.num_relation,
                          offsets=self._offsets, formal_charge=formal_charge, explicit_hs=explicit_hs,
                          radical_electrons=radical_electrons, meta_dict=self.meta_dict, **data_dict)

    @utils.cached_property
    def is_valid(self):
        """A coarse implementation of valence check, returning one bool per molecule."""
        # TODO: cross-check by any domain expert
        atom2valence = torch.tensor(float("nan")).repeat(constant.NUM_ATOM)
        for k, v in self.atom2valence.items():
            atom2valence[k] = v
        atom2valence = torch.as_tensor(atom2valence, device=self.device)

        max_atom_valence = atom2valence[self.atom_type]
        # special case for nitrogen
        pos_nitrogen = (self.atom_type == 7) & (self.formal_charge == 1)
        max_atom_valence[pos_nitrogen] = 4
        is_unknown = torch.isnan(max_atom_valence)
        if is_unknown.any():
            unknown_type = self.atom_type[is_unknown][0].item()
            raise ValueError("Fail to check valence. Unknown atom type %d" % unknown_type)

        is_valid = self.explicit_valence <= max_atom_valence
        # A molecule is valid iff it has no over-valent atom. Counting violations per molecule makes
        # molecules without atoms valid, consistent with `Molecule.is_valid` (`all()` of an empty tensor).
        num_invalid = scatter_add((~is_valid).long(), self.node2graph, dim_size=self.batch_size)
        return num_invalid == 0

    @utils.cached_property
    def is_valid_rdkit(self):
        """Per-molecule RDKit sanitization check, as a bool tensor of shape (batch_size,)."""
        if self.batch_size == 0:
            # torch.cat cannot concatenate an empty list
            return torch.zeros(0, dtype=torch.bool, device=self.device)
        return torch.cat([mol.is_valid_rdkit for mol in self])

    @classmethod
    @utils.deprecated_alias(node_feature="atom_feature", edge_feature="bond_feature", graph_feature="mol_feature")
    def from_molecule(cls, mols, atom_feature="default", bond_feature="default", mol_feature=None,
                      with_hydrogen=False, kekulize=False):
        """
        Create a packed molecule from a list of RDKit objects.

        Parameters:
            mols (list of rdchem.Mol): molecules. ``None`` entries are treated as empty molecules.
            atom_feature (str or list of str, optional): atom features to extract
            bond_feature (str or list of str, optional): bond features to extract
            mol_feature (str or list of str, optional): molecule features to extract
            with_hydrogen (bool, optional): store hydrogens in the molecule graph.
                By default, hydrogens are dropped
            kekulize (bool, optional): convert aromatic bonds to single/double bonds.
                Note this only affects the relation in ``edge_list``.
                For ``bond_type``, aromatic bonds are always stored explicitly.
                By default, aromatic bonds are stored.
        """
        atom_feature = cls._standarize_option(atom_feature)
        bond_feature = cls._standarize_option(bond_feature)
        mol_feature = cls._standarize_option(mol_feature)

        atom_type = []
        formal_charge = []
        explicit_hs = []
        chiral_tag = []
        radical_electrons = []
        atom_map = []
        edge_list = []
        bond_type = []
        bond_stereo = []
        stereo_atoms = []
        node_position = []
        _atom_feature = []
        _bond_feature = []
        _mol_feature = []
        num_nodes = []
        num_edges = []

        # A dummy molecule is appended, presumably so every list below is non-empty and correctly shaped
        # for `torch.tensor`. Its entries are stripped after the loop.
        mols = list(mols) + [cls.dummy_mol]
        for mol in mols:
            if mol is None:
                mol = cls.empty_mol
            # some RDKit operations are in-place
            # copy the object to avoid undesired behavior in the caller
            mol = copy(mol)
            if with_hydrogen:
                mol = Chem.AddHs(mol)
            if kekulize:
                Chem.Kekulize(mol)

            # NOTE(review): positions are only collected for molecules that have a conformer. A batch mixing
            # molecules with and without conformers yields fewer positions than atoms — confirm callers avoid this.
            if mol.GetNumConformers() > 0:
                node_position += mol.GetConformer().GetPositions().tolist()
            for atom in mol.GetAtoms():
                atom_type.append(atom.GetAtomicNum())
                formal_charge.append(atom.GetFormalCharge())
                explicit_hs.append(atom.GetNumExplicitHs())
                chiral_tag.append(atom.GetChiralTag())
                radical_electrons.append(atom.GetNumRadicalElectrons())
                atom_map.append(atom.GetAtomMapNum())
                feature = []
                for name in atom_feature:
                    func = R.get("features.atom.%s" % name)
                    feature += func(atom)
                _atom_feature.append(feature)

            num_edge = 0
            for bond in mol.GetBonds():
                bond_name = str(bond.GetBondType())
                stereo = bond.GetStereo()
                if stereo:
                    _atoms = list(bond.GetStereoAtoms())
                else:
                    _atoms = [0, 0]
                # bonds whose type is not in `bond2id` are dropped from the graph
                if bond_name not in cls.bond2id:
                    continue
                relation = cls.bond2id[bond_name]
                h, t = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
                feature = []
                for name in bond_feature:
                    func = R.get("features.bond.%s" % name)
                    feature += func(bond)
                edge_list += [[h, t, relation], [t, h, relation]]
                # always explicitly store aromatic bonds
                bond_id = cls.bond2id["AROMATIC"] if bond.GetIsAromatic() else relation
                bond_type += [bond_id, bond_id]
                bond_stereo += [stereo, stereo]
                stereo_atoms += [_atoms, _atoms]
                _bond_feature += [feature, feature]
                num_edge += 2

            feature = []
            for name in mol_feature:
                func = R.get("features.molecule.%s" % name)
                feature += func(mol)
            _mol_feature.append(feature)

            num_nodes.append(mol.GetNumAtoms())
            # count the edges actually stored; skipped bonds must not be counted
            num_edges.append(num_edge)

        # Strip the dummy molecule using its real sizes. With `with_hydrogen=True` it is C2H6
        # (8 atoms, 14 edges), not 2 atoms / 2 edges.
        num_dummy_node = num_nodes.pop()
        num_dummy_edge = num_edges.pop()
        num_real_node = len(atom_type) - num_dummy_node
        num_real_edge = len(edge_list) - num_dummy_edge

        atom_type = torch.tensor(atom_type)[:num_real_node]
        atom_map = torch.tensor(atom_map)[:num_real_node]
        formal_charge = torch.tensor(formal_charge)[:num_real_node]
        explicit_hs = torch.tensor(explicit_hs)[:num_real_node]
        chiral_tag = torch.tensor(chiral_tag)[:num_real_node]
        radical_electrons = torch.tensor(radical_electrons)[:num_real_node]
        if len(node_position) > 0:
            node_position = torch.tensor(node_position)
        else:
            node_position = None
        if len(atom_feature) > 0:
            _atom_feature = torch.tensor(_atom_feature)[:num_real_node]
        else:
            _atom_feature = None

        edge_list = torch.tensor(edge_list)[:num_real_edge]
        bond_type = torch.tensor(bond_type)[:num_real_edge]
        bond_stereo = torch.tensor(bond_stereo)[:num_real_edge]
        stereo_atoms = torch.tensor(stereo_atoms)[:num_real_edge]
        if len(bond_feature) > 0:
            _bond_feature = torch.tensor(_bond_feature)[:num_real_edge]
        else:
            _bond_feature = None

        if len(mol_feature) > 0:
            _mol_feature = torch.tensor(_mol_feature)[:-1]
        else:
            _mol_feature = None

        # after kekulization no edge uses the AROMATIC relation, so one fewer relation is needed
        num_relation = len(cls.bond2id) - 1 if kekulize else len(cls.bond2id)
        return cls(edge_list, atom_type, bond_type,
                   formal_charge=formal_charge, explicit_hs=explicit_hs,
                   chiral_tag=chiral_tag, radical_electrons=radical_electrons, atom_map=atom_map,
                   bond_stereo=bond_stereo, stereo_atoms=stereo_atoms, node_position=node_position,
                   atom_feature=_atom_feature, bond_feature=_bond_feature, mol_feature=_mol_feature,
                   num_nodes=num_nodes, num_edges=num_edges, num_relation=num_relation)

    @classmethod
    @utils.deprecated_alias(node_feature="atom_feature", edge_feature="bond_feature", graph_feature="mol_feature")
    def from_smiles(cls, smiles_list, atom_feature="default", bond_feature="default", mol_feature=None,
                    with_hydrogen=False, kekulize=False):
        """
        Create a packed molecule from a list of SMILES strings.

        Parameters:
            smiles_list (list of str): list of SMILES strings
            atom_feature (str or list of str, optional): atom features to extract
            bond_feature (str or list of str, optional): bond features to extract
            mol_feature (str or list of str, optional): molecule features to extract
            with_hydrogen (bool, optional): store hydrogens in the molecule graph.
                By default, hydrogens are dropped
            kekulize (bool, optional): convert aromatic bonds to single/double bonds.
                Note this only affects the relation in ``edge_list``.
                For ``bond_type``, aromatic bonds are always stored explicitly.
                By default, aromatic bonds are stored.

        Raises:
            ValueError: if RDKit cannot parse any of the SMILES strings
        """
        mols = []
        for smiles in smiles_list:
            mol = Chem.MolFromSmiles(smiles)
            if mol is None:
                raise ValueError("Invalid SMILES `%s`" % smiles)
            mols.append(mol)

        return cls.from_molecule(mols, atom_feature, bond_feature, mol_feature, with_hydrogen, kekulize)

    def to_smiles(self, isomeric=True, atom_map=True, canonical=False):
        """
        Return a list of SMILES strings.

        Parameters:
            isomeric (bool, optional): keep isomeric information or not
            atom_map (bool, optional): keep atom mapping or not
            canonical (bool, optional): if true, return the canonical form of smiles

        Returns:
            list of str
        """
        mols = self.to_molecule()
        smiles_list = []
        for mol in mols:
            if not atom_map:
                for atom in mol.GetAtoms():
                    atom.SetAtomMapNum(0)
            smiles = Chem.MolToSmiles(mol, isomericSmiles=isomeric)
            if canonical:
                # re-parse and re-write until the SMILES string reaches a fixed point (or repeats)
                smiles_set = set()
                while smiles not in smiles_set:
                    smiles_set.add(smiles)
                    mol = Chem.MolFromSmiles(smiles)
                    smiles = Chem.MolToSmiles(mol, isomericSmiles=isomeric)
            smiles_list.append(smiles)
        return smiles_list

    def to_molecule(self, ignore_error=False):
        """
        Return a list of RDKit objects.

        Parameters:
            ignore_error (bool, optional): if true, return ``None`` for illegal molecules.
                Otherwise, raise an exception.

        Returns:
            list of rdchem.Mol
        """
        atom_type = self.atom_type.tolist()
        bond_type = self.bond_type.tolist()
        formal_charge = self.formal_charge.tolist()
        explicit_hs = self.explicit_hs.tolist()
        chiral_tag = self.chiral_tag.tolist()
        radical_electrons = self.radical_electrons.tolist()
        atom_map = self.atom_map.tolist()
        bond_stereo = self.bond_stereo.tolist()
        stereo_atoms = self.stereo_atoms.tolist()
        if hasattr(self, "node_position"):
            node_position = self.node_position.tolist()
        else:
            node_position = None
        num_cum_nodes = [0] + self.num_cum_nodes.tolist()
        num_cum_edges = [0] + self.num_cum_edges.tolist()
        # convert absolute (packed) node ids back to per-molecule ids
        edge_list = self.edge_list.clone()
        edge_list[:, :2] -= self._offsets.unsqueeze(-1)
        edge_list = edge_list.tolist()

        mols = []
        for i in range(self.batch_size):
            mol = Chem.RWMol()
            if node_position:
                conformer = Chem.Conformer()
            else:
                conformer = None
            for j in range(num_cum_nodes[i], num_cum_nodes[i + 1]):
                atom = Chem.Atom(atom_type[j])
                atom.SetFormalCharge(formal_charge[j])
                atom.SetNumExplicitHs(explicit_hs[j])
                atom.SetChiralTag(Chem.ChiralType(chiral_tag[j]))
                atom.SetNumRadicalElectrons(radical_electrons[j])
                atom.SetNoImplicit(explicit_hs[j] > 0 or radical_electrons[j] > 0)
                atom.SetAtomMapNum(atom_map[j])
                if conformer is not None:
                    conformer.SetAtomPosition(j - num_cum_nodes[i], node_position[j])
                mol.AddAtom(atom)
            if conformer is not None:
                mol.AddConformer(conformer)

            # each bond is stored as two directed edges; only the h < t copy is added to RDKit
            for j in range(num_cum_edges[i], num_cum_edges[i + 1]):
                h, t, relation = edge_list[j]
                if h < t:
                    # AddBond returns the new number of bonds, so the new bond's index is one less
                    k = mol.AddBond(h, t, Chem.BondType.names[self.id2bond[relation]])
                    bond = mol.GetBondWithIdx(k - 1)
                    bond.SetIsAromatic(bond_type[j] == self.bond2id["AROMATIC"])
                    bond.SetStereo(Chem.BondStereo(bond_stereo[j]))
            # Stereo atoms are assigned in a second pass. This is presumably because they reference
            # neighbouring bonds that must already exist in the molecule.
            k = 0
            for j in range(num_cum_edges[i], num_cum_edges[i + 1]):
                h, t, relation = edge_list[j]
                if h < t:
                    if bond_stereo[j]:
                        bond = mol.GetBondWithIdx(k)
                        # These do not necessarily need to be the highest 'ranking' atoms like CIP stereo requires.
                        # They can be any arbitrary atoms neighboring the begin and end atoms of this bond respectively.
                        # STEREOCIS or STEREOTRANS is then set relative to only these atoms.
                        bond.SetStereoAtoms(*stereo_atoms[j])
                    k += 1

            if ignore_error:
                try:
                    with utils.no_rdkit_log():
                        mol.UpdatePropertyCache()
                        Chem.AssignStereochemistry(mol)
                        mol.ClearComputedProps()
                        mol.UpdatePropertyCache()
                except Exception:
                    mol = None
            else:
                mol.UpdatePropertyCache()
                Chem.AssignStereochemistry(mol)
                mol.ClearComputedProps()
                mol.UpdatePropertyCache()
            mols.append(mol)

        return mols

    def node_mask(self, index, compact=False):
        self._check_no_stereo()
        return super(PackedMolecule, self).node_mask(index, compact)

    def edge_mask(self, index):
        self._check_no_stereo()
        return super(PackedMolecule, self).edge_mask(index)

    def undirected(self, add_inverse=False):
        if add_inverse:
            raise ValueError("Bonds are undirected relations, but `add_inverse` is specified")
        return super(PackedMolecule, self).undirected(add_inverse)

    # graph-level names are aliases of the chemistry-level attributes
    @property
    def num_nodes(self):
        return self.num_atoms

    @num_nodes.setter
    def num_nodes(self, value):
        self.num_atoms = value

    @property
    def num_edges(self):
        return self.num_bonds

    @num_edges.setter
    def num_edges(self, value):
        self.num_bonds = value

    def __repr__(self):
        fields = ["batch_size=%d" % self.batch_size,
                  "num_atoms=%s" % pretty.long_array(self.num_atoms.tolist()),
                  "num_bonds=%s" % pretty.long_array(self.num_bonds.tolist())]
        if self.device.type != "cpu":
            fields.append("device='%s'" % self.device)
        return "%s(%s)" % (self.__class__.__name__, ", ".join(fields))

    def visualize(self, titles=None, save_file=None, figure_size=(3, 3), num_row=None, num_col=None, atom_map=False):
        """
        Visualize the packed molecules with matplotlib.

        Parameters:
            titles (list of str, optional): title for each molecule. Default is the ID of each molecule.
            save_file (str, optional): ``png`` or ``pdf`` file to save visualization.
                If not provided, show the figure in window.
            figure_size (tuple of int, optional): width and height of the figure
            num_row (int, optional): number of rows in the figure
            num_col (int, optional): number of columns in the figure
            atom_map (bool, optional): visualize atom mapping or not
        """
        if titles is None:
            graph = self.get_item(0)
            titles = ["%s %d" % (type(graph).__name__, i) for i in range(self.batch_size)]
        # default to a near-square grid; `figure_size` is the size of each cell
        if num_col is None:
            if num_row is None:
                num_col = math.ceil(self.batch_size ** 0.5)
            else:
                num_col = math.ceil(self.batch_size / num_row)
        if num_row is None:
            num_row = math.ceil(self.batch_size / num_col)

        figure_size = (num_col * figure_size[0], num_row * figure_size[1])
        fig = plt.figure(figsize=figure_size)

        for i in range(self.batch_size):
            graph = self.get_item(i)
            ax = fig.add_subplot(num_row, num_col, i + 1)
            graph.visualize(title=titles[i], ax=ax, atom_map=atom_map)
        # remove the space of axis labels
        fig.tight_layout()

        if save_file:
            fig.savefig(save_file)
        else:
            fig.show()
# Register PackedMolecule as the packed counterpart of Molecule (the reverse link of
# `PackedMolecule.unpacked_type`). This is presumably consumed by the Graph packing/batching machinery — confirm.
Molecule.packed_type = PackedMolecule