# Source code for torchdrug.tasks.pretrain

import copy
import math

import torch
from torch import nn
from torch.nn import functional as F
from torch_scatter import scatter_min

from torchdrug import core, tasks, layers
from torchdrug.data import constant
from torchdrug.layers import functional
from torchdrug.core import Registry as R


[docs]@R.register("tasks.EdgePrediction") class EdgePrediction(tasks.Task, core.Configurable): """ Edge prediction task proposed in `Inductive Representation Learning on Large Graphs`_. .. _Inductive Representation Learning on Large Graphs: https://arxiv.org/abs/1706.02216 Parameters: model (nn.Module): node representation model """ def __init__(self, model): super(EdgePrediction, self).__init__() self.model = model def _get_directed(self, graph): mask = graph.edge_list[:, 0] < graph.edge_list[:, 1] graph = graph.edge_mask(mask) return graph def predict(self, batch, all_loss=None, metric=None): graph = batch["graph"] output = self.model(graph, graph.node_feature.float(), all_loss, metric) node_feature = output["node_feature"] graph = self._get_directed(graph) node_in, node_out = graph.edge_list.t()[:2] neg_index = (torch.rand(2, graph.num_edge, device=self.device) * graph.num_nodes[graph.edge2graph]).long() neg_index = neg_index + (graph.num_cum_nodes - graph.num_nodes)[graph.edge2graph] node_in = torch.cat([node_in, neg_index[0]]) node_out = torch.cat([node_out, neg_index[1]]) pred = torch.einsum("bd, bd -> b", node_feature[node_in], node_feature[node_out]) return pred def target(self, batch): graph = batch["graph"] target = torch.ones(graph.num_edge, device=self.device) target[len(target) // 2:] = 0 return target def evaluate(self, pred, target): metric = {} accuracy = ((pred > 0) == (target > 0.5)).float().mean() name = tasks._get_metric_name("acc") metric[name] = accuracy return metric def forward(self, batch): """""" all_loss = torch.tensor(0, dtype=torch.float32, device=self.device) metric = {} pred = self.predict(batch, all_loss, metric) target = self.target(batch) loss = F.binary_cross_entropy_with_logits(pred, target) name = tasks._get_criterion_name("bce") metric[name] = loss metric.update(self.evaluate(pred, target)) all_loss += loss return all_loss, metric
[docs]@R.register("tasks.AttributeMasking") class AttributeMasking(tasks.Task, core.Configurable): """ Attribute masking proposed in `Strategies for Pre-training Graph Neural Networks`_. .. _Strategies for Pre-training Graph Neural Networks: https://arxiv.org/abs/1905.12265 Parameters: model (nn.Module): node representation model mask_rate (float, optional): rate of masked nodes num_mlp_layer (int, optional): number of MLP layers """ def __init__(self, model, mask_rate=0.15, num_mlp_layer=2, graph_construction_model=None): super(AttributeMasking, self).__init__() self.model = model self.mask_rate = mask_rate self.num_mlp_layer = num_mlp_layer self.graph_construction_model = graph_construction_model def preprocess(self, train_set, valid_set, test_set): data = train_set[0] self.view = getattr(data["graph"], "view", "atom") if hasattr(self.model, "node_output_dim"): model_output_dim = self.model.node_output_dim else: model_output_dim = self.model.output_dim if self.view == "atom": num_label = constant.NUM_ATOM else: num_label = constant.NUM_AMINO_ACID self.mlp = layers.MLP(model_output_dim, [model_output_dim] * (self.num_mlp_layer - 1) + [num_label]) def predict_and_target(self, batch, all_loss=None, metric=None): graph = batch["graph"] if self.graph_construction_model: graph = self.graph_construction_model.apply_node_layer(graph) num_nodes = graph.num_nodes if self.view in ["atom", "node"] else graph.num_residues num_cum_nodes = num_nodes.cumsum(0) num_samples = (num_nodes * self.mask_rate).long().clamp(1) num_sample = num_samples.sum() sample2graph = torch.repeat_interleave(num_samples) node_index = (torch.rand(num_sample, device=self.device) * num_nodes[sample2graph]).long() node_index = node_index + (num_cum_nodes - num_nodes)[sample2graph] if self.view == "atom": target = graph.atom_type[node_index] input = graph.node_feature.float() input[node_index] = 0 else: target = graph.residue_type[node_index] with graph.residue(): graph.residue_feature[node_index] = 0 graph.residue_type[node_index] = 0 # Generate masked edge features. Any better implementation? if self.graph_construction_model: graph = self.graph_construction_model.apply_edge_layer(graph) input = graph.residue_feature.float() output = self.model(graph, input, all_loss, metric) if self.view in ["node", "atom"]: node_feature = output["node_feature"] else: node_feature = output.get("residue_feature", output.get("node_feature")) node_feature = node_feature[node_index] pred = self.mlp(node_feature) return pred, target def evaluate(self, pred, target): metric = {} accuracy = (pred.argmax(dim=-1) == target).float().mean() name = tasks._get_metric_name("acc") metric[name] = accuracy return metric def forward(self, batch): """""" all_loss = torch.tensor(0, dtype=torch.float32, device=self.device) metric = {} pred, target = self.predict_and_target(batch, all_loss, metric) metric.update(self.evaluate(pred, target)) loss = F.cross_entropy(pred, target) name = tasks._get_criterion_name("ce") metric[name] = loss all_loss += loss return all_loss, metric
[docs]@R.register("tasks.ContextPrediction") class ContextPrediction(tasks.Task, core.Configurable): """ Context prediction task proposed in `Strategies for Pre-training Graph Neural Networks`_. .. _Strategies for Pre-training Graph Neural Networks: https://arxiv.org/abs/1905.12265 For a given center node, the subgraph is defined as a k-hop neighborhood (inclusive) around the selected node. The context graph is defined as the surrounding graph structure between r1- (exclusive) and r2-hop (inclusive) from the center node. Nodes between k- and r1-hop are picked as anchor nodes for the context representation. Parameters: model (nn.Module): node representation model for subgraphs. context_model (nn.Module, optional): node representation model for context graphs. By default, use the same architecture as ``model`` without parameter sharing. k (int, optional): radius for subgraphs r1 (int, optional): inner radius for context graphs r2 (int, optional): outer radius for context graphs readout (nn.Module, optional): readout function over context anchor nodes num_negative (int, optional): number of negative samples per positive sample """ def __init__(self, model, context_model=None, k=5, r1=4, r2=7, readout="mean", num_negative=1): super(ContextPrediction, self).__init__() self.model = model self.k = k self.r1 = r1 self.r2 = r2 self.num_negative = num_negative assert r1 < k < r2 if context_model is None: self.context_model = copy.deepcopy(model) else: self.context_model = context_model if readout == "sum": self.readout = layers.SumReadout() elif readout == "mean": self.readout = layers.MeanReadout() else: raise ValueError("Unknown readout `%s`" % readout) def substruct_and_context(self, graph): center_index = (torch.rand(len(graph), device=self.device) * graph.num_nodes).long() center_index = center_index + graph.num_cum_nodes - graph.num_nodes dist = torch.full((graph.num_node,), self.r2 + 1, dtype=torch.long, device=self.device) dist[center_index] = 0 # single source shortest path node_in, node_out = graph.edge_list.t()[:2] for i in range(self.r2): new_dist = scatter_min(dist[node_in], node_out, dim_size=graph.num_node)[0] + 1 dist = torch.min(dist, new_dist) substruct_mask = dist <= self.k context_mask = (dist > self.r1) & (dist <= self.r2) is_center_node = functional.as_mask(center_index, graph.num_node) is_anchor_node = (dist > self.r1) & (dist <= self.k) substruct = graph.clone() context = graph.clone() with substruct.node(): substruct.is_center_node = is_center_node with context.node(): context.is_anchor_node = is_anchor_node substruct = substruct.subgraph(substruct_mask) context = context.subgraph(context_mask) valid = context.num_nodes > 0 substruct = substruct[valid] context = context[valid] return substruct, context def predict_and_target(self, batch, all_loss=None, metric=None): graph = batch["graph"] substruct, context = self.substruct_and_context(graph) anchor = context.subgraph(context.is_anchor_node) substruct_output = self.model(substruct, substruct.node_feature.float(), all_loss, metric) substruct_feature = substruct_output["node_feature"][substruct.is_center_node] context_output = self.context_model(context, context.node_feature.float(), all_loss, metric) anchor_feature = context_output["node_feature"][context.is_anchor_node] context_feature = self.readout(anchor, anchor_feature) shift = torch.arange(self.num_negative, device=self.device) + 1 neg_index = (torch.arange(len(context), device=self.device).unsqueeze(-1) + shift) % len(context) # (batch_size, num_negative) 
context_feature = torch.cat([context_feature.unsqueeze(1), context_feature[neg_index]], dim=1) substruct_feature = substruct_feature.unsqueeze(1).expand_as(context_feature) pred = torch.einsum("bnd, bnd -> bn", substruct_feature, context_feature) target = torch.zeros_like(pred) target[:, 0] = 1 return pred, target def evaluate(self, pred, target): metric = {} accuracy = ((pred > 0) == (target > 0.5)).float().mean() name = tasks._get_metric_name("acc") metric[name] = accuracy return metric def forward(self, batch): """""" all_loss = torch.tensor(0, dtype=torch.float32, device=self.device) metric = {} pred, target = self.predict_and_target(batch, all_loss, metric) metric.update(self.evaluate(pred, target)) loss = F.binary_cross_entropy_with_logits(pred, target) name = tasks._get_criterion_name("bce") metric[name] = loss all_loss += loss return all_loss, metric
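
# The negative samples above come from a cyclic shift inside the batch: the
# substructure of graph b is paired with the contexts of graphs
# (b + 1), ..., (b + num_negative), modulo the batch size. A tiny sketch:
#
#     import torch
#
#     batch_size, num_negative = 4, 2
#     shift = torch.arange(num_negative) + 1
#     neg_index = (torch.arange(batch_size).unsqueeze(-1) + shift) % batch_size
#     # tensor([[1, 2],
#     #         [2, 3],
#     #         [3, 0],
#     #         [0, 1]])
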
[docs]@R.register("tasks.DistancePrediction") class DistancePrediction(tasks.Task, core.Configurable): """ Pairwise spatial distance prediction task proposed in `Protein Representation Learning by Geometric Structure Pretraining`_. .. _Protein Representation Learning by Geometric Structure Pretraining: https://arxiv.org/pdf/2203.06125.pdf Randomly select some edges and predict the lengths of the edges using the representations of two nodes. The selected edges are removed from the input graph to prevent trivial solutions. Parameters: model (nn.Module): node representation model num_sample (int, optional): number of edges selected from each graph num_mlp_layer (int, optional): number of MLP layers in distance predictor graph_construction_model (nn.Module, optional): graph construction model """ def __init__(self, model, num_sample=256, num_mlp_layer=2, graph_construction_model=None): super(DistancePrediction, self).__init__() self.model = model self.num_sample = num_sample self.num_mlp_layer = num_mlp_layer self.graph_construction_model = graph_construction_model self.mlp = layers.MLP(2 * model.output_dim, [model.output_dim] * (num_mlp_layer - 1) + [1]) def predict_and_target(self, batch, all_loss=None, metric=None): graph = batch["graph"] if self.graph_construction_model: graph = self.graph_construction_model(graph) node_in, node_out = graph.edge_list[:, :2].t() indices = torch.arange(graph.num_edge, device=self.device) indices = functional.variadic_sample(indices, graph.num_edges, self.num_sample).flatten(-2, -1) node_i = node_in[indices] node_j = node_out[indices] graph = graph.edge_mask(~functional.as_mask(indices, graph.num_edge)) # Calculate distance target = (graph.node_position[node_i] - graph.node_position[node_j]).norm(p=2, dim=-1) output = self.model(graph, graph.node_feature.float() , all_loss, metric)["node_feature"] node_feature = torch.cat([output[node_i], output[node_j]], dim=-1) pred = self.mlp(node_feature).squeeze(-1) return pred, target def evaluate(self, pred, target): metric = {} mse = F.mse_loss(pred, target) name = tasks._get_metric_name("mse") metric[name] = mse return metric
[docs] def forward(self, batch): all_loss = torch.tensor(0, dtype=torch.float32, device=self.device) metric = {} pred, target = self.predict_and_target(batch, all_loss, metric) metric.update(self.evaluate(pred, target)) loss = F.mse_loss(pred, target) name = tasks._get_criterion_name("mse") metric[name] = loss all_loss += loss return all_loss, metric
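
# A usage sketch for the geometric pretraining tasks (the dataset, edge cutoffs,
# and model sizes follow the style of the GearNet tutorials but are assumptions
# here, not values fixed by this module). The same wiring applies to
# AnglePrediction and DihedralPrediction below.
#
#     from torchdrug import core, datasets, layers, models, tasks
#     from torchdrug.layers import geometry
#
#     dataset = datasets.AlphaFoldDB("~/protein-datasets/", species_id=3)
#     graph_construction_model = layers.GraphConstruction(
#         node_layers=[geometry.AlphaCarbonNode()],
#         edge_layers=[geometry.SequentialEdge(max_distance=2),
#                      geometry.SpatialEdge(radius=10.0, min_distance=5),
#                      geometry.KNNEdge(k=10, min_distance=5)])
#     model = models.GearNet(input_dim=21, hidden_dims=[512, 512, 512],
#                            num_relation=7, batch_norm=True)
#     task = tasks.DistancePrediction(model, num_sample=64,
#                                     graph_construction_model=graph_construction_model)
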
[docs]@R.register("tasks.AnglePrediction") class AnglePrediction(tasks.Task, core.Configurable): """ Angle prediction task proposed in `Protein Representation Learning by Geometric Structure Pretraining`_. .. _Protein Representation Learning by Geometric Structure Pretraining: https://arxiv.org/pdf/2203.06125.pdf Randomly select pairs of adjacent edges and predict the angles between them using the representations of three nodes. The selected edges are removed from the input graph to prevent trivial solutions. Parameters: model (nn.Module): node representation model num_sample (int, optional): number of edge pairs selected from each graph num_class (int, optional): number of classes to discretize the angles num_mlp_layer (int, optional): number of MLP layers in angle predictor graph_construction_model (nn.Module, optional): graph construction model """ def __init__(self, model, num_sample=256, num_class=8, num_mlp_layer=2, graph_construction_model=None): super(AnglePrediction, self).__init__() self.model = model self.num_sample = num_sample self.num_mlp_layer = num_mlp_layer self.graph_construction_model = graph_construction_model boundary = torch.arange(0, math.pi, math.pi / num_class) self.register_buffer("boundary", boundary) self.mlp = layers.MLP(3 * model.output_dim, [model.output_dim] * (num_mlp_layer - 1) + [num_class]) def predict_and_target(self, batch, all_loss=None, metric=None): graph = batch["graph"] if self.graph_construction_model: graph = self.graph_construction_model(graph) node_in, node_out = graph.edge_list[:, :2].t() line_graph = graph.line_graph() edge_in, edge_out = line_graph.edge_list[:, :2].t() is_self_loop1 = (edge_in == edge_out) is_self_loop2 = (node_in[edge_in] == node_out[edge_out]) is_remove = is_self_loop1 | is_self_loop2 line_graph = line_graph.edge_mask(~is_remove) edge_in, edge_out = line_graph.edge_list[:, :2].t() # (k->j) - (j->i) node_i = node_out[edge_out] node_j = node_in[edge_out] node_k = node_in[edge_in] indices = torch.arange(line_graph.num_edge, device=self.device) indices = functional.variadic_sample(indices, line_graph.num_edges, self.num_sample).flatten(-2, -1) node_i = node_i[indices] node_j = node_j[indices] node_k = node_k[indices] mask = torch.ones((graph.num_edge,), device=graph.device, dtype=torch.bool) mask[edge_out[indices]] = 0 mask[edge_in[indices]] = 0 graph = graph.edge_mask(mask) # Calculate angles vector1 = graph.node_position[node_i] - graph.node_position[node_j] vector2 = graph.node_position[node_k] - graph.node_position[node_j] x = (vector1 * vector2).sum(dim=-1) y = torch.cross(vector1, vector2).norm(dim=-1) angle = torch.atan2(y, x) target = torch.bucketize(angle, self.boundary, right=True) - 1 output = self.model(graph, graph.node_feature.float() , all_loss, metric)["node_feature"] node_feature = torch.cat([output[node_i], output[node_j], output[node_k]], dim=-1) pred = self.mlp(node_feature) return pred, target def evaluate(self, pred, target): metric = {} accuracy = (pred.argmax(dim=-1) == target).float().mean() name = tasks._get_metric_name("acc") metric[name] = accuracy return metric
[docs] def forward(self, batch): all_loss = torch.tensor(0, dtype=torch.float32, device=self.device) metric = {} pred, target = self.predict_and_target(batch, all_loss, metric) metric.update(self.evaluate(pred, target)) loss = F.cross_entropy(pred, target) name = tasks._get_criterion_name("ce") metric[name] = loss all_loss += loss return all_loss, metric
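
# The angle target uses atan2(|v1 x v2|, v1 . v2) rather than acos of the cosine,
# which stays numerically stable when the vectors are nearly parallel. A tiny
# standalone check (the vectors are made up):
#
#     import math
#     import torch
#
#     v1 = torch.tensor([[1.0, 0.0, 0.0]])
#     v2 = torch.tensor([[1.0, 3.0 ** 0.5, 0.0]])   # 60 degrees away from v1
#     x = (v1 * v2).sum(dim=-1)                     # dot product ~ cos(angle)
#     y = torch.cross(v1, v2, dim=-1).norm(dim=-1)  # cross norm  ~ sin(angle)
#     angle = torch.atan2(y, x)                     # ~ pi / 3
#     boundary = torch.arange(0, math.pi, math.pi / 8)
#     label = torch.bucketize(angle, boundary, right=True) - 1  # class 2 of 8
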
[docs]@R.register("tasks.DihedralPrediction") class DihedralPrediction(tasks.Task, core.Configurable): """ Dihedral prediction task proposed in `Protein Representation Learning by Geometric Structure Pretraining`_. .. _Protein Representation Learning by Geometric Structure Pretraining: https://arxiv.org/pdf/2203.06125.pdf Randomly select three consecutive edges and predict the dihedrals among them using the representations of four nodes. The selected edges are removed from the input graph to prevent trivial solutions. Parameters: model (nn.Module): node representation model num_sample (int, optional): number of edge triplets selected from each graph num_class (int, optional): number of classes for discretizing the dihedrals num_mlp_layer (int, optional): number of MLP layers in dihedral angle predictor graph_construction_model (nn.Module, optional): graph construction model """ def __init__(self, model, num_sample=256, num_class=8, num_mlp_layer=2, graph_construction_model=None): super(DihedralPrediction, self).__init__() self.model = model self.num_sample = num_sample self.num_mlp_layer = num_mlp_layer self.graph_construction_model = graph_construction_model boundary = torch.arange(0, math.pi, math.pi / num_class) self.register_buffer("boundary", boundary) self.mlp = layers.MLP(4 * model.output_dim, [model.output_dim] * (num_mlp_layer - 1) + [num_class]) def predict_and_target(self, batch, all_loss=None, metric=None): graph = batch["graph"] if self.graph_construction_model: graph = self.graph_construction_model(graph) node_in, node_out = graph.edge_list[:, :2].t() line_graph = graph.line_graph() edge_in, edge_out = line_graph.edge_list[:, :2].t() is_self_loop1 = (edge_in == edge_out) is_self_loop2 = (node_in[edge_in] == node_out[edge_out]) is_remove = is_self_loop1 | is_self_loop2 line_graph = line_graph.edge_mask(~is_remove) edge_in, edge_out = line_graph.edge_list[:, :2].t() line2_graph = line_graph.line_graph() edge2_in, edge2_out = line2_graph.edge_list.t()[:2] is_self_loop1 = (edge2_in == edge2_out) is_self_loop2 = (edge_in[edge2_in] == edge_out[edge2_out]) is_remove = is_self_loop1 | is_self_loop2 line2_graph = line2_graph.edge_mask(~is_remove) edge2_in, edge2_out = line2_graph.edge_list[:, :2].t() # (k->t->j) - (t->j->i) node_i = node_out[edge_out[edge2_out]] node_j = node_in[edge_out[edge2_out]] node_t = node_in[edge_out[edge2_in]] node_k = node_in[edge_in[edge2_in]] indices = torch.arange(line2_graph.num_edge, device=self.device) indices = functional.variadic_sample(indices, line2_graph.num_edges, self.num_sample).flatten(-2, -1) node_i = node_i[indices] node_j = node_j[indices] node_t = node_t[indices] node_k = node_k[indices] mask = torch.ones((graph.num_edge,), device=graph.device, dtype=torch.bool) mask[edge_out[edge2_out[indices]]] = 0 mask[edge_out[edge2_in[indices]]] = 0 mask[edge_in[edge2_in[indices]]] = 0 graph = graph.edge_mask(mask) v_ctr = graph.node_position[node_t] - graph.node_position[node_j] # (A, 3) v1 = graph.node_position[node_i] - graph.node_position[node_j] v2 = graph.node_position[node_k] - graph.node_position[node_t] n1 = torch.cross(v_ctr, v1, dim=-1) # Normal vectors of the two planes n2 = torch.cross(v_ctr, v2, dim=-1) a = (n1 * n2).sum(dim=-1) b = torch.cross(n1, n2).norm(dim=-1) dihedral = torch.atan2(b, a) target = torch.bucketize(dihedral, self.boundary, right=True) - 1 output = self.model(graph, graph.node_feature.float() , all_loss, metric)["node_feature"] node_feature = torch.cat([output[node_i], output[node_j], output[node_k], output[node_t]], 
dim=-1) pred = self.mlp(node_feature) return pred, target def evaluate(self, pred, target): metric = {} accuracy = (pred.argmax(dim=-1) == target).float().mean() name = tasks._get_metric_name("acc") metric[name] = accuracy return metric
[docs] def forward(self, batch): all_loss = torch.tensor(0, dtype=torch.float32, device=self.device) metric = {} pred, target = self.predict_and_target(batch, all_loss, metric) metric.update(self.evaluate(pred, target)) loss = F.cross_entropy(pred, target) name = tasks._get_criterion_name("ce") metric[name] = loss all_loss += loss return all_loss, metric
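
# The dihedral target mirrors the angle computation one level up: n1 and n2 are
# normals of the two planes that share the edge j->t, and atan2(|n1 x n2|, n1 . n2)
# measures the angle between those planes. A standalone check on four made-up points:
#
#     import torch
#
#     pos = torch.tensor([[1.0, 0.0, 0.0],   # i
#                         [0.0, 0.0, 0.0],   # j
#                         [0.0, 1.0, 0.0],   # t
#                         [0.0, 1.0, 1.0]])  # k
#     v_ctr = pos[2] - pos[1]                # shared edge j -> t
#     v1 = pos[0] - pos[1]
#     v2 = pos[3] - pos[2]
#     n1 = torch.cross(v_ctr, v1, dim=-1)
#     n2 = torch.cross(v_ctr, v2, dim=-1)
#     a = (n1 * n2).sum(dim=-1)
#     b = torch.cross(n1, n2, dim=-1).norm(dim=-1)
#     dihedral = torch.atan2(b, a)           # pi / 2 for these four points
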