import warnings

from rdkit import Chem
from rdkit.Chem import AllChem

from torchdrug.core import Registry as R

# Atom symbols, ordered by atomic number (position in the periodic table).
atom_vocab = {
    symbol: index
    for index, symbol in enumerate(
        ["H", "B", "C", "N", "O", "F", "Mg", "Si", "P", "S",
         "Cl", "Cu", "Zn", "Se", "Br", "Sn", "I"]
    )
}
# Integer-valued atom properties are one-hot encoded over these ranges.
degree_vocab = range(7)
num_hs_vocab = range(7)
formal_charge_vocab = range(-5, 6)
chiral_tag_vocab = range(4)
total_valence_vocab = range(8)
num_radical_vocab = range(8)
# One slot per value of RDKit's hybridization enum.
hybridization_vocab = range(len(Chem.rdchem.HybridizationType.values))

# Bond types that receive their own one-hot slot, in this order.
bond_type_vocab = {
    bond_type: index
    for index, bond_type in enumerate([
        Chem.rdchem.BondType.SINGLE,
        Chem.rdchem.BondType.DOUBLE,
        Chem.rdchem.BondType.TRIPLE,
        Chem.rdchem.BondType.AROMATIC,
    ])
}
# One slot per value of the corresponding RDKit enum.
bond_dir_vocab = range(len(Chem.rdchem.BondDir.values))
bond_stereo_vocab = range(len(Chem.rdchem.BondStereo.values))

# Three-letter amino-acid residue names, ordered by molecular mass.
residue_vocab = (
    "GLY ALA SER PRO VAL THR CYS ILE LEU ASN "
    "ASP GLN LYS GLU MET HIS PHE ARG TYR TRP"
).split()

def onehot(x, vocab, allow_unknown=False):
    """
    One-hot encode ``x`` against ``vocab``.

    Parameters:
        x: value to encode.
        vocab (dict or sequence): vocabulary. A dict maps each value to its slot
            index; any other sequence (list, range, ...) uses ``vocab.index(x)``.
        allow_unknown (bool, optional): if True, the result has one extra trailing
            slot that is set (with a warning) when ``x`` is not in ``vocab``.
            If False, an unknown ``x`` raises ``ValueError``.

    Returns:
        list of int: ``len(vocab)`` entries (``len(vocab) + 1`` when
        ``allow_unknown``), exactly one of which is 1.

    Raises:
        ValueError: if ``x`` is not in ``vocab`` and ``allow_unknown`` is False.
    """
    if x in vocab:
        if isinstance(vocab, dict):
            index = vocab[x]
        else:
            index = vocab.index(x)
    else:
        index = -1

    if allow_unknown:
        feature = [0] * (len(vocab) + 1)
        if index == -1:
            warnings.warn("Unknown value `%s`" % x)
        # index -1 addresses the trailing "unknown" slot
        feature[index] = 1
    else:
        feature = [0] * len(vocab)
        if index == -1:
            raise ValueError("Unknown value `%s`. Available vocabulary is `%s`" % (x, vocab))
        feature[index] = 1

    return feature

# TODO: this one is too slow
@R.register("features.atom.default")
def atom_default(atom):
    """Default atom feature.

    The result concatenates, in this order:
        GetSymbol(): one-hot embedding for the atomic symbol (unknown values allowed)
        GetChiralTag(): one-hot embedding for the atomic chiral tag
        GetTotalDegree(): one-hot embedding for the degree of the atom including Hs (unknown values allowed)
        GetFormalCharge(): one-hot embedding for the formal charge of the atom
        GetTotalNumHs(): one-hot embedding for the total number of Hs (explicit and implicit) on the atom
        GetNumRadicalElectrons(): one-hot embedding for the number of radical electrons on the atom
        GetHybridization(): one-hot embedding for the atom's hybridization
        GetIsAromatic(): whether the atom is aromatic
        IsInRing(): whether the atom is in a ring
    """
    feature = onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True)
    feature += onehot(atom.GetChiralTag(), chiral_tag_vocab)
    feature += onehot(atom.GetTotalDegree(), degree_vocab, allow_unknown=True)
    feature += onehot(atom.GetFormalCharge(), formal_charge_vocab)
    feature += onehot(atom.GetTotalNumHs(), num_hs_vocab)
    feature += onehot(atom.GetNumRadicalElectrons(), num_radical_vocab)
    feature += onehot(atom.GetHybridization(), hybridization_vocab)
    feature.append(atom.GetIsAromatic())
    feature.append(atom.IsInRing())
    return feature
