import os
import string
import warnings
from collections import defaultdict
from rdkit import Chem
import torch
from torch_scatter import scatter_add, scatter_max, scatter_min
from torchdrug import utils
from torchdrug.data import Molecule, PackedMolecule, Dictionary, feature
from torchdrug.core import Registry as R
from torchdrug.utils import pretty
class Protein(Molecule):
    """
    Proteins with predefined chemical features.

    Residue-level and atom-level operations are both supported, and the two views are kept consistent.

    .. warning::

        Residues must appear in the same order as in the protein sequence.
        Nodes and edges, on the other hand, may be stored in any order,
        so the node order may differ from the residue order.

    Parameters:
        edge_list (array_like, optional): list of edges of shape :math:`(|E|, 3)`.
            Each tuple is (node_in, node_out, bond_type).
        atom_type (array_like, optional): atom types of shape :math:`(|V|,)`
        bond_type (array_like, optional): bond types of shape :math:`(|E|,)`
        residue_type (array_like, optional): residue types of shape :math:`(|V_{res}|,)`
        view (str, optional): default view for this protein. Can be ``atom`` or ``residue``.
        atom_name (array_like, optional): atom names in a residue of shape :math:`(|V|,)`
        atom2residue (array_like, optional): atom id to residue id mapping of shape :math:`(|V|,)`
        residue_feature (array_like, optional): residue features of shape :math:`(|V_{res}|, ...)`
        is_hetero_atom (array_like, optional): hetero atom indicators of shape :math:`(|V|,)`
        occupancy (array_like, optional): occupancy of shape :math:`(|V|,)`
        b_factor (array_like, optional): temperature factors of shape :math:`(|V|,)`
        residue_number (array_like, optional): residue numbers of shape :math:`(|V_{res}|,)`
        insertion_code (array_like, optional): insertion codes of shape :math:`(|V_{res}|,)`
        chain_id (array_like, optional): chain ids of shape :math:`(|V_{res}|,)`
    """

    _meta_types = {"node", "edge", "residue", "graph",
                   "node reference", "edge reference", "residue reference", "graph reference"}

    # one-residue (glycine) molecule; its first atom is appended as a sentinel when parsing molecules
    dummy_protein = Chem.MolFromSequence("G")
    dummy_atom = dummy_protein.GetAtomWithIdx(0)

    # TODO: rdkit isn't compatible with X in the sequence
    # three-letter residue name -> residue type id
    residue2id = {"GLY": 0, "ALA": 1, "SER": 2, "PRO": 3, "VAL": 4, "THR": 5, "CYS": 6,
                  "ILE": 7, "LEU": 8, "ASN": 9, "ASP": 10, "GLN": 11, "LYS": 12, "GLU": 13,
                  "MET": 14, "HIS": 15, "PHE": 16, "ARG": 17, "TYR": 18, "TRP": 19}
    # one-letter residue symbol -> residue type id (same ids as residue2id)
    residue_symbol2id = {"G": 0, "A": 1, "S": 2, "P": 3, "V": 4, "T": 5, "C": 6,
                         "I": 7, "L": 8, "N": 9, "D": 10, "Q": 11, "K": 12, "E": 13,
                         "M": 14, "H": 15, "F": 16, "R": 17, "Y": 18, "W": 19}
    # PDB atom name -> atom name id; names outside this table are mapped to "UNK"
    atom_name2id = {"C": 0, "CA": 1, "CB": 2, "CD": 3, "CD1": 4, "CD2": 5,
                    "CE": 6, "CE1": 7, "CE2": 8, "CE3": 9,
                    "CG": 10, "CG1": 11, "CG2": 12, "CH2": 13,
                    "CZ": 14, "CZ2": 15, "CZ3": 16,
                    "N": 17, "ND1": 18, "ND2": 19, "NE": 20, "NE1": 21, "NE2": 22,
                    "NH1": 23, "NH2": 24, "NZ": 25,
                    "O": 26, "OD1": 27, "OD2": 28, "OE1": 29, "OE2": 30, "OG": 31, "OG1": 32,
                    "OH": 33, "OXT": 34, "SD": 35, "SG": 36, "UNK": 37}
    # single characters allowed for insertion codes and chain ids
    alphabet2id = {c: i for i, c in enumerate(" " + string.ascii_uppercase + string.ascii_lowercase + string.digits)}

    # inverse lookup tables
    id2residue = dict(zip(residue2id.values(), residue2id))
    id2residue_symbol = dict(zip(residue_symbol2id.values(), residue_symbol2id))
    id2atom_name = dict(zip(atom_name2id.values(), atom_name2id))
    id2alphabet = dict(zip(alphabet2id.values(), alphabet2id))
def __init__(self, edge_list=None, atom_type=None, bond_type=None, residue_type=None, view=None,
             atom_name=None, atom2residue=None, residue_feature=None, is_hetero_atom=None, occupancy=None,
             b_factor=None, residue_number=None, insertion_code=None, chain_id=None, **kwargs):
    """
    Build a protein from atom-level and residue-level attributes.

    ``residue_type`` is mandatory. Missing per-atom and per-residue attributes are filled in by
    ``_standarize_attribute`` (``occupancy`` with a default of 1). See the class docstring for the
    meaning and shape of every argument.
    """
    super(Protein, self).__init__(edge_list, atom_type, bond_type, **kwargs)
    residue_type, num_residue = self._standarize_num_residue(residue_type)
    self.num_residue = num_residue
    self.view = self._standarize_view(view)

    standarize = self._standarize_attribute
    num_atom = self.num_node
    # per-atom attributes
    atom_name = standarize(atom_name, num_atom)
    atom2residue = standarize(atom2residue, num_atom)
    is_hetero_atom = standarize(is_hetero_atom, num_atom, dtype=torch.bool)
    occupancy = standarize(occupancy, num_atom, dtype=torch.float, default=1)
    b_factor = standarize(b_factor, num_atom, dtype=torch.float)
    # per-residue attributes
    residue_number = standarize(residue_number, self.num_residue)
    insertion_code = standarize(insertion_code, self.num_residue)
    chain_id = standarize(chain_id, self.num_residue)

    with self.atom():
        self.atom_name = atom_name
        # nested so that atom2residue is assigned under both the atom and the residue-reference context
        with self.residue_reference():
            self.atom2residue = atom2residue
        self.is_hetero_atom = is_hetero_atom
        self.occupancy = occupancy
        self.b_factor = b_factor

    with self.residue():
        self.residue_type = residue_type
        if residue_feature is not None:
            self.residue_feature = torch.as_tensor(residue_feature, device=self.device)
        self.residue_number = residue_number
        self.insertion_code = insertion_code
        self.chain_id = chain_id
def residue(self):
    """
    Context manager for residue attributes.

    Attributes assigned inside ``with protein.residue():`` are assigned under the ``"residue"`` context.
    """
    return self.context("residue")
def residue_reference(self):
    """
    Context manager for residue references.

    Attributes assigned inside ``with protein.residue_reference():`` are assigned under the
    ``"residue reference"`` context, i.e. they hold residue ids.
    """
    return self.context("residue reference")
@property
def node_feature(self):
    """
    Node features of the current view.

    Returns ``atom_feature`` in atom view (also when no view is set yet) and ``residue_feature`` otherwise.
    """
    current_view = getattr(self, "view", "atom")
    return self.atom_feature if current_view == "atom" else self.residue_feature

@node_feature.setter
def node_feature(self, value):
    # NOTE(review): always writes atom_feature, regardless of the current view
    self.atom_feature = value
@property
def num_node(self):
    """Number of nodes; for a protein, nodes are atoms, so this aliases ``num_atom``."""
    return self.num_atom

@num_node.setter
def num_node(self, value):
    # keep the alias writable: assigning num_node updates num_atom
    self.num_atom = value
def _check_attribute(self, key, value):
    """
    Validate ``value`` before it is stored as attribute ``key``.

    On top of the parent checks, for each meta type in ``self._meta_contexts``: a residue attribute must
    have ``num_residue`` rows, and a residue reference must lie in ``[-1, num_residue)``.

    Raises:
        ValueError: if one of the residue constraints is violated
    """
    super(Protein, self)._check_attribute(key, value)
    for meta_type in self._meta_contexts:
        if meta_type == "residue" and len(value) != self.num_residue:
            raise ValueError("Expect residue attribute `%s` to have shape (%d, *), but found %s" %
                             (key, self.num_residue, value.shape))
        if meta_type == "residue reference":
            is_valid = (value >= -1) & (value < self.num_residue)
            if not is_valid.all():
                raise ValueError("Expect residue reference in [-1, %d), but found %d" %
                                 (self.num_residue, value[~is_valid][0]))
def _standarize_num_residue(self, residue_type):
    """
    Convert ``residue_type`` to a long tensor on this protein's device and count the residues.

    Returns:
        (Tensor, Tensor): residue types, and the number of residues as a 0-dim tensor

    Raises:
        ValueError: if ``residue_type`` is ``None``
    """
    if residue_type is None:
        raise ValueError("`residue_type` should be provided")
    device = self.device
    residue_type = torch.as_tensor(residue_type, dtype=torch.long, device=device)
    return residue_type, torch.tensor(len(residue_type), device=device)
