Source code for torchdrug.datasets.gene_ontology

import os
import csv
import glob

import torch
from torch.utils import data as torch_data

from torchdrug import data, utils
from torchdrug.core import Registry as R


[docs]@R.register("datasets.GeneOntology") @utils.copy_args(data.ProteinDataset.load_pdbs) class GeneOntology(data.ProteinDataset): """ A set of proteins with their 3D structures and GO terms. These terms classify proteins into hierarchically related functional classes organized into three ontologies: molecular function (MF), biological process (BP) and cellular component (CC). Statistics (test_cutoff=0.95): - #Train: 27,496 - #Valid: 3,053 - #Test: 2,991 Parameters: path (str): the path to store the dataset branch (str, optional): the GO branch test_cutoff (float, optional): the test cutoff used to split the dataset verbose (int, optional): output verbose level **kwargs """ url = "https://zenodo.org/record/6622158/files/GeneOntology.zip" md5 = "376be1f088cd1fe720e1eaafb701b5cb" branches = ["MF", "BP", "CC"] processed_file = "gene_ontology.pkl.gz" test_cutoffs = [0.3, 0.4, 0.5, 0.7, 0.95] def __init__(self, path, branch="MF", test_cutoff=0.95, verbose=1, **kwargs): path = os.path.expanduser(path) if not os.path.exists(path): os.makedirs(path) self.path = path if branch not in self.branches: raise ValueError("Unknown branch `%s` for GeneOntology dataset" % branch) self.branch = branch if test_cutoff not in self.test_cutoffs: raise ValueError("Unknown test cutoff `%.2f` for GeneOntology dataset" % test_cutoff) self.test_cutoff = test_cutoff zip_file = utils.download(self.url, path, md5=self.md5) path = os.path.join(utils.extract(zip_file), "GeneOntology") pkl_file = os.path.join(path, self.processed_file) csv_file = os.path.join(path, "nrPDB-GO_test.csv") pdb_ids = [] with open(csv_file, "r") as fin: reader = csv.reader(fin, delimiter=",") idx = self.test_cutoffs.index(test_cutoff) + 1 _ = next(reader) for line in reader: if line[idx] == "0": pdb_ids.append(line[0]) if os.path.exists(pkl_file): self.load_pickle(pkl_file, verbose=verbose, **kwargs) else: pdb_files = [] for split in ["train", "valid", "test"]: split_path = utils.extract(os.path.join(path, "%s.zip" % split)) pdb_files += sorted(glob.glob(os.path.join(split_path, split, "*.pdb"))) self.load_pdbs(pdb_files, verbose=verbose, **kwargs) self.save_pickle(pkl_file, verbose=verbose) if len(pdb_ids) > 0: self.filter_pdb(pdb_ids) tsv_file = os.path.join(path, "nrPDB-GO_annot.tsv") pdb_ids = [os.path.basename(pdb_file).split("_")[0] for pdb_file in self.pdb_files] self.load_annotation(tsv_file, pdb_ids) splits = [os.path.basename(os.path.dirname(pdb_file)) for pdb_file in self.pdb_files] self.num_samples = [splits.count("train"), splits.count("valid"), splits.count("test")] def filter_pdb(self, pdb_ids): pdb_ids = set(pdb_ids) sequences = [] pdb_files = [] data = [] for sequence, pdb_file, protein in zip(self.sequences, self.pdb_files, self.data): if os.path.basename(pdb_file).split("_")[0] in pdb_ids: continue sequences.append(sequence) pdb_files.append(pdb_file) data.append(protein) self.sequences = sequences self.pdb_files = pdb_files self.data = data def load_annotation(self, tsv_file, pdb_ids): idx = self.branches.index(self.branch) with open(tsv_file, "r") as fin: reader = csv.reader(fin, delimiter="\t") for i in range(12): _ = next(reader) if i == idx * 4 + 1: tasks = _ task2id = {task: i for i, task in enumerate(tasks)} _ = next(reader) pos_targets = {} for line in reader: pdb_id, pos_target = line[0], line[idx + 1] if idx + 1 < len(line) else None pos_target = [task2id[t] for t in pos_target.split(",")] if pos_target else [] pos_target = torch.LongTensor(pos_target) pos_targets[pdb_id] = pos_target # fake targets to enable the property self.tasks self.targets = task2id self.pos_targets = [] for pdb_id in pdb_ids: self.pos_targets.append(pos_targets[pdb_id]) def split(self): offset = 0 splits = [] for num_sample in self.num_samples: split = torch_data.Subset(self, range(offset, offset + num_sample)) splits.append(split) offset += num_sample return splits def get_item(self, index): if getattr(self, "lazy", False): protein = data.Protein.from_pdb(self.pdb_files[index], self.kwargs) else: protein = self.data[index].clone() if hasattr(protein, "residue_feature"): with protein.residue(): protein.residue_feature = protein.residue_feature.to_dense() item = {"graph": protein} if self.transform: item = self.transform(item) indices = self.pos_targets[index].unsqueeze(0) values = torch.ones(len(self.pos_targets[index])) item["targets"] = utils.sparse_coo_tensor(indices, values, (len(self.tasks),)).to_dense() return item