@R.register("features.atom.center_identification")
def atom_center_identification(atom):
    """Reaction center identification atom feature.

    The result concatenates, in this order:
        GetSymbol(): one-hot embedding for the atomic symbol (unknown values allowed)
        GetTotalNumHs(): one-hot embedding for the total number of Hs (explicit and implicit) on the atom
        GetTotalDegree(): one-hot embedding for the degree of the atom including Hs (unknown values allowed)
        GetTotalValence(): one-hot embedding for the total valence (explicit + implicit) of the atom
        GetIsAromatic(): whether the atom is aromatic
        IsInRing(): whether the atom is in a ring
    """
    parts = [
        onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True),
        onehot(atom.GetTotalNumHs(), num_hs_vocab),
        onehot(atom.GetTotalDegree(), degree_vocab, allow_unknown=True),
        onehot(atom.GetTotalValence(), total_valence_vocab),
        [atom.GetIsAromatic(), atom.IsInRing()],
    ]
    return [value for part in parts for value in part]
@R.register("features.atom.synthon_completion")
def atom_synthon_completion(atom):
    """Synthon completion atom feature.

    The result concatenates, in this order:
        GetSymbol(): one-hot embedding for the atomic symbol (unknown values allowed)
        GetTotalNumHs(): one-hot embedding for the total number of Hs (explicit and implicit) on the atom
        GetTotalDegree(): one-hot embedding for the degree of the atom including Hs (unknown values allowed)
        IsInRing(): whether the atom is in a ring
        IsInRingSize(3, 4, 5, 6): whether the atom is in a ring of each of these sizes
        IsInRing() and not IsInRingSize(3, 4, 5, 6): whether the atom is only in rings of other sizes
    """
    in_ring = atom.IsInRing()
    in_small_ring = [atom.IsInRingSize(size) for size in (3, 4, 5, 6)]
    # Catch-all flag for ring membership not covered by the sizes above
    in_other_ring = in_ring and not any(in_small_ring)

    feature = onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True)
    feature += onehot(atom.GetTotalNumHs(), num_hs_vocab)
    feature += onehot(atom.GetTotalDegree(), degree_vocab, allow_unknown=True)
    return feature + [in_ring] + in_small_ring + [in_other_ring]
@R.register("features.atom.symbol")
def atom_symbol(atom):
    """Symbol atom feature.

    Features:
        GetSymbol(): one-hot embedding for the atomic symbol (unknown values allowed)
    """
    symbol = atom.GetSymbol()
    return onehot(symbol, atom_vocab, allow_unknown=True)
@R.register("features.atom.explicit_property_prediction")
def atom_explicit_property_prediction(atom):
    """Explicit property prediction atom feature.

    The result concatenates, in this order:
        GetSymbol(): one-hot embedding for the atomic symbol (unknown values allowed)
        GetDegree(): one-hot embedding for the degree of the atom in the molecule (unknown values allowed)
        GetTotalValence(): one-hot embedding for the total valence (explicit + implicit) of the atom (unknown values allowed)
        GetFormalCharge(): one-hot embedding for the formal charge of the atom
        GetIsAromatic(): whether the atom is aromatic
    """
    feature = onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True)
    feature += onehot(atom.GetDegree(), degree_vocab, allow_unknown=True)
    feature += onehot(atom.GetTotalValence(), total_valence_vocab, allow_unknown=True)
    feature += onehot(atom.GetFormalCharge(), formal_charge_vocab)
    feature.append(atom.GetIsAromatic())
    return feature
