import math
from collections import defaultdict
import torch
from torch import nn
from torch.nn import functional as F
from torchdrug import core, layers, tasks, metrics, utils
from torchdrug.core import Registry as R
from torchdrug.layers import functional
@R.register("tasks.PropertyPrediction")
class PropertyPrediction(tasks.Task, core.Configurable):
    """
    Graph / molecule / protein property prediction task.

    A graph representation model produces a graph-level feature, and an MLP head
    maps it to one or more task outputs. Samples flagged as unlabeled are skipped
    when computing the supervised loss, so the task can be used for
    semi-supervised learning.

    Parameters:
        model (nn.Module): graph representation model
        task (str, list or dict, optional): training task(s).
            For dict, the keys are tasks and the values are the corresponding weights.
        criterion (str, list or dict, optional): training criterion(s). For dict, the keys are criterions
            and the values are the corresponding weights. Available criterions are ``mse``, ``bce`` and ``ce``.
        metric (str or list of str, optional): metric(s).
            Metrics handled by :meth:`evaluate` are ``mae``, ``rmse``, ``acc``, ``mcc``, ``auroc``,
            ``auprc``, ``r2``, ``spearmanr`` and ``pearsonr``.
        num_mlp_layer (int, optional): number of layers in mlp prediction head
        normalization (bool, optional): whether to normalize the target.
            Ignored (forced off) when ``ce`` or ``bce`` is among the criterions.
        num_class (int, optional): number of classes
        mlp_batch_norm (bool, optional): apply batch normalization in mlp or not
        mlp_dropout (float, optional): dropout in mlp
        graph_construction_model (nn.Module, optional): graph construction model
        verbose (int, optional): output verbose level
    """

    eps = 1e-10
    _option_members = {"task", "criterion", "metric"}

    def __init__(self, model, task=(), criterion="mse", metric=("mae", "rmse"), num_mlp_layer=1,
                 normalization=True, num_class=None, mlp_batch_norm=False, mlp_dropout=0,
                 graph_construction_model=None, verbose=0):
        super(PropertyPrediction, self).__init__()
        self.model = model
        self.task = task
        self.criterion = criterion
        self.metric = metric
        self.num_mlp_layer = num_mlp_layer
        # Target normalization only makes sense for regression; classification
        # criterions switch it off regardless of the argument.
        self.normalization = normalization and not ("ce" in criterion or "bce" in criterion)
        if isinstance(num_class, int):
            self.num_class = (num_class,)
        else:
            self.num_class = num_class
        self.mlp_batch_norm = mlp_batch_norm
        self.mlp_dropout = mlp_dropout
        self.graph_construction_model = graph_construction_model
        self.verbose = verbose

    def preprocess(self, train_set, valid_set, test_set):
        """
        Compute the mean and standard deviation of every task on the training set,
        infer the number of output classes per task and build the MLP head.
        """
        # Gather the non-NaN labels of every task from the labeled samples.
        observed = defaultdict(list)
        for sample in train_set:
            if sample.get("labeled", True):
                for name in self.task:
                    label = sample[name]
                    if not math.isnan(label):
                        observed[name].append(label)

        means, stds, weights, class_counts = [], [], [], []
        for name, task_weight in self.task.items():
            labels = torch.tensor(observed[name])
            as_float = labels.float()
            means.append(as_float.mean())
            stds.append(as_float.std())
            weights.append(task_weight)
            if labels.ndim > 1:
                # vector-valued label: one output per component
                count = labels.shape[1]
            elif labels.dtype == torch.long:
                largest = labels.max().item()
                # binary labels trained with bce need a single logit
                if largest == 1 and "bce" in self.criterion:
                    count = 1
                else:
                    count = largest + 1
            else:
                count = 1
            class_counts.append(count)

        self.register_buffer("mean", torch.as_tensor(means, dtype=torch.float))
        self.register_buffer("std", torch.as_tensor(stds, dtype=torch.float))
        self.register_buffer("weight", torch.as_tensor(weights, dtype=torch.float))
        # A user-supplied num_class takes precedence over the inferred one.
        self.num_class = self.num_class or class_counts

        feature_dim = self.model.output_dim
        head_dims = [feature_dim] * (self.num_mlp_layer - 1) + [sum(self.num_class)]
        self.mlp = layers.MLP(feature_dim, head_dims,
                              batch_norm=self.mlp_batch_norm, dropout=self.mlp_dropout)

    def forward(self, batch):
        """"""
        all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
        metric = {}

        pred = self.predict(batch, all_loss, metric)

        if not any(t in batch for t in self.task):
            # unlabeled data
            return all_loss, metric

        target = self.target(batch)
        labeled = ~torch.isnan(target)
        # Replace NaN placeholders so the element-wise losses stay finite;
        # these positions are excluded again by the masked mean below.
        target[~labeled] = 0

        for criterion, weight in self.criterion.items():
            if criterion == "mse":
                if self.normalization:
                    loss_input = (pred - self.mean) / self.std
                    loss_target = (target - self.mean) / self.std
                else:
                    loss_input, loss_target = pred, target
                loss = F.mse_loss(loss_input, loss_target, reduction="none")
            elif criterion == "bce":
                loss = F.binary_cross_entropy_with_logits(pred, target, reduction="none")
            elif criterion == "ce":
                class_index = target.long().squeeze(-1)
                loss = F.cross_entropy(pred, class_index, reduction="none").unsqueeze(-1)
            else:
                raise ValueError("Unknown criterion `%s`" % criterion)
            loss = functional.masked_mean(loss, labeled, dim=0)

            name = tasks._get_criterion_name(criterion)
            if self.verbose > 0:
                metric.update({"%s [%s]" % (name, t): l for t, l in zip(self.task, loss)})
            # Weighted average over tasks.
            loss = (loss * self.weight).sum() / self.weight.sum()
            metric[name] = loss
            all_loss += loss * weight

        return all_loss, metric

    def predict(self, batch, all_loss=None, metric=None):
        """Return the MLP output for the graphs in ``batch`` (de-normalized if normalization is on)."""
        graph = batch["graph"]
        if self.graph_construction_model:
            graph = self.graph_construction_model(graph)
        output = self.model(graph, graph.node_feature.float(), all_loss=all_loss, metric=metric)
        pred = self.mlp(output["graph_feature"])
        if not self.normalization:
            return pred
        return pred * self.std + self.mean

    def target(self, batch):
        """Stack the task labels of ``batch`` along the last dimension; unlabeled rows become NaN."""
        columns = [batch[t].float() for t in self.task]
        target = torch.stack(columns, dim=-1)
        everything_labeled = torch.ones(len(target), dtype=torch.bool, device=target.device)
        labeled = batch.get("labeled", everything_labeled)
        target[~labeled] = math.nan
        return target

    def evaluate(self, pred, target):
        """
        Compute every configured metric per task.

        Returns a dict keyed by ``"<metric name> [<task>]"``.
        """
        labeled = ~torch.isnan(target)

        # Metrics scored column by column (one prediction column per task),
        # mapped to the name of the function in ``metrics`` that computes them.
        per_column = {
            "auroc": "area_under_roc",
            "auprc": "area_under_prc",
            "r2": "r2",
            "spearmanr": "spearmanr",
            "pearsonr": "pearsonr",
        }
        # Column metrics that receive integer targets.
        integer_target = {"auroc", "auprc"}
        # Metrics scored on the slice of logits that belongs to each task.
        per_slice = {"acc": "accuracy", "mcc": "matthews_corrcoef"}

        def column_scores(score_fn, targets):
            scores = [
                score_fn(p[mask], t[mask])
                for p, t, mask in zip(pred.t(), targets.t(), labeled.t())
            ]
            return torch.stack(scores)

        def slice_scores(score_fn):
            scores = []
            start = 0
            for i, width in enumerate(self.num_class):
                logits = pred[:, start:start + width]
                mask = labeled[:, i]
                scores.append(score_fn(logits[mask], target[:, i][mask].long()))
                start += width
            return torch.stack(scores)

        metric = {}
        for _metric in self.metric:
            if _metric == "mae":
                error = F.l1_loss(pred, target, reduction="none")
                score = functional.masked_mean(error, labeled, dim=0)
            elif _metric == "rmse":
                error = F.mse_loss(pred, target, reduction="none")
                score = functional.masked_mean(error, labeled, dim=0).sqrt()
            elif _metric in per_slice:
                score = slice_scores(getattr(metrics, per_slice[_metric]))
            elif _metric in per_column:
                targets = target.long() if _metric in integer_target else target
                score = column_scores(getattr(metrics, per_column[_metric]), targets)
            else:
                raise ValueError("Unknown metric `%s`" % _metric)

            name = tasks._get_metric_name(_metric)
            for t, s in zip(self.task, score):
                metric["%s [%s]" % (name, t)] = s

        return metric
