Source code for torchdrug.datasets.wn18

import os

from torch.utils import data as torch_data

from torchdrug import data, utils
from torchdrug.core import Registry as R


@R.register("datasets.WN18")
class WN18(data.KnowledgeGraphDataset):
    """
    WordNet knowledge base.

    Statistics:
        - #Entity: 40,943
        - #Relation: 18
        - #Triplet: 151,442

    Parameters:
        path (str): path to store the dataset
        verbose (int, optional): output verbose level
    """

    # Train / valid / test files, in this order. ``md5s`` is index-aligned
    # with ``urls``: entry i is the checksum passed along with URL i.
    urls = [
        "https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding/raw/master/data/wn18/train.txt",
        "https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding/raw/master/data/wn18/valid.txt",
        "https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding/raw/master/data/wn18/test.txt",
    ]
    md5s = [
        "7d68324d293837ac165c3441a6c8b0eb",
        "f4f66fec0ca83b5ebe7ad7003404e61d",
        "b035247a8916c7ec3443fa949e1ff02c",
    ]

    def __init__(self, path, verbose=1):
        """
        Fetch the three WN18 files into ``path`` and load them.

        Parameters:
            path (str): directory to store the dataset; ``~`` is expanded and
                the directory is created if it does not exist
            verbose (int, optional): output verbose level, forwarded to ``load_tsvs``
        """
        path = os.path.expanduser(path)
        # exist_ok=True instead of a separate os.path.exists() check: the
        # check-then-create form fails with FileExistsError when another
        # process creates the directory in between.
        os.makedirs(path, exist_ok=True)
        self.path = path

        txt_files = []
        for url, md5 in zip(self.urls, self.md5s):
            # Prefix the local file name so files of different datasets can
            # share one directory without clashing (all are train/valid/test.txt).
            save_file = "wn18_%s" % os.path.basename(url)
            # NOTE(review): utils.download presumably returns the local file
            # path and verifies it against md5 -- confirm in torchdrug.utils.
            txt_file = utils.download(url, self.path, save_file=save_file, md5=md5)
            txt_files.append(txt_file)

        self.load_tsvs(txt_files, verbose=verbose)

    def split(self):
        """
        Partition the dataset into consecutive, non-overlapping subsets.

        One subset is produced per entry of ``self.num_samples``; the i-th
        subset covers the index range that starts where the previous one ended.

        Returns:
            list of torch.utils.data.Subset: the subsets, in ``self.num_samples`` order
        """
        # NOTE(review): self.num_samples is not set in this class; presumably
        # load_tsvs fills it with one count per input file -- confirm.
        offset = 0
        splits = []
        for num_sample in self.num_samples:
            split = torch_data.Subset(self, range(offset, offset + num_sample))
            splits.append(split)
            offset += num_sample
        return splits
@R.register("datasets.WN18RR")
class WN18RR(data.KnowledgeGraphDataset):
    """
    A filtered version of WN18 dataset without trivial cases.

    Statistics:
        - #Entity: 40,943
        - #Relation: 11
        - #Triplet: 93,003

    Parameters:
        path (str): path to store the dataset
        verbose (int, optional): output verbose level
    """

    # Train / valid / test files, in this order. ``md5s`` is index-aligned
    # with ``urls``: entry i is the checksum passed along with URL i.
    urls = [
        "https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding/raw/master/data/wn18rr/train.txt",
        "https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding/raw/master/data/wn18rr/valid.txt",
        "https://github.com/DeepGraphLearning/KnowledgeGraphEmbedding/raw/master/data/wn18rr/test.txt",
    ]
    md5s = [
        "35e81af3ae233327c52a87f23b30ad3c",
        "74a2ee9eca9a8d31f1a7d4d95b5e0887",
        "2b45ba1ba436b9d4ff27f1d3511224c9",
    ]

    def __init__(self, path, verbose=1):
        """
        Fetch the three WN18RR files into ``path`` and load them.

        Parameters:
            path (str): directory to store the dataset; ``~`` is expanded and
                the directory is created if it does not exist
            verbose (int, optional): output verbose level, forwarded to ``load_tsvs``
        """
        path = os.path.expanduser(path)
        # exist_ok=True instead of a separate os.path.exists() check: the
        # check-then-create form fails with FileExistsError when another
        # process creates the directory in between.
        os.makedirs(path, exist_ok=True)
        self.path = path

        txt_files = []
        for url, md5 in zip(self.urls, self.md5s):
            # Prefix the local file name so files of different datasets can
            # share one directory without clashing (all are train/valid/test.txt).
            save_file = "wn18rr_%s" % os.path.basename(url)
            # NOTE(review): utils.download presumably returns the local file
            # path and verifies it against md5 -- confirm in torchdrug.utils.
            txt_file = utils.download(url, self.path, save_file=save_file, md5=md5)
            txt_files.append(txt_file)

        self.load_tsvs(txt_files, verbose=verbose)

    def split(self):
        """
        Partition the dataset into consecutive, non-overlapping subsets.

        One subset is produced per entry of ``self.num_samples``; the i-th
        subset covers the index range that starts where the previous one ended.

        Returns:
            list of torch.utils.data.Subset: the subsets, in ``self.num_samples`` order
        """
        # NOTE(review): self.num_samples is not set in this class; presumably
        # load_tsvs fills it with one count per input file -- confirm.
        offset = 0
        splits = []
        for num_sample in self.num_samples:
            split = torch_data.Subset(self, range(offset, offset + num_sample))
            splits.append(split)
            offset += num_sample
        return splits