Source code for torchdrug.models.statistic

import torch
from torch import nn

from torch_scatter import scatter_add

from torchdrug import core, layers, data
from torchdrug.core import Registry as R


@R.register("models.Statistic")
class Statistic(nn.Module, core.Configurable):
    """
    The statistic feature engineering for protein sequence proposed in
    `Harnessing Computational Biology for Exact Linear B-cell Epitope Prediction`_.

    .. _Harnessing Computational Biology for Exact Linear B-cell Epitope Prediction:
        https://www.liebertpub.com/doi/abs/10.1089/omi.2015.0095

    Each protein is summarized by the frequencies of its adjacent residue pairs
    (dipeptides), standardized against a codon-usage-based expectation, and the
    result is passed through an MLP followed by a ReLU.

    Parameters:
        type (str, optional): statistic feature. Currently the only available
            feature is ``DDE`` (dipeptide deviation from expected mean).
        hidden_dims (list of int, optional): hidden dimensions of the MLP;
            the last entry is the output dimension.
    """

    num_residue_type = len(data.Protein.id2residue_symbol)
    # One feature slot per ordered residue pair (dipeptide).
    input_dim = num_residue_type ** 2
    # Number of codons that encode each amino acid.
    _codons = {"A": 4, "C": 2, "D": 2, "E": 2, "F": 2, "G": 4, "H": 2, "I": 3, "K": 2, "L": 6,
               "M": 1, "N": 2, "P": 4, "Q": 2, "R": 6, "S": 6, "T": 4, "V": 4, "W": 1, "Y": 2}
    # Total codon count of the table above (61: the sense codons, stop codons excluded).
    # Used as C_N in the DDE theoretical mean.
    _num_codons = sum(_codons.values())

    def __init__(self, type="DDE", hidden_dims=(512,)):
        super(Statistic, self).__init__()
        self.type = type
        self.output_dim = hidden_dims[-1]

        # Per-residue-type codon counts, indexed by residue id. A buffer so it follows .to(device).
        self.register_buffer("codons", self.calculate_codons())
        self.mlp = layers.Sequential(
            layers.MultiLayerPerceptron(self.input_dim, hidden_dims),
            nn.ReLU(),
        )

    def calculate_codons(self):
        """
        Build the codon-count lookup table indexed by residue id.

        Returns:
            Tensor: integer tensor of length ``num_residue_type``, where entry ``i`` is the
            number of codons for the residue symbol ``data.Protein.id2residue_symbol[i]``.
            Raises ``KeyError`` if a residue symbol is missing from ``_codons``.
        """
        codons = [0] * self.num_residue_type
        for i, token in data.Protein.id2residue_symbol.items():
            codons[i] = self._codons[token]
        return torch.tensor(codons)

    def forward(self, graph, input, all_loss=None, metric=None):
        """
        Compute the graph representation(s) from dipeptide statistics of the residue sequence.

        Note that ``input`` is ignored: features are computed from ``graph.residue_type``.
        No residue-level representation is produced.

        Parameters:
            graph (Protein): :math:`n` protein(s)
            input (Tensor): input node representations (unused)
            all_loss (Tensor, optional): if specified, add loss to this tensor
            metric (dict, optional): if specified, output metrics to this dict

        Returns:
            dict with ``graph_feature`` field: graph representations of shape :math:`(n, d)`,
            where :math:`d` is ``hidden_dims[-1]``

        Raises:
            ValueError: if ``self.type`` is not a supported statistic feature.
        """
        # Fail fast before doing any work for an unsupported feature type.
        if self.type != "DDE":
            raise ValueError("Unknown statistic feature type `%s`" % self.type)

        composition = self._dipeptide_composition(graph)
        feature = self._dipeptide_deviation(graph, composition)
        graph_feature = self.mlp(feature)

        return {
            "graph_feature": graph_feature,
        }

    def _dipeptide_composition(self, graph):
        """Return per-protein dipeptide frequencies, shape (batch_size, input_dim)."""
        residue_type = graph.residue_type
        # Flat id of each adjacent pair (residue k, residue k + 1).
        index = residue_type[:-1] * self.num_residue_type + residue_type[1:]
        # Shift each pair into the block of slots belonging to its protein.
        index = graph.residue2graph[:-1] * self.input_dim + index
        # Pairs whose two residues belong to different proteins contribute 0.
        # NOTE(review): assumes the residues of each protein are stored contiguously.
        mask = graph.residue2graph[:-1] == graph.residue2graph[1:]
        counts = scatter_add(mask.float(), index, dim=0, dim_size=graph.batch_size * self.input_dim)
        counts = counts.view(graph.batch_size, self.input_dim)
        # Counts -> frequencies; the epsilon avoids 0/0 for proteins with no dipeptide.
        return counts / (counts.sum(dim=-1, keepdim=True) + 1e-10)

    def _dipeptide_deviation(self, graph, composition):
        """Standardize dipeptide frequencies: (composition - T_m) / sqrt(T_v)."""
        # Theoretical mean T_m(i, j) = C_i * C_j / C_N^2, flattened to match the pair ids above.
        mean = (self.codons.unsqueeze(0) * self.codons.unsqueeze(1) / self._num_codons ** 2).flatten()
        # Theoretical variance T_v = T_m * (1 - T_m) / (L - 1), with L the per-protein length.
        # The epsilon keeps the division finite for single-residue proteins.
        variance = (mean * (1 - mean)).unsqueeze(0) / (graph.num_residues - 1 + 1e-10).unsqueeze(1)
        return (composition - mean.unsqueeze(0)) / (variance.sqrt() + 1e-10)