import math
from collections import defaultdict
import torch
from torch import nn
from torch.nn import functional as F
from torch_scatter import scatter_mean
from torchdrug import core, data, layers, tasks, metrics
from torchdrug.core import Registry as R
from torchdrug.layers import functional
@R.register("tasks.PropertyPrediction")
class PropertyPrediction(tasks.Task, core.Configurable):
    """
    Graph / molecule property prediction task.

    This class is also compatible with semi-supervised learning.

    Parameters:
        model (nn.Module): graph representation model
        task (str, list or dict, optional): training task(s).
            For dict, the keys are tasks and the values are the corresponding weights.
        criterion (str, list or dict, optional): training criterion(s). For dict, the keys are criterions and the values
            are the corresponding weights. Available criterions are ``mse``, ``bce`` and ``ce``.
        metric (str or list of str, optional): metric(s).
            Available metrics are ``mae``, ``rmse``, ``auprc``, ``auroc``, ``acc``, ``mcc``,
            ``r2``, ``spearmanr`` and ``pearsonr``.
        num_mlp_layer (int, optional): number of layers in the MLP prediction head
        normalization (bool, optional): whether to normalize regression targets with the training mean / std
        entity_level (str, optional): feature level fed to the model; ``node`` / ``atom`` or ``residue``
        num_class (int or list of int, optional): number of classes per task; inferred from data when not given
        verbose (int, optional): output verbose level
    """

    eps = 1e-10  # NOTE(review): currently unused; presumably meant to guard divisions by a zero std
    _option_members = {"task", "criterion", "metric"}

    def __init__(self, model, task=(), criterion="mse", metric=("mae", "rmse"), num_mlp_layer=1,
                 normalization=True, entity_level="residue", num_class=None, verbose=0):
        super(PropertyPrediction, self).__init__()
        self.model = model
        self.task = task
        self.criterion = criterion
        self.metric = metric
        self.num_mlp_layer = num_mlp_layer
        self.normalization = normalization
        self.entity_level = entity_level
        self.num_class = num_class
        self.verbose = verbose

    def preprocess(self, train_set, valid_set, test_set):
        """
        Compute the mean and standard deviation for each task on the training set,
        infer the number of classes per task, and build the MLP prediction head.
        """
        values = defaultdict(list)
        for sample in train_set:
            # samples explicitly flagged as unlabeled contribute no statistics
            if not sample.get("labeled", True):
                continue
            for task in self.task:
                if not math.isnan(sample[task]):
                    values[task].append(sample[task])
        mean = []
        std = []
        weight = []
        num_class = []
        for task, w in self.task.items():
            value = torch.tensor(values[task])
            mean.append(value.float().mean())
            std.append(value.float().std())
            weight.append(w)
            if value.ndim > 1:
                # multi-label task: one output per label
                num_class.append(value.shape[1])
            elif value.dtype == torch.long:
                # categorical task: classes are 0 .. max
                num_class.append(value.max().item() + 1)
            else:
                # regression task: single output
                num_class.append(1)

        self.register_buffer("mean", torch.as_tensor(mean, dtype=torch.float))
        self.register_buffer("std", torch.as_tensor(std, dtype=torch.float))
        self.register_buffer("weight", torch.as_tensor(weight, dtype=torch.float))
        self.num_class = self.num_class or num_class

        hidden_dims = [self.model.output_dim] * (self.num_mlp_layer - 1)
        self.mlp = layers.MLP(self.model.output_dim, hidden_dims + [sum(self.num_class)])

    def forward(self, batch):
        """
        Compute the weighted training loss and per-criterion metrics for a batch.
        """
        all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
        metric = {}

        pred = self.predict(batch, all_loss, metric)

        if all(t not in batch for t in self.task):
            # unlabeled data
            return all_loss, metric

        target = self.target(batch)
        labeled = ~torch.isnan(target)
        target[~labeled] = 0

        for criterion, weight in self.criterion.items():
            if criterion == "mse":
                if self.normalization:
                    # normalize a local copy so other criteria in this loop still see raw targets
                    _target = (target - self.mean) / self.std
                else:
                    _target = target
                loss = F.mse_loss(pred, _target, reduction="none")
            elif criterion == "bce":
                loss = F.binary_cross_entropy_with_logits(pred, target, reduction="none")
            elif criterion == "ce":
                loss = F.cross_entropy(pred, target.long().squeeze(-1), reduction="none").unsqueeze(-1)
            else:
                raise ValueError("Unknown criterion `%s`" % criterion)
            # average only over labeled samples
            loss = functional.masked_mean(loss, labeled, dim=0)

            name = tasks._get_criterion_name(criterion)
            if self.verbose > 0:
                for t, l in zip(self.task, loss):
                    metric["%s [%s]" % (name, t)] = l
            loss = (loss * self.weight).sum() / self.weight.sum()
            metric[name] = loss
            all_loss += loss * weight

        return all_loss, metric

    def predict(self, batch, all_loss=None, metric=None):
        """
        Return the MLP head's predictions over the model's graph-level representation.
        """
        graph = batch["graph"]
        if self.entity_level in ["node", "atom"]:
            input = graph.node_feature.float()
        elif self.entity_level == "residue":
            input = graph.residue_feature.float()
        else:
            raise ValueError("Unknown entity level `%s`" % self.entity_level)
        output = self.model(graph, input, all_loss=all_loss, metric=metric)
        pred = self.mlp(output["graph_feature"])
        return pred

    def target(self, batch):
        """
        Stack per-task targets into a (batch, num_task) tensor; unlabeled rows become NaN.
        """
        target = torch.stack([batch[t].float() for t in self.task], dim=-1)
        labeled = batch.get("labeled", torch.ones(len(target), dtype=torch.bool, device=target.device))
        target[~labeled] = math.nan
        return target

    def evaluate(self, pred, target):
        """
        Evaluate each configured metric, skipping unlabeled (NaN) entries.
        """
        labeled = ~torch.isnan(target)

        metric = {}
        for _metric in self.metric:
            if _metric == "mae":
                # denormalize on a local copy so later metrics do not denormalize twice
                _pred = pred * self.std + self.mean if self.normalization else pred
                score = F.l1_loss(_pred, target, reduction="none")
                score = functional.masked_mean(score, labeled, dim=0)
            elif _metric == "rmse":
                _pred = pred * self.std + self.mean if self.normalization else pred
                score = F.mse_loss(_pred, target, reduction="none")
                score = functional.masked_mean(score, labeled, dim=0).sqrt()
            elif _metric == "acc":
                score = []
                num_class = 0
                for i, cur_num_class in enumerate(self.num_class):
                    # slice out the logits belonging to task i
                    _pred = pred[:, num_class:num_class + cur_num_class]
                    _target = target[:, i]
                    _labeled = labeled[:, i]
                    _score = metrics.accuracy(_pred[_labeled], _target[_labeled].long())
                    score.append(_score)
                    num_class += cur_num_class
                score = torch.stack(score)
            elif _metric == "mcc":
                score = []
                num_class = 0
                for i, cur_num_class in enumerate(self.num_class):
                    _pred = pred[:, num_class:num_class + cur_num_class]
                    _target = target[:, i]
                    _labeled = labeled[:, i]
                    _score = metrics.matthews_corrcoef(_pred[_labeled], _target[_labeled].long())
                    score.append(_score)
                    num_class += cur_num_class
                score = torch.stack(score)
            elif _metric == "auroc":
                score = []
                for _pred, _target, _labeled in zip(pred.t(), target.long().t(), labeled.t()):
                    _score = metrics.area_under_roc(_pred[_labeled], _target[_labeled])
                    score.append(_score)
                score = torch.stack(score)
            elif _metric == "auprc":
                score = []
                for _pred, _target, _labeled in zip(pred.t(), target.long().t(), labeled.t()):
                    _score = metrics.area_under_prc(_pred[_labeled], _target[_labeled])
                    score.append(_score)
                score = torch.stack(score)
            elif _metric == "r2":
                score = []
                # honor the normalization flag (was applied unconditionally before)
                new_pred = pred * self.std + self.mean if self.normalization else pred
                for _pred, _target, _labeled in zip(new_pred.t(), target.t(), labeled.t()):
                    _score = metrics.r2(_pred[_labeled], _target[_labeled])
                    score.append(_score)
                score = torch.stack(score)
            elif _metric == "spearmanr":
                score = []
                new_pred = pred * self.std + self.mean if self.normalization else pred
                for _pred, _target, _labeled in zip(new_pred.t(), target.t(), labeled.t()):
                    _score = metrics.spearmanr(_pred[_labeled], _target[_labeled])
                    score.append(_score)
                score = torch.stack(score)
            elif _metric == "pearsonr":
                score = []
                new_pred = pred * self.std + self.mean if self.normalization else pred
                for _pred, _target, _labeled in zip(new_pred.t(), target.t(), labeled.t()):
                    _score = metrics.pearsonr(_pred[_labeled], _target[_labeled])
                    score.append(_score)
                score = torch.stack(score)
            else:
                raise ValueError("Unknown criterion `%s`" % _metric)

            name = tasks._get_metric_name(_metric)
            for t, s in zip(self.task, score):
                metric["%s [%s]" % (name, t)] = s

        return metric