def __setattr__(self, key, value):
    """Set an attribute, rejecting any ``view`` other than ``"atom"`` or ``"residue"``."""
    is_bad_view = key == "view" and value not in ("atom", "residue")
    if is_bad_view:
        raise ValueError("Expect `view` to be either `atom` or `residue`, but found `%s`" % value)
    return super(Protein, self).__setattr__(key, value)
def _standarize_view(self, view):
    """
    Resolve the default view.

    An explicit ``view`` is returned unchanged. Otherwise the view is ``"atom"`` if the protein has
    atoms and ``"residue"`` if it has none.
    """
    if view is not None:
        return view
    return "atom" if self.num_atom > 0 else "residue"
@classmethod
@utils.deprecated_alias(node_feature="atom_feature", edge_feature="bond_feature", graph_feature="mol_feature")
def from_molecule(cls, mol, atom_feature="default", bond_feature="default", residue_feature="default",
                  mol_feature=None, kekulize=False):
    """
    Create a protein from an RDKit object.

    Residue information is read from the PDB residue info of each atom. Consecutive atoms with the same
    (chain id, residue number, insertion code, residue name) form one residue.

    Parameters:
        mol (rdchem.Mol): molecule
        atom_feature (str or list of str, optional): atom features to extract
        bond_feature (str or list of str, optional): bond features to extract
        residue_feature (str, list of str, optional): residue features to extract
        mol_feature (str or list of str, optional): molecule features to extract
        kekulize (bool, optional): convert aromatic bonds to single/double bonds.
            Note this only affects the relation in ``edge_list``.
            For ``bond_type``, aromatic bonds are always stored explicitly.
            By default, aromatic bonds are stored.

    Returns:
        Protein, or ``None`` (with a warning) if an insertion code or chain id is not in ``alphabet2id``
    """
    protein = Molecule.from_molecule(mol, atom_feature=atom_feature, bond_feature=bond_feature,
                                     mol_feature=mol_feature, with_hydrogen=False, kekulize=kekulize)
    residue_feature = cls._standarize_option(residue_feature)
    # ``mol`` is deliberately not kekulized here: only PDB residue info is read below, which
    # kekulization does not change, and kekulizing in place would mutate the caller's molecule.

    residue_type = []
    atom_name = []
    is_hetero_atom = []
    occupancy = []
    b_factor = []
    atom2residue = []
    residue_number = []
    insertion_code = []
    chain_id = []
    _residue_feature = []
    last_residue = None
    # A sentinel atom is appended so that every tensor below is built from a non-empty list;
    # its entries are dropped again with ``[:-1]``.
    num_atom = mol.GetNumAtoms()
    atoms = [mol.GetAtomWithIdx(i) for i in range(num_atom)] + [cls.dummy_atom]
    for i, atom in enumerate(atoms):
        pdbinfo = atom.GetPDBResidueInfo()
        number = pdbinfo.GetResidueNumber()
        code = pdbinfo.GetInsertionCode()
        chain = pdbinfo.GetChainId()
        residue_name = pdbinfo.GetResidueName().strip()
        # residues in different chains are different residues, even if number, code and name coincide
        canonical_residue = (chain, number, code, residue_name)
        # The sentinel must always open its own residue. Otherwise it could merge into a trailing residue
        # with the same key (e.g. the glycine of a "G" sequence), and ``[:-1]`` would drop a real residue.
        if canonical_residue != last_residue or i == num_atom:
            last_residue = canonical_residue
            if residue_name not in cls.residue2id:
                warnings.warn("Unknown residue `%s`. Treat as glycine" % residue_name)
                residue_name = "GLY"
            residue_type.append(cls.residue2id[residue_name])
            residue_number.append(number)
            if code not in cls.alphabet2id:
                warnings.warn(f"Fail to create the protein. Unknown insertion code {code}.")
                return None
            if chain not in cls.alphabet2id:
                warnings.warn(f"Fail to create the protein. Unknown chain id {chain}.")
                return None
            insertion_code.append(cls.alphabet2id[code])
            chain_id.append(cls.alphabet2id[chain])
            residue_values = []
            for name in residue_feature:
                func = R.get("features.residue.%s" % name)
                residue_values += func(pdbinfo)
            _residue_feature.append(residue_values)
        name = pdbinfo.GetName().strip()
        if name not in cls.atom_name2id:
            name = "UNK"
        atom_name.append(cls.atom_name2id[name])
        is_hetero_atom.append(pdbinfo.GetIsHeteroAtom())
        occupancy.append(pdbinfo.GetOccupancy())
        b_factor.append(pdbinfo.GetTempFactor())
        atom2residue.append(len(residue_type) - 1)

    # drop the sentinel atom and its residue
    residue_type = torch.tensor(residue_type)[:-1]
    atom_name = torch.tensor(atom_name)[:-1]
    is_hetero_atom = torch.tensor(is_hetero_atom)[:-1]
    occupancy = torch.tensor(occupancy)[:-1]
    b_factor = torch.tensor(b_factor)[:-1]
    atom2residue = torch.tensor(atom2residue)[:-1]
    residue_number = torch.tensor(residue_number)[:-1]
    insertion_code = torch.tensor(insertion_code)[:-1]
    chain_id = torch.tensor(chain_id)[:-1]
    if len(residue_feature) > 0:
        _residue_feature = torch.tensor(_residue_feature)[:-1]
    else:
        _residue_feature = None

    return cls(protein.edge_list, num_node=protein.num_node, residue_type=residue_type,
               atom_name=atom_name, atom2residue=atom2residue, residue_feature=_residue_feature,
               is_hetero_atom=is_hetero_atom, occupancy=occupancy, b_factor=b_factor,
               residue_number=residue_number, insertion_code=insertion_code, chain_id=chain_id,
               meta_dict=protein.meta_dict, **protein.data_dict)
@classmethod
def _residue_from_sequence(cls, sequence):
    """
    Build a residue-only protein (no atoms, no bonds) from a one-letter sequence.

    Unknown symbols are treated as glycine with a warning. Residue features are one-hot residue types.
    """
    type_ids = []
    onehots = []
    # a trailing glycine is appended and removed again after tensor conversion
    for symbol in sequence + "G":
        if symbol not in cls.residue_symbol2id:
            warnings.warn("Unknown residue symbol `%s`. Treat as glycine" % symbol)
            symbol = "G"
        type_ids.append(cls.residue_symbol2id[symbol])
        onehots.append(feature.onehot(symbol, cls.residue_symbol2id, allow_unknown=True))
    return cls(edge_list=None, atom_type=[], bond_type=[], num_node=0, residue_type=type_ids[:-1],
               residue_feature=torch.tensor(onehots)[:-1])
@classmethod
@utils.deprecated_alias(node_feature="atom_feature", edge_feature="bond_feature", graph_feature="mol_feature")
def from_sequence(cls, sequence, atom_feature="default", bond_feature="default", residue_feature="default",
                  mol_feature=None, kekulize=False):
    """
    Create a protein from a sequence.

    .. note::

        It takes considerable time to construct proteins with a large number of atoms and bonds.
        If you only need residue information, you may speed up the construction by setting
        ``atom_feature`` and ``bond_feature`` to ``None``.

    Parameters:
        sequence (str): protein sequence
        atom_feature (str or list of str, optional): atom features to extract
        bond_feature (str or list of str, optional): bond features to extract
        residue_feature (str, list of str, optional): residue features to extract
        mol_feature (str or list of str, optional): molecule features to extract
        kekulize (bool, optional): convert aromatic bonds to single/double bonds.
            Note this only affects the relation in ``edge_list``.
            For ``bond_type``, aromatic bonds are always stored explicitly.
            By default, aromatic bonds are stored.

    Raises:
        ValueError: if RDKit cannot parse ``sequence``
    """
    residue_only = atom_feature is None and bond_feature is None and residue_feature == "default"
    if residue_only:
        # fast path: no RDKit molecule is built, so mol_feature and kekulize are not used
        return cls._residue_from_sequence(sequence)

    mol = Chem.MolFromSequence(sequence)
    if mol is None:
        raise ValueError("Invalid sequence `%s`" % sequence)
    return cls.from_molecule(mol, atom_feature, bond_feature, residue_feature, mol_feature, kekulize)
@classmethod
@utils.deprecated_alias(node_feature="atom_feature", edge_feature="bond_feature", graph_feature="mol_feature")
def from_pdb(cls, pdb_file, atom_feature="default", bond_feature="default", residue_feature="default",
             mol_feature=None, kekulize=False):
    """
    Create a protein from a PDB file.

    Parameters:
        pdb_file (str): file name
        atom_feature (str or list of str, optional): atom features to extract
        bond_feature (str or list of str, optional): bond features to extract
        residue_feature (str, list of str, optional): residue features to extract
        mol_feature (str or list of str, optional): molecule features to extract
        kekulize (bool, optional): convert aromatic bonds to single/double bonds.
            Note this only affects the relation in ``edge_list``.
            For ``bond_type``, aromatic bonds are always stored explicitly.
            By default, aromatic bonds are stored.

    Raises:
        FileNotFoundError: if ``pdb_file`` does not exist
        ValueError: if RDKit cannot parse the file
    """
    if not os.path.exists(pdb_file):
        raise FileNotFoundError("No such file `%s`" % pdb_file)
    mol = Chem.MolFromPDBFile(pdb_file)
    if mol is None:
        raise ValueError("RDKit cannot read PDB file `%s`" % pdb_file)
    return cls.from_molecule(mol, atom_feature, bond_feature, residue_feature, mol_feature, kekulize)
