# Source code for torchdrug.layers.geometry.function

import torch
from torch import nn
from torch_cluster import knn_graph, radius_graph

from torchdrug import core, data
from torchdrug.layers import functional
from torchdrug.core import Registry as R


@R.register("layers.geometry.BondEdge")
class BondEdge(nn.Module, core.Configurable):
    """
    Construct all bond edges.
    """

    def forward(self, graph):
        """
        Return the bond edges already stored in the input graph, together with its number of relations.
        Edge types are taken over unchanged from the input graph.

        Parameters:
            graph (Graph): :math:`n` graph(s)

        Returns:
            (Tensor, int): edge list of shape :math:`(|E|, 3)`, number of relations
        """
        bonds = graph.edge_list
        num_bond_types = graph.num_relation
        return bonds, num_bond_types
@R.register("layers.geometry.KNNEdge")
class KNNEdge(nn.Module, core.Configurable):
    """
    Construct edges between each node and its nearest neighbors.

    Parameters:
        k (int, optional): number of neighbors
        min_distance (int, optional): minimum distance between the residues of two nodes.
            The filter is only applied when this value is positive.
        max_distance (int, optional): maximum distance between the residues of two nodes.
            ``None`` (or any falsy value) disables this filter.
    """

    # edges whose endpoints are closer than this (e.g. coincident nodes) are discarded
    eps = 1e-10

    def __init__(self, k=10, min_distance=5, max_distance=None):
        super(KNNEdge, self).__init__()
        self.k = k
        self.min_distance = min_distance
        self.max_distance = max_distance

    def _residue_gap(self, graph, edge_list):
        # absolute difference between the residue indices of the two endpoints of each edge
        source, target = edge_list.t()[:2]
        return (graph.atom2residue[source] - graph.atom2residue[target]).abs()

    def _is_degenerate(self, graph, edge_list):
        # True for edges whose endpoints lie (numerically) at the same position
        source, target = edge_list.t()[:2]
        return (graph.node_position[source] - graph.node_position[target]).norm(dim=-1) < self.eps

    def forward(self, graph):
        """
        Return KNN edges constructed from the input graph.

        Neighbors are searched within each graph of the batch (``batch=graph.node2graph``).
        Edges are then dropped if their residue gap is below ``min_distance`` or above
        ``max_distance``, or if their endpoints coincide. All edges get relation 0.

        Parameters:
            graph (Graph): :math:`n` graph(s)

        Returns:
            (Tensor, int): edge list of shape :math:`(|E|, 3)`, number of relations
        """
        neighbor_pairs = knn_graph(graph.node_position, k=self.k, batch=graph.node2graph).t()
        zero_relation = torch.zeros(len(neighbor_pairs), 1, dtype=torch.long, device=graph.device)
        edge_list = torch.cat([neighbor_pairs, zero_relation], dim=-1)

        if self.min_distance > 0:
            edge_list = edge_list[self._residue_gap(graph, edge_list) >= self.min_distance]
        if self.max_distance:
            edge_list = edge_list[self._residue_gap(graph, edge_list) <= self.max_distance]

        # self-loops and coincident nodes have (near-)zero length
        edge_list = edge_list[~self._is_degenerate(graph, edge_list)]
        return edge_list, 1
