import os
import warnings
import torch
from torch import nn
import esm
from torchdrug import core, layers, utils, data
from torchdrug.layers import functional
from torchdrug.core import Registry as R
@R.register("models.ESM")
class EvolutionaryScaleModeling(nn.Module, core.Configurable):
    """
    The protein language model, Evolutionary Scale Modeling (ESM) proposed in
    `Biological Structure and Function Emerge from Scaling Unsupervised Learning to 250 Million Protein Sequences`_.

    .. _Biological Structure and Function Emerge from Scaling Unsupervised Learning to 250 Million Protein Sequences:
        https://www.biorxiv.org/content/10.1101/622803v1.full.pdf

    Parameters:
        path (str): path to store ESM model weights
        model (str, optional): model name. Available model names are ``ESM-1b``, ``ESM-1v``, ``ESM-2-8M``,
            ``ESM-2-35M``, ``ESM-2-150M``, ``ESM-2-650M``, ``ESM-2-3B`` and ``ESM-2-15B``.
        readout (str, optional): readout function. Available functions are ``sum`` and ``mean``.
    """

    # Download locations of the checkpoints. ``ESM-1b-regression`` is not a standalone model:
    # it is the contact-regression file (per its name) loaded together with ``ESM-1b``.
    url = {
        "ESM-1b": "https://dl.fbaipublicfiles.com/fair-esm/models/esm1b_t33_650M_UR50S.pt",
        "ESM-1v": "https://dl.fbaipublicfiles.com/fair-esm/models/esm1v_t33_650M_UR90S_1.pt",
        "ESM-1b-regression":
            "https://dl.fbaipublicfiles.com/fair-esm/regression/esm1b_t33_650M_UR50S-contact-regression.pt",
        "ESM-2-8M": "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t6_8M_UR50D.pt",
        "ESM-2-35M": "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t12_35M_UR50D.pt",
        "ESM-2-150M": "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t30_150M_UR50D.pt",
        "ESM-2-650M": "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t33_650M_UR50D.pt",
        "ESM-2-3B": "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t36_3B_UR50D.pt",
        "ESM-2-15B": "https://dl.fbaipublicfiles.com/fair-esm/models/esm2_t48_15B_UR50D.pt",
    }
    # md5 checksums passed to ``utils.download`` for the files in ``url``.
    md5 = {
        "ESM-1b": "ba8914bc3358cae2254ebc8874ee67f6",
        "ESM-1v": "1f04c2d2636b02b544ecb5fbbef8fefd",
        "ESM-1b-regression": "e7fe626dfd516fb6824bd1d30192bdb1",
        "ESM-2-8M": "8039fc9cee7f71cd2633b13b5a38ff50",
        "ESM-2-35M": "a894ddb31522e511e1273abb23b5f974",
        "ESM-2-150M": "229fcf8f9f3d4d442215662ca001b906",
        "ESM-2-650M": "ba6d997e29db07a2ad9dca20e024b102",
        "ESM-2-3B": "d37a0d0dbe7431e48a72072b9180b16b",
        "ESM-2-15B": "af61a9c0b792ae50e244cde443b7f4ac",
    }
    # Dimension of the residue representations produced by each backbone.
    # NOTE: ``__init__`` shadows this dict with the chosen model's int on the instance.
    output_dim = {
        "ESM-1b": 1280,
        "ESM-1v": 1280,
        "ESM-2-8M": 320,
        "ESM-2-35M": 480,
        "ESM-2-150M": 640,
        "ESM-2-650M": 1280,
        "ESM-2-3B": 2560,
        "ESM-2-15B": 5120,
    }
    # Representation layer requested from each backbone in ``forward``; it matches the layer
    # count in each checkpoint's file name (e.g. ``t33`` -> 33). Its keys are the valid model names.
    num_layer = {
        "ESM-1b": 33,
        "ESM-1v": 33,
        "ESM-2-8M": 6,
        "ESM-2-35M": 12,
        "ESM-2-150M": 30,
        "ESM-2-650M": 33,
        "ESM-2-3B": 36,
        "ESM-2-15B": 48,
    }
    # Longest protein (in residues) passed to the backbone; presumably 1024 positions minus
    # room for the BOS/EOS tokens added in ``forward``.
    max_input_length = 1024 - 2

    def __init__(self, path, model="ESM-1b", readout="mean"):
        super(EvolutionaryScaleModeling, self).__init__()
        # Validate the cheap argument before downloading potentially multi-GB weights.
        if readout == "sum":
            readout_layer = layers.SumReadout("residue")
        elif readout == "mean":
            readout_layer = layers.MeanReadout("residue")
        else:
            raise ValueError("Unknown readout `%s`" % readout)

        path = os.path.expanduser(path)
        # exist_ok avoids the race between an existence check and directory creation.
        os.makedirs(path, exist_ok=True)
        self.path = path

        # load_weight validates ``model`` and raises ValueError for unknown names.
        _model, alphabet = self.load_weight(path, model)
        self.model = _model
        self.alphabet = alphabet
        self.output_dim = self.output_dim[model]
        self.repr_layer = self.num_layer[model]
        self.register_buffer("mapping", self.construct_mapping(alphabet))
        self.readout = readout_layer

    def load_weight(self, path, model):
        """
        Download (if not cached under ``path``) and build a pretrained ESM backbone.

        Parameters:
            path (str): directory where checkpoint files are stored
            model (str): model name, one of the keys of ``num_layer``

        Returns:
            the ``(model, alphabet)`` pair returned by ``esm.pretrained.load_model_and_alphabet_core``

        Raises:
            ValueError: if ``model`` is not a known backbone name
        """
        # ``url`` also holds auxiliary entries (``ESM-1b-regression``) that are not backbones,
        # so validate against ``num_layer`` instead.
        if model not in self.num_layer:
            raise ValueError("Unknown model `%s`" % model)
        model_file = utils.download(self.url[model], path, md5=self.md5[model])
        # NOTE(review): torch.load unpickles the file; the md5 check is the only integrity guard.
        # Newer PyTorch releases default ``weights_only=True``, which may reject these
        # checkpoints -- confirm against the installed version.
        model_data = torch.load(model_file, map_location="cpu")

        # Only models with a ``<name>-regression`` entry (ESM-1b) ship a contact-regression file.
        regression_model = "%s-regression" % model
        if regression_model in self.url:
            regression_file = utils.download(self.url[regression_model], path, md5=self.md5[regression_model])
            regression_data = torch.load(regression_file, map_location="cpu")
        else:
            regression_data = None

        # fair-esm receives the checkpoint file name alongside the weights.
        model_name = os.path.basename(self.url[model])
        return esm.pretrained.load_model_and_alphabet_core(model_name, model_data, regression_data)

    def construct_mapping(self, alphabet):
        """
        Build a lookup table from TorchDrug residue ids to ESM token ids.

        Parameters:
            alphabet: ESM alphabet providing ``get_idx`` and ``len()``

        Returns:
            Tensor: ``mapping[i]`` is the ESM token id of TorchDrug residue ``i``; entries without
            a TorchDrug residue are ``-1``
        """
        num_symbol = len(data.Protein.id2residue_symbol)
        # The table is long enough to be indexed by any ESM token id, presumably so raw ESM ids
        # can pass through ``forward`` (which replaces -1 entries with the input id) -- confirm.
        mapping = [-1] * max(num_symbol, len(alphabet))
        for i, token in data.Protein.id2residue_symbol.items():
            mapping[i] = alphabet.get_idx(token)
        return torch.tensor(mapping)

    def forward(self, graph, input, all_loss=None, metric=None):
        """
        Compute the residue representations and the graph representation(s).

        Proteins longer than ``max_input_length`` residues are truncated to their first
        ``max_input_length`` residues (with a warning). Both outputs describe the truncated proteins.

        Parameters:
            graph (Protein): :math:`n` protein(s)
            input (Tensor): input node representations. Unused: residue ids are read from
                ``graph.residue_type``.
            all_loss (Tensor, optional): if specified, add loss to this tensor (unused)
            metric (dict, optional): if specified, output metrics to this dict (unused)

        Returns:
            dict with ``residue_feature`` and ``graph_feature`` fields:
                residue representations of shape :math:`(|V_{res}|, d)`, graph representations of shape :math:`(n, d)`
        """
        # Translate TorchDrug residue ids into ESM token ids. Ids with no mapping (-1) are
        # passed through unchanged.
        residue_type = graph.residue_type
        input = self.mapping[residue_type]
        unmapped = input == -1
        input[unmapped] = residue_type[unmapped]

        size = graph.num_residues
        if (size > self.max_input_length).any():
            warnings.warn("ESM can only encode proteins within %d residues. Truncate the input to fit into ESM."
                          % self.max_input_length)
            # Keep only the first ``max_input_length`` residues of every protein.
            starts = size.cumsum(0) - size
            size = size.clamp(max=self.max_input_length)
            ends = starts + size
            mask = functional.multi_slice_mask(starts, ends, graph.num_residue)
            input = input[mask]
            graph = graph.subresidue(mask)

        # Surround each sequence with the special tokens this alphabet expects.
        # Special tokens are created on the same device as the tokens they are concatenated with.
        size_ext = size
        if self.alphabet.prepend_bos:
            bos = torch.full((graph.batch_size,), self.alphabet.cls_idx, dtype=torch.long, device=input.device)
            input, size_ext = functional._extend(bos, torch.ones_like(size_ext), input, size_ext)
        if self.alphabet.append_eos:
            eos = torch.full((graph.batch_size,), self.alphabet.eos_idx, dtype=torch.long, device=input.device)
            input, size_ext = functional._extend(input, size_ext, eos, torch.ones_like(size_ext))
        input = functional.variadic_to_padded(input, size_ext, value=self.alphabet.padding_idx)[0]

        output = self.model(input, repr_layers=[self.repr_layer])
        residue_feature = output["representations"][self.repr_layer]

        # Return to the variadic layout and drop the special-token positions,
        # leaving exactly one feature per (possibly truncated) residue.
        residue_feature = functional.padded_to_variadic(residue_feature, size_ext)
        starts = size_ext.cumsum(0) - size_ext
        if self.alphabet.prepend_bos:
            starts = starts + 1
        ends = starts + size
        mask = functional.multi_slice_mask(starts, ends, len(residue_feature))
        residue_feature = residue_feature[mask]
        graph_feature = self.readout(graph, residue_feature)

        return {
            "graph_feature": graph_feature,
            "residue_feature": residue_feature
        }