def to_molecule(self, ignore_error=False):
    """
    Return an RDKit object of this protein, with PDB residue info attached to every atom.

    Parameters:
        ignore_error (bool, optional): if true, return ``None`` for illegal molecules.
            Otherwise, raise an exception.

    Returns:
        rdchem.Mol
    """
    mol = super(Protein, self).to_molecule(ignore_error)
    if mol is None:
        return mol

    # convert tensors to Python lists once, outside the per-atom loop
    residue_type = self.residue_type.tolist()
    atom_name = self.atom_name.tolist()
    atom2residue = self.atom2residue.tolist()
    is_hetero_atom = self.is_hetero_atom.tolist()
    occupancy = self.occupancy.tolist()
    b_factor = self.b_factor.tolist()
    residue_number = self.residue_number.tolist()
    chain_id = self.chain_id.tolist()
    insertion_code = self.insertion_code.tolist()

    for atom_id, atom in enumerate(mol.GetAtoms()):
        res = atom2residue[atom_id]
        info = Chem.AtomPDBResidueInfo()
        # residue-level fields
        info.SetResidueNumber(residue_number[res])
        info.SetChainId(self.id2alphabet[chain_id[res]])
        info.SetInsertionCode(self.id2alphabet[insertion_code[res]])
        # atom-level fields
        info.SetName(" %-3s" % self.id2atom_name[atom_name[atom_id]])
        info.SetResidueName(self.id2residue[residue_type[res]])
        info.SetIsHeteroAtom(is_hetero_atom[atom_id])
        info.SetOccupancy(occupancy[atom_id])
        info.SetTempFactor(b_factor[atom_id])
        atom.SetPDBResidueInfo(info)
    return mol
def to_sequence(self):
    """
    Return the one-letter sequence of this protein.

    A ``"."`` is inserted between consecutive residues whenever the connected component id increases.

    Returns:
        str
    """
    type_ids = self.residue_type.tolist()
    component = self.connected_component_id.tolist()
    chars = []
    for pos in range(self.num_residue):
        if pos and component[pos] > component[pos - 1]:
            chars.append(".")
        chars.append(self.id2residue_symbol[type_ids[pos]])
    return "".join(chars)
def to_pdb(self, pdb_file):
    """
    Write this protein to a pdb file.

    Parameters:
        pdb_file (str): file name
    """
    # flavor=10 is an RDKit bit-flag combination (2 | 8); see the RDKit MolToPDBBlock docs for its meaning
    Chem.MolToPDBFile(self.to_molecule(), pdb_file, flavor=10)
def split(self, node2graph):
    """
    Split this protein into a packed protein according to a node-to-graph assignment.

    Edges whose endpoints fall into different graphs are dropped. A residue whose atoms are assigned to
    several graphs is duplicated, once per graph.

    Parameters:
        node2graph (array_like): graph id of each node, of shape :math:`(|V|,)`.
            Arbitrary ids are allowed; they are coalesced to :math:`[0, n)` in sorted order.

    Returns:
        packed protein (``self.packed_type``)
    """
    node2graph = torch.as_tensor(node2graph, dtype=torch.long, device=self.device)
    # coalesce arbitrary graph IDs to [0, n)
    _, node2graph = torch.unique(node2graph, return_inverse=True)
    num_graph = node2graph.max() + 1
    index = node2graph.argsort()
    mapping = torch.zeros_like(index)
    mapping[index] = torch.arange(len(index), device=self.device)

    node_in, node_out = self.edge_list.t()[:2]
    edge_mask = node2graph[node_in] == node2graph[node_out]
    edge2graph = node2graph[node_in]
    edge_index = edge2graph.argsort()
    edge_index = edge_index[edge_mask[edge_index]]

    prepend = torch.tensor([-1], device=self.device)
    is_first_node = torch.diff(node2graph[index], prepend=prepend) > 0
    graph_index = self.node2graph[index[is_first_node]]

    # A residue can be split into multiple graphs, so residues are keyed by (graph id, residue id).
    # The key base must exceed every residue id; a smaller base (e.g. the largest per-graph node count)
    # would make ``key % base`` wrap around and mislabel residues.
    residue_base = max(int(self.num_residue), 1)
    key = node2graph[index] * residue_base + self.atom2residue[index]
    key_set, atom2residue = key.unique(return_inverse=True)
    residue_index = key_set % residue_base

    edge_list = self.edge_list.clone()
    edge_list[:, :2] = mapping[edge_list[:, :2]]

    num_nodes = node2graph.bincount(minlength=num_graph)
    num_edges = edge2graph[edge_index].bincount(minlength=num_graph)
    num_cum_residues = scatter_max(atom2residue, node2graph[index], dim_size=num_graph)[0] + 1
    prepend = torch.tensor([0], device=self.device)
    num_residues = torch.diff(num_cum_residues, prepend=prepend)

    num_cum_nodes = num_nodes.cumsum(0)
    offsets = (num_cum_nodes - num_nodes)[edge2graph[edge_index]]

    data_dict, meta_dict = self.data_mask(index, edge_index, residue_index, graph_index,
                                          exclude=("residue reference", "graph reference"))

    return self.packed_type(edge_list[edge_index], edge_weight=self.edge_weight[edge_index],
                            num_nodes=num_nodes, num_edges=num_edges, num_residues=num_residues, view=self.view,
                            offsets=offsets, atom2residue=atom2residue, meta_dict=meta_dict, **data_dict)
@classmethod
def pack(cls, graphs):
    """
    Pack a list of proteins into a single packed protein.

    Node, edge, residue and graph references of each protein are shifted by the total number of nodes,
    edges, residues and graphs of the proteins before it; references equal to ``-1`` stay ``-1``.
    The meta dict, view and number of relations are taken from the first protein.

    Parameters:
        graphs (list of Protein): proteins to pack

    Returns:
        packed protein (``cls.packed_type``)
    """
    meta_dict = graphs[0].meta_dict
    view = graphs[0].view
    edge_lists, edge_weights = [], []
    num_nodes, num_edges, num_residues = [], [], []
    chunks = defaultdict(list)
    # running offset for each kind of reference
    shift = {"node reference": 0, "edge reference": 0, "residue reference": 0, "graph reference": 0}

    for graph in graphs:
        edge_lists.append(graph.edge_list)
        edge_weights.append(graph.edge_weight)
        num_nodes.append(graph.num_node)
        num_edges.append(graph.num_edge)
        num_residues.append(graph.num_residue)
        for key, value in graph.data_dict.items():
            for meta_type in meta_dict[key]:
                if meta_type == "graph":
                    value = value.unsqueeze(0)
                elif meta_type in shift:
                    value = torch.where(value != -1, value + shift[meta_type], -1)
            chunks[key].append(value)
        shift["node reference"] += graph.num_node
        shift["edge reference"] += graph.num_edge
        shift["residue reference"] += graph.num_residue
        shift["graph reference"] += 1

    data_dict = {key: torch.cat(values) for key, values in chunks.items()}
    return cls.packed_type(torch.cat(edge_lists), edge_weight=torch.cat(edge_weights),
                           num_relation=graphs[0].num_relation, num_nodes=num_nodes, num_edges=num_edges,
                           num_residues=num_residues, view=view, meta_dict=meta_dict, **data_dict)
def repeat(self, count):
    """
    Repeat this protein ``count`` times into a packed protein.

    Node, edge, residue and graph references in each copy are shifted by the size of the preceding copies.
    References equal to ``-1`` (no target) are kept as ``-1``, as in :meth:`pack`. Reference tensors may
    have trailing dimensions; every element of a row receives that row's offset.

    Parameters:
        count (int): number of repetitions

    Returns:
        packed protein (``self.packed_type``)
    """
    edge_list = self.edge_list.repeat(count, 1)
    edge_weight = self.edge_weight.repeat(count)
    num_nodes = [self.num_node] * count
    num_edges = [self.num_edge] * count
    num_residues = [self.num_residue] * count
    num_relation = self.num_relation
    # number of targets per copy for each kind of reference
    reference_size = {
        "node reference": self.num_node,
        "edge reference": self.num_edge,
        "residue reference": self.num_residue,
        "graph reference": 1,
    }

    data_dict = {}
    for k, v in self.data_dict.items():
        if "graph" in self.meta_dict[k]:
            v = v.unsqueeze(0)
        shape = [1] * v.ndim
        shape[0] = count
        length = len(v)
        v = v.repeat(shape)
        for meta_type in self.meta_dict[k]:
            if meta_type not in reference_size:
                continue
            offsets = torch.arange(count, device=self.device) * reference_size[meta_type]
            # one offset per row, broadcast over any trailing dimensions of the reference tensor
            offsets = offsets.repeat_interleave(length).view(-1, *([1] * (v.ndim - 1)))
            v = torch.where(v != -1, v + offsets, -1)
        data_dict[k] = v

    return self.packed_type(edge_list, edge_weight=edge_weight,
                            num_nodes=num_nodes, num_edges=num_edges, num_residues=num_residues, view=self.view,
                            num_relation=num_relation, meta_dict=self.meta_dict, **data_dict)