@R.register("layers.geometry.SpatialEdge")
class SpatialEdge(nn.Module, core.Configurable):
    """
    Construct edges between nodes within a specified radius.

    Parameters:
        radius (float, optional): spatial radius
        min_distance (int, optional): minimum distance between the residues of two nodes.
            The filter is only applied when this value is positive.
        max_distance (int, optional): maximum distance between the residues of two nodes.
            ``None`` (or any falsy value) disables this filter.
        max_num_neighbors (int, optional): ``max_num_neighbors`` forwarded to ``radius_graph``
    """

    # edges whose endpoints are closer than this (e.g. coincident nodes) are discarded
    eps = 1e-10

    def __init__(self, radius=5, min_distance=5, max_distance=None, max_num_neighbors=32):
        super(SpatialEdge, self).__init__()
        self.radius = radius
        self.min_distance = min_distance
        self.max_distance = max_distance
        self.max_num_neighbors = max_num_neighbors

    def _residue_gap(self, graph, edge_list):
        # absolute difference between the residue indices of the two endpoints of each edge
        source, target = edge_list.t()[:2]
        return (graph.atom2residue[source] - graph.atom2residue[target]).abs()

    def _is_degenerate(self, graph, edge_list):
        # True for edges whose endpoints lie (numerically) at the same position
        source, target = edge_list.t()[:2]
        return (graph.node_position[source] - graph.node_position[target]).norm(dim=-1) < self.eps

    def forward(self, graph):
        """
        Return spatial radius edges constructed based on the input graph.

        Pairs are searched within each graph of the batch (``batch=graph.node2graph``).
        Edges are then dropped if their residue gap is below ``min_distance`` or above
        ``max_distance``, or if their endpoints coincide. All edges get relation 0.

        Parameters:
            graph (Graph): :math:`n` graph(s)

        Returns:
            (Tensor, int): edge list of shape :math:`(|E|, 3)`, number of relations
        """
        ball_pairs = radius_graph(graph.node_position, r=self.radius, batch=graph.node2graph,
                                  max_num_neighbors=self.max_num_neighbors).t()
        zero_relation = torch.zeros(len(ball_pairs), 1, dtype=torch.long, device=graph.device)
        edge_list = torch.cat([ball_pairs, zero_relation], dim=-1)

        if self.min_distance > 0:
            edge_list = edge_list[self._residue_gap(graph, edge_list) >= self.min_distance]
        if self.max_distance:
            edge_list = edge_list[self._residue_gap(graph, edge_list) <= self.max_distance]

        # self-loops and coincident nodes have (near-)zero length
        edge_list = edge_list[~self._is_degenerate(graph, edge_list)]
        return edge_list, 1
@R.register("layers.geometry.SequentialEdge")
class SequentialEdge(nn.Module, core.Configurable):
    """
    Construct edges between atoms within close residues.

    Parameters:
        max_distance (int, optional): maximum distance between two residues in the sequence
        only_backbone (bool, optional): if True, only atoms named ``CA``, ``C`` or ``N`` take part
    """

    def __init__(self, max_distance=2, only_backbone=False):
        super(SequentialEdge, self).__init__()
        self.max_distance = max_distance
        self.only_backbone = only_backbone

    def forward(self, graph):
        """
        Return sequential edges constructed based on the input graph.

        For every offset ``i`` in ``[-max_distance, max_distance]``, atoms of residue ``r`` are
        connected to atoms of residue ``r + i`` within the same graph. Pairs whose residues
        belong to different chains are removed. The edge type is ``i + max_distance``.

        Parameters:
            graph (Graph): :math:`n` graph(s)

        Returns:
            (Tensor, int): edge list of shape :math:`(|E|, 3)`, number of relations
        """
        atom_residue = graph.atom2residue
        if self.only_backbone:
            name2id = graph.atom_name2id
            is_backbone = (graph.atom_name == name2id["CA"]) \
                | (graph.atom_name == name2id["C"]) \
                | (graph.atom_name == name2id["N"])
            counted_residues = atom_residue[is_backbone]
        else:
            counted_residues = atom_residue
        # number of (participating) atoms in each residue
        residue2num_atom = counted_residues.bincount(minlength=graph.num_residue)

        # [lo, hi) residue-index range of the graph each atom / residue belongs to
        residue_hi = graph.num_cum_residues
        residue_lo = residue_hi - graph.num_residues
        atom_lo, atom_hi = residue_lo[graph.atom2graph], residue_hi[graph.atom2graph]
        res_lo, res_hi = residue_lo[graph.residue2graph], residue_hi[graph.residue2graph]
        node_index = torch.arange(graph.num_node, device=graph.device)
        residue_index = torch.arange(graph.num_residue, device=graph.device)

        chunks = []
        for shift in range(-self.max_distance, self.max_distance + 1):
            # keep only residues whose partner residue (r + shift) exists in the same graph
            if shift > 0:
                src_atoms = atom_residue < atom_hi - shift
                dst_atoms = atom_residue >= atom_lo + shift
                src_residues = residue_index < res_hi - shift
                dst_residues = residue_index >= res_lo + shift
            else:
                src_atoms = atom_residue >= atom_lo - shift
                dst_atoms = atom_residue < atom_hi + shift
                src_residues = residue_index >= res_lo - shift
                dst_residues = residue_index < res_hi + shift
            if self.only_backbone:
                src_atoms = src_atoms & is_backbone
                dst_atoms = dst_atoms & is_backbone

            src = node_index[src_atoms]
            dst = node_index[dst_atoms]
            # make atoms of the same residue contiguous, matching the per-residue counts below
            src = src[atom_residue[src].argsort()]
            dst = dst[atom_residue[dst].argsort()]
            src, dst = functional.variadic_meshgrid(src, residue2num_atom[src_residues],
                                                    dst, residue2num_atom[dst_residues])

            # exclude cross-chain edges
            same_chain = graph.chain_id[atom_residue[src]] == graph.chain_id[atom_residue[dst]]
            src = src[same_chain]
            dst = dst[same_chain]

            relation = torch.full((len(src),), shift + self.max_distance, dtype=torch.long, device=graph.device)
            chunks.append(torch.stack([src, dst, relation], dim=-1))

        return torch.cat(chunks), 2 * self.max_distance + 1