@R.register("tasks.MultipleBinaryClassification")
class MultipleBinaryClassification(tasks.Task, core.Configurable):
    """
    Multiple binary classification task for graphs / molecules / proteins.

    Parameters:
        model (nn.Module): graph representation model
        task (list of int, optional): training task id(s).
        criterion (list or dict, optional): training criterion(s). For dict, the keys are criterions and the values
            are the corresponding weights. Available criterions are ``bce``.
        metric (str or list of str, optional): metric(s).
            Available metrics are ``auroc@macro``, ``auprc@macro``, ``auroc@micro``, ``auprc@micro`` and ``f1_max``.
        num_mlp_layer (int, optional): number of layers in the MLP prediction head
        normalization (bool, optional): whether to normalize the target
            (stored on the instance; not read by the methods defined in this class)
        reweight (bool, optional): whether to re-weight tasks according to the number of positive samples
        graph_construction_model (nn.Module, optional): graph construction model
        verbose (int, optional): output verbose level
    """

    eps = 1e-10
    _option_members = {"criterion", "metric"}

    def __init__(self, model, task=(), criterion="bce", metric=("auprc@micro", "f1_max"), num_mlp_layer=1,
                 normalization=True, reweight=False, graph_construction_model=None, verbose=0):
        super(MultipleBinaryClassification, self).__init__()
        self.model = model
        self.task = task
        # Kept as a buffer so the indices follow the module across devices.
        self.register_buffer("task_indices", torch.LongTensor(task))
        self.criterion = criterion
        self.metric = metric
        self.num_mlp_layer = num_mlp_layer
        self.normalization = normalization
        self.reweight = reweight
        self.graph_construction_model = graph_construction_model
        self.verbose = verbose

        # One logit per selected task.
        hidden_dims = [self.model.output_dim] * (self.num_mlp_layer - 1)
        self.mlp = layers.MLP(self.model.output_dim, hidden_dims + [len(task)])

    def preprocess(self, train_set, valid_set, test_set):
        """
        Compute the weight for each task on the training set.

        With ``reweight`` enabled, a task's weight is the mean positive count over
        all tasks divided by its own positive count, clamped to ``[1, 10]``.
        Otherwise every task gets weight 1.
        """
        values = []
        for data in train_set:
            values.append(data["targets"][self.task_indices])
        values = torch.stack(values, dim=0)

        if self.reweight:
            num_positive = values.sum(dim=0)
            weight = (num_positive.mean() / num_positive).clamp(1, 10)
        else:
            weight = torch.ones(len(self.task), dtype=torch.float)

        self.register_buffer("weight", torch.as_tensor(weight, dtype=torch.float))

    def forward(self, batch):
        """
        Compute the training loss for ``batch``.

        Returns:
            (Tensor, dict): the criterion-weighted total loss and a dict holding
            the value of each criterion.

        Raises:
            ValueError: if a configured criterion is not ``bce``.
        """
        all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
        metric = {}

        pred = self.predict(batch, all_loss, metric)
        target = self.target(batch)

        for criterion, weight in self.criterion.items():
            if criterion == "bce":
                loss = F.binary_cross_entropy_with_logits(pred, target, reduction="none")
            else:
                raise ValueError("Unknown criterion `%s`" % criterion)
            # Average over the batch first, then take the task-weighted mean.
            loss = loss.mean(dim=0)
            loss = (loss * self.weight).sum() / self.weight.sum()

            name = tasks._get_criterion_name(criterion)
            metric[name] = loss
            all_loss += loss * weight

        return all_loss, metric

    def predict(self, batch, all_loss=None, metric=None):
        """Return the MLP logits computed from the graph feature of ``batch["graph"]``."""
        graph = batch["graph"]
        if self.graph_construction_model:
            graph = self.graph_construction_model(graph)
        output = self.model(graph, graph.node_feature.float(), all_loss=all_loss, metric=metric)
        pred = self.mlp(output["graph_feature"])
        return pred

    def target(self, batch):
        """Return the columns of ``batch["targets"]`` selected by ``task_indices``."""
        target = batch["targets"][:, self.task_indices]
        return target

    def evaluate(self, pred, target):
        """
        Compute every configured metric.

        Returns:
            dict: metric name -> score.

        Raises:
            ValueError: if a configured metric is not supported.
        """
        metric = {}
        for _metric in self.metric:
            if _metric == "auroc@micro":
                score = metrics.area_under_roc(pred.flatten(), target.long().flatten())
            elif _metric == "auroc@macro":
                score = metrics.variadic_area_under_roc(pred, target.long(), dim=0).mean()
            elif _metric == "auprc@micro":
                score = metrics.area_under_prc(pred.flatten(), target.long().flatten())
            elif _metric == "auprc@macro":
                score = metrics.variadic_area_under_prc(pred, target.long(), dim=0).mean()
            elif _metric == "f1_max":
                score = metrics.f1_max(pred, target)
            else:
                # This branch rejects a metric name, not a criterion.
                raise ValueError("Unknown metric `%s`" % _metric)

            name = tasks._get_metric_name(_metric)
            metric[name] = score

        return metric