def residue2atom(self, residue_index):
    """
    Map residue ids to atom ids.

    Parameters:
        residue_index (array_like): residue ids

    Returns:
        Tensor: ids of all atoms of the given residues, grouped by residue in the order of ``residue_index``
    """
    residue_index = self._standarize_index(residue_index, self.num_residue)
    if not hasattr(self, "node_inverted_index"):
        # built lazily on first use and cached on the instance
        self.node_inverted_index = self._build_node_inverted_index()
    inverted_range, order = self.node_inverted_index
    starts, ends = inverted_range[residue_index].unbind(dim=-1)
    counts = ends - starts
    # exclusive prefix sum: where each residue's atoms begin in the flattened output
    out_start = counts.cumsum(0) - counts
    positions = torch.arange(counts.sum(), device=self.device) + (starts - out_start).repeat_interleave(counts)
    return order[positions]
def _build_node_inverted_index(self):
    """
    Build the residue -> atom lookup used by :meth:`residue2atom`.

    Returns:
        (Dictionary, Tensor): a dictionary from each residue id present in ``atom2residue`` to its
        ``[start, end)`` range in ``order``, and ``order``, the atom ids sorted by residue id
    """
    order = self.atom2residue.argsort()
    residue_ids, atom_counts = self.atom2residue.unique(return_counts=True)
    ends = atom_counts.cumsum(0)
    ranges = torch.stack([ends - atom_counts, ends], dim=-1)
    return Dictionary(residue_ids, ranges), order
def __getitem__(self, index):
    """
    Select residues, equivalent to ``residue_mask(index, compact=True)``.

    Only one axis is supported: ``x[0, 1]`` arrives as the tuple ``(0, 1)`` and is rejected,
    while ``x[[0, 1]]`` arrives as the list ``[0, 1]`` and selects two residues.
    """
    axes = index if isinstance(index, tuple) else (index,)
    if len(axes) > 1:
        raise ValueError("Protein has only 1 axis, but %d axis is indexed" % len(axes))
    return self.residue_mask(axes[0], compact=True)
def data_mask(self, node_index=None, edge_index=None, residue_index=None, graph_index=None, include=None,
              exclude=None):
    """
    Mask data attributes, extending the parent implementation to residue-level attributes.

    With ``residue_index`` given, residue attributes are indexed by it, and residue references are
    remapped through ``self._get_mapping(residue_index, num_residue)``. Without it, residue data is
    returned as produced by the parent.

    Returns:
        (dict, dict): masked data dict and meta dict
    """
    data_dict, meta_dict = super(Protein, self).data_mask(node_index, edge_index, graph_index=graph_index,
                                                          include=include, exclude=exclude)
    if residue_index is None:
        return data_dict, meta_dict

    residue_mapping = None
    for key, value in data_dict.items():
        for meta_type in meta_dict[key]:
            if meta_type == "residue":
                value = value.to_dense()[residue_index].to_sparse() if value.is_sparse else value[residue_index]
            elif meta_type == "residue reference":
                if residue_mapping is None:
                    # computed lazily, only when some attribute references residues
                    residue_mapping = self._get_mapping(residue_index, self.num_residue)
                value = residue_mapping[value]
        data_dict[key] = value
    return data_dict, meta_dict
def residue_mask(self, index, compact=False):
    """
    Return a masked protein based on the specified residues.

    Edges touching an atom of a dropped residue are always removed. With ``compact=True``, the atoms and
    residues themselves are removed as well and their ids are compacted. With ``compact=False``, atom ids
    and atom/residue data are kept. A warning is issued if ``index`` is not strictly increasing.

    Parameters:
        index (array_like): residue index
        compact (bool, optional): compact residue ids or not

    Returns:
        Protein
    """
    index = self._standarize_index(index, self.num_residue)
    if (torch.diff(index) <= 0).any():
        warnings.warn("`residue_mask()` is called to re-order the residues. This will change the protein sequence. "
                      "If this is not desired, you might have passed a wrong index to this function.")
    residue_map = -torch.ones(self.num_residue, dtype=torch.long, device=self.device)
    residue_map[index] = torch.arange(len(index), device=self.device)
    # an atom survives iff its residue survives
    node_index = self._standarize_index(residue_map[self.atom2residue] >= 0, self.num_node)

    node_map = -torch.ones(self.num_node, dtype=torch.long, device=self.device)
    if compact:
        node_map[node_index] = torch.arange(len(node_index), device=self.device)
        num_node = len(node_index)
    else:
        node_map[node_index] = node_index
        num_node = self.num_node

    edge_list = self.edge_list.clone()
    edge_list[:, :2] = node_map[edge_list[:, :2]]
    # an edge survives iff both endpoints survive
    edge_index = self._standarize_index((edge_list[:, :2] >= 0).all(dim=-1), self.num_edge)

    if compact:
        data_dict, meta_dict = self.data_mask(node_index, edge_index, residue_index=index)
    else:
        data_dict, meta_dict = self.data_mask(edge_index=edge_index)
    return type(self)(edge_list[edge_index], edge_weight=self.edge_weight[edge_index], num_node=num_node,
                      view=self.view, num_relation=self.num_relation, meta_dict=meta_dict, **data_dict)
def subresidue(self, index):
    """
    Return a subgraph based on the specified residues.

    Shorthand for :meth:`residue_mask(index, compact=True) <residue_mask>`.

    Parameters:
        index (array_like): residue index

    Returns:
        Protein

    See also:
        :meth:`Protein.residue_mask`
    """
    return self.residue_mask(index, compact=True)
@property
def residue2graph(self):
    """Residue id to graph id mapping: a single protein is graph 0, so this is ``num_residue`` zeros."""
    return torch.zeros(self.num_residue, dtype=torch.long, device=self.device)
@utils.cached_property
def connected_component_id(self):
    """
    Connected component id of each residue.

    Two residues are linked if an edge connects an atom of one to an atom of the other. Component ids are
    ranked by the smallest residue id in each component.
    """
    node_in, node_out = self.edge_list.t()[:2]
    residue_in = self.atom2residue[node_in]
    residue_out = self.atom2residue[node_out]
    # keep inter-residue edges only, symmetrize them and add one self loop per residue
    inter = residue_in != residue_out
    residue_in, residue_out = residue_in[inter], residue_out[inter]
    all_residues = torch.arange(self.num_residue, device=self.device)
    targets = torch.cat([residue_in, residue_out, all_residues])
    sources = torch.cat([residue_out, residue_in, all_residues])

    label = torch.arange(self.num_residue, device=self.device)
    previous = torch.zeros_like(label)
    # propagate the minimum residue id along edges until nothing changes
    while not torch.equal(label, previous):
        previous = label
        label = scatter_min(label[sources], targets, dim_size=self.num_residue)[0]
    return torch.unique(label, return_inverse=True)[1]
def __repr__(self):
    """Summarize atom, bond and residue counts, plus the device when it is not the CPU."""
    parts = ["num_atom=%d" % self.num_node,
             "num_bond=%d" % self.num_edge,
             "num_residue=%d" % self.num_residue]
    if self.device.type != "cpu":
        parts.append("device='%s'" % self.device)
    return "%s(%s)" % (self.__class__.__name__, ", ".join(parts))
[docs]class PackedProtein(PackedMolecule, Protein):
"""
Container for proteins with variadic sizes.
Support both residue-level and atom-level operations and ensure consistency between two views.
.. warning::
Edges of the same graph are guaranteed to be consecutive in the edge list.
The order of residues must be the same as the protein sequence.
However, this class doesn't enforce any order on nodes or edges.
Nodes may have a different order with residues.
Parameters:
edge_list (array_like, optional): list of edges of shape :math:`(|E|, 3)`.
Each tuple is (node_in, node_out, bond_type).
atom_type (array_like, optional): atom types of shape :math:`(|V|,)`
bond_type (array_like, optional): bond types of shape :math:`(|E|,)`
residue_type (array_like, optional): residue types of shape :math:`(|V_{res}|,)`
view (str, optional): default view for this protein. Can be ``atom`` or ``residue``.
num_nodes (array_like, optional): number of nodes in each graph
By default, it will be inferred from the largest id in `edge_list`
num_edges (array_like, optional): number of edges in each graph
num_residues (array_like, optional): number of residues in each graph
offsets (array_like, optional): node id offsets of shape :math:`(|E|,)`.
If not provided, nodes in `edge_list` should be relative index, i.e., the index in each graph.
If provided, nodes in `edge_list` should be absolute index, i.e., the index in the packed graph.
"""
unpacked_type = Protein
_check_attribute = Protein._check_attribute
def __init__(self, edge_list=None, atom_type=None, bond_type=None, residue_type=None, view=None, num_nodes=None,
             num_edges=None, num_residues=None, offsets=None, **kwargs):
    """
    Build a packed protein. See the class docstring for the arguments.

    On top of the parent constructor, stores ``num_residues`` as a tensor and its cumulative sum as
    ``num_cum_residues``.
    """
    super(PackedProtein, self).__init__(edge_list=edge_list, num_nodes=num_nodes, num_edges=num_edges,
                                        offsets=offsets, atom_type=atom_type, bond_type=bond_type,
                                        residue_type=residue_type, view=view, **kwargs)
    residues_per_graph = torch.as_tensor(num_residues, device=self.device)
    self.num_residues = residues_per_graph
    self.num_cum_residues = residues_per_graph.cumsum(0)
