"""Source code for torchdrug.transforms.transform: data transforms applied to dataset samples."""

import copy
import logging
from collections import deque

import torch

from torchdrug import core
from torchdrug.core import Registry as R


logger = logging.getLogger(__name__)


@R.register("transforms.NormalizeTarget")
class NormalizeTarget(core.Configurable):
    """
    Normalize the target values in a sample.

    Every target ``k`` listed in ``mean`` is replaced by ``(item[k] - mean[k]) / std[k]``.

    Parameters:
        mean (dict of float): mean of targets
        std (dict of float): standard deviation of targets
    """

    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, item):
        # work on a shallow copy so the caller's sample dict is not modified
        normalized = item.copy()
        for target, mu in self.mean.items():
            if target not in normalized:
                raise ValueError("Can't find target `%s` in data item" % target)
            normalized[target] = (normalized[target] - mu) / self.std[target]
        return normalized
@R.register("transforms.RemapAtomType")
class RemapAtomType(core.Configurable):
    """
    Map atom types to their index in a vocabulary.
    Atom types that are not present in the vocabulary are mapped to -1.

    Parameters:
        atom_types (array_like): vocabulary of atom types
    """

    def __init__(self, atom_types):
        atom_types = torch.as_tensor(atom_types)
        self.id2atom = atom_types
        # lookup table indexed by raw atom type; -1 marks types outside the vocabulary
        self.atom2id = - torch.ones(atom_types.max() + 1, dtype=torch.long, device=atom_types.device)
        self.atom2id[atom_types] = torch.arange(len(atom_types), device=atom_types.device)

    def __call__(self, item):
        graph = copy.copy(item["graph"])
        atom_type = graph.atom_type
        atom2id = self.atom2id.to(atom_type.device)
        # Types larger than the vocabulary maximum (or negative) fall outside the table;
        # indexing with them would raise or wrap around, so they are mapped to -1 explicitly.
        in_table = (atom_type >= 0) & (atom_type < len(atom2id))
        safe_type = torch.where(in_table, atom_type, torch.zeros_like(atom_type))
        mapped = atom2id[safe_type]
        graph.atom_type = torch.where(in_table, mapped, torch.full_like(mapped, -1))
        item = item.copy()
        item["graph"] = graph
        return item
@R.register("transforms.RandomBFSOrder")
class RandomBFSOrder(core.Configurable):
    """
    Order the nodes in a graph according to a random BFS order.

    The traversal starts from a uniformly sampled node. If some nodes are not reachable
    from it (disconnected graph), the traversal restarts from the unvisited nodes, taken
    in a random order, so that every node is kept in the reordered graph.
    """

    def __call__(self, item):
        graph = item["graph"]
        num_node = int(graph.num_node)
        if num_node == 0:
            # nothing to reorder; torch.randint would fail on an empty range
            return item.copy()

        edge_list = graph.edge_list[:, :2].tolist()
        neighbor = [[] for _ in range(num_node)]
        for h, t in edge_list:
            neighbor[h].append(t)

        visited = [False] * num_node
        order = []

        def bfs(root):
            # breadth-first traversal from `root`, appending nodes to `order` as they are dequeued
            visited[root] = True
            queue = deque([root])
            while queue:
                h = queue.popleft()
                order.append(h)
                for t in neighbor[h]:
                    if not visited[t]:
                        visited[t] = True
                        queue.append(t)

        bfs(torch.randint(num_node, (1,)).item())
        if len(order) < num_node:
            # disconnected graph: without this, subgraph(order) would drop the other components
            for root in torch.randperm(num_node).tolist():
                if not visited[root]:
                    bfs(root)

        item = item.copy()
        item["graph"] = graph.subgraph(order)
        return item