@R.register("tasks.NodePropertyPrediction")
class NodePropertyPrediction(tasks.Task, core.Configurable):
    """
    Node / residue level property prediction task.

    Parameters:
        model (nn.Module): graph representation model
        criterion (str, list or dict, optional): training criterion(s).
            Available criterions are ``mse``, ``bce`` and ``ce``.
        metric (str or list of str, optional): metric(s). Available metrics are ``micro_acc``,
            ``micro_auroc``, ``micro_auprc``, ``macro_acc``, ``macro_auroc`` and ``macro_auprc``.
        num_mlp_layer (int, optional): number of layers in the MLP prediction head
        normalization (bool, optional): whether to normalize regression targets
        entity_level (str, optional): prediction level; ``node`` / ``atom`` or ``residue``
        num_class (int or list of int, optional): number of classes; inferred from data when not given
        use_node_dim (bool, optional): use the model's node output dim instead of its graph output dim
        verbose (int, optional): output verbose level
    """

    _option_members = {"criterion", "metric"}

    def __init__(self, model, criterion="bce", metric=("macro_auprc", "macro_auroc"), num_mlp_layer=1, normalization=True,
                 entity_level="residue", num_class=None, use_node_dim=False, verbose=0):
        super(NodePropertyPrediction, self).__init__()
        self.model = model
        self.criterion = criterion
        self.metric = metric
        self.normalization = normalization
        self.num_mlp_layer = num_mlp_layer
        self.entity_level = entity_level
        self.num_class = num_class
        self.use_node_dim = use_node_dim
        self.verbose = verbose

    def preprocess(self, train_set, valid_set, test_set):
        """
        Compute the mean and standard deviation on the training set and build the MLP head.
        """
        values = torch.cat([data["graph"].target for data in train_set])
        mean = values.float().mean()
        std = values.float().std()
        if values.dtype == torch.long:
            # categorical targets: classes are 0 .. max (as a plain int, not a 0-d tensor)
            num_class = values.max().item() + 1
        else:
            num_class = 1
        self.register_buffer("mean", torch.as_tensor(mean, dtype=torch.float))
        self.register_buffer("std", torch.as_tensor(std, dtype=torch.float))
        self.num_class = self.num_class or num_class

        model_output_dim = self.model.node_output_dim if self.use_node_dim else self.model.output_dim
        hidden_dims = [model_output_dim] * (self.num_mlp_layer - 1)
        # ``num_class`` may be a user-supplied per-task list or a single int
        out_dim = sum(self.num_class) if isinstance(self.num_class, (list, tuple)) else self.num_class
        self.mlp = layers.MLP(model_output_dim, hidden_dims + [out_dim])

    def predict(self, batch, all_loss=None, metric=None):
        """
        Return per-entity (node or residue) predictions from the MLP head.
        """
        graph = batch["graph"]
        if self.entity_level in ["node", "atom"]:
            input = graph.node_feature.float()
        elif self.entity_level == "residue":
            input = graph.residue_feature.float()
        else:
            raise ValueError("Unknown entity level `%s`" % self.entity_level)
        output = self.model(graph, input, all_loss=all_loss, metric=metric)
        if self.entity_level in ["node", "atom"]:
            pred = self.mlp(output["node_feature"])
        elif self.entity_level == "residue":
            pred = self.mlp(output["residue_feature"])
        else:
            raise ValueError("Unknown entity level `%s`" % self.entity_level)
        return pred

    def target(self, batch):
        """
        Return the labels, supervision mask and per-graph entity counts for a batch.
        """
        if self.entity_level == "residue":
            size = batch["graph"].num_residues
        else:
            size = batch["graph"].num_nodes
        return {
            "label": batch["graph"].target,
            "mask": batch["graph"].mask,
            "size": size
        }

    def forward(self, batch):
        """
        Compute the weighted training loss and per-criterion metrics for a batch.
        """
        all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
        metric = {}

        pred, target = self.predict_and_target(batch, all_loss, metric)
        # entities are supervised only when labeled and inside the mask
        labeled = ~torch.isnan(target["label"]) & target["mask"]

        for criterion, weight in self.criterion.items():
            if criterion == "mse":
                # fixed: normalize the label tensor, not the whole target dict
                label = target["label"].float()
                if self.normalization:
                    label = (label - self.mean) / self.std
                loss = F.mse_loss(pred, label, reduction="none")
            elif criterion == "bce":
                loss = F.binary_cross_entropy_with_logits(pred, target["label"].float(), reduction="none")
            elif criterion == "ce":
                loss = F.cross_entropy(pred, target["label"], reduction="none")
            else:
                raise ValueError("Unknown criterion `%s`" % criterion)
            loss = functional.masked_mean(loss, labeled, dim=0)

            name = tasks._get_criterion_name(criterion)
            metric[name] = loss
            # fixed: the loss was previously added twice (once weighted, once unweighted)
            all_loss += loss * weight

        return all_loss, metric

    def evaluate(self, pred, target):
        """
        Evaluate the configured micro / macro metrics on labeled, masked entities.
        """
        metric = {}
        _target = target["label"]
        _labeled = ~torch.isnan(_target) & target["mask"]
        # number of labeled entities per graph, for macro averaging
        _size = functional.variadic_sum(_labeled.long(), target["size"])
        for _metric in self.metric:
            if _metric == "micro_acc":
                score = metrics.accuracy(pred[_labeled], _target[_labeled].long())
            elif _metric == "micro_auroc":
                # fixed: previously compared ``metric`` (the result dict) instead of ``_metric``
                score = metrics.area_under_roc(pred[_labeled], _target[_labeled])
            elif _metric == "micro_auprc":
                score = metrics.area_under_prc(pred[_labeled], _target[_labeled])
            elif _metric == "macro_auroc":
                score = metrics.variadic_area_under_roc(pred[_labeled], _target[_labeled], _size).mean()
            elif _metric == "macro_auprc":
                score = metrics.variadic_area_under_prc(pred[_labeled], _target[_labeled], _size).mean()
            elif _metric == "macro_acc":
                score = pred[_labeled].argmax(-1) == _target[_labeled]
                score = functional.variadic_mean(score.float(), _size).mean()
            else:
                raise ValueError("Unknown criterion `%s`" % _metric)

            name = tasks._get_metric_name(_metric)
            metric[name] = score

        return metric