@property
def num_nodes(self):
    """Number of nodes (atoms) of each protein in the batch; an alias of ``num_atoms``."""
    return self.num_atoms

@num_nodes.setter
def num_nodes(self, value):
    # keep the alias writable: assigning num_nodes updates num_atoms
    self.num_atoms = value
def data_mask(self, node_index=None, edge_index=None, residue_index=None, graph_index=None, include=None,
              exclude=None):
    """
    Mask data attributes, extending the parent implementation to residue-level attributes.

    With ``residue_index`` given, residue attributes are indexed by it, and residue references are
    remapped through ``self._get_mapping(residue_index, num_residue)``. Without it, residue data is
    returned as produced by the parent.

    Returns:
        (dict, dict): masked data dict and meta dict
    """
    data_dict, meta_dict = super(PackedProtein, self).data_mask(node_index, edge_index, graph_index=graph_index,
                                                                include=include, exclude=exclude)
    if residue_index is None:
        return data_dict, meta_dict

    residue_mapping = None
    for key, value in data_dict.items():
        for meta_type in meta_dict[key]:
            if meta_type == "residue":
                value = value.to_dense()[residue_index].to_sparse() if value.is_sparse else value[residue_index]
            elif meta_type == "residue reference":
                if residue_mapping is None:
                    # computed lazily, only when some attribute references residues
                    residue_mapping = self._get_mapping(residue_index, self.num_residue)
                value = residue_mapping[value]
        data_dict[key] = value
    return data_dict, meta_dict
def node_mask(self, index, compact=True):
    """
    Return a masked packed protein that keeps only the specified atoms.

    Edges are kept only if both endpoints are kept. Residue data and ``num_residues`` are left unchanged.

    Parameters:
        index (array_like): node index
        compact (bool, optional): compact node ids or not

    Returns:
        PackedProtein
    """
    index = self._standarize_index(index, self.num_node)
    node_map = -torch.ones(self.num_node, dtype=torch.long, device=self.device)
    if not compact:
        node_map[index] = index
        num_nodes = self.num_nodes
        offsets = self._offsets
    else:
        node_map[index] = torch.arange(len(index), device=self.device)
        num_nodes = self._get_num_xs(index, self.num_cum_nodes)
        offsets = self._get_offsets(num_nodes, self.num_edges)

    edge_list = self.edge_list.clone()
    edge_list[:, :2] = node_map[edge_list[:, :2]]
    # boolean edge mask: both endpoints must survive
    edge_index = (edge_list[:, :2] >= 0).all(dim=-1)
    num_edges = self._get_num_xs(edge_index, self.num_cum_edges)

    if compact:
        data_dict, meta_dict = self.data_mask(index, edge_index)
    else:
        data_dict, meta_dict = self.data_mask(edge_index=edge_index)
    return type(self)(edge_list[edge_index], edge_weight=self.edge_weight[edge_index],
                      num_nodes=num_nodes, num_edges=num_edges, num_residues=self.num_residues,
                      view=self.view, num_relation=self.num_relation, offsets=offsets[edge_index],
                      meta_dict=meta_dict, **data_dict)
def edge_mask(self, index):
    """
    Return a masked packed protein that keeps only the specified edges.

    Nodes, residues and their per-graph counts are unchanged.

    Parameters:
        index (array_like): edge index

    Returns:
        PackedProtein
    """
    index = self._standarize_index(index, self.num_edge)
    data_dict, meta_dict = self.data_mask(edge_index=index)
    num_edges = self._get_num_xs(index, self.num_cum_edges)
    kept_edges = self.edge_list[index]
    kept_weights = self.edge_weight[index]
    kept_offsets = self._offsets[index]
    return type(self)(kept_edges, edge_weight=kept_weights,
                      num_nodes=self.num_nodes, num_edges=num_edges, num_residues=self.num_residues,
                      view=self.view, num_relation=self.num_relation, offsets=kept_offsets,
                      meta_dict=meta_dict, **data_dict)
def residue_mask(self, index, compact=False):
    """
    Return a masked packed protein based on the specified residues.

    The compact option applies to both residue and atom ids, but not to graph ids. Edges touching an atom
    of a dropped residue are always removed. With ``compact=False``, atom ids, residue data and per-graph
    node/residue counts are kept.

    Parameters:
        index (array_like): residue index
        compact (bool, optional): compact residue ids or not

    Returns:
        PackedProtein
    """
    index = self._standarize_index(index, self.num_residue)
    residue_map = -torch.ones(self.num_residue, dtype=torch.long, device=self.device)
    residue_map[index] = torch.arange(len(index), device=self.device)
    # an atom survives iff its residue survives
    node_index = self._standarize_index(residue_map[self.atom2residue] >= 0, self.num_node)

    node_map = -torch.ones(self.num_node, dtype=torch.long, device=self.device)
    if compact:
        node_map[node_index] = torch.arange(len(node_index), device=self.device)
        num_nodes = self._get_num_xs(node_index, self.num_cum_nodes)
        num_residues = self._get_num_xs(index, self.num_cum_residues)
    else:
        node_map[node_index] = node_index
        num_nodes = self.num_nodes
        num_residues = self.num_residues

    edge_list = self.edge_list.clone()
    edge_list[:, :2] = node_map[edge_list[:, :2]]
    # an edge survives iff both endpoints survive
    edge_index = self._standarize_index((edge_list[:, :2] >= 0).all(dim=-1), self.num_edge)
    num_edges = self._get_num_xs(edge_index, self.num_cum_edges)
    offsets = self._get_offsets(num_nodes, num_edges)

    if compact:
        data_dict, meta_dict = self.data_mask(node_index, edge_index, residue_index=index)
    else:
        data_dict, meta_dict = self.data_mask(edge_index=edge_index)
    return type(self)(edge_list[edge_index], edge_weight=self.edge_weight[edge_index],
                      num_nodes=num_nodes, num_edges=num_edges, num_residues=num_residues,
                      view=self.view, num_relation=self.num_relation, offsets=offsets,
                      meta_dict=meta_dict, **data_dict)
def graph_mask(self, index, compact=False):
    """
    Return a masked packed protein that keeps only the specified proteins.

    Parameters:
        index (array_like): graph index
        compact (bool, optional): compact graph ids or not. If true, kept proteins are reordered to follow
            ``index``. Otherwise, dropped proteins stay in the batch with zero nodes, edges and residues.

    Returns:
        PackedProtein
    """
    index = self._standarize_index(index, self.batch_size)
    graph_map = -torch.ones(self.batch_size, dtype=torch.long, device=self.device)
    graph_map[index] = torch.arange(len(index), device=self.device)

    node_index = self._standarize_index(graph_map[self.node2graph] >= 0, self.num_node)
    residue_index = self._standarize_index(graph_map[self.residue2graph] >= 0, self.num_residue)
    node_map = -torch.ones(self.num_node, dtype=torch.long, device=self.device)
    if compact:
        # order kept nodes / residues by the position of their graph in `index`, then by original id
        node_order = (graph_map[self.node2graph[node_index]] * self.num_node + node_index).argsort()
        node_index = node_index[node_order]
        residue_order = (graph_map[self.residue2graph[residue_index]] * self.num_residue + residue_index).argsort()
        residue_index = residue_index[residue_order]
        node_map[node_index] = torch.arange(len(node_index), device=self.device)
        num_nodes = self.num_nodes[index]
        num_residues = self.num_residues[index]
    else:
        node_map[node_index] = node_index
        num_nodes = torch.zeros_like(self.num_nodes)
        num_nodes[index] = self.num_nodes[index]
        num_residues = torch.zeros_like(self.num_residues)
        num_residues[index] = self.num_residues[index]

    edge_list = self.edge_list.clone()
    edge_list[:, :2] = node_map[edge_list[:, :2]]
    edge_index = self._standarize_index((edge_list[:, :2] >= 0).all(dim=-1), self.num_edge)
    if compact:
        edge_order = (graph_map[self.edge2graph[edge_index]] * self.num_edge + edge_index).argsort()
        edge_index = edge_index[edge_order]
        num_edges = self.num_edges[index]
    else:
        num_edges = torch.zeros_like(self.num_edges)
        num_edges[index] = self.num_edges[index]
    offsets = self._get_offsets(num_nodes, num_edges)

    if compact:
        data_dict, meta_dict = self.data_mask(node_index, edge_index,
                                              residue_index=residue_index, graph_index=index)
    else:
        data_dict, meta_dict = self.data_mask(edge_index=edge_index)
    return type(self)(edge_list[edge_index], edge_weight=self.edge_weight[edge_index],
                      num_nodes=num_nodes, num_edges=num_edges, num_residues=num_residues,
                      view=self.view, num_relation=self.num_relation, offsets=offsets,
                      meta_dict=meta_dict, **data_dict)