@R.register("transforms.Shuffle")
class Shuffle(core.Configurable):
    """
    Shuffle the order of nodes and edges in a graph.

    Parameters:
        shuffle_node (bool, optional): shuffle node order or not
        shuffle_edge (bool, optional): shuffle edge order or not
    """

    def __init__(self, shuffle_node=True, shuffle_edge=True):
        self.shuffle_node = shuffle_node
        self.shuffle_edge = shuffle_edge

    def __call__(self, item):
        graph = item["graph"]
        data = self.transform_data(graph.data_dict, graph.meta)

        item = item.copy()
        item["graph"] = type(graph)(**data)
        return item

    def transform_data(self, data, meta):
        """
        Permute node and edge attributes in ``data``.

        Parameters:
            data (dict): graph data; must contain ``edge_list``, ``num_node`` and ``num_edge``
            meta (dict): maps each key of ``data`` (other than ``edge_list``) to
                ``"node"``, ``"edge"`` or another type that is left untouched

        Returns:
            dict: permuted data with the same keys
        """
        edge_list = data["edge_list"]
        num_node = int(data["num_node"])
        num_edge = int(data["num_edge"])
        device = edge_list.device

        if self.shuffle_node:
            node_perm = torch.randperm(num_node, device=device)
        else:
            node_perm = torch.arange(num_node, device=device)
        if self.shuffle_edge:
            edge_perm = torch.randperm(num_edge, device=device)
        else:
            edge_perm = torch.arange(num_edge, device=device)

        # New node i is old node node_perm[i], so an edge endpoint that referred to
        # old node j must now refer to inv_node_perm[j].
        inv_node_perm = torch.empty_like(node_perm)
        inv_node_perm[node_perm] = torch.arange(num_node, device=device)

        new_data = {}
        for key, value in data.items():
            if key == "edge_list":
                # only the endpoint columns hold node indices; a relation column is kept as is
                value = value[edge_perm]
                value[:, :2] = inv_node_perm[value[:, :2]]
                new_data[key] = value
            elif meta[key] == "node":
                new_data[key] = value[node_perm]
            elif meta[key] == "edge":
                new_data[key] = value[edge_perm]
            else:
                new_data[key] = value
        return new_data
@R.register("transforms.VirtualNode")
class VirtualNode(core.Configurable):
    """
    Add a virtual node and connect it with every node in the graph.

    Parameters:
        relation (int, optional): relation of virtual edges.
            By default, use the maximal relation in the graph plus 1.
        weight (int, optional): weight of virtual edges
        node_feature (array_like, optional): feature of the virtual node
        edge_feature (array_like, optional): feature of virtual edges
        kwargs: other attributes of the virtual node or virtual edges.
            Attributes given as ``None`` get no default and are zero-filled.
    """

    def __init__(self, relation=None, weight=1, node_feature=None, edge_feature=None, **kwargs):
        self.relation = relation
        self.weight = weight

        # torch.as_tensor(None) raises, and None means "no default" just like
        # node_feature / edge_feature below, so such entries are skipped
        self.default = {k: torch.as_tensor(v) for k, v in kwargs.items() if v is not None}
        if node_feature is not None:
            self.default["node_feature"] = torch.as_tensor(node_feature)
        if edge_feature is not None:
            self.default["edge_feature"] = torch.as_tensor(edge_feature)

    def __call__(self, item):
        graph = item["graph"]
        edge_list = graph.edge_list
        edge_weight = graph.edge_weight
        num_node = graph.num_node
        num_relation = graph.num_relation

        # the virtual node gets index num_node; add edges in both directions
        existing_node = torch.arange(num_node, device=edge_list.device)
        virtual_node = torch.ones(num_node, dtype=torch.long, device=edge_list.device) * num_node
        node_in = torch.cat([virtual_node, existing_node])
        node_out = torch.cat([existing_node, virtual_node])
        if edge_list.shape[1] == 2:
            new_edge = torch.stack([node_in, node_out], dim=-1)
        else:
            if self.relation is None:
                relation = num_relation
                num_relation = num_relation + 1
            else:
                relation = self.relation
            relation = relation * torch.ones(num_node * 2, dtype=torch.long, device=edge_list.device)
            new_edge = torch.stack([node_in, node_out, relation], dim=-1)
        edge_list = torch.cat([edge_list, new_edge])
        new_edge_weight = self.weight * torch.ones(num_node * 2, device=edge_weight.device)
        edge_weight = torch.cat([edge_weight, new_edge_weight])

        # add default node/edge attributes (zeros when no default was given)
        data = graph.data_dict.copy()
        for key, value in graph.meta.items():
            if value == "node":
                if key in self.default:
                    # defaults are created on CPU; move them next to the graph data before concatenating
                    new_data = self.default[key].to(data[key].device).unsqueeze(0)
                else:
                    new_data = torch.zeros(1, *data[key].shape[1:], dtype=data[key].dtype, device=data[key].device)
                data[key] = torch.cat([data[key], new_data])
            elif value == "edge":
                if key in self.default:
                    repeat = [-1] * (data[key].ndim - 1)
                    new_data = self.default[key].to(data[key].device).expand(num_node * 2, *repeat)
                else:
                    new_data = torch.zeros(num_node * 2, *data[key].shape[1:],
                                           dtype=data[key].dtype, device=data[key].device)
                data[key] = torch.cat([data[key], new_data])

        graph = type(graph)(edge_list, edge_weight=edge_weight, num_node=num_node + 1,
                            num_relation=num_relation, meta=graph.meta, **data)

        item = item.copy()
        item["graph"] = graph
        return item