@R.register("tasks.InteractionPrediction")
class InteractionPrediction(PropertyPrediction):
    """
    Predict interaction properties of graph (e.g. protein) pairs.

    Parameters:
        model (nn.Module): representation model for the first graph
        model2 (nn.Module, optional): representation model for the second graph;
            defaults to sharing ``model``
        task (str, list or dict, optional): training task(s)
        criterion (str, list or dict, optional): training criterion(s)
        metric (str or list of str, optional): metric(s)
        num_mlp_layer (int, optional): number of layers in the MLP prediction head
        normalization (bool, optional): whether to normalize regression targets
        num_class (int or list of int, optional): number of classes per task
        verbose (int, optional): output verbose level
    """

    def __init__(self, model, model2=None, task=(), criterion="mse", metric=("mae", "rmse"),
                 num_mlp_layer=1, normalization=True, num_class=None, verbose=0):
        # forward by keyword: the parent also takes ``entity_level`` before ``num_class``,
        # so positional forwarding would bind ``num_class`` to ``entity_level`` and
        # ``verbose`` to ``num_class``
        super(InteractionPrediction, self).__init__(model, task=task, criterion=criterion, metric=metric,
                                                    num_mlp_layer=num_mlp_layer, normalization=normalization,
                                                    num_class=num_class, verbose=verbose)
        self.model2 = model2 or model

    def preprocess(self, train_set, valid_set, test_set):
        """
        Compute the mean and standard deviation for each task on the training set,
        infer the number of classes per task, and build the pairwise MLP head.
        """
        values = defaultdict(list)
        for sample in train_set:
            # samples explicitly flagged as unlabeled contribute no statistics
            if not sample.get("labeled", True):
                continue
            for task in self.task:
                if not math.isnan(sample[task]):
                    values[task].append(sample[task])
        mean = []
        std = []
        weight = []
        num_class = []
        for task, w in self.task.items():
            value = torch.tensor(values[task])
            mean.append(value.float().mean())
            std.append(value.float().std())
            weight.append(w)
            if value.ndim > 1:
                # multi-label task: one output per label
                num_class.append(value.shape[1])
            elif value.dtype == torch.long:
                # categorical task: classes are 0 .. max
                num_class.append(value.max().item() + 1)
            else:
                # regression task: single output
                num_class.append(1)

        self.register_buffer("mean", torch.as_tensor(mean, dtype=torch.float))
        self.register_buffer("std", torch.as_tensor(std, dtype=torch.float))
        self.register_buffer("weight", torch.as_tensor(weight, dtype=torch.float))
        self.num_class = self.num_class or num_class

        hidden_dims = [self.model.output_dim] * (self.num_mlp_layer - 1)
        # the head consumes the concatenation of both graphs' representations
        self.mlp = layers.MLP(self.model.output_dim + self.model2.output_dim, hidden_dims + [sum(self.num_class)])

    def predict(self, batch, all_loss=None, metric=None):
        """
        Encode both graphs, concatenate their graph-level features, and predict.
        """
        graph1 = batch["graph1"]
        output1 = self.model(graph1, graph1.node_feature.float(), all_loss=all_loss, metric=metric)
        graph2 = batch["graph2"]
        output2 = self.model2(graph2, graph2.node_feature.float(), all_loss=all_loss, metric=metric)
        pred = self.mlp(torch.cat([output1["graph_feature"], output2["graph_feature"]], dim=-1))
        return pred