def get_item(self, index):
    """
    Get the ``index``-th protein of this batch as an unpacked protein.

    The returned protein keeps the view of this packed protein, matching the other packing and masking
    methods, which all propagate ``view``.

    Parameters:
        index (int): graph index

    Returns:
        Protein
    """
    node_index = torch.arange(self.num_cum_nodes[index] - self.num_nodes[index], self.num_cum_nodes[index],
                              device=self.device)
    edge_index = torch.arange(self.num_cum_edges[index] - self.num_edges[index], self.num_cum_edges[index],
                              device=self.device)
    residue_index = torch.arange(self.num_cum_residues[index] - self.num_residues[index],
                                 self.num_cum_residues[index], device=self.device)
    graph_index = index
    edge_list = self.edge_list[edge_index].clone()
    # convert absolute node ids back to ids local to this protein
    edge_list[:, :2] -= self._offsets[edge_index].unsqueeze(-1)
    data_dict, meta_dict = self.data_mask(node_index, edge_index,
                                          residue_index=residue_index, graph_index=graph_index)

    return self.unpacked_type(edge_list, edge_weight=self.edge_weight[edge_index], num_node=self.num_nodes[index],
                              view=self.view, num_relation=self.num_relation, meta_dict=meta_dict, **data_dict)
@classmethod
@utils.deprecated_alias(node_feature="atom_feature", edge_feature="bond_feature", graph_feature="mol_feature")
def from_molecule(cls, mols, atom_feature="default", bond_feature="default", residue_feature="default",
                  mol_feature=None, kekulize=False):
    """
    Create a packed protein from a list of RDKit objects.

    Atoms are grouped into residues using their PDB residue information: consecutive atoms that
    share the same (residue number, insertion code, residue name) form one residue. Residue
    detection restarts for every molecule, so a residue never spans two proteins.

    Parameters:
        mols (list of rdchem.Mol): molecules. ``None`` entries are replaced by an empty molecule.
        atom_feature (str or list of str, optional): atom features to extract
        bond_feature (str or list of str, optional): bond features to extract
        residue_feature (str or list of str, optional): residue features to extract
        mol_feature (str or list of str, optional): molecule features to extract
        kekulize (bool, optional): convert aromatic bonds to single/double bonds.
            Note this only affects the relation in ``edge_list``.
            For ``bond_type``, aromatic bonds are always stored explicitly.
            By default, aromatic bonds are stored.
    """
    # materialize once so that an iterator is not exhausted by the first consumer below
    mols = list(mols)
    protein = PackedMolecule.from_molecule(mols, atom_feature=atom_feature, bond_feature=bond_feature,
                                           mol_feature=mol_feature, with_hydrogen=False, kekulize=kekulize)
    residue_feature = cls._standarize_option(residue_feature)
    # resolve the feature functions once instead of once per residue
    feature_funcs = [R.get("features.residue.%s" % name) for name in residue_feature]

    residue_type = []
    atom_name = []
    is_hetero_atom = []
    occupancy = []
    b_factor = []
    atom2residue = []
    residue_number = []
    insertion_code = []
    chain_id = []
    _residue_feature = []
    num_residues = []
    num_cum_residue = 0

    # A dummy protein is appended so that every list below is non-empty when converted to a
    # tensor (e.g. when all input molecules are empty); its entries are stripped afterwards.
    dummy = cls.dummy_protein
    for mol in mols + [dummy]:
        if mol is None:
            mol = cls.empty_mol
        if kekulize:
            Chem.Kekulize(mol)
        # Restart residue detection for every molecule. Otherwise a molecule whose first residue
        # equals the previous molecule's last residue would be merged into that residue.
        last_residue = None
        for atom in mol.GetAtoms():
            residue = atom.GetPDBResidueInfo()
            number = residue.GetResidueNumber()
            code = residue.GetInsertionCode()
            residue_name = residue.GetResidueName().strip()
            canonical_residue = (number, code, residue_name)
            if canonical_residue != last_residue:
                last_residue = canonical_residue
                if residue_name not in cls.residue2id:
                    warnings.warn("Unknown residue `%s`. Treat as glycine" % residue_name)
                    residue_name = "GLY"
                residue_type.append(cls.residue2id[residue_name])
                residue_number.append(number)
                insertion_code.append(cls.alphabet2id[code])
                chain_id.append(cls.alphabet2id[residue.GetChainId()])
                res_feature = []
                for func in feature_funcs:
                    res_feature += func(residue)
                _residue_feature.append(res_feature)
            name = residue.GetName().strip()
            if name not in cls.atom_name2id:
                name = "UNK"
            atom_name.append(cls.atom_name2id[name])
            is_hetero_atom.append(residue.GetIsHeteroAtom())
            occupancy.append(residue.GetOccupancy())
            b_factor.append(residue.GetTempFactor())
            atom2residue.append(len(residue_type) - 1)
        num_residues.append(len(residue_type) - num_cum_residue)
        num_cum_residue = len(residue_type)

    # drop everything contributed by the dummy protein (tensors are built first so that
    # empty results keep an integer dtype)
    num_residue = num_cum_residue - num_residues.pop()
    num_atom = len(atom_name) - dummy.GetNumAtoms()
    residue_type = torch.tensor(residue_type)[:num_residue]
    atom_name = torch.tensor(atom_name)[:num_atom]
    is_hetero_atom = torch.tensor(is_hetero_atom)[:num_atom]
    occupancy = torch.tensor(occupancy)[:num_atom]
    b_factor = torch.tensor(b_factor)[:num_atom]
    atom2residue = torch.tensor(atom2residue)[:num_atom]
    residue_number = torch.tensor(residue_number)[:num_residue]
    insertion_code = torch.tensor(insertion_code)[:num_residue]
    chain_id = torch.tensor(chain_id)[:num_residue]
    if len(residue_feature) > 0:
        _residue_feature = torch.tensor(_residue_feature)[:num_residue]
    else:
        _residue_feature = None

    return cls(protein.edge_list, residue_type=residue_type,
               num_nodes=protein.num_nodes, num_edges=protein.num_edges, num_residues=num_residues,
               atom_name=atom_name, atom2residue=atom2residue, residue_feature=_residue_feature,
               is_hetero_atom=is_hetero_atom, occupancy=occupancy, b_factor=b_factor,
               residue_number=residue_number, insertion_code=insertion_code, chain_id=chain_id,
               offsets=protein._offsets, meta_dict=protein.meta_dict, **protein.data_dict)
@classmethod
def _residue_from_sequence(cls, sequences):
    """
    Build a residue-only packed protein (no atoms, no bonds) straight from sequences.

    Unknown residue symbols are replaced by glycine with a warning.
    """
    # A single-glycine dummy sequence guarantees the feature tensor is never built from an
    # empty list; its residue and its count are removed again below.
    padded = sequences + ["G"]
    symbol2id = cls.residue_symbol2id
    residue_type = []
    residue_feature = []
    num_residues = []
    for seq in padded:
        for symbol in seq:
            if symbol not in symbol2id:
                warnings.warn("Unknown residue symbol `%s`. Treat as glycine" % symbol)
                symbol = "G"
            residue_type.append(symbol2id[symbol])
            residue_feature.append(feature.onehot(symbol, symbol2id, allow_unknown=True))
        num_residues.append(len(seq))

    num_graph = len(padded) - 1
    return cls(edge_list=torch.zeros(0, 3, dtype=torch.long), atom_type=[], bond_type=[],
               residue_type=residue_type[:-1],
               num_nodes=[0] * num_graph, num_edges=[0] * num_graph, num_residues=num_residues[:-1],
               residue_feature=torch.tensor(residue_feature)[:-1])
@classmethod
@utils.deprecated_alias(node_feature="atom_feature", edge_feature="bond_feature", graph_feature="mol_feature")
def from_sequence(cls, sequences, atom_feature="default", bond_feature="default", residue_feature="default",
                  mol_feature=None, kekulize=False):
    """
    Create a packed protein from a list of sequences.

    .. note::

        It takes considerable time to construct proteins with a large number of atoms and bonds.
        If you only need residue information, you may speed up the construction by setting
        ``atom_feature`` and ``bond_feature`` to ``None``.

    Parameters:
        sequences (list of str): protein sequences
        atom_feature (str or list of str, optional): atom features to extract
        bond_feature (str or list of str, optional): bond features to extract
        residue_feature (str or list of str, optional): residue features to extract
        mol_feature (str or list of str, optional): molecule features to extract
        kekulize (bool, optional): convert aromatic bonds to single/double bonds.
            Note this only affects the relation in ``edge_list``.
            For ``bond_type``, aromatic bonds are always stored explicitly.
            By default, aromatic bonds are stored.

    Raises:
        ValueError: if RDKit cannot build a molecule from one of the sequences
    """
    # fast path: residue-level information only, no RDKit molecules needed
    residue_only = atom_feature is None and bond_feature is None and residue_feature == "default"
    if residue_only:
        return cls._residue_from_sequence(sequences)

    def _to_mol(sequence):
        mol = Chem.MolFromSequence(sequence)
        if mol is None:
            raise ValueError("Invalid sequence `%s`" % sequence)
        return mol

    mols = [_to_mol(sequence) for sequence in sequences]
    return cls.from_molecule(mols, atom_feature, bond_feature, residue_feature, mol_feature, kekulize)