@R.register("features.atom.property_prediction")
def atom_property_prediction(atom):
    """Property prediction atom feature.

    The result concatenates, in this order (every one-hot part allows unknown values):
        GetSymbol(): one-hot embedding for the atomic symbol
        GetDegree(): one-hot embedding for the degree of the atom in the molecule
        GetTotalNumHs(): one-hot embedding for the total number of Hs (explicit and implicit) on the atom
        GetTotalValence(): one-hot embedding for the total valence (explicit + implicit) of the atom
        GetFormalCharge(): one-hot embedding for the formal charge of the atom
        GetIsAromatic(): whether the atom is aromatic
    """
    parts = [
        onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True),
        onehot(atom.GetDegree(), degree_vocab, allow_unknown=True),
        onehot(atom.GetTotalNumHs(), num_hs_vocab, allow_unknown=True),
        onehot(atom.GetTotalValence(), total_valence_vocab, allow_unknown=True),
        onehot(atom.GetFormalCharge(), formal_charge_vocab, allow_unknown=True),
        [atom.GetIsAromatic()],
    ]
    return [value for part in parts for value in part]
@R.register("features.atom.position")
def atom_position(atom):
    """
    Atom position in the molecular conformation.

    Returns ``[x, y, z]`` taken from the owning molecule's conformer. If the
    molecule has no conformer, 2D coordinates are computed on it first.
    Note it takes much time to compute the conformation for large molecules.
    """
    mol = atom.GetOwningMol()
    if not mol.GetNumConformers():
        # No geometry attached yet: fall back to 2D layout (modifies the owning molecule)
        mol.Compute2DCoords()
    point = mol.GetConformer().GetAtomPosition(atom.GetIdx())
    return [point.x, point.y, point.z]
@R.register("features.atom.pretrain")
def atom_pretrain(atom):
    """Atom feature for pretraining.

    The result concatenates, in this order:
        GetSymbol(): one-hot embedding for the atomic symbol (unknown values allowed)
        GetChiralTag(): one-hot embedding for the atomic chiral tag
    """
    symbol_part = onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True)
    chirality_part = onehot(atom.GetChiralTag(), chiral_tag_vocab)
    return symbol_part + chirality_part
@R.register("features.atom.residue_symbol")
def atom_residue_symbol(atom):
    """Residue symbol as atom feature. Only support atoms in a protein.

    The result concatenates, in this order:
        GetSymbol(): one-hot embedding for the atomic symbol (unknown values allowed)
        GetResidueName(): one-hot embedding for the residue symbol (unknown values allowed)
    """
    info = atom.GetPDBResidueInfo()
    symbol_part = onehot(atom.GetSymbol(), atom_vocab, allow_unknown=True)
    if info:
        residue_name = info.GetResidueName()
    else:
        # Without residue info, -1 is used; it is not a residue name, so it is encoded as unknown
        residue_name = -1
    return symbol_part + onehot(residue_name, residue_vocab, allow_unknown=True)
# Key follows the "features.<kind>.<name>" scheme used by the atom/residue/molecule features
@R.register("features.bond.default")
def bond_default(bond):
    """Default bond feature.

    The result concatenates, in this order:
        GetBondType(): one-hot embedding for the type of the bond
        GetBondDir(): one-hot embedding for the direction of the bond
        GetStereo(): one-hot embedding for the stereo configuration of the bond
        GetIsConjugated(): whether the bond is considered to be conjugated (as 0/1)
    """
    return onehot(bond.GetBondType(), bond_type_vocab) + \
        onehot(bond.GetBondDir(), bond_dir_vocab) + \
        onehot(bond.GetStereo(), bond_stereo_vocab) + \
        [int(bond.GetIsConjugated())]