@R.register("tasks.AntigenAntibodyInteraction")
class AntigenAntibodyInteraction(PropertyPrediction, core.Configurable):
    """
    Predict interaction properties of (heavy chain, light chain, antigen) triplets.

    CDR residue features of both antibody chains are fused with a mean-pooled antigen
    embedding through a transformer encoder before the final MLP head. Criteria and
    metrics may be given either globally or per task (``{task_name: criterion_name}``).

    Parameters:
        model (nn.Module): graph representation model
        task (str, list or dict, optional): training task(s)
        criterion (str, list or dict, optional): training criterion(s)
        metric (str, list or dict, optional): metric(s)
        num_transformer_layer (int, optional): number of transformer encoder layers
        num_mlp_layer (int, optional): number of layers in the MLP prediction head
        normalization (bool, optional): whether to normalize regression targets
        verbose (int, optional): output verbose level
    """

    def __init__(self, model, task=(), criterion="mse", metric=("acc", "auroc"), num_transformer_layer=4,
                 num_mlp_layer=2, normalization=True, verbose=0):
        # forward by keyword: the parent also accepts ``entity_level`` / ``num_class``
        # before ``verbose``, so positional forwarding would bind ``verbose`` to
        # ``entity_level`` and leave ``self.verbose`` at its default
        super(AntigenAntibodyInteraction, self).__init__(model, task=task, criterion=criterion, metric=metric,
                                                         num_mlp_layer=num_mlp_layer, normalization=normalization,
                                                         verbose=verbose)
        encoder_layer = nn.TransformerEncoderLayer(model.output_dim, 8, activation="gelu")
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_transformer_layer)

    def preprocess(self, train_set, valid_set, test_set):
        """
        Compute the mean and standard deviation for each task on the training set,
        infer the number of classes per task, and build the antibody-pair MLP head.
        """
        values = defaultdict(list)
        for sample in train_set:
            # samples explicitly flagged as unlabeled contribute no statistics
            if not sample.get("labeled", True):
                continue
            for task in self.task:
                if not math.isnan(sample[task]):
                    values[task].append(sample[task])
        mean = []
        std = []
        weight = []
        num_class = []
        for task, w in self.task.items():
            value = torch.tensor(values[task], dtype=torch.float)
            mean.append(value.mean())
            std.append(value.std())
            weight.append(w)
            if value.ndim > 1:
                num_class.append(value.shape[1])
            elif value.dtype == torch.long:
                num_class.append(value.max() + 1)
            else:
                num_class.append(1)

        self.register_buffer("mean", torch.as_tensor(mean, dtype=torch.float))
        self.register_buffer("std", torch.as_tensor(std, dtype=torch.float))
        self.register_buffer("weight", torch.as_tensor(weight, dtype=torch.float))
        self.num_class = num_class

        hidden_dims = [self.model.output_dim] * (self.num_mlp_layer - 1)
        # the head consumes the concatenated heavy- and light-chain representations
        self.mlp = layers.MLP(self.model.output_dim * 2, hidden_dims + [sum(num_class)])

    def forward(self, batch):
        """
        Compute the weighted training loss; criteria may be global or per-task.
        """
        all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
        metric = {}

        pred = self.predict(batch, all_loss, metric)

        if all(t not in batch for t in self.task):
            # unlabeled data
            return all_loss, metric

        target = self.target(batch)
        labeled = ~torch.isnan(target)
        target[~labeled] = 0

        for criterion, weight in self.criterion.items():
            if isinstance(weight, str):
                # entry of the form {task: criterion}: restrict the loss to that task's columns
                task, criterion = criterion, weight
                weight = 1
                # TODO: ensure the order of dict?
                index = list(self.task.keys()).index(task)
                mask = slice(sum(self.num_class[:index]), sum(self.num_class[:index + 1]))
                _pred = pred[:, mask]
                _target = target[:, mask]
                _labeled = labeled[:, mask]
                mean = self.mean[mask]
                std = self.std[mask]
            else:
                _pred = pred
                _target = target
                _labeled = labeled
                mean = self.mean
                std = self.std

            if criterion == "mse":
                if self.normalization:
                    _target = (_target - mean) / std
                loss = F.mse_loss(_pred, _target, reduction="none")
            elif criterion == "bce":
                loss = F.binary_cross_entropy_with_logits(_pred, _target, reduction="none")
            elif criterion == "ce":
                loss = F.cross_entropy(_pred, _target.long().squeeze(-1), reduction="none").unsqueeze(-1)
            else:
                raise ValueError("Unknown criterion `%s`" % criterion)
            loss = functional.masked_mean(loss, _labeled, dim=0)

            name = tasks._get_criterion_name(criterion)
            if self.verbose > 0:
                for t, l in zip(self.task, loss):
                    metric["%s [%s]" % (name, t)] = l
            # NOTE(review): for task-restricted criteria ``loss`` covers only that task's
            # columns while ``self.weight`` covers all tasks — verify shapes agree here
            loss = (loss * self.weight).sum() / self.weight.sum()
            metric[name] = loss
            all_loss += loss * weight

        return all_loss, metric

    def predict(self, batch, all_loss=None, metric=None):
        """
        Encode heavy and light chains, fuse CDR + antigen features per chain with the
        transformer, and predict from the concatenated chain representations.
        """

        def get_chain_feature(chain):
            # ``antigen`` is taken from the enclosing scope (unpacked from the batch below)
            output = self.model(chain, chain.node_feature.float(), all_loss=all_loss, metric=metric)
            residue_feature = output["residue_feature"]

            # gather the residues of the three CDR loops of this chain
            interval = torch.cat([chain.cdr1_interval, chain.cdr2_interval, chain.cdr3_interval], dim=-1)
            starts, ends = interval.view(-1, 2).t()
            lengths = ends - starts
            mask = functional.multi_slice_mask(starts, ends, length=chain.num_residue)
            # type 0 / 1 / 2 marks CDR1 / CDR2 / CDR3 residues
            feature_type = torch.arange(3, device=self.device).repeat(len(chain)).repeat_interleave(lengths)
            cdr_feature = residue_feature[mask]
            cdr_size = lengths.view(-1, 3).sum(-1)

            # append one mean-pooled antigen embedding (type 3) per chain
            antigen_feature = antigen.truncated_residue_feature.mean(dim=1)
            antigen_size = torch.ones(len(chain), dtype=torch.long, device=self.device)
            feature = functional._extend(cdr_feature, cdr_size, antigen_feature, antigen_size)[0]
            antigen_feature_type = torch.ones(len(antigen_feature), dtype=torch.long, device=self.device) * 3
            # NOTE(review): ``feature_type`` is computed but never consumed afterwards —
            # presumably intended as a type embedding input; confirm
            feature_type, size = functional._extend(feature_type, cdr_size, antigen_feature_type, antigen_size)

            # readout CDR + antigen features into a single representation
            feature, mask = functional.variadic_to_padded(feature, size)
            feature = feature.transpose(0, 1)  # convert to sequence first
            feature = self.transformer(feature, src_key_padding_mask=~mask)
            feature = feature.transpose(0, 1)  # convert to batch first
            # unpad with ``cdr_size`` (not ``size``), i.e. the antigen token is dropped
            cdr_feature = functional.padded_to_variadic(feature, cdr_size)
            index2chain = functional._size_to_index(cdr_size)
            # mean-pool the transformed CDR residues of each chain
            feature = scatter_mean(cdr_feature, index2chain, dim=0)
            return feature

        heavy, light, antigen = batch["graph"]
        heavy_feature = get_chain_feature(heavy)
        light_feature = get_chain_feature(light)
        antibody_feature = torch.cat([heavy_feature, light_feature], dim=-1)
        pred = self.mlp(antibody_feature)
        return pred

    def evaluate(self, pred, target):
        """
        Evaluate configured metrics; entries of the form ``{task: metric}`` restrict
        evaluation to that task's output columns.
        """
        labeled = ~torch.isnan(target)

        metric = {}
        for _metric, weight in self.metric.items():
            if isinstance(weight, str):
                # entry of the form {task: metric}: slice out that task's columns
                task, _metric = _metric, weight
                # TODO: ensure the order of dict?
                index = list(self.task.keys()).index(task)
                task = [task]
                mask = slice(sum(self.num_class[:index]), sum(self.num_class[:index + 1]))
                _pred = pred[:, mask]
                _target = target[:, mask]
                _labeled = labeled[:, mask]
                mean = self.mean[mask]
                std = self.std[mask]
            else:
                task = self.task
                _pred = pred
                _target = target
                _labeled = labeled
                mean = self.mean
                std = self.std

            if _metric == "mae":
                if self.normalization:
                    _pred = _pred * std + mean
                score = F.l1_loss(_pred, _target, reduction="none")
                score = functional.masked_mean(score, _labeled, dim=0)
            elif _metric == "rmse":
                if self.normalization:
                    _pred = _pred * std + mean
                score = F.mse_loss(_pred, _target, reduction="none")
                score = functional.masked_mean(score, _labeled, dim=0).sqrt()
            elif _metric == "acc":
                # accuracy is only defined over the full multi-task layout
                assert not isinstance(weight, str)
                score = []
                num_class = 0
                for i, cur_num_class in enumerate(self.num_class):
                    _pred = pred[:, num_class:num_class + cur_num_class]
                    _target = target[:, i]
                    _labeled = labeled[:, i]
                    _score = metrics.accuracy(_pred[_labeled], _target[_labeled].long())
                    score.append(_score)
                    num_class += cur_num_class
                score = torch.stack(score)
            elif _metric == "mcc":
                assert not isinstance(weight, str)
                score = []
                num_class = 0
                for i, cur_num_class in enumerate(self.num_class):
                    _pred = pred[:, num_class:num_class + cur_num_class]
                    _target = target[:, i]
                    _labeled = labeled[:, i]
                    _score = metrics.matthews_corrcoef(_pred[_labeled], _target[_labeled].long())
                    score.append(_score)
                    num_class += cur_num_class
                score = torch.stack(score)
            elif _metric == "auroc":
                score = []
                for __pred, __target, __labeled in zip(_pred.t(), _target.long().t(), _labeled.t()):
                    _score = metrics.area_under_roc(__pred[__labeled], __target[__labeled])
                    score.append(_score)
                score = torch.stack(score)
            elif _metric == "auprc":
                score = []
                for __pred, __target, __labeled in zip(_pred.t(), _target.long().t(), _labeled.t()):
                    _score = metrics.area_under_prc(__pred[__labeled], __target[__labeled])
                    score.append(_score)
                score = torch.stack(score)
            elif _metric == "r2":
                score = []
                if self.normalization:
                    _pred = _pred * std + mean
                for __pred, __target, __labeled in zip(_pred.t(), _target.t(), _labeled.t()):
                    _score = metrics.r2(__pred[__labeled], __target[__labeled])
                    score.append(_score)
                score = torch.stack(score)
            elif _metric == "spearmanr":
                score = []
                if self.normalization:
                    _pred = _pred * std + mean
                for __pred, __target, __labeled in zip(_pred.t(), _target.t(), _labeled.t()):
                    _score = metrics.spearmanr(__pred[__labeled], __target[__labeled])
                    score.append(_score)
                score = torch.stack(score)
            elif _metric == "pearsonr":
                score = []
                if self.normalization:
                    _pred = _pred * std + mean
                for __pred, __target, __labeled in zip(_pred.t(), _target.t(), _labeled.t()):
                    _score = metrics.pearsonr(__pred[__labeled], __target[__labeled])
                    score.append(_score)
                score = torch.stack(score)
            else:
                raise ValueError("Unknown criterion `%s`" % _metric)

            name = tasks._get_metric_name(_metric)
            for t, s in zip(task, score):
                metric["%s [%s]" % (name, t)] = s

        return metric
class Unsupervised(nn.Module):
    """
    Wrapper task for unsupervised learning.

    The unsupervised loss is expected to be accumulated by the wrapped model itself
    through the ``all_loss`` / ``metric`` arguments; this wrapper contributes none.

    Parameters:
        model (nn.Module): any model
    """

    def __init__(self, model):
        super(Unsupervised, self).__init__()
        self.model = model

    def forward(self, batch):
        """
        Run the model on the batch and return the loss it accumulated, with metrics.
        """
        all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
        metric = {}

        # called for its side effect of accumulating into all_loss / metric
        self.predict(batch, all_loss, metric)

        return all_loss, metric

    def predict(self, batch, all_loss=None, metric=None):
        """
        Return the model's output on the batch graph's node features.
        """
        graph = batch["graph"]
        pred = self.model(graph, graph.node_feature.float(), all_loss=all_loss, metric=metric)
        return pred