@R.register("layers.geometry.AlphaCarbonNode")
class AlphaCarbonNode(nn.Module, core.Configurable):
    """
    Construct only alpha carbon atoms.
    """

    def forward(self, graph):
        """
        Return a subgraph that only consists of alpha carbon nodes.

        Atoms named ``CA`` that belong to a residue are kept, and residues without such an atom
        are removed. Afterwards every graph must have exactly one node per residue.

        Parameters:
            graph (Graph): :math:`n` graph(s)
        """
        is_alpha_carbon = graph.atom_name == data.Protein.atom_name2id["CA"]
        has_residue = graph.atom2residue != -1
        atom_mask = is_alpha_carbon & has_residue

        ca_per_residue = graph.atom2residue[atom_mask].bincount(minlength=graph.num_residue)
        residue_mask = ca_per_residue > 0
        atom_mask = atom_mask & residue_mask[graph.atom2residue]

        cropped = graph.subgraph(atom_mask).subresidue(residue_mask)
        # a residue with several CA atoms would break the one-node-per-residue invariant
        assert (cropped.num_node == cropped.num_residue).all()
        return cropped
@R.register("layers.geometry.IdentityNode")
class IdentityNode(nn.Module, core.Configurable):
    """
    Construct all nodes as the input.
    """

    def forward(self, graph):
        """
        Pass the input graph through untouched.

        Parameters:
            graph (Graph): :math:`n` graph(s)
        """
        unchanged = graph
        return unchanged
@R.register("layers.geometry.RandomEdgeMask")
class RandomEdgeMask(nn.Module, core.Configurable):
    """
    Construct nodes by random edge masking.

    Parameters:
        mask_rate (float, optional): rate of masked edges
    """

    def __init__(self, mask_rate=0.15):
        super(RandomEdgeMask, self).__init__()
        self.mask_rate = mask_rate

    def forward(self, graph):
        """
        Return a graph with some edges masked out.

        For each graph, ``max(1, floor(num_edges * mask_rate))`` edge positions are drawn uniformly
        with replacement, so duplicates can make the masked count smaller. Graphs without edges
        draw nothing.

        Parameters:
            graph (Graph): :math:`n` graph(s)
        """
        num_samples = (graph.num_edges * self.mask_rate).long().clamp(min=1)
        # A graph with no edges must not draw a sample. Otherwise its index would be
        # `offset + 0`, which is the first edge of the next graph, or `num_edge` (out of
        # range) if the empty graph is the last one in the batch.
        num_samples = torch.where(graph.num_edges > 0, num_samples, torch.zeros_like(num_samples))
        num_sample = num_samples.sum()
        sample2graph = torch.repeat_interleave(num_samples)
        edge_index = (torch.rand(num_sample, device=graph.device) * graph.num_edges[sample2graph]).long()
        # shift local edge indices to batch-global indices
        edge_index = edge_index + (graph.num_cum_edges - graph.num_edges)[sample2graph]
        edge_mask = ~functional.as_mask(edge_index, graph.num_edge)
        return graph.edge_mask(edge_mask)