@R.register("tasks.NodePropertyPrediction")
class NodePropertyPrediction(tasks.Task, core.Configurable):
    """
    Node / atom / residue property prediction task.

    The prediction granularity follows the ``view`` attribute of the training
    graphs (read in :meth:`preprocess`, defaulting to ``atom``): for ``node`` and
    ``atom`` the model's node features are used, otherwise residue features are
    used when the model provides them.

    Parameters:
        model (nn.Module): graph representation model
        criterion (str, list or dict, optional): training criterion(s). For dict, the keys are criterions
            and the values are the corresponding weights. Available criterions are ``mse``, ``bce`` and ``ce``.
        metric (str or list of str, optional): metric(s).
            Available metrics are ``micro_acc``, ``micro_auroc``, ``micro_auprc``, ``macro_acc``,
            ``macro_auroc`` and ``macro_auprc``.
        num_mlp_layer (int, optional): number of layers in mlp prediction head
        normalization (bool, optional): whether to normalize the target.
            Ignored (forced off) when ``ce`` or ``bce`` is among the criterions.
        num_class (int, optional): number of classes
        verbose (int, optional): output verbose level
            (stored on the instance; not read by the methods defined in this class)
    """

    _option_members = {"criterion", "metric"}

    def __init__(self, model, criterion="bce", metric=("macro_auprc", "macro_auroc"), num_mlp_layer=1,
                 normalization=True, num_class=None, verbose=0):
        super(NodePropertyPrediction, self).__init__()
        self.model = model
        self.criterion = criterion
        self.metric = metric
        # For classification tasks, we disable normalization tricks.
        self.normalization = normalization and ("ce" not in criterion) and ("bce" not in criterion)
        self.num_mlp_layer = num_mlp_layer
        self.num_class = num_class
        self.verbose = verbose

    def preprocess(self, train_set, valid_set, test_set):
        """
        Compute the mean and standard deviation of the node targets on the
        training set, infer the number of classes and build the MLP head.
        """
        self.view = getattr(train_set[0]["graph"], "view", "atom")
        values = torch.cat([data["graph"].target for data in train_set])
        mean = values.float().mean()
        std = values.float().std()
        if values.dtype == torch.long:
            num_class = values.max().item()
            # Binary labels trained with bce keep a single logit; every other
            # integer-labeled setup needs max_label + 1 outputs.
            if num_class > 1 or "bce" not in self.criterion:
                num_class += 1
        else:
            num_class = 1

        self.register_buffer("mean", torch.as_tensor(mean, dtype=torch.float))
        self.register_buffer("std", torch.as_tensor(std, dtype=torch.float))
        # A user-supplied num_class takes precedence over the inferred one.
        self.num_class = self.num_class or num_class

        if hasattr(self.model, "node_output_dim"):
            model_output_dim = self.model.node_output_dim
        else:
            model_output_dim = self.model.output_dim
        hidden_dims = [model_output_dim] * (self.num_mlp_layer - 1)
        self.mlp = layers.MLP(model_output_dim, hidden_dims + [self.num_class])

    def predict(self, batch, all_loss=None, metric=None):
        """Return per-node (or per-residue) predictions, de-normalized if normalization is on."""
        graph = batch["graph"]
        output = self.model(graph, graph.node_feature.float(), all_loss=all_loss, metric=metric)
        if self.view in ["node", "atom"]:
            output_feature = output["node_feature"]
        else:
            # Fall back to node features when the model has no residue-level output.
            output_feature = output.get("residue_feature", output.get("node_feature"))
        pred = self.mlp(output_feature)
        if self.normalization:
            pred = pred * self.std + self.mean
        return pred

    def target(self, batch):
        """
        Return a dict with the graph's ``target`` (``"label"``), its ``mask``
        (``"mask"``) and the per-graph entity counts (``"size"``).
        """
        size = batch["graph"].num_nodes if self.view in ["node", "atom"] else batch["graph"].num_residues
        return {
            "label": batch["graph"].target,
            "mask": batch["graph"].mask,
            "size": size
        }

    def forward(self, batch):
        """"""
        all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
        metric = {}

        pred, target = self.predict_and_target(batch, all_loss, metric)
        label = target["label"]
        labeled = ~torch.isnan(label) & target["mask"]

        for criterion, weight in self.criterion.items():
            if criterion == "mse":
                # ``target`` is a dict; the regression target is its "label" entry.
                # NOTE(review): assumes the label tensor has the same shape as
                # ``pred`` (or broadcasts as intended) -- confirm against the dataset.
                regression_target = label.float()
                if self.normalization:
                    loss = F.mse_loss((pred - self.mean) / self.std,
                                      (regression_target - self.mean) / self.std, reduction="none")
                else:
                    loss = F.mse_loss(pred, regression_target, reduction="none")
            elif criterion == "bce":
                loss = F.binary_cross_entropy_with_logits(pred, label.float(), reduction="none")
            elif criterion == "ce":
                loss = F.cross_entropy(pred, label, reduction="none")
            else:
                raise ValueError("Unknown criterion `%s`" % criterion)
            loss = functional.masked_mean(loss, labeled, dim=0)

            name = tasks._get_criterion_name(criterion)
            metric[name] = loss
            all_loss += loss * weight

        # Each criterion is accumulated exactly once, scaled by its weight. The
        # previous trailing ``all_loss += loss`` counted the last criterion a
        # second time (unweighted) and failed when no criterion was configured.
        return all_loss, metric

    def evaluate(self, pred, target):
        """
        Compute every configured metric over the labeled, unmasked entries.

        Returns:
            dict: metric name -> score.

        Raises:
            ValueError: if a configured metric is not supported.
        """
        metric = {}
        _target = target["label"]
        _labeled = ~torch.isnan(_target) & target["mask"]
        # Number of labeled entries per graph, used by the macro (per-graph) metrics.
        _size = functional.variadic_sum(_labeled.long(), target["size"])
        for _metric in self.metric:
            if _metric == "micro_acc":
                score = metrics.accuracy(pred[_labeled], _target[_labeled].long())
            elif _metric == "micro_auroc":
                score = metrics.area_under_roc(pred[_labeled], _target[_labeled])
            elif _metric == "micro_auprc":
                score = metrics.area_under_prc(pred[_labeled], _target[_labeled])
            elif _metric == "macro_auroc":
                score = metrics.variadic_area_under_roc(pred[_labeled], _target[_labeled], _size).mean()
            elif _metric == "macro_auprc":
                score = metrics.variadic_area_under_prc(pred[_labeled], _target[_labeled], _size).mean()
            elif _metric == "macro_acc":
                score = pred[_labeled].argmax(-1) == _target[_labeled]
                score = functional.variadic_mean(score.float(), _size).mean()
            else:
                # This branch rejects a metric name, not a criterion.
                raise ValueError("Unknown metric `%s`" % _metric)

            name = tasks._get_metric_name(_metric)
            metric[name] = score

        return metric
