Source code for torchdrug.tasks.reasoning

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils import data as torch_data

from torchdrug import core, tasks
from torchdrug.layers import functional
from torchdrug.core import Registry as R


@R.register("tasks.KnowledgeGraphCompletion")
class KnowledgeGraphCompletion(tasks.Task, core.Configurable):
    """
    Knowledge graph completion task.

    This class provides routines for the family of knowledge graph embedding models.

    Parameters:
        model (nn.Module): knowledge graph completion model
        criterion (str, list or dict, optional): training criterion(s). For dict, the keys are criterions and the values
            are the corresponding weights. Available criterions are ``bce``, ``ce`` and ``ranking``.
        metric (str or list of str, optional): metric(s). Available metrics are ``mr``, ``mrr`` and ``hits@K``.
        num_negative (int, optional): number of negative samples per positive sample
        margin (float, optional): margin in ranking criterion
        adversarial_temperature (float, optional): temperature for self-adversarial negative sampling.
            Set ``0`` to disable self-adversarial negative sampling.
        strict_negative (bool, optional): use strict negative sampling or not
        fact_ratio (float, optional): split the training set into facts and labels.
            Set ``None`` to use the whole training set as both facts and labels.
        sample_weight (bool, optional): whether to down-weight triplets from entities of large degrees
        filtered_ranking (bool, optional): use filtered or unfiltered ranking for evaluation
        full_batch_eval (bool, optional): whether to feed test negative samples by full batch or mini batch.
            Full batch speeds up evaluation significantly, but may cause OOM problems for some models and datasets.
    """

    _option_members = {"criterion", "metric"}

    def __init__(self, model, criterion="bce", metric=("mr", "mrr", "hits@1", "hits@3", "hits@10"),
                 num_negative=128, margin=6, adversarial_temperature=0, strict_negative=True, fact_ratio=None,
                 sample_weight=True, filtered_ranking=True, full_batch_eval=False):
        super(KnowledgeGraphCompletion, self).__init__()
        self.model = model
        self.criterion = criterion
        self.metric = metric
        self.num_negative = num_negative
        self.margin = margin
        self.adversarial_temperature = adversarial_temperature
        self.strict_negative = strict_negative
        self.fact_ratio = fact_ratio
        self.sample_weight = sample_weight
        self.filtered_ranking = filtered_ranking
        self.full_batch_eval = full_batch_eval

    def preprocess(self, train_set, valid_set, test_set):
        """
        Build the graphs and degree statistics used by training and evaluation.

        Registers the full graph as ``graph`` and the graph without validation / test edges
        (and, if ``fact_ratio`` is set, without the held-out training labels) as ``fact_graph``.
        When ``sample_weight`` is enabled, also counts (head, relation) and (tail, relation)
        occurrences over the training triplets.

        Parameters:
            train_set (Dataset or Subset): training set
            valid_set (Subset): validation set; its ``indices`` are read
            test_set (Subset): test set; its ``indices`` are read

        Returns:
            tuple: ``(train_set, valid_set, test_set)``; ``train_set`` is replaced by the label
            subset when ``fact_ratio`` is set
        """
        if isinstance(train_set, torch_data.Subset):
            dataset = train_set.dataset
        else:
            dataset = train_set
        self.num_entity = dataset.num_entity
        self.num_relation = dataset.num_relation
        self.register_buffer("graph", dataset.graph)

        # edges that may be used as facts: everything except validation and test triplets
        fact_mask = torch.ones(len(dataset), dtype=torch.bool)
        fact_mask[valid_set.indices] = 0
        fact_mask[test_set.indices] = 0
        if self.fact_ratio:
            # a random `fact_ratio` share of the training set stays as facts,
            # the remainder becomes the labels to train on
            length = int(len(train_set) * self.fact_ratio)
            index = torch.randperm(len(train_set))[length:]
            train_indices = torch.tensor(train_set.indices)
            fact_mask[train_indices[index]] = 0
            train_set = torch_data.Subset(train_set, index)
        self.register_buffer("fact_graph", dataset.graph.edge_mask(fact_mask))

        if self.sample_weight:
            degree_hr = torch.zeros(self.num_entity, self.num_relation, dtype=torch.long)
            degree_tr = torch.zeros(self.num_entity, self.num_relation, dtype=torch.long)
            for h, t, r in train_set:
                degree_hr[h, r] += 1
                degree_tr[t, r] += 1
            self.register_buffer("degree_hr", degree_hr)
            self.register_buffer("degree_tr", degree_tr)

        return train_set, valid_set, test_set

    def forward(self, batch, all_loss=None, metric=None):
        """
        Compute the weighted training loss over all configured criterions.

        Parameters:
            batch (Tensor): positive triplets; ``batch.t()`` is unpacked as (head, tail, relation)
            all_loss (Tensor, optional): ignored, a fresh accumulator is always created
            metric (dict, optional): ignored, a fresh dict is always created

        Returns:
            tuple: ``(all_loss, metric)`` where ``metric`` maps each criterion name to its loss

        Raises:
            ValueError: if an unknown criterion is configured
        """
        all_loss = torch.tensor(0, dtype=torch.float32, device=self.device)
        metric = {}

        pred = self.predict(batch, all_loss, metric)
        pos_h_index, pos_t_index, pos_r_index = batch.t()

        for criterion, weight in self.criterion.items():
            if criterion == "bce":
                # column 0 holds the positive sample, the remaining columns are negatives
                target = torch.zeros_like(pred)
                target[:, 0] = 1
                loss = F.binary_cross_entropy_with_logits(pred, target, reduction="none")

                neg_weight = torch.ones_like(pred)
                if self.adversarial_temperature > 0:
                    # self-adversarial sampling: harder negatives get larger weights
                    with torch.no_grad():
                        neg_weight[:, 1:] = F.softmax(pred[:, 1:] / self.adversarial_temperature, dim=-1)
                else:
                    neg_weight[:, 1:] = 1 / self.num_negative
                loss = (loss * neg_weight).sum(dim=-1) / neg_weight.sum(dim=-1)
            elif criterion == "ce":
                target = torch.zeros(len(pred), dtype=torch.long, device=self.device)
                loss = F.cross_entropy(pred, target, reduction="none")
            elif criterion == "ranking":
                positive = pred[:, :1]
                negative = pred[:, 1:]
                target = torch.ones_like(negative)
                # Keep one loss value per positive sample. With the default "mean" reduction the
                # loss is already a scalar and the sample weighting below silently has no effect.
                loss = F.margin_ranking_loss(positive, negative, target, margin=self.margin, reduction="none")
                loss = loss.mean(dim=-1)
            else:
                raise ValueError("Unknown criterion `%s`" % criterion)

            if self.sample_weight:
                # down-weight triplets whose (head, relation) / (tail, relation) pairs are frequent
                sample_weight = self.degree_hr[pos_h_index, pos_r_index] * self.degree_tr[pos_t_index, pos_r_index]
                sample_weight = 1 / sample_weight.float().sqrt()
                loss = (loss * sample_weight).sum() / sample_weight.sum()
            else:
                loss = loss.mean()

            name = tasks._get_criterion_name(criterion)
            metric[name] = loss
            all_loss += loss * weight

        return all_loss, metric

    def predict(self, batch, all_loss=None, metric=None):
        """
        Score triplets with the model.

        With ``all_loss=None`` (test), every entity is scored as a tail candidate and as a head
        candidate for each triplet, and the stacked result is moved to CPU. Otherwise (train),
        each positive is scored together with ``num_negative`` corrupted samples: tails are
        corrupted for the first half of the batch and heads for the second half.

        Parameters:
            batch (Tensor): positive triplets; ``batch.t()`` is unpacked as (head, tail, relation)
            all_loss (Tensor, optional): loss accumulator forwarded to the model
            metric (dict, optional): metric dict forwarded to the model

        Returns:
            Tensor: model scores; tail and head predictions are stacked along dim 1 at test time
        """
        pos_h_index, pos_t_index, pos_r_index = batch.t()
        batch_size = len(batch)

        if all_loss is None:
            # test
            all_index = torch.arange(self.num_entity, device=self.device)
            t_preds = []
            h_preds = []
            num_negative = self.num_entity if self.full_batch_eval else self.num_negative
            for neg_index in all_index.split(num_negative):
                num_candidate = len(neg_index)
                # Explicit broadcasting, equivalent to torch.meshgrid(pos_h_index, neg_index) in
                # "ij" layout, without the deprecation warning for calling it with no `indexing`.
                r_index = pos_r_index.unsqueeze(-1).expand(-1, num_candidate)
                h_index = pos_h_index.unsqueeze(-1).expand(-1, num_candidate)
                t_index = neg_index.unsqueeze(0).expand(batch_size, -1)
                t_pred = self.model(self.fact_graph, h_index, t_index, r_index, all_loss=all_loss, metric=metric)
                t_preds.append(t_pred)
            t_pred = torch.cat(t_preds, dim=-1)
            for neg_index in all_index.split(num_negative):
                num_candidate = len(neg_index)
                r_index = pos_r_index.unsqueeze(-1).expand(-1, num_candidate)
                t_index = pos_t_index.unsqueeze(-1).expand(-1, num_candidate)
                h_index = neg_index.unsqueeze(0).expand(batch_size, -1)
                h_pred = self.model(self.fact_graph, h_index, t_index, r_index, all_loss=all_loss, metric=metric)
                h_preds.append(h_pred)
            h_pred = torch.cat(h_preds, dim=-1)
            pred = torch.stack([t_pred, h_pred], dim=1)
            # in case of GPU OOM
            pred = pred.cpu()
        else:
            # train
            if self.strict_negative:
                neg_index = self._strict_negative(pos_h_index, pos_t_index, pos_r_index)
            else:
                neg_index = torch.randint(self.num_entity, (batch_size, self.num_negative), device=self.device)
            h_index = pos_h_index.unsqueeze(-1).repeat(1, self.num_negative + 1)
            t_index = pos_t_index.unsqueeze(-1).repeat(1, self.num_negative + 1)
            r_index = pos_r_index.unsqueeze(-1).repeat(1, self.num_negative + 1)
            # column 0 stays the positive; corrupt tails in the first half, heads in the second
            t_index[:batch_size // 2, 1:] = neg_index[:batch_size // 2]
            h_index[batch_size // 2:, 1:] = neg_index[batch_size // 2:]
            pred = self.model(self.fact_graph, h_index, t_index, r_index, all_loss=all_loss, metric=metric)

        return pred

    def target(self, batch):
        """
        Build the evaluation target for a batch of test triplets.

        Parameters:
            batch (Tensor): positive triplets; ``batch.t()`` is unpacked as (head, tail, relation)

        Returns:
            tuple: ``(mask, target)`` on CPU. ``mask`` is a bool tensor, stacked as (tail, head)
            along dim 1, that is ``False`` for every entity forming an edge of ``self.graph``
            with the query. ``target`` stacks the true (tail, head) entity indices along dim 1.
        """
        # test target
        batch_size = len(batch)
        pos_h_index, pos_t_index, pos_r_index = batch.t()
        # -1 presumably acts as a wildcard in `graph.match` -- confirm against its documentation
        any_index = -torch.ones_like(pos_h_index)

        pattern = torch.stack([pos_h_index, any_index, pos_r_index], dim=-1)
        edge_index, num_t_truth = self.graph.match(pattern)
        t_truth_index = self.graph.edge_list[edge_index, 1]
        # row index in the batch for each matched edge
        pos_index = torch.repeat_interleave(num_t_truth)
        t_mask = torch.ones(batch_size, self.num_entity, dtype=torch.bool, device=self.device)
        t_mask[pos_index, t_truth_index] = 0

        pattern = torch.stack([any_index, pos_t_index, pos_r_index], dim=-1)
        edge_index, num_h_truth = self.graph.match(pattern)
        h_truth_index = self.graph.edge_list[edge_index, 0]
        pos_index = torch.repeat_interleave(num_h_truth)
        h_mask = torch.ones(batch_size, self.num_entity, dtype=torch.bool, device=self.device)
        h_mask[pos_index, h_truth_index] = 0

        mask = torch.stack([t_mask, h_mask], dim=1)
        target = torch.stack([pos_t_index, pos_h_index], dim=1)

        # in case of GPU OOM
        return mask.cpu(), target.cpu()

    def evaluate(self, pred, target):
        """
        Compute ranking metrics from predictions and targets.

        The rank of a positive is one plus the number of competing entities whose score is
        greater than or equal to the positive score (ties count against the positive).

        Parameters:
            pred (Tensor): scores as produced by :meth:`predict` at test time
            target (tuple): ``(mask, target)`` as produced by :meth:`target`

        Returns:
            dict: metric name to scalar score

        Raises:
            ValueError: if an unknown metric is configured
        """
        mask, target = target

        pos_pred = pred.gather(-1, target.unsqueeze(-1))
        if self.filtered_ranking:
            # `mask` already drops the positive and every other known true entity
            candidate = mask
        else:
            # Every entity competes except the positive itself. Without this exclusion the
            # positive is counted as its own competitor and the best reachable rank is 2.
            entity_index = torch.arange(pred.shape[-1], device=pred.device)
            candidate = entity_index != target.unsqueeze(-1)
        ranking = torch.sum((pos_pred <= pred) & candidate, dim=-1) + 1

        metric = {}
        for _metric in self.metric:
            if _metric == "mr":
                score = ranking.float().mean()
            elif _metric == "mrr":
                score = (1 / ranking.float()).mean()
            elif _metric.startswith("hits@"):
                threshold = int(_metric[5:])
                score = (ranking <= threshold).float().mean()
            else:
                raise ValueError("Unknown metric `%s`" % _metric)

            name = tasks._get_metric_name(_metric)
            metric[name] = score

        return metric

    def visualize(self, batch):
        """
        Delegate visualization of a batch of triplets to the model.

        Parameters:
            batch (Tensor): triplets; ``batch.t()`` is unpacked as (head, tail, relation)
        """
        h_index, t_index, r_index = batch.t()
        return self.model.visualize(self.fact_graph, h_index, t_index, r_index)

    @torch.no_grad()
    def _strict_negative(self, pos_h_index, pos_t_index, pos_r_index):
        """
        Sample negatives that do not form an edge of the fact graph.

        Tail negatives are drawn for the first half of the batch and head negatives for the
        second half, matching how :meth:`predict` corrupts the training batch.

        Parameters:
            pos_h_index (Tensor): head entities of the positive triplets
            pos_t_index (Tensor): tail entities of the positive triplets
            pos_r_index (Tensor): relations of the positive triplets

        Returns:
            Tensor: sampled negative entity indices, tail negatives followed by head negatives
        """
        batch_size = len(pos_h_index)
        # -1 presumably acts as a wildcard in `fact_graph.match` -- confirm against its documentation
        any_index = -torch.ones_like(pos_h_index)

        pattern = torch.stack([pos_h_index, any_index, pos_r_index], dim=-1)
        pattern = pattern[:batch_size // 2]
        edge_index, num_t_truth = self.fact_graph.match(pattern)
        t_truth_index = self.fact_graph.edge_list[edge_index, 1]
        pos_index = torch.repeat_interleave(num_t_truth)
        t_mask = torch.ones(len(pattern), self.num_entity, dtype=torch.bool, device=self.device)
        t_mask[pos_index, t_truth_index] = 0
        # flatten the allowed entities row by row, then sample a fixed number per row
        neg_t_candidate = t_mask.nonzero()[:, 1]
        num_t_candidate = t_mask.sum(dim=-1)
        neg_t_index = functional.variadic_sample(neg_t_candidate, num_t_candidate, self.num_negative)

        pattern = torch.stack([any_index, pos_t_index, pos_r_index], dim=-1)
        pattern = pattern[batch_size // 2:]
        edge_index, num_h_truth = self.fact_graph.match(pattern)
        h_truth_index = self.fact_graph.edge_list[edge_index, 0]
        pos_index = torch.repeat_interleave(num_h_truth)
        h_mask = torch.ones(len(pattern), self.num_entity, dtype=torch.bool, device=self.device)
        h_mask[pos_index, h_truth_index] = 0
        neg_h_candidate = h_mask.nonzero()[:, 1]
        num_h_candidate = h_mask.sum(dim=-1)
        neg_h_index = functional.variadic_sample(neg_h_candidate, num_h_candidate, self.num_negative)

        neg_index = torch.cat([neg_t_index, neg_h_index])

        return neg_index