@classmethod
@utils.deprecated_alias(node_feature="atom_feature", edge_feature="bond_feature", graph_feature="mol_feature")
def from_pdb(cls, pdb_files, atom_feature="default", bond_feature="default", residue_feature="default",
             mol_feature=None, kekulize=False):
    """
    Create a packed protein from a list of PDB files.

    Each file is read with ``Chem.MolFromPDBFile``. If RDKit returns ``None`` for a file, that
    entry is passed on unchanged, and ``from_molecule`` substitutes an empty molecule for it.

    Parameters:
        pdb_files (list of str): file names, one protein per file
        atom_feature (str or list of str, optional): atom features to extract
        bond_feature (str or list of str, optional): bond features to extract
        residue_feature (str or list of str, optional): residue features to extract
        mol_feature (str or list of str, optional): molecule features to extract
        kekulize (bool, optional): convert aromatic bonds to single/double bonds.
            Note this only affects the relation in ``edge_list``.
            For ``bond_type``, aromatic bonds are always stored explicitly.
            By default, aromatic bonds are stored.
    """
    mols = [Chem.MolFromPDBFile(path) for path in pdb_files]
    return cls.from_molecule(mols, atom_feature, bond_feature, residue_feature, mol_feature, kekulize)
def to_molecule(self, ignore_error=False):
    """
    Return a list of RDKit objects, one per protein, with PDB residue information attached.

    Every atom receives a ``Chem.AtomPDBResidueInfo`` carrying its atom name, residue name,
    residue number, chain id, insertion code, hetero-atom flag, occupancy and temperature factor.

    Parameters:
        ignore_error (bool, optional): forwarded to the parent ``to_molecule``

    Returns:
        list of rdchem.Mol: ``None`` entries produced by the parent are returned unchanged
    """
    mols = super(PackedProtein, self).to_molecule(ignore_error)

    residue_type = self.residue_type.tolist()
    atom_name = self.atom_name.tolist()
    atom2residue = self.atom2residue.tolist()
    is_hetero_atom = self.is_hetero_atom.tolist()
    occupancy = self.occupancy.tolist()
    b_factor = self.b_factor.tolist()
    residue_number = self.residue_number.tolist()
    chain_id = self.chain_id.tolist()
    insertion_code = self.insertion_code.tolist()
    num_cum_nodes = [0] + self.num_cum_nodes.tolist()

    for i, mol in enumerate(mols):
        # The parent may yield None for a molecule it could not build (presumably when
        # ``ignore_error`` is set). There are no atoms to annotate in that case.
        if mol is None:
            continue
        # atoms of protein i map to packed node ids starting at num_cum_nodes[i]
        # (assumes the parent preserves packed node order within each molecule)
        for j, atom in enumerate(mol.GetAtoms(), num_cum_nodes[i]):
            r = atom2residue[j]
            residue = Chem.AtomPDBResidueInfo()
            residue.SetResidueNumber(residue_number[r])
            residue.SetChainId(self.id2alphabet[chain_id[r]])
            residue.SetInsertionCode(self.id2alphabet[insertion_code[r]])
            residue.SetName(" %-3s" % self.id2atom_name[atom_name[j]])
            residue.SetResidueName(self.id2residue[residue_type[r]])
            residue.SetIsHeteroAtom(is_hetero_atom[j])
            residue.SetOccupancy(occupancy[j])
            residue.SetTempFactor(b_factor[j])
            atom.SetPDBResidueInfo(residue)
    return mols
def to_sequence(self):
    """
    Return one sequence string per protein in the batch.

    Consecutive residues whose connected-component id increases are separated by ``"."``.

    Returns:
        list of str
    """
    residue_type = self.residue_type.tolist()
    cc_id = self.connected_component_id.tolist()
    bounds = [0] + self.num_cum_residues.tolist()
    symbols = self.id2residue_symbol

    sequences = []
    for i in range(self.batch_size):
        start, end = bounds[i], bounds[i + 1]
        chars = []
        for j in range(start, end):
            # a jump to a larger component id starts a new fragment
            if j != start and cc_id[j] > cc_id[j - 1]:
                chars.append(".")
            chars.append(symbols[residue_type[j]])
        sequences.append("".join(chars))
    return sequences
def to_pdb(self, pdb_files):
    """
    Write each protein of this batch to its own PDB file.

    Proteins and file names are paired with ``zip``, so any surplus on either side is ignored.

    Parameters:
        pdb_files (list of str): output file names, one per protein
    """
    # flavor=10 is a bit-flag argument of RDKit's PDB writer; see the RDKit docs for its meaning
    for mol, path in zip(self.to_molecule(), pdb_files):
        Chem.MolToPDBFile(mol, path, flavor=10)
def merge(self, graph2graph):
    """
    Merge proteins of this batch into fewer proteins.

    Parameters:
        graph2graph (array_like): id of the merged protein each protein goes to, of shape
            ``(batch_size,)``. Ids may be arbitrary; they are coalesced to ``[0, n)``.

    Returns:
        a new instance of ``type(self)`` with ``n`` proteins. Graph-level fields are dropped
        (``data_mask(exclude="graph")``); the relation count and view are preserved.
    """
    graph2graph = torch.as_tensor(graph2graph, dtype=torch.long, device=self.device)
    # coalesce arbitrary graph IDs to [0, n)
    _, graph2graph = torch.unique(graph2graph, return_inverse=True)

    # sort by new id (ties broken by original position) so merged proteins become contiguous
    graph_key = graph2graph * self.batch_size + torch.arange(self.batch_size, device=self.device)
    graph_index = graph_key.argsort()
    graph = self.subbatch(graph_index)
    graph2graph = graph2graph[graph_index]

    num_graph = graph2graph[-1] + 1
    num_nodes = scatter_add(graph.num_nodes, graph2graph, dim_size=num_graph)
    num_edges = scatter_add(graph.num_edges, graph2graph, dim_size=num_graph)
    num_residues = scatter_add(graph.num_residues, graph2graph, dim_size=num_graph)
    offsets = self._get_offsets(num_nodes, num_edges)

    data_dict, meta_dict = graph.data_mask(exclude="graph")

    # forward num_relation like every other constructor of this class, so it is not re-inferred
    return type(self)(graph.edge_list, edge_weight=graph.edge_weight, num_nodes=num_nodes,
                      num_edges=num_edges, num_residues=num_residues, view=self.view,
                      num_relation=self.num_relation, offsets=offsets,
                      meta_dict=meta_dict, **data_dict)
def repeat(self, count):
    """
    Repeat the whole batch ``count`` times (copies are concatenated, not interleaved).

    Parameters:
        count (int): number of copies

    Returns:
        a new instance of ``type(self)`` with ``count * batch_size`` proteins
    """
    num_nodes = self.num_nodes.repeat(count)
    num_edges = self.num_edges.repeat(count)
    num_residues = self.num_residues.repeat(count)
    offsets = self._get_offsets(num_nodes, num_edges)

    edge_list = self.edge_list.repeat(count, 1)
    # move the endpoints of every copied edge into the node range of its own copy
    edge_list[:, :2] += (offsets - self._offsets.repeat(count)).unsqueeze(-1)

    # Reference fields store indexes; the k-th copy must be shifted by k times the size
    # of the referenced dimension. Sizes are read lazily by attribute name.
    size_attr = {
        "node reference": "num_node",
        "edge reference": "num_edge",
        "residue reference": "num_residue",
        "graph reference": "batch_size",
    }

    data_dict = {}
    for key, value in self.data_dict.items():
        tile = [1] * value.ndim
        tile[0] = count
        length = len(value)
        value = value.repeat(tile)
        for _type in self.meta_dict[key]:
            attr = size_attr.get(_type)
            if attr is None:
                continue
            shift = torch.arange(count, device=self.device) * getattr(self, attr)
            value = value + shift.repeat_interleave(length)
        data_dict[key] = value

    return type(self)(edge_list, edge_weight=self.edge_weight.repeat(count),
                      num_nodes=num_nodes, num_edges=num_edges, num_residues=num_residues, view=self.view,
                      num_relation=self.num_relation, offsets=offsets,
                      meta_dict=self.meta_dict, **data_dict)
