Source code for torchdrug.datasets.hetionet

import os

from torch.utils import data as torch_data

from torchdrug import data, utils
from torchdrug.core import Registry as R


@R.register("datasets.Hetionet")
class Hetionet(data.KnowledgeGraphDataset):
    """
    Hetionet for knowledge graph reasoning.

    Statistics:
        - #Entity: 45,158
        - #Relation: 24
        - #Triplet: 2,025,177

    Parameters:
        path (str): path to store the dataset
        verbose (int, optional): output verbose level
    """

    # Download locations of the train / valid / test triplet files, in that order.
    # The order matters: ``split`` slices the loaded triplets in this same order.
    urls = [
        "https://www.dropbox.com/s/y47bt9oq57h6l5k/train.txt?dl=1",
        "https://www.dropbox.com/s/a0pbrx9tz3dgsff/valid.txt?dl=1",
        "https://www.dropbox.com/s/4dhrvg3fyq5tnu4/test.txt?dl=1",
    ]

    # Expected MD5 checksums, paired positionally with ``urls``.
    md5s = [
        "6e58915d70ce6d9389c6e4785245e0b3",
        "77f15fac4f8170b836392a5b1d315afa",
        "e8877aafe89d0c9b9c1efb9027cb7226",
    ]

    def __init__(self, path, verbose=1):
        """
        Download (if needed) and load the Hetionet train/valid/test files.

        Parameters:
            path (str): directory to store the dataset; ``~`` is expanded and
                the directory is created if it does not exist
            verbose (int, optional): output verbose level, forwarded to ``load_tsvs``
        """
        path = os.path.expanduser(path)
        # exist_ok=True instead of an exists()-then-makedirs() check: the latter
        # races when several processes (e.g. one per GPU rank) create the
        # directory at the same time and one of them hits FileExistsError.
        os.makedirs(path, exist_ok=True)
        self.path = path

        txt_files = []
        for url, md5 in zip(self.urls, self.md5s):
            # basename(".../train.txt?dl=1") is "train.txt?dl=1"; splitext cuts at
            # the last dot, leaving "train", so the local name is "hetionet_train.txt".
            save_file = "hetionet_%s.txt" % os.path.splitext(os.path.basename(url))[0]
            txt_file = utils.download(url, self.path, save_file=save_file, md5=md5)
            txt_files.append(txt_file)

        # NOTE(review): load_tsvs comes from KnowledgeGraphDataset (not visible
        # here); presumably it also populates ``self.num_samples`` used by split().
        self.load_tsvs(txt_files, verbose=verbose)

    def split(self):
        """
        Split the dataset into consecutive subsets, one per loaded file.

        Returns:
            list of torch.utils.data.Subset: subsets over contiguous index
            ranges, in the order of ``self.num_samples`` (train, valid, test
            given the order of ``urls``)
        """
        offset = 0
        splits = []
        # assumes self.num_samples holds one triplet count per loaded file,
        # in file order — TODO confirm against KnowledgeGraphDataset.load_tsvs
        for num_sample in self.num_samples:
            split = torch_data.Subset(self, range(offset, offset + num_sample))
            splits.append(split)
            offset += num_sample
        return splits