"""Evaluation metrics (source code for torchdrug.metrics.metric)."""

import torch
from torch.nn import functional as F
from torch_scatter import scatter_add, scatter_mean, scatter_max
import networkx as nx
from rdkit import Chem
from rdkit.Chem import Descriptors

from torchdrug import utils
from torchdrug.layers import functional
from torchdrug.core import Registry as R
from torchdrug.metrics.rdkit import sascorer


@R.register("metrics.auroc")
def area_under_roc(pred, target):
    """
    Area under receiver operating characteristic curve (ROC).

    Samples are sorted by descending prediction; for every negative sample the
    number of positives ranked above it is counted, and the total is divided by
    the number of (positive, negative) pairs.

    Note: samples with tied predictions are ordered by ``argsort`` without any
    tie averaging, so ties are not given half credit.

    Parameters:
        pred (Tensor): predictions of shape :math:`(n,)`
        target (Tensor): binary targets of shape :math:`(n,)`

    Returns:
        Tensor: scalar AUROC
    """
    order = pred.argsort(descending=True)
    target = target[order]
    # hit[i] = number of positives at ranks 0..i
    hit = target.cumsum(0)
    # number of (positive, negative) pairs; not named ``all`` to avoid shadowing the builtin
    num_pair = (target == 0).sum() * (target == 1).sum()
    # each negative contributes the count of positives ranked before it;
    # the epsilon avoids 0 / 0 when only one class is present
    auroc = hit[target == 0].sum() / (num_pair + 1e-10)
    return auroc


@R.register("metrics.auprc")
def area_under_prc(pred, target):
    """
    Area under precision-recall curve (PRC).

    Estimated as the mean of the precision values taken at the rank of every
    positive sample, after sorting samples by descending prediction.

    Parameters:
        pred (Tensor): predictions of shape :math:`(n,)`
        target (Tensor): binary targets of shape :math:`(n,)`

    Returns:
        Tensor: scalar AUPRC
    """
    ranked = target[pred.argsort(descending=True)]
    rank = torch.arange(1, len(ranked) + 1, device=ranked.device)
    precision_at_rank = ranked.cumsum(0) / rank
    is_positive = ranked == 1
    # epsilon guards against a target with no positives
    return precision_at_rank[is_positive].sum() / (is_positive.sum() + 1e-10)


@R.register("metrics.r2")
def r2(pred, target):
    """
    :math:`R^2` (coefficient of determination) regression score.

    Computed as ``1 - MSE(pred, target) / Var(target)``, where the variance of
    ``target`` is the population (biased) variance.

    Parameters:
        pred (Tensor): predictions of shape :math:`(n,)`
        target (Tensor): targets of shape :math:`(n,)`

    Returns:
        Tensor: scalar :math:`R^2` score
    """
    mean_squared_error = F.mse_loss(pred, target)
    target_variance = torch.var(target, unbiased=False)
    score = 1 - mean_squared_error / target_variance
    return score


@R.register("metrics.logp")
def logP(pred):
    """
    Logarithm of partition coefficient between octanol and water for a compound.

    Molecules for which RDKit raises ``AtomValenceException`` while updating the
    property cache or computing logP are scored 0.

    Parameters:
        pred (PackedMolecule): molecules to evaluate

    Returns:
        Tensor: float tensor with one score per molecule, on ``pred.device``
    """
    def score_one(graph):
        rdkit_mol = graph.to_molecule()
        try:
            with utils.no_rdkit_log():
                rdkit_mol.UpdatePropertyCache()
                return Descriptors.MolLogP(rdkit_mol)
        except Chem.AtomValenceException:
            return 0

    scores = [score_one(graph) for graph in pred]
    return torch.tensor(scores, dtype=torch.float, device=pred.device)


