Source code for torchdrug.metrics.metric

import torch
from torch.nn import functional as F
from torch_scatter import scatter_add, scatter_mean, scatter_max
import networkx as nx
from rdkit import Chem
from rdkit.Chem import Descriptors

from torchdrug import utils
from torchdrug.layers import functional
from torchdrug.core import Registry as R
from torchdrug.metrics.rdkit import sascorer


[docs]@R.register("metrics.auroc") def area_under_roc(pred, target): """ Area under receiver operating characteristic curve (ROC). Parameters: pred (Tensor): predictions of shape :math:`(n,)` target (Tensor): binary targets of shape :math:`(n,)` """ if target.dtype != torch.long: raise TypeError("Expect `target` to be torch.long, but found %s" % target.dtype) order = pred.argsort(descending=True) target = target[order] hit = target.cumsum(0) all = (target == 0).sum() * (target == 1).sum() auroc = hit[target == 0].sum() / (all + 1e-10) return auroc
[docs]@R.register("metrics.auprc") def area_under_prc(pred, target): """ Area under precision-recall curve (PRC). Parameters: pred (Tensor): predictions of shape :math:`(n,)` target (Tensor): binary targets of shape :math:`(n,)` """ if target.dtype != torch.long: raise TypeError("Expect `target` to be torch.long, but found %s" % target.dtype) order = pred.argsort(descending=True) target = target[order] precision = target.cumsum(0) / torch.arange(1, len(target) + 1, device=target.device) auprc = precision[target == 1].sum() / ((target == 1).sum() + 1e-10) return auprc
[docs]@R.register("metrics.r2") def r2(pred, target): """ :math:`R^2` regression score. Parameters: pred (Tensor): predictions of shape :math:`(n,)` target (Tensor): targets of shape :math:`(n,)` """ total = torch.var(target, unbiased=False) residual = F.mse_loss(pred, target) return 1 - residual / total
[docs]@R.register("metrics.logp") def logP(pred): """ Logarithm of partition coefficient between octanol and water for a compound. Parameters: pred (PackedMolecule): molecules to evaluate """ logp = [] for mol in pred: mol = mol.to_molecule() try: with utils.no_rdkit_log(): mol.UpdatePropertyCache() score = Descriptors.MolLogP(mol) except Chem.AtomValenceException: score = 0 logp.append(score) return torch.tensor(logp, dtype=torch.float, device=pred.device)
[docs]@R.register("metrics.plogp") def penalized_logP(pred): """ Logarithm of partition coefficient, penalized by cycle length and synthetic accessibility. Parameters: pred (PackedMolecule): molecules to evaluate """ # statistics from ZINC250k logp_mean = 2.4570953396190123 logp_std = 1.434324401111988 sa_mean = 3.0525811293166134 sa_std = 0.8335207024513095 cycle_mean = 0.0485696876403053 cycle_std = 0.2860212110245455 plogp = [] for mol in pred: cycles = nx.cycle_basis(nx.Graph(mol.edge_list[:, :2].tolist())) if cycles: max_cycle = max([len(cycle) for cycle in cycles]) cycle = max(0, max_cycle - 6) else: cycle = 0 mol = mol.to_molecule() try: with utils.no_rdkit_log(): mol.UpdatePropertyCache() Chem.GetSymmSSSR(mol) logp = Descriptors.MolLogP(mol) sa = sascorer.calculateScore(mol) logp = (logp - logp_mean) / logp_std sa = (sa - sa_mean) / sa_std cycle = (cycle - cycle_mean) / cycle_std score = logp - sa - cycle except Chem.AtomValenceException: score = -30 plogp.append(score) return torch.tensor(plogp, dtype=torch.float, device=pred.device)
[docs]@R.register("metrics.SA") def SA(pred): """ Synthetic accesibility score. Parameters: pred (PackedMolecule): molecules to evaluate """ sa = [] for mol in pred: with utils.no_rdkit_log(): score = sascorer.calculateScore(mol.to_molecule()) sa.append(score) return torch.tensor(sa, dtype=torch.float, device=pred.device)
[docs]@R.register("metrics.qed") def QED(pred): """ Quantitative estimation of drug-likeness. Parameters: pred (PackedMolecule): molecules to evaluate """ qed = [] for mol in pred: try: with utils.no_rdkit_log(): score = Descriptors.qed(mol.to_molecule()) except Chem.AtomValenceException: score = -1 qed.append(score) return torch.tensor(qed, dtype=torch.float, device=pred.device)
[docs]@R.register("metrics.validity") def chemical_validity(pred): """ Chemical validity of molecules. Parameters: pred (PackedMolecule): molecules to evaluate """ validity = [] for i, mol in enumerate(pred): with utils.no_rdkit_log(): smiles = mol.to_smiles() mol = Chem.MolFromSmiles(smiles) validity.append(1 if mol else 0) return torch.tensor(validity, dtype=torch.float, device=pred.device)
@R.register("metrics.variadic_auroc") def variadic_area_under_roc(pred, target, size): index2graph = functional._size_to_index(size) _, order = functional.variadic_sort(pred, size, descending=True) cum_size = (size.cumsum(0) - size)[index2graph] target = target[order + cum_size] total_hit = functional.variadic_sum(target, size) total_hit = total_hit.cumsum(0) - total_hit hit = target.cumsum(0) - total_hit[index2graph] hit = torch.where(target == 0, hit, torch.zeros_like(hit)) all = functional.variadic_sum((target == 0).float(), size) * \ functional.variadic_sum((target == 1).float(), size) auroc = functional.variadic_sum(hit, size) / (all + 1e-10) return auroc @R.register("metrics.variadic_auprc") def variadic_area_under_prc(pred, target, size): index2graph = functional._size_to_index(size) _, order = functional.variadic_sort(pred, size, descending=True) cum_size = (size.cumsum(0) - size)[index2graph] target = target[order + cum_size] total_hit = functional.variadic_sum(target, size) total_hit = total_hit.cumsum(0) - total_hit hit = target.cumsum(0) - total_hit[index2graph] total = torch.ones_like(target).cumsum(0) - (size.cumsum(0) - size)[index2graph] precision = hit / total precision = torch.where(target == 1, precision, torch.zeros_like(precision)) auprc = functional.variadic_sum(precision, size) / \ (functional.variadic_sum((target == 1).float(), size) + 1e-10) return auprc @R.register("metrics.accuracy") def accuracy(pred, target): return (pred.argmax(dim=-1) == target).float().mean()
[docs]@R.register("metrics.variadic_accuracy") def variadic_accuracy(input, target, size): """ Compute classification accuracy over variadic sizes of categories. Suppose there are :math:`N` samples, and the number of categories in all samples is summed to :math`B`. Parameters: input (Tensor): prediction of shape :math:`(B,)` target (Tensor): target of shape :math:`(N,)`. Each target is a relative index in a sample. size (Tensor): number of categories of shape :math:`(N,)` """ index2graph = functional._size_to_index(size) input_class = scatter_max(input, index2graph)[1] target_index = target + size.cumsum(0) - size accuracy = (input_class == target_index).float() return accuracy
@R.register("metrics.variadic_top_precision") def variadic_top_precision(pred, target, size, k): index = functional.variadic_topk(pred, size, k, largest=True)[1] index = index + (size.cumsum(0) - size).repeat_interleave(k) precision = functional.variadic_sum(target[index], k) / k precision[size < k] = 0 return precision @R.register("metrics.mcc") def matthews_corrcoef(pred, target, eps=1e-6): """ Matthews correlation coefficient between target and prediction. Definition follows matthews_corrcoef for K classes in sklearn. For details, see: 'https://scikit-learn.org/stable/modules/model_evaluation.html#matthews-corrcoef' Parameters: pred (Tensor): prediction of shape :math: `(N,)` target (Tensor): target of shape :math: `(N,)` """ num_class = pred.size(-1) pred = pred.argmax(-1) ones = torch.ones(len(target), device=pred.device) confusion_matrix = scatter_add(ones, target * num_class + pred, dim=0, dim_size=num_class ** 2) confusion_matrix = confusion_matrix.view(num_class, num_class) t = confusion_matrix.sum(dim=1) p = confusion_matrix.sum(dim=0) c = confusion_matrix.trace() s = confusion_matrix.sum() return (c * s - t @ p) / ((s * s - p @ p) * (s * s - t @ t) + eps).sqrt() @R.register("metrics.pearsonr") def pearsonr(pred, target): """ Pearson correlation between target and prediction. Mimics `scipy.stats.pearsonr`. Parameters: pred (Tensor): prediction of shape :math: `(N,)` target (Tensor): target of shape :math: `(N,)` """ pred_mean = pred.float().mean() target_mean = target.float().mean() pred_centered = pred - pred_mean target_centered = target - target_mean pred_normalized = pred_centered / pred_centered.norm(2) target_normalized = target_centered / target_centered.norm(2) pearsonr = pred_normalized @ target_normalized return pearsonr @R.register("metrics.spearmanr") def spearmanr(pred, target, eps=1e-6): """ Spearman correlation between target and prediction. Implement in PyTorch, but non-diffierentiable. (validation metric only) Parameters: pred (Tensor): prediction of shape :math: `(N,)` target (Tensor): target of shape :math: `(N,)` """ def get_ranking(input): input_set, input_inverse = input.unique(return_inverse=True) order = input_inverse.argsort() ranking = torch.zeros(len(input_inverse), device=input.device) ranking[order] = torch.arange(1, len(input) + 1, dtype=torch.float, device=input.device) # for elements that have the same value, replace their rankings with the mean of their rankings mean_ranking = scatter_mean(ranking, input_inverse, dim=0, dim_size=len(input_set)) ranking = mean_ranking[input_inverse] return ranking pred = get_ranking(pred) target = get_ranking(target) covariance = (pred * target).mean() - pred.mean() * target.mean() pred_std = pred.std(unbiased=False) target_std = target.std(unbiased=False) spearmanr = covariance / (pred_std * target_std + eps) return spearmanr