[docs] def repeat_interleave(self, repeats):
"""
Repeat each protein ``repeats[i]`` times, keeping the copies of one protein adjacent.

Parameters:
    repeats (int or array_like): number of copies of each protein; a single value is
        broadcast to every protein in the batch

Returns:
    a new instance of ``type(self)`` with ``repeats.sum()`` proteins
"""
repeats = torch.as_tensor(repeats, dtype=torch.long, device=self.device)
# broadcast a scalar repeat count to every protein
if repeats.numel() == 1:
repeats = repeats * torch.ones(self.batch_size, dtype=torch.long, device=self.device)
num_nodes = self.num_nodes.repeat_interleave(repeats)
num_edges = self.num_edges.repeat_interleave(repeats)
num_residues = self.num_residues.repeat_interleave(repeats)
num_cum_nodes = num_nodes.cumsum(0)
num_cum_edges = num_edges.cumsum(0)
num_cum_residues = num_residues.cumsum(0)
num_node = num_nodes.sum()
num_edge = num_edges.sum()
num_residue = num_residues.sum()
batch_size = repeats.sum()
# NOTE(review): `num_graphs` has one (float) entry per *output* protein, while `repeats` has one
# entry per input protein. It is used below as `repeat_interleave` counts for "graph" fields
# and passed to `_get_repeat_pack_offsets` together with `repeats`. Confirm that graph
# reference fields work when any repeat != 1 (and with a float count tensor).
num_graphs = torch.ones(batch_size, device=self.device)
# special case 1: graphs[i] may have no node or no edge
# special case 2: repeats[i] may be 0
# position (in the output batch) of the first copy of each input protein
cum_repeats_shifted = repeats.cumsum(0) - repeats
# trailing input proteins with 0 repeats point past the end of the output batch; drop them
graph_mask = cum_repeats_shifted < batch_size
cum_repeats_shifted = cum_repeats_shifted[graph_mask]
# Build a gather index from output nodes to input nodes. The index advances by one per output
# position. At the start of every output copy it jumps back by that copy's size (restarting the
# copy). For the first copy of each input protein, that jump is cancelled by adding the input
# protein's size back. Input proteins with 0 repeats add their size at the next start
# position, skipping their nodes.
index = num_cum_nodes - num_nodes
index = torch.cat([index, index[cum_repeats_shifted]])
value = torch.cat([-num_nodes, self.num_nodes[graph_mask]])
# start positions equal to the total count (empty trailing copies) are out of range
mask = index < num_node
node_index = scatter_add(value[mask], index[mask], dim_size=num_node)
node_index = (node_index + 1).cumsum(0) - 1
# same gather-index construction for edges
index = num_cum_edges - num_edges
index = torch.cat([index, index[cum_repeats_shifted]])
value = torch.cat([-num_edges, self.num_edges[graph_mask]])
mask = index < num_edge
edge_index = scatter_add(value[mask], index[mask], dim_size=num_edge)
edge_index = (edge_index + 1).cumsum(0) - 1
# same gather-index construction for residues
index = num_cum_residues - num_residues
index = torch.cat([index, index[cum_repeats_shifted]])
value = torch.cat([-num_residues, self.num_residues[graph_mask]])
mask = index < num_residue
residue_index = scatter_add(value[mask], index[mask], dim_size=num_residue)
residue_index = (residue_index + 1).cumsum(0) - 1
# input protein id of every output protein
graph_index = torch.repeat_interleave(repeats)
offsets = self._get_offsets(num_nodes, num_edges)
# move edge endpoints from their old packed offsets to the offsets of their copy
edge_list = self.edge_list[edge_index]
edge_list[:, :2] += (offsets - self._offsets[edge_index]).unsqueeze(-1)
# per-type reference offsets are computed lazily, at most once each
node_offsets = None
edge_offsets = None
residue_offsets = None
graph_offsets = None
data_dict = {}
for k, v in self.data_dict.items():
num_xs = None
pack_offsets = None
for _type in self.meta_dict[k]:
if _type == "node":
v = v[node_index]
num_xs = num_nodes
elif _type == "edge":
v = v[edge_index]
num_xs = num_edges
elif _type == "residue":
v = v[residue_index]
num_xs = num_residues
elif _type == "graph":
v = v[graph_index]
num_xs = num_graphs
elif _type == "node reference":
if node_offsets is None:
node_offsets = self._get_repeat_pack_offsets(self.num_nodes, repeats)
pack_offsets = node_offsets
elif _type == "edge reference":
if edge_offsets is None:
edge_offsets = self._get_repeat_pack_offsets(self.num_edges, repeats)
pack_offsets = edge_offsets
elif _type == "residue reference":
if residue_offsets is None:
residue_offsets = self._get_repeat_pack_offsets(self.num_residues, repeats)
pack_offsets = residue_offsets
elif _type == "graph reference":
if graph_offsets is None:
graph_offsets = self._get_repeat_pack_offsets(num_graphs, repeats)
pack_offsets = graph_offsets
# add offsets to make references point to indexes in their own graph
if num_xs is not None and pack_offsets is not None:
v = v + pack_offsets.repeat_interleave(num_xs)
data_dict[k] = v
return type(self)(edge_list, edge_weight=self.edge_weight[edge_index],
num_nodes=num_nodes, num_edges=num_edges, num_residues=num_residues, view=self.view,
num_relation=self.num_relation, offsets=offsets, meta_dict=self.meta_dict, **data_dict)
def undirected(self, add_inverse=True):
    """
    Return an undirected copy of this packed protein.

    Edges are produced by ``PackedMolecule.undirected``. Residue counts and the view are taken
    from this protein.

    Parameters:
        add_inverse (bool, optional): forwarded to ``PackedMolecule.undirected``
    """
    base = PackedMolecule.undirected(self, add_inverse=add_inverse)
    return type(self)(
        base.edge_list,
        edge_weight=base.edge_weight,
        num_nodes=base.num_nodes,
        num_edges=base.num_edges,
        num_residues=self.num_residues,
        view=self.view,
        num_relation=base.num_relation,
        offsets=base._offsets,
        meta_dict=base.meta_dict,
        **base.data_dict
    )
def detach(self):
    """
    Return a new packed protein whose edge list, edge weights and data fields are detached.

    Counts, view, relation count, offsets and meta data are passed through unchanged.
    """
    detached = utils.detach(self.data_dict)
    return type(self)(
        self.edge_list.detach(),
        edge_weight=self.edge_weight.detach(),
        num_nodes=self.num_nodes,
        num_edges=self.num_edges,
        num_residues=self.num_residues,
        view=self.view,
        num_relation=self.num_relation,
        offsets=self._offsets,
        meta_dict=self.meta_dict,
        **detached
    )
def clone(self):
    """
    Return a new packed protein with cloned edge list, edge weights and data fields.

    Counts, view, relation count, offsets and meta data are passed through unchanged.
    """
    cloned = utils.clone(self.data_dict)
    return type(self)(
        self.edge_list.clone(),
        edge_weight=self.edge_weight.clone(),
        num_nodes=self.num_nodes,
        num_edges=self.num_edges,
        num_residues=self.num_residues,
        view=self.view,
        num_relation=self.num_relation,
        offsets=self._offsets,
        meta_dict=self.meta_dict,
        **cloned
    )
def cuda(self, *args, **kwargs):
    """
    Return a copy of this packed protein on a CUDA device, or ``self`` if no move is needed.

    Arguments are forwarded to ``Tensor.cuda`` and ``utils.cuda``.
    """
    moved = self.edge_list.cuda(*args, **kwargs)
    # Tensor.cuda returns the very same tensor when it is already on the target device
    if moved is self.edge_list:
        return self
    return type(self)(moved, edge_weight=self.edge_weight,
                      num_nodes=self.num_nodes, num_edges=self.num_edges, num_residues=self.num_residues,
                      view=self.view, num_relation=self.num_relation, offsets=self._offsets,
                      meta_dict=self.meta_dict, **utils.cuda(self.data_dict, *args, **kwargs))
def cpu(self):
    """
    Return a copy of this packed protein in CPU memory, or ``self`` if it already is.
    """
    moved = self.edge_list.cpu()
    # Tensor.cpu returns the very same tensor when it already lives on the CPU
    if moved is self.edge_list:
        return self
    return type(self)(moved, edge_weight=self.edge_weight,
                      num_nodes=self.num_nodes, num_edges=self.num_edges, num_residues=self.num_residues,
                      view=self.view, num_relation=self.num_relation, offsets=self._offsets,
                      meta_dict=self.meta_dict, **utils.cpu(self.data_dict))
@utils.cached_property
def residue2graph(self):
    """Residue id to graph id mapping."""
    # graph id i repeated num_residues[i] times
    graph_ids = torch.arange(self.batch_size, device=self.device)
    return graph_ids.repeat_interleave(self.num_residues)
@utils.cached_property
def connected_component_id(self):
    """
    Connected component id of each residue, renumbered so each protein's smallest id is 0.
    """
    cc_id = super(PackedProtein, self).connected_component_id
    # Smallest component id of each protein. The scatter output has one slot per protein
    # (indexed by graph id), so its size must be the batch size, not the residue count.
    cc_id_offsets = scatter_min(cc_id, self.residue2graph, dim_size=self.batch_size)[0][self.residue2graph]
    return cc_id - cc_id_offsets
def __repr__(self):
    """Summarize batch size, per-protein atom/bond/residue counts and a non-CPU device."""
    parts = [
        "batch_size=%d" % self.batch_size,
        "num_atoms=%s" % pretty.long_array(self.num_nodes.tolist()),
        "num_bonds=%s" % pretty.long_array(self.num_edges.tolist()),
        "num_residues=%s" % pretty.long_array(self.num_residues.tolist()),
    ]
    on_cpu = self.device.type == "cpu"
    if not on_cpu:
        parts.append("device='%s'" % self.device)
    return "%s(%s)" % (type(self).__name__, ", ".join(parts))
# Link the single-protein class to its batched counterpart. The attribute is read elsewhere,
# presumably when packing Protein objects into a batch — confirm against Protein / Graph.pack.
Protein.packed_type = PackedProtein