@R.register("metrics.plogp")
def penalized_logP(pred):
    """
    Logarithm of partition coefficient, penalized by cycle length and synthetic accessibility.

    The logP, SA score and large-ring penalty are each standardized with ZINC250k
    statistics and combined as ``logp - sa - cycle``. Molecules for which RDKit
    raises ``AtomValenceException`` during evaluation are scored -30.

    Parameters:
        pred (PackedMolecule): molecules to evaluate

    Returns:
        Tensor: float tensor with one score per molecule, on ``pred.device``
    """
    # (mean, std) normalization statistics from ZINC250k
    logp_mean, logp_std = 2.4570953396190123, 1.434324401111988
    sa_mean, sa_std = 3.0525811293166134, 0.8335207024513095
    cycle_mean, cycle_std = 0.0485696876403053, 0.2860212110245455

    def large_ring_penalty(graph):
        # how many atoms the largest cycle in the cycle basis has beyond 6 (0 if no cycles);
        # the graph is built from the first two columns of edge_list
        edges = graph.edge_list[:, :2].tolist()
        ring_sizes = [len(ring) for ring in nx.cycle_basis(nx.Graph(edges))]
        if not ring_sizes:
            return 0
        return max(0, max(ring_sizes) - 6)

    scores = []
    for graph in pred:
        ring_penalty = large_ring_penalty(graph)
        mol = graph.to_molecule()
        try:
            with utils.no_rdkit_log():
                mol.UpdatePropertyCache()
                Chem.GetSymmSSSR(mol)
                raw_logp = Descriptors.MolLogP(mol)
                raw_sa = sascorer.calculateScore(mol)
                z_logp = (raw_logp - logp_mean) / logp_std
                z_sa = (raw_sa - sa_mean) / sa_std
                z_cycle = (ring_penalty - cycle_mean) / cycle_std
                value = z_logp - z_sa - z_cycle
        except Chem.AtomValenceException:
            value = -30
        scores.append(value)
    return torch.tensor(scores, dtype=torch.float, device=pred.device)


@R.register("metrics.SA")
def SA(pred):
    """
    Synthetic accessibility score.

    Each molecule is converted with ``to_molecule`` and scored with
    ``sascorer.calculateScore``; RDKit logging is suppressed for both steps.
    No exceptions are caught here.

    Parameters:
        pred (PackedMolecule): molecules to evaluate

    Returns:
        Tensor: float tensor with one score per molecule, on ``pred.device``
    """
    def score_one(graph):
        with utils.no_rdkit_log():
            return sascorer.calculateScore(graph.to_molecule())

    scores = [score_one(graph) for graph in pred]
    return torch.tensor(scores, dtype=torch.float, device=pred.device)


@R.register("metrics.qed")
def QED(pred):
    """
    Quantitative estimation of drug-likeness.

    Molecules for which RDKit raises ``AtomValenceException`` (during conversion
    or scoring) are given the sentinel score -1.

    Parameters:
        pred (PackedMolecule): molecules to evaluate

    Returns:
        Tensor: float tensor with one score per molecule, on ``pred.device``
    """
    def qed_or_sentinel(graph):
        try:
            with utils.no_rdkit_log():
                return Descriptors.qed(graph.to_molecule())
        except Chem.AtomValenceException:
            return -1

    scores = [qed_or_sentinel(graph) for graph in pred]
    return torch.tensor(scores, dtype=torch.float, device=pred.device)


@R.register("metrics.validity")
def chemical_validity(pred):
    """
    Chemical validity of molecules.

    A molecule counts as valid when its SMILES string can be parsed back by
    ``Chem.MolFromSmiles``.

    Parameters:
        pred (PackedMolecule): molecules to evaluate

    Returns:
        Tensor: float tensor of 1 (valid) / 0 (invalid) per molecule, on ``pred.device``
    """
    validity = []
    for graph in pred:
        with utils.no_rdkit_log():
            smiles = graph.to_smiles()
            parsed = Chem.MolFromSmiles(smiles)
        # MolFromSmiles returns None when parsing fails
        validity.append(1 if parsed else 0)
    return torch.tensor(validity, dtype=torch.float, device=pred.device)