# Key follows the "features.<kind>.<name>" scheme used by the atom/residue/molecule features
@R.register("features.bond.length")
def bond_length(bond):
    """
    Bond length in the molecular conformation.

    Returns a one-element list with the distance between the bond's begin and
    end atoms in the owning molecule's conformer. If the molecule has no
    conformer, 2D coordinates are computed on it first.
    Note it takes much time to compute the conformation for large molecules.
    """
    mol = bond.GetOwningMol()
    if mol.GetNumConformers() == 0:
        # Fall back to a 2D layout (modifies the owning molecule)
        mol.Compute2DCoords()
    conformer = mol.GetConformer()
    h = conformer.GetAtomPosition(bond.GetBeginAtomIdx())
    t = conformer.GetAtomPosition(bond.GetEndAtomIdx())
    return [h.Distance(t)]
# Key follows the "features.<kind>.<name>" scheme used by the atom/residue/molecule features
@R.register("features.bond.property_prediction")
def bond_property_prediction(bond):
    """Property prediction bond feature.

    The result concatenates, in this order:
        GetBondType(): one-hot embedding for the type of the bond
        GetIsConjugated(): whether the bond is considered to be conjugated (as 0/1)
        IsInRing(): whether the bond is in a ring
    """
    return onehot(bond.GetBondType(), bond_type_vocab) + \
        [int(bond.GetIsConjugated()), bond.IsInRing()]
# Key follows the "features.<kind>.<name>" scheme used by the atom/residue/molecule features
@R.register("features.bond.pretrain")
def bond_pretrain(bond):
    """Bond feature for pretraining.

    The result concatenates, in this order:
        GetBondType(): one-hot embedding for the type of the bond
        GetBondDir(): one-hot embedding for the direction of the bond
    """
    return onehot(bond.GetBondType(), bond_type_vocab) + \
        onehot(bond.GetBondDir(), bond_dir_vocab)
@R.register("features.residue.symbol")
def residue_symbol(residue):
    """Symbol residue feature.

    Features:
        GetResidueName(): one-hot embedding for the residue symbol (unknown values allowed)
    """
    name = residue.GetResidueName()
    return onehot(name, residue_vocab, allow_unknown=True)
@R.register("features.residue.default")
def residue_default(residue):
    """Default residue feature: identical to ``residue_symbol``.

    Features:
        GetResidueName(): one-hot embedding for the residue symbol
    """
    feature = residue_symbol(residue)
    return feature
@R.register("features.molecule.ecfp")
def ExtendedConnectivityFingerprint(mol, radius=2, length=1024):
    """Extended Connectivity Fingerprint molecule feature.

    Parameters:
        mol: RDKit molecule to fingerprint.
        radius (int, optional): Morgan radius.
        length (int, optional): number of bits in the fingerprint.

    Returns:
        list: the bits of ``GetMorganFingerprintAsBitVect(mol, radius, length)``.
    """
    bit_vector = AllChem.GetMorganFingerprintAsBitVect(mol, radius, length)
    bits = list(bit_vector)
    return bits
@R.register("features.molecule.default")
def molecule_default(mol):
    """Default molecule feature: ECFP with the default radius and length."""
    fingerprint = ExtendedConnectivityFingerprint(mol)
    return fingerprint
# Short alias for the ECFP featurizer
ECFP = ExtendedConnectivityFingerprint

__all__ = [
    "atom_default", "atom_center_identification", "atom_synthon_completion",
    "atom_symbol", "atom_explicit_property_prediction", "atom_property_prediction",
    "atom_position", "atom_pretrain", "atom_residue_symbol",
    "bond_default", "bond_length", "bond_property_prediction", "bond_pretrain",
    "residue_symbol", "residue_default",
    "ExtendedConnectivityFingerprint", "molecule_default",
    "ECFP",
]