# Source code for torchdrug.layers.geometry.function

```import torch
from torch import nn

from torchdrug import core, data
from torchdrug.layers import functional
from torchdrug.core import Registry as R

@R.register("layers.geometry.BondEdge")
class BondEdge(nn.Module, core.Configurable):
    """
    Construct all bond edges.

    No new edges are built: the edges already stored in the input graph are
    handed back together with the graph's number of relation types.
    """

    def forward(self, graph):
        """
        Return bond edges from the input graph.

        Both the edge list and the relation count are read directly from ``graph``,
        so edge types are inherited unchanged.

        Parameters:
            graph (Graph): :math:`n` graph(s)

        Returns:
            (Tensor, int): edge list of shape :math:`(|E|, 3)`, number of relations
        """
        edges = graph.edge_list
        num_relation = graph.num_relation
        return edges, num_relation

@R.register("layers.geometry.KNNEdge")
class KNNEdge(nn.Module, core.Configurable):
    """
    Construct edges between each node and its nearest neighbors.

    Parameters:
        k (int, optional): number of neighbors
        min_distance (int, optional): minimum distance between the residues of two nodes
        max_distance (int, optional): maximum distance between the residues of two nodes.
            No upper bound is applied when this is ``None`` (or 0).
    """

    # Edges whose endpoints are closer than this are treated as degenerate
    # (numerically coincident positions) and removed.
    eps = 1e-10

    def __init__(self, k=10, min_distance=5, max_distance=None):
        super(KNNEdge, self).__init__()
        self.k = k
        self.min_distance = min_distance
        self.max_distance = max_distance

    def forward(self, graph):
        """
        Return KNN edges constructed from the input graph.

        After the k-nearest-neighbor search, an edge is dropped when

        - ``min_distance > 0`` and the residue indices of its endpoints differ by
          less than ``min_distance``;
        - ``max_distance`` is set and the residue indices differ by more than
          ``max_distance``;
        - its endpoints are closer than ``eps`` in space.

        Parameters:
            graph (Graph): :math:`n` graph(s)

        Returns:
            (Tensor, int): edge list of shape :math:`(|E|, 3)`, number of relations
        """
        # NOTE(review): ``knn_graph`` is not imported in the visible import block;
        # presumably torch_cluster.knn_graph — confirm the module-level import.
        edge_list = knn_graph(graph.node_position, k=self.k, batch=graph.node2graph).t()
        # every KNN edge gets the single relation type 0
        relation = torch.zeros(len(edge_list), 1, dtype=torch.long, device=graph.device)
        edge_list = torch.cat([edge_list, relation], dim=-1)

        if self.min_distance > 0:
            node_in, node_out = edge_list.t()[:2]
            mask = (graph.atom2residue[node_in] - graph.atom2residue[node_out]).abs() < self.min_distance
            edge_list = edge_list[~mask]

        if self.max_distance:
            node_in, node_out = edge_list.t()[:2]
            mask = (graph.atom2residue[node_in] - graph.atom2residue[node_out]).abs() > self.max_distance
            edge_list = edge_list[~mask]

        node_in, node_out = edge_list.t()[:2]
        mask = (graph.node_position[node_in] - graph.node_position[node_out]).norm(dim=-1) < self.eps
        edge_list = edge_list[~mask]

        return edge_list, 1

@R.register("layers.geometry.SpatialEdge")
class SpatialEdge(nn.Module, core.Configurable):
    """
    Construct edges between nodes within a specified radius.

    Parameters:
        radius (float, optional): spatial radius
        min_distance (int, optional): minimum distance between the residues of two nodes
        max_distance (int, optional): maximum distance between the residues of two nodes.
            No upper bound is applied when this is ``None`` (or 0).
        max_num_neighbors (int, optional): maximum number of neighbors returned for each node
    """

    # Edges whose endpoints are closer than this are treated as degenerate
    # (numerically coincident positions) and removed.
    eps = 1e-10

    def __init__(self, radius=5, min_distance=5, max_distance=None, max_num_neighbors=32):
        super(SpatialEdge, self).__init__()
        self.radius = radius
        self.min_distance = min_distance
        self.max_distance = max_distance
        self.max_num_neighbors = max_num_neighbors

    def forward(self, graph):
        """
        Return spatial radius edges constructed based on the input graph.

        Nodes of the same graph that lie within ``radius`` of each other are
        connected (at most ``max_num_neighbors`` neighbors per node). An edge is
        then dropped when

        - ``min_distance > 0`` and the residue indices of its endpoints differ by
          less than ``min_distance``;
        - ``max_distance`` is set and the residue indices differ by more than
          ``max_distance``;
        - its endpoints are closer than ``eps`` in space.

        Parameters:
            graph (Graph): :math:`n` graph(s)

        Returns:
            (Tensor, int): edge list of shape :math:`(|E|, 3)`, number of relations
        """
        # NOTE(review): ``radius_graph`` is not imported in the visible import block.
        # torch_cluster provides it alongside ``knn_graph`` (used by KNNEdge); confirm
        # the dependency and move this to the module-level imports if preferred.
        from torch_cluster import radius_graph

        # Previously `edge_list` was read before assignment (UnboundLocalError).
        edge_list = radius_graph(graph.node_position, r=self.radius, batch=graph.node2graph,
                                 max_num_neighbors=self.max_num_neighbors).t()
        # every spatial edge gets the single relation type 0
        relation = torch.zeros(len(edge_list), 1, dtype=torch.long, device=graph.device)
        edge_list = torch.cat([edge_list, relation], dim=-1)

        if self.min_distance > 0:
            node_in, node_out = edge_list.t()[:2]
            mask = (graph.atom2residue[node_in] - graph.atom2residue[node_out]).abs() < self.min_distance
            edge_list = edge_list[~mask]

        if self.max_distance:
            node_in, node_out = edge_list.t()[:2]
            mask = (graph.atom2residue[node_in] - graph.atom2residue[node_out]).abs() > self.max_distance
            edge_list = edge_list[~mask]

        node_in, node_out = edge_list.t()[:2]
        mask = (graph.node_position[node_in] - graph.node_position[node_out]).norm(dim=-1) < self.eps
        edge_list = edge_list[~mask]

        return edge_list, 1

@R.register("layers.geometry.SequentialEdge")
class SequentialEdge(nn.Module, core.Configurable):
    """
    Construct edges between atoms within close residues.

    Parameters:
        max_distance (int, optional): maximum distance between two residues in the sequence
        only_backbone (bool, optional): if ``True``, only backbone atoms (``CA``, ``C`` and ``N``)
            are used as edge endpoints
    """

    def __init__(self, max_distance=2, only_backbone=False):
        super(SequentialEdge, self).__init__()
        self.max_distance = max_distance
        self.only_backbone = only_backbone

    def forward(self, graph):
        """
        Return sequential edges constructed based on the input graph.
        Edge types are defined by the relative distance between two residues in the sequence.

        For every offset ``i`` in ``[-max_distance, max_distance]``, atoms of residue ``r``
        are paired with atoms of residue ``r + i`` of the same graph (assuming
        ``functional.variadic_meshgrid`` forms the per-residue Cartesian product), pairs
        spanning two chains are discarded, and the relation type is ``i + max_distance``.

        Parameters:
            graph (Graph): :math:`n` graph(s)

        Returns:
            (Tensor, int): edge list of shape :math:`(|E|, 3)`, number of relations
        """
        if self.only_backbone:
            is_backbone = (graph.atom_name == graph.atom_name2id["CA"]) \
                        | (graph.atom_name == graph.atom_name2id["C"]) \
                        | (graph.atom_name == graph.atom_name2id["N"])
            atom2residue = graph.atom2residue[is_backbone]
        else:
            atom2residue = graph.atom2residue
        # number of candidate atoms per residue; used as group sizes for the meshgrid
        residue2num_atom = atom2residue.bincount(minlength=graph.num_residue)

        # Everything below is independent of the offset `i`, so it is built once
        # instead of once per loop iteration.
        node_index = torch.arange(graph.num_node, device=graph.device)
        residue_index = torch.arange(graph.num_residue, device=graph.device)
        # global id of the first residue / one past the last residue of each graph
        residue_start = graph.num_cum_residues - graph.num_residues
        residue_end = graph.num_cum_residues
        # the same boundaries looked up for every atom and for every residue
        atom_start = residue_start[graph.atom2graph]
        atom_end = residue_end[graph.atom2graph]
        res_start = residue_start[graph.residue2graph]
        res_end = residue_end[graph.residue2graph]

        edge_list = []
        for i in range(-self.max_distance, self.max_distance + 1):
            # keep atoms / residues that have a partner residue `i` steps away
            # inside the same graph
            if i > 0:
                is_node_in = graph.atom2residue < atom_end - i
                is_node_out = graph.atom2residue >= atom_start + i
                is_residue_in = residue_index < res_end - i
                is_residue_out = residue_index >= res_start + i
            else:
                is_node_in = graph.atom2residue >= atom_start - i
                is_node_out = graph.atom2residue < atom_end + i
                is_residue_in = residue_index >= res_start - i
                is_residue_out = residue_index < res_end + i
            if self.only_backbone:
                is_node_in = is_node_in & is_backbone
                is_node_out = is_node_out & is_backbone
            node_in = node_index[is_node_in]
            node_out = node_index[is_node_out]
            # group atoms by residue ids
            node_in = node_in[graph.atom2residue[node_in].argsort()]
            node_out = node_out[graph.atom2residue[node_out].argsort()]
            num_node_in = residue2num_atom[is_residue_in]
            num_node_out = residue2num_atom[is_residue_out]
            node_in, node_out = functional.variadic_meshgrid(node_in, num_node_in, node_out, num_node_out)
            # exclude cross-chain edges
            is_same_chain = (graph.chain_id[graph.atom2residue[node_in]] == graph.chain_id[graph.atom2residue[node_out]])
            node_in = node_in[is_same_chain]
            node_out = node_out[is_same_chain]
            relation = torch.ones(len(node_in), dtype=torch.long, device=graph.device) * (i + self.max_distance)
            edges = torch.stack([node_in, node_out, relation], dim=-1)
            edge_list.append(edges)

        edge_list = torch.cat(edge_list)

        return edge_list, 2 * self.max_distance + 1

@R.register("layers.geometry.AlphaCarbonNode")
class AlphaCarbonNode(nn.Module, core.Configurable):
    """
    Construct only alpha carbon atoms.
    """

    def forward(self, graph):
        """
        Return a subgraph that only consists of alpha carbon nodes.

        Residues that own no alpha carbon are removed as well, so that the
        result is expected to have exactly one node per residue (checked by an
        assertion).

        Parameters:
            graph (Graph): :math:`n` graph(s)
        """
        # alpha-carbon atoms that are assigned to some residue
        mask = (graph.atom_name == data.Protein.atom_name2id["CA"]) & (graph.atom2residue != -1)
        # residues owning at least one selected alpha carbon
        residue2num_atom = graph.atom2residue[mask].bincount(minlength=graph.num_residue)
        residue_mask = residue2num_atom > 0
        # atoms with atom2residue == -1 are already False in `mask`, so the -1 lookup is harmless
        mask = mask & residue_mask[graph.atom2residue]
        # Previously the mask was computed but the unfiltered graph was returned.
        graph = graph.subgraph(mask).subresidue(residue_mask)
        assert (graph.num_node == graph.num_residue).all()

        return graph

@R.register("layers.geometry.IdentityNode")
class IdentityNode(nn.Module, core.Configurable):
    """
    Construct all nodes as the input.
    """

    def forward(self, graph):
        """
        Return the input graph as is.

        Parameters:
            graph (Graph): :math:`n` graph(s)
        """
        return graph


@R.register("layers.geometry.RandomEdgeMask")
class RandomEdgeMask(nn.Module, core.Configurable):
    """
    Construct a graph by randomly masking out edges.

    Parameters:
        mask_rate (float, optional): fraction of edges to mask out in each graph
    """

    def __init__(self, mask_rate=0.15):
        super(RandomEdgeMask, self).__init__()
        self.mask_rate = mask_rate

    def forward(self, graph):
        """
        Return a graph with some edges masked out.

        Each graph draws ``max(1, int(num_edges * mask_rate))`` of its own edge indices
        uniformly with replacement (none if it has no edges); the drawn edges are removed.
        Because of replacement, fewer distinct edges than drawn may be removed.

        Parameters:
            graph (Graph): :math:`n` graph(s)
        """
        # samples per graph: at least one, but never more than the graph has edges,
        # so edgeless graphs draw nothing instead of indexing into another graph
        num_samples = (graph.num_edges * self.mask_rate).long().clamp(min=1)
        num_samples = torch.min(num_samples, graph.num_edges)
        num_sample = num_samples.sum()
        sample2graph = torch.repeat_interleave(num_samples)
        # local edge id inside each graph, then shifted to a global edge id in the batch
        edge_index = (torch.rand(num_sample, device=graph.device) * graph.num_edges[sample2graph]).long()
        edge_index = edge_index + (graph.num_cum_edges - graph.num_edges)[sample2graph]
        edge_mask = ~functional.as_mask(edge_index, graph.num_edge)

        return graph.edge_mask(edge_mask)

@R.register("layers.geometry.SubsequenceNode")
class SubsequenceNode(nn.Module, core.Configurable):
    """
    Construct nodes by taking a random subsequence of the original graph.

    Parameters:
        max_length (int, optional): maximal length of the sequence after cropping
    """

    def __init__(self, max_length=100):
        super(SubsequenceNode, self).__init__()
        self.max_length = max_length

    def forward(self, graph):
        """
        Randomly take a subsequence of the specified length.
        Return the full sequence if the sequence is shorter than the specified length.

        The window start is drawn uniformly from ``{0, ..., num_residues - max_length}``
        for every graph in the batch, and only the residues inside the window are kept.

        Parameters:
            graph (Graph): :math:`n` graph(s)
        """
        # `+ 1` makes the last window (ending at the final residue) reachable; without it
        # the final residue of a sequence longer than max_length could never be kept.
        starts = (torch.rand(graph.batch_size, device=graph.device) *
                  (graph.num_residues - self.max_length + 1).clamp(min=0)).long()
        ends = torch.min(starts + self.max_length, graph.num_residues)
        # convert per-graph offsets into global residue ids of the batch
        offsets = graph.num_cum_residues - graph.num_residues
        starts = starts + offsets
        ends = ends + offsets

        # Previously the window was computed but the uncropped graph was returned.
        residue_mask = functional.multi_slice_mask(starts, ends, graph.num_residue)
        graph = graph.subresidue(residue_mask)

        return graph

[docs]@R.register("layers.geometry.SubspaceNode")
class SubspaceNode(nn.Module, core.Configurable):
"""
Construct nodes by taking a spatial ball of the original graph.

Parameters:
entity_level (str, optional): level to perform cropping.
Available options are ``node``, ``atom`` and ``residue``.
min_neighbor (int, optional): minimum number of nodes in the spatial ball
"""

super(SubspaceNode, self).__init__()
self.entity_level = entity_level
self.min_neighbor = min_neighbor

[docs]    def forward(self, graph):
"""
Randomly pick a node as the center, and crop a spatial ball
that is at least `radius` large and contain at least `k` nodes.

Parameters:
graph (Graph): :math:`n` graph(s)
"""
node_in = torch.arange(graph.num_node, device=graph.device)
node_in = node_in.repeat_interleave(graph.num_nodes)
node_out = torch.arange(graph.num_node, device=graph.device)
dist = (graph.node_position[node_in] - graph.node_position[node_out]).norm(dim=-1)
topk_dist = functional.variadic_topk(dist, graph.num_nodes, self.min_neighbor, largest=False)[0]