@R.register("metrics.variadic_auroc")
def variadic_area_under_roc(pred, target, size):
    """
    Area under receiver operating characteristic curve (ROC) for sets with variadic sizes.

    Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`.
    Ties in ``pred`` are not averaged (same as the non-variadic AUROC).

    Parameters:
        pred (Tensor): prediction of shape :math:`(B,)`
        target (Tensor): target of shape :math:`(B,)`.
        size (Tensor): size of sets of shape :math:`(N,)`

    Returns:
        Tensor: one AUROC value per set
    """
    index2graph = torch.repeat_interleave(size)
    _, order = functional.variadic_sort(pred, size, descending=True)
    # offset of each element's set in the flattened tensors; ``order`` is presumably
    # relative to its set (hence the offset) — confirm against functional.variadic_sort
    cum_size = (size.cumsum(0) - size)[index2graph]
    target = target[order + cum_size]
    # exclusive prefix sum of per-set positives, used to restart the running count at each set
    total_hit = functional.variadic_sum(target, size)
    total_hit = total_hit.cumsum(0) - total_hit
    hit = target.cumsum(0) - total_hit[index2graph]
    # only negatives contribute: the number of positives ranked before them within the set
    hit = torch.where(target == 0, hit, torch.zeros_like(hit))
    # number of (positive, negative) pairs per set; not named ``all`` to avoid shadowing the builtin
    num_pair = functional.variadic_sum((target == 0).float(), size) * \
               functional.variadic_sum((target == 1).float(), size)
    auroc = functional.variadic_sum(hit, size) / (num_pair + 1e-10)
    return auroc


@R.register("metrics.variadic_auprc")
def variadic_area_under_prc(pred, target, size):
    """
    Area under precision-recall curve (PRC) for sets with variadic sizes.

    Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`.

    Parameters:
        pred (Tensor): prediction of shape :math:`(B,)`
        target (Tensor): target of shape :math:`(B,)`.
        size (Tensor): size of sets of shape :math:`(N,)`

    Returns:
        Tensor: one AUPRC value per set
    """
    index2graph = torch.repeat_interleave(size)
    _, order = functional.variadic_sort(pred, size, descending=True)
    # offset of each element's set in the flattened tensors
    cum_size = (size.cumsum(0) - size)[index2graph]
    target = target[order + cum_size]
    # running number of positives within each set
    total_hit = functional.variadic_sum(target, size)
    total_hit = total_hit.cumsum(0) - total_hit
    hit = target.cumsum(0) - total_hit[index2graph]
    # 1-based rank of each element within its set (reuses the set offsets computed above)
    total = torch.ones_like(target).cumsum(0) - cum_size
    precision = hit / total
    # average precision over the positive positions of each set
    precision = torch.where(target == 1, precision, torch.zeros_like(precision))
    auprc = functional.variadic_sum(precision, size) / \
            (functional.variadic_sum((target == 1).float(), size) + 1e-10)
    return auprc