@R.register("transforms.VirtualAtom")
class VirtualAtom(VirtualNode, core.Configurable):
    """
    Add a virtual atom and connect it with every atom in the molecule.

    Parameters:
        atom_type (int, optional): type of the virtual atom
        bond_type (int, optional): type of the virtual bonds
        node_feature (array_like, optional): feature of the virtual atom
        edge_feature (array_like, optional): feature of virtual bonds
        kwargs: other attributes of the virtual atoms or virtual bonds
    """

    def __init__(self, atom_type=None, bond_type=None, node_feature=None, edge_feature=None, **kwargs):
        # Forward atom_type only when it is given: VirtualNode turns every extra keyword into
        # a default tensor via torch.as_tensor, and None is not a valid tensor value.
        if atom_type is not None:
            kwargs["atom_type"] = atom_type
        super(VirtualAtom, self).__init__(relation=bond_type, weight=1, node_feature=node_feature,
                                          edge_feature=edge_feature, **kwargs)
@R.register("transforms.TruncateProtein")
class TruncateProtein(core.Configurable):
    """
    Truncate over long protein sequences into a fixed length.

    Parameters:
        max_length (int, optional): maximal length of the sequence. Truncate the sequence if it exceeds this limit.
            If ``None``, sequences are never truncated.
        random (bool, optional): truncate the sequence at a random position.
            If not, truncate the suffix of the sequence.
        keys (str or list of str, optional): keys for the items that require truncation in a sample
    """

    def __init__(self, max_length=None, random=False, keys="graph"):
        self.truncate_length = max_length
        self.random = random
        if isinstance(keys, str):
            keys = [keys]
        self.keys = keys

    def __call__(self, item):
        new_item = item.copy()
        if self.truncate_length is None:
            # no limit: comparing num_residue with None would raise TypeError
            return new_item
        for key in self.keys:
            graph = item[key]
            if graph.num_residue > self.truncate_length:
                if self.random:
                    # there are num_residue - truncate_length + 1 valid start positions,
                    # so the window may also end exactly on the last residue
                    start = torch.randint(graph.num_residue - self.truncate_length + 1, (1,)).item()
                else:
                    start = 0
                end = start + self.truncate_length
                mask = torch.zeros(graph.num_residue, dtype=torch.bool, device=graph.device)
                mask[start:end] = True
                graph = graph.subresidue(mask)

            new_item[key] = graph
        return new_item
@R.register("transforms.ProteinView")
class ProteinView(core.Configurable):
    """
    Convert proteins to a specific view.

    The selected graphs are shallow-copied and their ``view`` attribute is set,
    leaving the graphs in the input sample unchanged.

    Parameters:
        view (str): protein view. Can be ``atom`` or ``residue``.
        keys (str or list of str, optional): keys for the items that require view change in a sample
    """

    def __init__(self, view, keys="graph"):
        self.view = view
        self.keys = [keys] if isinstance(keys, str) else keys

    def __call__(self, item):
        result = item.copy()
        for key in self.keys:
            # shallow copy so setting the view does not leak into the caller's graph
            viewed = copy.copy(result[key])
            viewed.view = self.view
            result[key] = viewed
        return result
@R.register("transforms.Compose")
class Compose(core.Configurable):
    """
    Compose a list of transforms into one.

    Nested ``Compose`` objects are flattened into their individual transforms,
    and ``None`` entries are dropped.

    Parameters:
        transforms (list of callable): list of transforms
    """

    def __init__(self, transforms):
        flattened = []
        for transform in transforms:
            if transform is None:
                continue
            if isinstance(transform, Compose):
                # inline the steps of a nested composition
                flattened.extend(transform.transforms)
            else:
                flattened.append(transform)
        self.transforms = flattened

    def __call__(self, item):
        result = item
        for step in self.transforms:
            result = step(result)
        return result