@R.register("tasks.InteractionPrediction")
@utils.copy_args(PropertyPrediction, ignore=("graph_construction_model",))
class InteractionPrediction(PropertyPrediction):
    """
    Predict the interaction property of graph pairs.

    Each item of a pair is encoded by its own model; the two graph features are
    concatenated and fed to an MLP head.

    Parameters:
        model (nn.Module): graph representation model
        model2 (nn.Module, optional): graph representation model for the second item. If ``None``, use tied-weight
            model for the second item.
        **kwargs
    """

    def __init__(self, model, model2=None, **kwargs):
        super(InteractionPrediction, self).__init__(model, **kwargs)
        # Tie weights with the first encoder when no second encoder is given.
        self.model2 = model2 or model

    def preprocess(self, train_set, valid_set, test_set):
        """
        Compute the mean and standard deviation for each task on the training set,
        infer the number of output classes per task and build the MLP head over
        the concatenated pair feature.
        """
        # Gather the non-NaN labels of every task from the labeled samples.
        values = defaultdict(list)
        for sample in train_set:
            if not sample.get("labeled", True):
                continue
            for task in self.task:
                if not math.isnan(sample[task]):
                    values[task].append(sample[task])

        mean = []
        std = []
        weight = []
        num_class = []
        for task, w in self.task.items():
            value = torch.tensor(values[task])
            mean.append(value.float().mean())
            std.append(value.float().std())
            weight.append(w)
            if value.ndim > 1:
                # vector-valued label: one output per component
                num_class.append(value.shape[1])
            elif value.dtype == torch.long:
                task_class = value.max().item()
                # binary labels trained with bce need a single logit
                if task_class == 1 and "bce" in self.criterion:
                    num_class.append(1)
                else:
                    num_class.append(task_class + 1)
            else:
                num_class.append(1)

        self.register_buffer("mean", torch.as_tensor(mean, dtype=torch.float))
        self.register_buffer("std", torch.as_tensor(std, dtype=torch.float))
        self.register_buffer("weight", torch.as_tensor(weight, dtype=torch.float))
        # A user-supplied num_class takes precedence over the inferred one.
        self.num_class = self.num_class or num_class

        hidden_dims = [self.model.output_dim] * (self.num_mlp_layer - 1)
        # Forward the MLP options accepted by the constructor, as
        # PropertyPrediction.preprocess does; previously they were stored but
        # never applied to this head. The defaults (False, 0) build the same MLP.
        self.mlp = layers.MLP(self.model.output_dim + self.model2.output_dim, hidden_dims + [sum(self.num_class)],
                              batch_norm=self.mlp_batch_norm, dropout=self.mlp_dropout)

    def predict(self, batch, all_loss=None, metric=None):
        """
        Return the MLP output for the graph pairs in ``batch`` (keys ``graph1`` and
        ``graph2``), de-normalized if normalization is on.
        """
        graph1 = batch["graph1"]
        output1 = self.model(graph1, graph1.node_feature.float(), all_loss=all_loss, metric=metric)
        graph2 = batch["graph2"]
        output2 = self.model2(graph2, graph2.node_feature.float(), all_loss=all_loss, metric=metric)
        pred = self.mlp(torch.cat([output1["graph_feature"], output2["graph_feature"]], dim=-1))
        if self.normalization:
            pred = pred * self.std + self.mean
        return pred
@R.register("tasks.Unsupervised")
class Unsupervised(nn.Module, core.Configurable):
    """
    Wrapper task for unsupervised learning.

    The task itself adds nothing to the loss: the wrapped model receives the
    ``all_loss`` and ``metric`` accumulators and is expected to fill them in.

    Parameters:
        model (nn.Module): any model
        graph_construction_model (nn.Module, optional): graph construction model
            applied to the input graph before it is passed to ``model``
    """

    def __init__(self, model, graph_construction_model=None):
        super(Unsupervised, self).__init__()
        self.model = model
        self.graph_construction_model = graph_construction_model

    def forward(self, batch):
        """"""
        metric = {}
        all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
        # Called for its effect on the accumulators; the prediction is not needed here.
        self.predict(batch, all_loss, metric)
        return all_loss, metric

    def predict(self, batch, all_loss=None, metric=None):
        """Run the wrapped model on ``batch["graph"]`` and return its output."""
        graph = batch["graph"]
        constructor = self.graph_construction_model
        if constructor:
            graph = constructor(graph)
        return self.model(graph, graph.node_feature.float(), all_loss=all_loss, metric=metric)