@R.register("metrics.f1_max")
def f1_max(pred, target):
    """
    F1 score with the optimal threshold.

    This function first enumerates all possible thresholds for deciding positive and negative samples,
    and then picks the threshold with the maximal F1 score.

    For every threshold (every element of ``pred`` in descending global order), precision is
    averaged over the rows that already have at least one element above the threshold, and
    recall is averaged over all :math:`B` rows.

    Parameters:
        pred (Tensor): predictions of shape :math:`(B, N)`
        target (Tensor): binary targets of shape :math:`(B, N)`

    Returns:
        Tensor: scalar maximal F1 score over all thresholds
    """
    # per-row ranking: precision / recall of each row at every cutoff position
    order = pred.argsort(descending=True, dim=1)
    target = target.gather(1, order)
    precision = target.cumsum(1) / torch.ones_like(target).cumsum(1)
    recall = target.cumsum(1) / (target.sum(1, keepdim=True) + 1e-10)
    # mark the top-ranked element of every row, mapped back to its original position
    is_start = torch.zeros_like(target).bool()
    is_start[:, 0] = 1
    is_start = torch.scatter(is_start, 1, order, is_start)

    # global ranking over all B * N elements
    all_order = pred.flatten().argsort(descending=True)
    # convert per-row sort indices to flat indices, then invert: original flat index -> flat
    # position in the row-sorted precision / recall arrays
    order = order + torch.arange(order.shape[0], device=order.device).unsqueeze(1) * order.shape[1]
    order = order.flatten()
    inv_order = torch.zeros_like(order)
    inv_order[order] = torch.arange(order.shape[0], device=order.device)
    is_start = is_start.flatten()[all_order]
    all_order = inv_order[all_order]
    precision = precision.flatten()
    recall = recall.flatten()
    # increment in a row's precision caused by lowering the threshold past this element;
    # for the first element of a row the previous value is masked to 0, so the (possibly
    # wrapping) index ``all_order - 1`` is never used there
    all_precision = precision[all_order] - \
                    torch.where(is_start, torch.zeros_like(precision), precision[all_order - 1])
    # running sum of current row precisions / number of rows seen so far
    all_precision = all_precision.cumsum(0) / is_start.cumsum(0)
    all_recall = recall[all_order] - \
                 torch.where(is_start, torch.zeros_like(recall), recall[all_order - 1])
    # recall is averaged over all rows, including rows with nothing above the threshold
    all_recall = all_recall.cumsum(0) / pred.shape[0]
    all_f1 = 2 * all_precision * all_recall / (all_precision + all_recall + 1e-10)
    return all_f1.max()


@R.register("metrics.accuracy")
def accuracy(pred, target):
    """
    Classification accuracy.

    Suppose there are :math:`N` samples and :math:`C` categories. The predicted
    category of each sample is the argmax over its last dimension.

    Parameters:
        pred (Tensor): prediction of shape :math:`(N, C)`
        target (Tensor): target of shape :math:`(N,)`

    Returns:
        Tensor: scalar fraction of correctly classified samples
    """
    predicted_class = pred.argmax(dim=-1)
    is_correct = predicted_class == target
    return is_correct.float().mean()


@R.register("metrics.variadic_accuracy")
def variadic_accuracy(input, target, size):
    """
    Classification accuracy for categories with variadic sizes.

    Suppose there are :math:`N` samples, and the number of categories in all samples is summed to :math:`B`.
    Note that this returns the per-sample correctness, not its mean.

    Parameters:
        input (Tensor): prediction of shape :math:`(B,)`
        target (Tensor): target of shape :math:`(N,)`. Each target is a relative index in a sample.
        size (Tensor): number of categories of shape :math:`(N,)`

    Returns:
        Tensor: float tensor of shape :math:`(N,)`, 1 where the argmax category matches the target.
        Samples with no categories are scored 0.
    """
    index2graph = torch.repeat_interleave(size)
    # dim_size keeps one slot per sample even if trailing samples have no categories,
    # so the result always lines up with ``target``
    input_class = scatter_max(input, index2graph, dim_size=len(size))[1]
    # convert relative targets to flat indices into ``input``
    target_index = target + size.cumsum(0) - size
    accuracy = (input_class == target_index).float()
    # an empty sample has no valid prediction; never count it as correct
    accuracy = torch.where(size > 0, accuracy, torch.zeros_like(accuracy))
    return accuracy