@R.register("layers.geometry.SubsequenceNode")
class SubsequenceNode(nn.Module, core.Configurable):
    """
    Construct nodes by taking a random subsequence of the original graph.

    Parameters:
        max_length (int, optional): maximal length of the sequence after cropping
    """

    def __init__(self, max_length=100):
        super(SubsequenceNode, self).__init__()
        self.max_length = max_length

    def forward(self, graph):
        """
        Randomly take a subsequence of the specified length.
        Return the full sequence if the sequence is not longer than the specified length.

        Parameters:
            graph (Graph): :math:`n` graph(s)
        """
        # A window of `max_length` residues fits at `num_residues - max_length + 1` start
        # positions, including the one that ends at the last residue. Short sequences have a
        # single start position (0).
        num_starts = (graph.num_residues - self.max_length + 1).clamp(min=1)
        starts = (torch.rand(graph.batch_size, device=graph.device) * num_starts).long()
        ends = torch.min(starts + self.max_length, graph.num_residues)
        # convert per-graph residue positions to batch-global residue indices
        offsets = graph.num_cum_residues - graph.num_residues
        starts = starts + offsets
        ends = ends + offsets
        residue_mask = functional.multi_slice_mask(starts, ends, graph.num_residue)
        graph = graph.subresidue(residue_mask)
        return graph
@R.register("layers.geometry.SubspaceNode")
class SubspaceNode(nn.Module, core.Configurable):
    """
    Construct nodes by taking a spatial ball of the original graph.

    Parameters:
        entity_level (str, optional): level to perform cropping.
            Available options are ``node``, ``atom`` and ``residue``.
        min_radius (float, optional): minimum radius of the spatial ball
        min_neighbor (int, optional): minimum number of nodes in the spatial ball

    Raises:
        ValueError: if ``entity_level`` is not one of the available options
    """

    def __init__(self, entity_level="node", min_radius=15.0, min_neighbor=50):
        super(SubspaceNode, self).__init__()
        # previously any unknown value silently fell through to residue-level cropping
        if entity_level not in ("node", "atom", "residue"):
            raise ValueError("Unknown entity level `%s`. Expect `node`, `atom` or `residue`" % entity_level)
        self.entity_level = entity_level
        self.min_radius = min_radius
        self.min_neighbor = min_neighbor

    def forward(self, graph):
        """
        Randomly pick a node as the center of each graph and crop a spatial ball around it.

        The ball radius is 1.5 times the distance to the ``min_neighbor``-th closest node, but at
        least ``min_radius``. At ``residue`` level, every residue with a node inside the ball is kept.

        Parameters:
            graph (Graph): :math:`n` graph(s)
        """
        # one random center per graph, repeated for every node of that graph
        node_in = torch.arange(graph.num_node, device=graph.device)
        node_in = functional.variadic_sample(node_in, graph.num_nodes, 1).squeeze(-1)
        node_in = node_in.repeat_interleave(graph.num_nodes)
        node_out = torch.arange(graph.num_node, device=graph.device)
        dist = (graph.node_position[node_in] - graph.node_position[node_out]).norm(dim=-1)

        topk_dist = functional.variadic_topk(dist, graph.num_nodes, self.min_neighbor, largest=False)[0]
        radius = (topk_dist[:, -1] * 1.5).clamp(min=self.min_radius)
        radius = radius.repeat_interleave(graph.num_nodes)
        node_index = node_out[dist < radius]

        if self.entity_level in ["node", "atom"]:
            graph = graph.subgraph(node_index)
        else:
            residue_index = graph.atom2residue[node_index].unique()
            graph = graph.subresidue(residue_index)
        return graph