Source code for torchdrug.tasks.contact_prediction

import torch
from torch import nn
from torch.nn import functional as F

from torchdrug import core, layers, tasks, metrics
from torchdrug.core import Registry as R
from torchdrug.layers import functional


@R.register("tasks.ContactPrediction")
class ContactPrediction(tasks.Task, core.Configurable):
    """
    Predict whether each pair of amino acids is in contact in the folded structure.

    Every ordered residue pair inside a protein is scored by an MLP applied to the
    concatenation of the element-wise product and the absolute difference of the
    two residue representations produced by ``model``.

    Parameters:
        model (nn.Module): protein sequence representation model
        max_length (int, optional): maximal length of sequence. Truncate the sequence if it exceeds this limit.
        random_truncate (bool, optional): truncate the sequence at a random position.
            If not, truncate the suffix of the sequence.
        threshold (float, optional): distance threshold for contact
        gap (int, optional): sequential distance cutoff for evaluation
        criterion (str or dict, optional): training criterion. For dict, the key is criterion and the value
            is the corresponding weight. Available criterion is ``bce``.
        metric (str or list of str, optional): metric(s).
            Available metrics are ``accuracy``, ``prec@Lk`` and ``prec@k``.
            For ``prec@Lk`` the cutoff is ``L // k`` where ``L`` is the (truncated) number of residues;
            ``prec@L`` means ``k = 1`` and the spelling ``prec@L/k`` is accepted as well.
        num_mlp_layer (int, optional): number of layers in mlp prediction head
        verbose (int, optional): output verbose level
    """

    eps = 1e-10
    _option_members = {"task", "criterion", "metric"}

    def __init__(self, model, max_length=500, random_truncate=True, threshold=8.0, gap=6, criterion="bce",
                 metric=("accuracy", "prec@L5"), num_mlp_layer=1, verbose=0):
        super(ContactPrediction, self).__init__()
        self.model = model
        self.max_length = max_length
        self.random_truncate = random_truncate
        self.threshold = threshold
        self.gap = gap
        self.criterion = criterion
        self.metric = metric
        self.num_mlp_layer = num_mlp_layer
        self.verbose = verbose

        # Prefer the per-node output width when the encoder exposes one.
        if hasattr(self.model, "node_output_dim"):
            model_output_dim = self.model.node_output_dim
        else:
            model_output_dim = self.model.output_dim
        hidden_dims = [model_output_dim] * (self.num_mlp_layer - 1)
        # The head consumes [product, |difference|] of two residue embeddings,
        # hence twice the encoder width, and emits one logit per pair.
        self.mlp = layers.MLP(2 * model_output_dim, hidden_dims + [1])

    def truncate(self, batch):
        """
        Crop every protein in the batch to at most ``max_length`` residues.

        Nothing is changed when no protein exceeds the limit. Otherwise a window
        is kept per protein: a random one if ``random_truncate`` is set, else the
        prefix.

        Parameters:
            batch (dict): batch containing the key ``graph``

        Returns:
            dict: a new dict holding only the (possibly cropped) ``graph``
        """
        graph = batch["graph"]
        size = graph.num_residues
        if (size > self.max_length).any():
            if self.random_truncate:
                # Largest admissible start per protein; 0 for proteins that already fit.
                max_start = (graph.num_residues - self.max_length).clamp(min=0)
                starts = (torch.rand(graph.batch_size, device=graph.device) * max_start).long()
                ends = torch.min(starts + self.max_length, graph.num_residues)
                # Shift per-protein positions to indices in the packed residue dimension.
                offsets = graph.num_cum_residues - graph.num_residues
                starts = starts + offsets
                ends = ends + offsets
            else:
                starts = size.cumsum(0) - size
                ends = starts + size.clamp(max=self.max_length)
            mask = functional.multi_slice_mask(starts, ends, graph.num_residue)
            graph = graph.subresidue(mask)

        return {
            "graph": graph
        }

    def forward(self, batch):
        """"""
        # The empty docstring above is kept exactly as in the original source.
        all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
        metric = {}

        batch = self.truncate(batch)
        pred = self.predict(batch, all_loss, metric)
        target = self.target(batch)

        # NOTE(review): ``self.criterion`` must be a mapping here although the default
        # argument is a str; presumably the base class normalizes ``_option_members``
        # on assignment -- confirm against ``tasks.Task``.
        for criterion, weight in self.criterion.items():
            if criterion == "bce":
                loss = F.binary_cross_entropy_with_logits(pred, target["label"], reduction="none")
                # Masked-out pairs contribute zero loss but still count in the per-protein mean.
                loss = functional.variadic_mean(loss * target["mask"].float(), size=target["size"])
            else:
                raise ValueError("Unknown criterion `%s`" % criterion)
            loss = loss.mean()

            name = tasks._get_criterion_name(criterion)
            metric[name] = loss
            all_loss += loss * weight

        return all_loss, metric

    def _score_pairs(self, output, node_in, node_out):
        """Return the MLP output for the residue pairs ``(node_in[i], node_out[i])``."""
        feature_in = output[node_in]
        feature_out = output[node_out]
        prod = feature_in * feature_out
        diff = (feature_in - feature_out).abs()
        pairwise_features = torch.cat((prod, diff), -1)
        return self.mlp(pairwise_features)

    def predict(self, batch, all_loss=None, metric=None):
        """
        Compute one contact logit for every ordered residue pair within each protein.

        Parameters:
            batch (dict): batch containing the key ``graph``
            all_loss (Tensor, optional): forwarded to the model; ``None`` is treated as test time
            metric (dict, optional): forwarded to the model

        Returns:
            Tensor: pair logits with the trailing singleton dimension squeezed away
        """
        graph = batch["graph"]
        output = self.model(graph, graph.residue_feature.float(), all_loss=all_loss, metric=metric)
        output = output["residue_feature"]

        index = torch.arange(graph.num_residue, device=self.device)
        node_in, node_out = functional.variadic_meshgrid(index, graph.num_residues, index, graph.num_residues)

        chunk_size = (self.max_length ** 2) * graph.batch_size
        if all_loss is None and node_in.shape[0] > chunk_size:
            # test
            # split large input to reduce memory cost
            chunks = zip(node_in.split(chunk_size, dim=0), node_out.split(chunk_size, dim=0))
            pred = torch.cat([self._score_pairs(output, _in, _out) for _in, _out in chunks], dim=0)
        else:
            pred = self._score_pairs(output, node_in, node_out)

        return pred.squeeze(-1)

    def target(self, batch):
        """
        Build the contact labels for every ordered residue pair within each protein.

        Parameters:
            batch (dict): batch containing the key ``graph``

        Returns:
            dict: ``label`` (1.0 where the pair distance is below ``threshold``),
            ``mask`` (both residues valid and at least ``gap`` apart in index) and
            ``size`` (number of pairs per protein, ``num_residues ** 2``)
        """
        graph = batch["graph"]
        valid_mask = graph.mask
        residue_position = graph.residue_position

        index = torch.arange(graph.num_residue, device=self.device)
        node_in, node_out = functional.variadic_meshgrid(index, graph.num_residues, index, graph.num_residues)
        dist = (residue_position[node_in] - residue_position[node_out]).norm(p=2, dim=-1)
        label = (dist < self.threshold).float()

        mask = valid_mask[node_in] & valid_mask[node_out] & ((node_in - node_out).abs() >= self.gap)

        return {
            "label": label,
            "mask": mask,
            "size": graph.num_residues ** 2
        }

    def evaluate(self, pred, target):
        """
        Compute the configured metrics on the unmasked residue pairs.

        Parameters:
            pred (Tensor): pair logits as returned by :meth:`predict`
            target (dict): dict with ``label``, ``mask`` and ``size`` as returned by :meth:`target`

        Returns:
            dict: metric name -> score averaged over proteins

        Raises:
            ValueError: if a metric name is unknown or ``k`` in ``prec@Lk`` is not positive
        """
        label = target["label"]
        mask = target["mask"]
        # Number of unmasked pairs per protein.
        size = functional.variadic_sum(mask.long(), target["size"])
        label = label[mask]
        pred = pred[mask]

        length_prefix = "prec@L"
        metric = {}
        for _metric in self.metric:
            if _metric == "accuracy":
                score = (pred > 0) == label
                score = functional.variadic_mean(score.float(), size).mean()
            elif _metric.startswith(length_prefix):
                # ``size`` holds num_residues ** 2, so its square root is the sequence length.
                length = target["size"].sqrt().long()
                # Parse everything after the prefix. The previous slice started one
                # character too late, so "prec@L5" was silently evaluated as "prec@L".
                suffix = _metric[len(length_prefix):]
                if suffix.startswith("/"):
                    # keep the "prec@L/k" spelling working
                    suffix = suffix[1:]
                k = int(suffix) if suffix else 1
                if k <= 0:
                    raise ValueError("Invalid divisor in metric `%s`" % _metric)
                top_k = torch.div(length, k, rounding_mode="floor")
                score = metrics.variadic_top_precision(pred, label, size, top_k).mean()
            elif _metric.startswith("prec@"):
                k = int(_metric[5:])
                k = torch.full_like(size, k)
                score = metrics.variadic_top_precision(pred, label, size, k).mean()
            else:
                raise ValueError("Unknown metric `%s`" % _metric)

            name = tasks._get_metric_name(_metric)
            metric[name] = score

        return metric