@R.register("metrics.variadic_top_precision")
def variadic_top_precision(pred, target, size, k):
    """
    Top-k precision for sets with variadic sizes.

    Suppose there are :math:`N` sets, and the sizes of all sets are summed to :math:`B`.
    Sets with fewer than ``k`` elements are assigned a precision of 0.

    Parameters:
        pred (Tensor): prediction of shape :math:`(B,)`
        target (Tensor): target of shape :math:`(B,)`
        size (Tensor): size of sets of shape :math:`(N,)`
        k (LongTensor): the k in "top-k" for different sets of shape :math:`(N,)`

    Returns:
        Tensor: top-k precision of each set
    """
    # indices of the top-k predictions; presumably relative to each set (hence the
    # offset added below) — confirm against functional.variadic_topk
    topk_index = functional.variadic_topk(pred, size, k, largest=True)[1]
    set_offset = size.cumsum(0) - size
    topk_index = topk_index + set_offset.repeat_interleave(k)
    num_hit = functional.variadic_sum(target[topk_index], k)
    precision = num_hit / k
    precision[size < k] = 0
    return precision


@R.register("metrics.mcc")
def matthews_corrcoef(pred, target):
    """
    Matthews correlation coefficient between prediction and target.

    Definition follows matthews_corrcoef for K classes in sklearn.
    For details, see: `https://scikit-learn.org/stable/modules/model_evaluation.html#matthews-corrcoef`

    Parameters:
        pred (Tensor): prediction of shape :math: `(N, K)`
        target (Tensor): target of shape :math: `(N,)`

    Returns:
        Tensor: scalar MCC
    """
    num_class = pred.size(-1)
    pred_class = pred.argmax(-1)
    # flattened (true class, predicted class) cell of the confusion matrix for each sample
    cell = target * num_class + pred_class
    ones = torch.ones(len(target), device=pred_class.device)
    confusion = scatter_add(ones, cell, dim=0, dim_size=num_class ** 2)
    confusion = confusion.view(num_class, num_class)

    true_count = confusion.sum(dim=1)  # samples per true class
    pred_count = confusion.sum(dim=0)  # samples per predicted class
    num_correct = confusion.trace()
    num_sample = confusion.sum()

    numerator = num_correct * num_sample - true_count @ pred_count
    denominator = ((num_sample * num_sample - pred_count @ pred_count)
                   * (num_sample * num_sample - true_count @ true_count) + 1e-10).sqrt()
    return numerator / denominator


@R.register("metrics.pearsonr")
def pearsonr(pred, target):
    """
    Pearson correlation between prediction and target.

    Both inputs are mean-centered and scaled to unit L2 norm; the correlation is
    their dot product.

    Parameters:
        pred (Tensor): prediction of shape :math: `(N,)`
        target (Tensor): target of shape :math: `(N,)`

    Returns:
        Tensor: scalar Pearson correlation
    """
    def center_and_normalize(values):
        centered = values - values.float().mean()
        return centered / centered.norm(2)

    return center_and_normalize(pred) @ center_and_normalize(target)


@R.register("metrics.spearmanr")
def spearmanr(pred, target):
    """
    Spearman correlation between prediction and target.

    Values are converted to 1-based ranks (tied values share the mean of their
    ranks), and the Pearson correlation of the ranks is returned.

    Parameters:
        pred (Tensor): prediction of shape :math: `(N,)`
        target (Tensor): target of shape :math: `(N,)`

    Returns:
        Tensor: scalar Spearman correlation
    """
    def average_rank(values):
        unique_values, inverse = values.unique(return_inverse=True)
        rank = torch.zeros(len(inverse), device=values.device)
        rank[inverse.argsort()] = torch.arange(1, len(values) + 1, dtype=torch.float, device=values.device)
        # tied values are replaced by the mean of the ranks they occupy
        tie_mean = scatter_mean(rank, inverse, dim=0, dim_size=len(unique_values))
        return tie_mean[inverse]

    pred_rank = average_rank(pred)
    target_rank = average_rank(target)
    covariance = (pred_rank * target_rank).mean() - pred_rank.mean() * target_rank.mean()
    std_product = pred_rank.std(unbiased=False) * target_rank.std(unbiased=False)
    return covariance / (std_product + 1e-10)