Source code for torchdrug.tasks.property_prediction

import math
from collections import defaultdict

import torch
from torch import nn
from torch.nn import functional as F
from torch_scatter import scatter_mean

from torchdrug import core, data, layers, tasks, metrics
from torchdrug.core import Registry as R
from torchdrug.layers import functional


@R.register("tasks.PropertyPrediction")
class PropertyPrediction(tasks.Task, core.Configurable):
    """
    Graph / molecule property prediction task.

    This class is also compatible with semi-supervised learning.

    Parameters:
        model (nn.Module): graph representation model
        task (str, list or dict, optional): training task(s).
            For dict, the keys are tasks and the values are the corresponding weights.
        criterion (str, list or dict, optional): training criterion(s).
            For dict, the keys are criterions and the values are the corresponding weights.
            Available criterions are ``mse``, ``bce`` and ``ce``.
        metric (str or list of str, optional): metric(s).
            Available metrics are ``mae``, ``rmse``, ``auprc`` and ``auroc``.
        num_mlp_layer (int, optional): number of layers in the prediction MLP
        normalization (bool, optional): whether to normalize regression targets
        entity_level (str, optional): level of input features.
            Available levels are ``node``, ``atom`` and ``residue``.
        num_class (int or list of int, optional): number of classes for each task
        verbose (int, optional): output verbose level
    """

    eps = 1e-10
    _option_members = {"task", "criterion", "metric"}

    def __init__(self, model, task=(), criterion="mse", metric=("mae", "rmse"), num_mlp_layer=1,
                 normalization=True, entity_level="residue", num_class=None, verbose=0):
        super(PropertyPrediction, self).__init__()
        self.model = model
        self.task = task
        self.criterion = criterion
        self.metric = metric
        self.num_mlp_layer = num_mlp_layer
        self.normalization = normalization
        self.entity_level = entity_level
        self.num_class = num_class
        self.verbose = verbose
    def preprocess(self, train_set, valid_set, test_set):
        """
        Compute the mean and standard deviation for each task on the training set.
        """
        values = defaultdict(list)
        for sample in train_set:
            if not sample.get("labeled", True):
                continue
            for task in self.task:
                if not math.isnan(sample[task]):
                    values[task].append(sample[task])
        mean = []
        std = []
        weight = []
        num_class = []
        for task, w in self.task.items():
            value = torch.tensor(values[task])
            mean.append(value.float().mean())
            std.append(value.float().std())
            weight.append(w)
            if value.ndim > 1:
                num_class.append(value.shape[1])
            elif value.dtype == torch.long:
                num_class.append(value.max().item() + 1)
            else:
                num_class.append(1)

        self.register_buffer("mean", torch.as_tensor(mean, dtype=torch.float))
        self.register_buffer("std", torch.as_tensor(std, dtype=torch.float))
        self.register_buffer("weight", torch.as_tensor(weight, dtype=torch.float))
        self.num_class = self.num_class or num_class

        hidden_dims = [self.model.output_dim] * (self.num_mlp_layer - 1)
        self.mlp = layers.MLP(self.model.output_dim, hidden_dims + [sum(self.num_class)])
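
    # For example, a hypothetical task dict ``{"binding": 1.0}`` with long-typed
    # labels ``[0, 2, 1, 2]`` makes preprocess infer ``num_class = [3]``
    # (``value.max() + 1``), while float-valued labels yield a single regression
    # output (``num_class = [1]``); the MLP head width is ``sum(self.num_class)``.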
    def forward(self, batch):
        """"""
        all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
        metric = {}

        pred = self.predict(batch, all_loss, metric)

        if all([t not in batch for t in self.task]):
            # unlabeled data
            return all_loss, metric

        target = self.target(batch)
        labeled = ~torch.isnan(target)
        target[~labeled] = 0

        for criterion, weight in self.criterion.items():
            if criterion == "mse":
                # normalize into a local so later criterions still see raw targets
                if self.normalization:
                    norm_target = (target - self.mean) / self.std
                else:
                    norm_target = target
                loss = F.mse_loss(pred, norm_target, reduction="none")
            elif criterion == "bce":
                loss = F.binary_cross_entropy_with_logits(pred, target, reduction="none")
            elif criterion == "ce":
                loss = F.cross_entropy(pred, target.long().squeeze(-1), reduction="none").unsqueeze(-1)
            else:
                raise ValueError("Unknown criterion `%s`" % criterion)
            loss = functional.masked_mean(loss, labeled, dim=0)

            name = tasks._get_criterion_name(criterion)
            if self.verbose > 0:
                for t, l in zip(self.task, loss):
                    metric["%s [%s]" % (name, t)] = l
            loss = (loss * self.weight).sum() / self.weight.sum()
            metric[name] = loss
            all_loss += loss * weight

        return all_loss, metric

    def predict(self, batch, all_loss=None, metric=None):
        graph = batch["graph"]
        if self.entity_level in ["node", "atom"]:
            input = graph.node_feature.float()
        elif self.entity_level == "residue":
            input = graph.residue_feature.float()
        else:
            raise ValueError("Unknown entity level `%s`" % self.entity_level)
        output = self.model(graph, input, all_loss=all_loss, metric=metric)
        pred = self.mlp(output["graph_feature"])
        return pred

    def target(self, batch):
        target = torch.stack([batch[t].float() for t in self.task], dim=-1)
        labeled = batch.get("labeled", torch.ones(len(target), dtype=torch.bool, device=target.device))
        target[~labeled] = math.nan
        return target

    def evaluate(self, pred, target):
        labeled = ~torch.isnan(target)

        metric = {}
        for _metric in self.metric:
            if _metric == "mae":
                # denormalize into a local so later metrics see the raw predictions
                _pred = pred * self.std + self.mean if self.normalization else pred
                score = F.l1_loss(_pred, target, reduction="none")
                score = functional.masked_mean(score, labeled, dim=0)
            elif _metric == "rmse":
                _pred = pred * self.std + self.mean if self.normalization else pred
                score = F.mse_loss(_pred, target, reduction="none")
                score = functional.masked_mean(score, labeled, dim=0).sqrt()
            elif _metric == "acc":
                score = []
                num_class = 0
                for i, cur_num_class in enumerate(self.num_class):
                    _pred = pred[:, num_class:num_class + cur_num_class]
                    _target = target[:, i]
                    _labeled = labeled[:, i]
                    _score = metrics.accuracy(_pred[_labeled], _target[_labeled].long())
                    score.append(_score)
                    num_class += cur_num_class
                score = torch.stack(score)
            elif _metric == "mcc":
                score = []
                num_class = 0
                for i, cur_num_class in enumerate(self.num_class):
                    _pred = pred[:, num_class:num_class + cur_num_class]
                    _target = target[:, i]
                    _labeled = labeled[:, i]
                    _score = metrics.matthews_corrcoef(_pred[_labeled], _target[_labeled].long())
                    score.append(_score)
                    num_class += cur_num_class
                score = torch.stack(score)
            elif _metric == "auroc":
                score = []
                for _pred, _target, _labeled in zip(pred.t(), target.long().t(), labeled.t()):
                    _score = metrics.area_under_roc(_pred[_labeled], _target[_labeled])
                    score.append(_score)
                score = torch.stack(score)
            elif _metric == "auprc":
                score = []
                for _pred, _target, _labeled in zip(pred.t(), target.long().t(), labeled.t()):
                    _score = metrics.area_under_prc(_pred[_labeled], _target[_labeled])
                    score.append(_score)
                score = torch.stack(score)
            elif _metric == "r2":
                score = []
                new_pred = pred * self.std + self.mean if self.normalization else pred
                for _pred, _target, _labeled in zip(new_pred.t(), target.t(), labeled.t()):
                    _score = metrics.r2(_pred[_labeled], _target[_labeled])
                    score.append(_score)
                score = torch.stack(score)
            elif _metric == "spearmanr":
                score = []
                new_pred = pred * self.std + self.mean if self.normalization else pred
                for _pred, _target, _labeled in zip(new_pred.t(), target.t(), labeled.t()):
                    _score = metrics.spearmanr(_pred[_labeled], _target[_labeled])
                    score.append(_score)
                score = torch.stack(score)
            elif _metric == "pearsonr":
                score = []
                new_pred = pred * self.std + self.mean if self.normalization else pred
                for _pred, _target, _labeled in zip(new_pred.t(), target.t(), labeled.t()):
                    _score = metrics.pearsonr(_pred[_labeled], _target[_labeled])
                    score.append(_score)
                score = torch.stack(score)
            else:
                raise ValueError("Unknown metric `%s`" % _metric)

            name = tasks._get_metric_name(_metric)
            for t, s in zip(self.task, score):
                metric["%s [%s]" % (name, t)] = s

        return metric
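
# A minimal usage sketch (hypothetical names: ``model`` is any torchdrug graph
# representation model, and ``train_set``/``valid_set``/``test_set``/``batch``
# are assumed to come from a torchdrug dataset and data loader):
#
#   task = PropertyPrediction(model, task={"solubility": 1.0}, criterion="mse",
#                             metric=("mae", "rmse"), num_mlp_layer=2,
#                             entity_level="node")
#   task.preprocess(train_set, valid_set, test_set)  # builds mean/std buffers and the MLP head
#   loss, metric = task(batch)                       # forward() returns (all_loss, metric)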
@R.register("tasks.NodePropertyPrediction") class NodePropertyPrediction(tasks.Task, core.Configurable): _option_members = set(["criterion", "metric"]) def __init__(self, model, criterion="bce", metric=("macro_auprc", "macro_auroc"), num_mlp_layer=1, normalization=True, entity_level="residue", num_class=None, use_node_dim=False, verbose=0): super(NodePropertyPrediction, self).__init__() self.model = model self.criterion = criterion self.metric = metric self.normalization = normalization self.num_mlp_layer = num_mlp_layer self.entity_level = entity_level self.num_class = num_class self.use_node_dim = use_node_dim self.verbose = verbose def preprocess(self, train_set, valid_set, test_set): """ Compute the mean and derivation on the training set. """ values = torch.cat([data["graph"].target for data in train_set]) mean = values.float().mean() std = values.float().std() if values.dtype == torch.long: num_class = values.max() + 1 else: num_class = 1 self.register_buffer("mean", torch.as_tensor(mean, dtype=torch.float)) self.register_buffer("std", torch.as_tensor(std, dtype=torch.float)) self.num_class = self.num_class or num_class model_output_dim = self.model.node_output_dim if self.use_node_dim else self.model.output_dim hidden_dims = [model_output_dim] * (self.num_mlp_layer - 1) self.mlp = layers.MLP(model_output_dim, hidden_dims + [sum(self.num_class)]) def predict(self, batch, all_loss=None, metric=None): graph = batch["graph"] if self.entity_level in ["node", "atom"]: input = graph.node_feature.float() elif self.entity_level == "residue": input = graph.residue_feature.float() else: raise ValueError("Unknown entity level `%s`" % self.entity_level) output = self.model(graph, input, all_loss=all_loss, metric=metric) if self.entity_level in ["node", "atom"]: pred = self.mlp(output["node_feature"]) elif self.entity_level == "residue": pred = self.mlp(output["residue_feature"]) else: raise ValueError("Unknown entity level `%s`" % self.entity_level) return pred def target(self, batch): if self.entity_level == "residue": size = batch["graph"].num_residues else: size = batch["graph"].num_nodes return { "label": batch["graph"].target, "mask": batch["graph"].mask, "size": size } def forward(self, batch): all_loss = torch.tensor(0, dtype=torch.float32, device=self.device) metric = {} pred, target = self.predict_and_target(batch, all_loss, metric) labeled = ~torch.isnan(target["label"]) & target["mask"] for criterion, weight in self.criterion.items(): if criterion == "mse": if self.normalization: target = (target - self.mean) / self.std loss = F.mse_loss(pred, target, reduction="none") elif criterion == "bce": loss = F.binary_cross_entropy_with_logits(pred, target['label'].float(), reduction="none") elif criterion == "ce": loss = F.cross_entropy(pred, target['label'], reduction="none") else: raise ValueError("Unknown criterion `%s`" % criterion) loss = functional.masked_mean(loss, labeled, dim=0) name = tasks._get_criterion_name(criterion) metric[name] = loss all_loss += loss * weight all_loss += loss return all_loss, metric def evaluate(self, pred, target): metric = {} _target = target["label"] _labeled = ~torch.isnan(_target) & target["mask"] _size = functional.variadic_sum(_labeled.long(), target["size"]) for _metric in self.metric: if _metric == "micro_acc": score = metrics.accuracy(pred[_labeled], _target[_labeled].long()) elif metric == "micro_auroc": score = metrics.area_under_roc(pred[_labeled], _target[_labeled]) elif metric == "micro_auprc": score = metrics.area_under_prc(pred[_labeled], 

@R.register("tasks.InteractionPrediction")
class InteractionPrediction(PropertyPrediction):
    """
    Predict the interaction property of protein pairs.

    Parameters:
        model (nn.Module): graph representation model for the first graph
        model2 (nn.Module, optional): graph representation model for the second graph.
            If not specified, ``model`` is shared between both graphs.
        task (str, list or dict, optional): training task(s).
            For dict, the keys are tasks and the values are the corresponding weights.
        criterion (str, list or dict, optional): training criterion(s)
        metric (str or list of str, optional): metric(s)
        num_mlp_layer (int, optional): number of layers in the prediction MLP
        normalization (bool, optional): whether to normalize regression targets
        num_class (int or list of int, optional): number of classes for each task
        verbose (int, optional): output verbose level
    """

    def __init__(self, model, model2=None, task=(), criterion="mse", metric=("mae", "rmse"), num_mlp_layer=1,
                 normalization=True, num_class=None, verbose=0):
        # pass keyword arguments so num_class / verbose do not shift into entity_level
        super(InteractionPrediction, self).__init__(model, task, criterion, metric, num_mlp_layer,
                                                    normalization, num_class=num_class, verbose=verbose)
        self.model2 = model2 or model

    def preprocess(self, train_set, valid_set, test_set):
        """
        Compute the mean and standard deviation for each task on the training set.
        """
        values = defaultdict(list)
        for sample in train_set:
            if not sample.get("labeled", True):
                continue
            for task in self.task:
                if not math.isnan(sample[task]):
                    values[task].append(sample[task])
        mean = []
        std = []
        weight = []
        num_class = []
        for task, w in self.task.items():
            value = torch.tensor(values[task])
            mean.append(value.float().mean())
            std.append(value.float().std())
            weight.append(w)
            if value.ndim > 1:
                num_class.append(value.shape[1])
            elif value.dtype == torch.long:
                num_class.append(value.max().item() + 1)
            else:
                num_class.append(1)

        self.register_buffer("mean", torch.as_tensor(mean, dtype=torch.float))
        self.register_buffer("std", torch.as_tensor(std, dtype=torch.float))
        self.register_buffer("weight", torch.as_tensor(weight, dtype=torch.float))
        self.num_class = self.num_class or num_class

        hidden_dims = [self.model.output_dim] * (self.num_mlp_layer - 1)
        self.mlp = layers.MLP(self.model.output_dim + self.model2.output_dim,
                              hidden_dims + [sum(self.num_class)])

    def predict(self, batch, all_loss=None, metric=None):
        graph1 = batch["graph1"]
        output1 = self.model(graph1, graph1.node_feature.float(), all_loss=all_loss, metric=metric)
        graph2 = batch["graph2"]
        output2 = self.model2(graph2, graph2.node_feature.float(), all_loss=all_loss, metric=metric)
        pred = self.mlp(torch.cat([output1["graph_feature"], output2["graph_feature"]], dim=-1))
        return pred
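
# A minimal pair-input sketch (hypothetical names; batches are assumed to carry
# two packed graphs under the keys ``graph1`` and ``graph2``):
#
#   task = InteractionPrediction(model, task={"affinity": 1.0}, criterion="mse")
#   task.preprocess(train_set, valid_set, test_set)
#   pred = task.predict({"graph1": graph1, "graph2": graph2})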

@R.register("tasks.AntigenAntibodyInteraction")
class AntigenAntibodyInteraction(PropertyPrediction, core.Configurable):

    def __init__(self, model, task=(), criterion="mse", metric=("acc", "auroc"), num_transformer_layer=4,
                 num_mlp_layer=2, normalization=True, verbose=0):
        # pass verbose as a keyword so it does not shift into entity_level
        super(AntigenAntibodyInteraction, self).__init__(model, task, criterion, metric, num_mlp_layer,
                                                         normalization, verbose=verbose)
        encoder_layer = nn.TransformerEncoderLayer(model.output_dim, 8, activation="gelu")
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_transformer_layer)

    def preprocess(self, train_set, valid_set, test_set):
        values = defaultdict(list)
        for sample in train_set:
            if not sample.get("labeled", True):
                continue
            for task in self.task:
                if not math.isnan(sample[task]):
                    values[task].append(sample[task])
        mean = []
        std = []
        weight = []
        num_class = []
        for task, w in self.task.items():
            value = torch.tensor(values[task])
            mean.append(value.float().mean())
            std.append(value.float().std())
            weight.append(w)
            if value.ndim > 1:
                num_class.append(value.shape[1])
            elif value.dtype == torch.long:
                num_class.append(value.max().item() + 1)
            else:
                num_class.append(1)

        self.register_buffer("mean", torch.as_tensor(mean, dtype=torch.float))
        self.register_buffer("std", torch.as_tensor(std, dtype=torch.float))
        self.register_buffer("weight", torch.as_tensor(weight, dtype=torch.float))
        self.num_class = num_class

        hidden_dims = [self.model.output_dim] * (self.num_mlp_layer - 1)
        self.mlp = layers.MLP(self.model.output_dim * 2, hidden_dims + [sum(num_class)])

    def forward(self, batch):
        all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
        metric = {}

        pred = self.predict(batch, all_loss, metric)

        if all([t not in batch for t in self.task]):
            # unlabeled data
            return all_loss, metric

        target = self.target(batch)
        labeled = ~torch.isnan(target)
        target[~labeled] = 0

        for criterion, weight in self.criterion.items():
            if isinstance(weight, str):
                # the criterion dict may map a task name to a criterion string
                task, criterion = criterion, weight
                weight = 1
                # TODO: ensure the order of dict?
                index = list(self.task.keys()).index(task)
                mask = slice(sum(self.num_class[:index]), sum(self.num_class[:index + 1]))
                _pred = pred[:, mask]
                _target = target[:, mask]
                _labeled = labeled[:, mask]
                mean = self.mean[mask]
                std = self.std[mask]
            else:
                _pred = pred
                _target = target
                _labeled = labeled
                mean = self.mean
                std = self.std
            if criterion == "mse":
                if self.normalization:
                    _target = (_target - mean) / std
                loss = F.mse_loss(_pred, _target, reduction="none")
            elif criterion == "bce":
                loss = F.binary_cross_entropy_with_logits(_pred, _target, reduction="none")
            elif criterion == "ce":
                loss = F.cross_entropy(_pred, _target.long().squeeze(-1), reduction="none").unsqueeze(-1)
            else:
                raise ValueError("Unknown criterion `%s`" % criterion)
            loss = functional.masked_mean(loss, _labeled, dim=0)

            name = tasks._get_criterion_name(criterion)
            if self.verbose > 0:
                for t, l in zip(self.task, loss):
                    metric["%s [%s]" % (name, t)] = l
            loss = (loss * self.weight).sum() / self.weight.sum()
            metric[name] = loss
            all_loss += loss * weight

        return all_loss, metric

    def predict(self, batch, all_loss=None, metric=None):

        def get_chain_feature(chain):
            output = self.model(chain, chain.node_feature.float(), all_loss=all_loss, metric=metric)
            residue_feature = output["residue_feature"]

            interval = torch.cat([chain.cdr1_interval, chain.cdr2_interval, chain.cdr3_interval], dim=-1)
            starts, ends = interval.view(-1, 2).t()
            lengths = ends - starts
            mask = functional.multi_slice_mask(starts, ends, length=chain.num_residue)
            feature_type = torch.arange(3, device=self.device).repeat(len(chain)).repeat_interleave(lengths)
            cdr_feature = residue_feature[mask]
            cdr_size = lengths.view(-1, 3).sum(-1)

            antigen_feature = antigen.truncated_residue_feature.mean(dim=1)
            antigen_size = torch.ones(len(chain), dtype=torch.long, device=self.device)
            feature = functional._extend(cdr_feature, cdr_size, antigen_feature, antigen_size)[0]
            antigen_feature_type = torch.ones(len(antigen_feature), dtype=torch.long, device=self.device) * 3
            feature_type, size = functional._extend(feature_type, cdr_size, antigen_feature_type, antigen_size)

            # readout CDR + antigen features into a single representation
            feature, mask = functional.variadic_to_padded(feature, size)
            feature = feature.transpose(0, 1)  # convert to sequence first
            feature = self.transformer(feature, src_key_padding_mask=~mask)
            feature = feature.transpose(0, 1)  # convert to batch first
            cdr_feature = functional.padded_to_variadic(feature, cdr_size)
            index2chain = functional._size_to_index(cdr_size)
            feature = scatter_mean(cdr_feature, index2chain, dim=0)
            return feature

        heavy, light, antigen = batch["graph"]
        heavy_feature = get_chain_feature(heavy)
        light_feature = get_chain_feature(light)
        antibody_feature = torch.cat([heavy_feature, light_feature], dim=-1)
        pred = self.mlp(antibody_feature)
        return pred

    def evaluate(self, pred, target):
        labeled = ~torch.isnan(target)

        metric = {}
        for _metric, weight in self.metric.items():
            if isinstance(weight, str):
                # the metric dict may map a task name to a metric string
                task, _metric = _metric, weight
                # TODO: ensure the order of dict?
                index = list(self.task.keys()).index(task)
                task = [task]
                mask = slice(sum(self.num_class[:index]), sum(self.num_class[:index + 1]))
                _pred = pred[:, mask]
                _target = target[:, mask]
                _labeled = labeled[:, mask]
                mean = self.mean[mask]
                std = self.std[mask]
            else:
                task = self.task
                _pred = pred
                _target = target
                _labeled = labeled
                mean = self.mean
                std = self.std
            if _metric == "mae":
                if self.normalization:
                    _pred = _pred * std + mean
                score = F.l1_loss(_pred, _target, reduction="none")
                score = functional.masked_mean(score, _labeled, dim=0)
            elif _metric == "rmse":
                if self.normalization:
                    _pred = _pred * std + mean
                score = F.mse_loss(_pred, _target, reduction="none")
                score = functional.masked_mean(score, _labeled, dim=0).sqrt()
            elif _metric == "acc":
                assert not isinstance(weight, str)
                score = []
                num_class = 0
                for i, cur_num_class in enumerate(self.num_class):
                    _pred = pred[:, num_class:num_class + cur_num_class]
                    _target = target[:, i]
                    _labeled = labeled[:, i]
                    _score = metrics.accuracy(_pred[_labeled], _target[_labeled].long())
                    score.append(_score)
                    num_class += cur_num_class
                score = torch.stack(score)
            elif _metric == "mcc":
                assert not isinstance(weight, str)
                score = []
                num_class = 0
                for i, cur_num_class in enumerate(self.num_class):
                    _pred = pred[:, num_class:num_class + cur_num_class]
                    _target = target[:, i]
                    _labeled = labeled[:, i]
                    _score = metrics.matthews_corrcoef(_pred[_labeled], _target[_labeled].long())
                    score.append(_score)
                    num_class += cur_num_class
                score = torch.stack(score)
            elif _metric == "auroc":
                score = []
                for __pred, __target, __labeled in zip(_pred.t(), _target.long().t(), _labeled.t()):
                    _score = metrics.area_under_roc(__pred[__labeled], __target[__labeled])
                    score.append(_score)
                score = torch.stack(score)
            elif _metric == "auprc":
                score = []
                for __pred, __target, __labeled in zip(_pred.t(), _target.long().t(), _labeled.t()):
                    _score = metrics.area_under_prc(__pred[__labeled], __target[__labeled])
                    score.append(_score)
                score = torch.stack(score)
            elif _metric == "r2":
                score = []
                if self.normalization:
                    _pred = _pred * std + mean
                for __pred, __target, __labeled in zip(_pred.t(), _target.t(), _labeled.t()):
                    _score = metrics.r2(__pred[__labeled], __target[__labeled])
                    score.append(_score)
                score = torch.stack(score)
            elif _metric == "spearmanr":
                score = []
                if self.normalization:
                    _pred = _pred * std + mean
                for __pred, __target, __labeled in zip(_pred.t(), _target.t(), _labeled.t()):
                    _score = metrics.spearmanr(__pred[__labeled], __target[__labeled])
                    score.append(_score)
                score = torch.stack(score)
            elif _metric == "pearsonr":
                score = []
                if self.normalization:
                    _pred = _pred * std + mean
                for __pred, __target, __labeled in zip(_pred.t(), _target.t(), _labeled.t()):
                    _score = metrics.pearsonr(__pred[__labeled], __target[__labeled])
                    score.append(_score)
                score = torch.stack(score)
            else:
                raise ValueError("Unknown metric `%s`" % _metric)

            name = tasks._get_metric_name(_metric)
            for t, s in zip(task, score):
                metric["%s [%s]" % (name, t)] = s

        return metric
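
# Example configuration (a minimal sketch with a hypothetical task name): the
# criterion and metric dicts may map a task name to a criterion or metric
# string, which restricts the loss and the metric to that task's output columns:
#
#   task = AntigenAntibodyInteraction(model, task={"neutralization": 1.0},
#                                     criterion={"neutralization": "bce"},
#                                     metric={"neutralization": "auroc"})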

class Unsupervised(nn.Module):
    """
    Wrapper task for unsupervised learning.
    The unsupervised loss should be computed by the model.

    Parameters:
        model (nn.Module): any model
    """

    def __init__(self, model):
        super(Unsupervised, self).__init__()
        self.model = model

    def forward(self, batch):
        """"""
        all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
        metric = {}

        pred = self.predict(batch, all_loss, metric)

        return all_loss, metric

    def predict(self, batch, all_loss=None, metric=None):
        graph = batch["graph"]
        pred = self.model(graph, graph.node_feature.float(), all_loss=all_loss, metric=metric)
        return pred
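
# A minimal sketch (hypothetical: ``model`` is assumed to add its own
# self-supervised loss to ``all_loss`` inside its forward pass):
#
#   task = Unsupervised(model)
#   loss, metric = task(batch)  # the loss comes entirely from the wrapped model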