import os
import csv
import math
import lmdb
import pickle
import logging
import warnings
from collections import defaultdict
from collections.abc import Sequence
from tqdm import tqdm
import numpy as np
from rdkit import Chem
from rdkit.Chem.Scaffolds import MurckoScaffold
import torch
from torch.utils import data as torch_data
from torchdrug import core, data, utils
logger = logging.getLogger(__name__)
class MoleculeDataset(torch_data.Dataset, core.Configurable):
"""
Molecule dataset.
Each sample contains a molecule graph, and any number of prediction targets.
"""
@utils.copy_args(data.Molecule.from_molecule)
def load_smiles(self, smiles_list, targets, transform=None, lazy=False, verbose=0, **kwargs):
"""
Load the dataset from SMILES and targets.
Parameters:
smiles_list (list of str): SMILES strings
targets (dict of list): prediction targets
transform (Callable, optional): data transformation function
lazy (bool, optional): if lazy mode is used, the molecules are processed in the dataloader.
This may slow down the data loading process, but save a lot of CPU memory and dataset loading time.
verbose (int, optional): output verbose level
**kwargs
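Example (a minimal usage sketch; the SMILES strings and target values are illustrative, and in practice this loader is usually called from a subclass's ``__init__``)::

>>> dataset = MoleculeDataset()
>>> dataset.load_smiles(["CCO", "c1ccccc1"], {"logp": [-0.03, 2.13]})
>>> sample = dataset[0]  # {"graph": data.Molecule, "logp": -0.03}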
"""
num_sample = len(smiles_list)
if num_sample > 1000000:
warnings.warn("Preprocessing molecules of a large dataset consumes a lot of CPU memory and time. "
"Use load_smiles(lazy=True) to construct molecules in the dataloader instead.")
for field, target_list in targets.items():
if len(target_list) != num_sample:
raise ValueError("Number of target `%s` doesn't match with number of molecules. "
"Expect %d but found %d" % (field, num_sample, len(target_list)))
self.transform = transform
self.lazy = lazy
self.kwargs = kwargs
self.smiles_list = []
self.data = []
self.targets = defaultdict(list)
if verbose:
smiles_list = tqdm(smiles_list, "Constructing molecules from SMILES")
for i, smiles in enumerate(smiles_list):
if not self.lazy or len(self.data) == 0:
mol = Chem.MolFromSmiles(smiles)
if not mol:
logger.debug("Can't construct molecule from SMILES `%s`. Ignore this sample." % smiles)
continue
mol = data.Molecule.from_molecule(mol, **kwargs)
else:
mol = None
self.data.append(mol)
self.smiles_list.append(smiles)
for field in targets:
self.targets[field].append(targets[field][i])
@utils.copy_args(load_smiles)
def load_csv(self, csv_file, smiles_field="smiles", target_fields=None, verbose=0, **kwargs):
"""
Load the dataset from a csv file.
Parameters:
csv_file (str): file name
smiles_field (str, optional): name of the SMILES column in the table.
Use ``None`` if there is no SMILES column.
target_fields (list of str, optional): name of target columns in the table.
Default is all columns other than the SMILES column.
verbose (int, optional): output verbose level
**kwargs
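Example (a sketch of the expected file layout; ``my_dataset.csv`` and the values are illustrative)::

>>> # my_dataset.csv
>>> # smiles,logp
>>> # CCO,-0.03
>>> # c1ccccc1,2.13
>>> dataset = MoleculeDataset()
>>> dataset.load_csv("my_dataset.csv", smiles_field="smiles", target_fields=["logp"])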
"""
if target_fields is not None:
target_fields = set(target_fields)
with open(csv_file, "r") as fin:
reader = csv.reader(fin)
if verbose:
reader = iter(tqdm(reader, "Loading %s" % csv_file, utils.get_line_count(csv_file)))
fields = next(reader)
smiles = []
targets = defaultdict(list)
for values in reader:
if not any(values):
continue
if smiles_field is None:
smiles.append("")
for field, value in zip(fields, values):
if field == smiles_field:
smiles.append(value)
elif target_fields is None or field in target_fields:
value = utils.literal_eval(value)
if value == "":
value = math.nan
targets[field].append(value)
self.load_smiles(smiles, targets, verbose=verbose, **kwargs)
def load_pickle(self, pkl_file, verbose=0):
"""
Load the dataset from a pickle file.
Parameters:
pkl_file (str): file name
verbose (int, optional): output verbose level
"""
with utils.smart_open(pkl_file, "rb") as fin:
num_sample, tasks = pickle.load(fin)
self.smiles_list = []
self.data = []
self.targets = {task: [] for task in tasks}
indexes = range(num_sample)
if verbose:
indexes = tqdm(indexes, "Loading %s" % pkl_file)
for i in indexes:
smiles, mol, values = pickle.load(fin)
self.smiles_list.append(smiles)
self.data.append(mol)
for task, value in zip(tasks, values):
self.targets[task].append(value)
def save_pickle(self, pkl_file, verbose=0):
"""
Save the dataset to a pickle file.
Parameters:
pkl_file (str): file name
verbose (int, optional): output verbose level
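Example (a sketch of a save / load round trip; the file name is illustrative)::

>>> dataset.save_pickle("molecules.pkl")
>>> new_dataset = MoleculeDataset()
>>> new_dataset.load_pickle("molecules.pkl", verbose=1)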
"""
with utils.smart_open(pkl_file, "wb") as fout:
num_sample = len(self.data)
tasks = list(self.targets.keys())  # dict views can't be pickled
pickle.dump((num_sample, tasks), fout)
indexes = range(num_sample)
if verbose:
indexes = tqdm(indexes, "Dumping to %s" % pkl_file)
for i in indexes:
values = [v[i] for v in self.targets.values()]
pickle.dump((self.smiles_list[i], self.data[i], values), fout)
def _standarize_index(self, index, count):
if isinstance(index, slice):
start = index.start or 0
if start < 0:
start += count
stop = index.stop or count
if stop < 0:
stop += count
step = index.step or 1
index = range(start, stop, step)
elif not isinstance(index, list):
raise ValueError("Unknown index `%s`" % index)
return index
def get_item(self, index):
if getattr(self, "lazy", False):
# TODO: what if the smiles is invalid here?
item = {"graph": data.Molecule.from_smiles(self.smiles_list[index], **self.kwargs)}
else:
item = {"graph": self.data[index]}
item.update({k: v[index] for k, v in self.targets.items()})
if self.transform:
item = self.transform(item)
return item
def __getitem__(self, index):
if isinstance(index, int):
return self.get_item(index)
index = self._standarize_index(index, len(self))
return [self.get_item(i) for i in index]
@property
def tasks(self):
"""List of tasks."""
return list(self.targets.keys())
@property
def node_feature_dim(self):
"""Dimension of node features."""
return self.data[0].node_feature.shape[-1]
@property
def edge_feature_dim(self):
"""Dimension of edge features."""
return self.data[0].edge_feature.shape[-1]
@property
def num_atom_type(self):
"""Number of different atom types."""
return len(self.atom_types)
@property
def num_bond_type(self):
"""Number of different bond types."""
return len(self.bond_types)
@utils.cached_property
def atom_types(self):
"""All atom types."""
atom_types = set()
if getattr(self, "lazy", False):
warnings.warn("Calling this function for dataset with lazy=True may take a large amount of time.")
for smiles in self.smiles_list:
graph = data.Molecule.from_smiles(smiles, **self.kwargs)
atom_types.update(graph.atom_type.tolist())
else:
for graph in self.data:
atom_types.update(graph.atom_type.tolist())
return sorted(atom_types)
@utils.cached_property
def bond_types(self):
"""All bond types."""
bond_types = set()
if getattr(self, "lazy", False):
warnings.warn("Calling this function for dataset with lazy=True may take a large amount of time.")
for smiles in self.smiles_list:
graph = data.Molecule.from_smiles(smiles, **self.kwargs)
bond_types.update(graph.edge_list[:, 2].tolist())
else:
for graph in self.data:
bond_types.update(graph.edge_list[:, 2].tolist())
return sorted(bond_types)
def __len__(self):
return len(self.data)
def __repr__(self):
lines = [
"#sample: %d" % len(self),
"#task: %d" % len(self.tasks),
]
return "%s(\n %s\n)" % (self.__class__.__name__, "\n ".join(lines))
class ReactionDataset(MoleculeDataset, core.Configurable):
"""
Chemical reaction dataset.
Each sample contains two molecule graphs, and any number of prediction targets.
"""
@utils.copy_args(data.Molecule.from_molecule)
def load_smiles(self, smiles_list, targets, transform=None, verbose=0, **kwargs):
"""
Load the dataset from SMILES and targets.
Parameters:
smiles_list (list of str): SMILES strings
targets (dict of list): prediction targets
transform (Callable, optional): data transformation function
verbose (int, optional): output verbose level
**kwargs
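Example (a minimal sketch; the reaction SMILES and target values are illustrative).
Each reaction SMILES follows the ``reactant>agent>product`` convention::

>>> dataset = ReactionDataset()
>>> dataset.load_smiles(["CCO>>CC=O"], {"yield": [0.7]})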
"""
num_sample = len(smiles_list)
for field, target_list in targets.items():
if len(target_list) != num_sample:
raise ValueError("Number of target `%s` doesn't match with number of molecules. "
"Expect %d but found %d" % (field, num_sample, len(target_list)))
self.smiles_list = []
self.data = []
self.targets = defaultdict(list)
if verbose:
smiles_list = tqdm(smiles_list, "Constructing molecules from SMILES")
for i, smiles in enumerate(smiles_list):
smiles_reactant, agent, smiles_product = smiles.split(">")
mols = []
for _smiles in [smiles_reactant, smiles_product]:
mol = Chem.MolFromSmiles(_smiles)
if not mol:
logger.debug("Can't construct molecule from SMILES `%s`. Ignore this sample." % _smiles)
break
mol = data.Molecule.from_molecule(mol, **kwargs)
mols.append(mol)
else:
self.data.append(mols)
self.smiles_list.append(smiles)
for field in targets:
self.targets[field].append(targets[field][i])
self.transform = transform
@property
def node_feature_dim(self):
"""Dimension of node features."""
return self.data[0][0].node_feature.shape[-1]
@property
def edge_feature_dim(self):
"""Dimension of edge features."""
return self.data[0][0].edge_feature.shape[-1]
@property
def num_atom_type(self):
"""Number of different atom types."""
return len(self.atom_types)
@property
def num_bond_type(self):
"""Number of different bond types."""
return len(self.bond_types)
@utils.cached_property
def atom_types(self):
"""All atom types."""
atom_types = set()
for graphs in self.data:
for graph in graphs:
atom_types.update(graph.atom_type.tolist())
return sorted(atom_types)
@utils.cached_property
def bond_types(self):
"""All bond types."""
bond_types = set()
for graphs in self.data:
for graph in graphs:
bond_types.update(graph.edge_list[:, 2].tolist())
return sorted(bond_types)
def __len__(self):
return len(self.data)
class NodeClassificationDataset(torch_data.Dataset, core.Configurable):
"""
Node classification dataset.
The whole dataset contains one graph, where each node has its own node feature and label.
"""
def load_tsv(self, node_file, edge_file, verbose=0):
"""
Load the edge list from a tsv file.
Parameters:
node_file (str): node feature and label file
edge_file (str): edge list file
verbose (int, optional): output verbose level
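Example (a sketch of the expected file layout; the file names are illustrative)::

>>> # node.tsv rows: <node>\t<feature_1>\t...\t<feature_d>\t<label>
>>> # edge.tsv rows: <head>\t<tail>
>>> dataset = NodeClassificationDataset()
>>> dataset.load_tsv("node.tsv", "edge.tsv", verbose=1)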
"""
inv_node_vocab = {}
inv_label_vocab = {}
node_feature = []
node_label = []
with open(node_file, "r") as fin:
reader = csv.reader(fin, delimiter="\t")
if verbose:
reader = tqdm(reader, "Loading %s" % node_file, utils.get_line_count(node_file))
for tokens in reader:
node_token = tokens[0]
feature_tokens = tokens[1: -1]
label_token = tokens[-1]
inv_node_vocab[node_token] = len(inv_node_vocab)
if label_token not in inv_label_vocab:
inv_label_vocab[label_token] = len(inv_label_vocab)
feature = [utils.literal_eval(f) for f in feature_tokens]
label = inv_label_vocab[label_token]
node_feature.append(feature)
node_label.append(label)
edge_list = []
with open(edge_file, "r") as fin:
reader = csv.reader(fin, delimiter="\t")
if verbose:
reader = tqdm(reader, "Loading %s" % edge_file, utils.get_line_count(edge_file))
for tokens in reader:
h_token, t_token = tokens
if h_token not in inv_node_vocab:
inv_node_vocab[h_token] = len(inv_node_vocab)
h = inv_node_vocab[h_token]
if t_token not in inv_node_vocab:
inv_node_vocab[t_token] = len(inv_node_vocab)
t = inv_node_vocab[t_token]
edge_list.append((h, t))
self.load_edge(edge_list, node_feature, node_label, inv_node_vocab=inv_node_vocab,
inv_label_vocab=inv_label_vocab)
def load_edge(self, edge_list, node_feature, node_label, node_vocab=None, inv_node_vocab=None, label_vocab=None,
inv_label_vocab=None):
node_vocab, inv_node_vocab = self._standarize_vocab(node_vocab, inv_node_vocab)
label_vocab, inv_label_vocab = self._standarize_vocab(label_vocab, inv_label_vocab)
self.num_labeled_node = len(node_feature)
if len(node_vocab) > len(node_feature):
logger.warning("Missing features & labels for %d / %d nodes" %
(len(node_vocab) - len(node_feature), len(node_vocab)))
dummy_label = 0
dummy_feature = [0] * len(node_feature[0])
node_label += [dummy_label] * (len(node_vocab) - len(node_feature))
node_feature += [dummy_feature] * (len(node_vocab) - len(node_feature))
self.graph = data.Graph(edge_list, num_node=len(node_vocab), node_feature=node_feature)
with self.graph.node():
self.graph.node_label = torch.as_tensor(node_label)
self.node_vocab = node_vocab
self.inv_node_vocab = inv_node_vocab
self.label_vocab = label_vocab
self.inv_label_vocab = inv_label_vocab
def _standarize_vocab(self, vocab, inverse_vocab):
if vocab is not None:
if isinstance(vocab, dict):
assert set(vocab.keys()) == set(range(len(vocab))), "Vocab keys should be consecutive numbers"
vocab = [vocab[k] for k in range(len(vocab))]
if inverse_vocab is None:
inverse_vocab = {v: i for i, v in enumerate(vocab)}
if inverse_vocab is not None:
assert set(inverse_vocab.values()) == set(range(len(inverse_vocab))), \
"Inverse vocab values should be consecutive numbers"
if vocab is None:
vocab = sorted(inverse_vocab, key=lambda k: inverse_vocab[k])
return vocab, inverse_vocab
@property
def num_node(self):
"""Number of nodes."""
return self.graph.num_node
@property
def num_edge(self):
"""Number of edges."""
return self.graph.num_edge
@property
def node_feature_dim(self):
"""Dimension of node features."""
return self.graph.node_feature.shape[-1]
def __getitem__(self, index):
return {
"node_index": index,
"label": self.graph.node_label[index]
}
def __len__(self):
return self.num_labeled_node
def __repr__(self):
lines = [
"#node: %d" % self.num_node,
"#edge: %d" % self.num_edge,
"#class: %d" % len(self.label_vocab),
]
return "%s(\n %s\n)" % (self.__class__.__name__, "\n ".join(lines))
class KnowledgeGraphDataset(torch_data.Dataset, core.Configurable):
"""
Knowledge graph dataset.
The whole dataset contains one knowledge graph.
"""
def load_triplet(self, triplets, entity_vocab=None, relation_vocab=None, inv_entity_vocab=None,
inv_relation_vocab=None):
"""
Load the dataset from triplets.
The mapping between indexes and tokens is specified through either vocabularies or inverse vocabularies.
Parameters:
triplets (array_like): triplets of shape :math:`(n, 3)`
entity_vocab (dict of str, optional): maps entity indexes to tokens
relation_vocab (dict of str, optional): maps relation indexes to tokens
inv_entity_vocab (dict of str, optional): maps tokens to entity indexes
inv_relation_vocab (dict of str, optional): maps tokens to relation indexes
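Example (a minimal sketch; the entities, relations and indexes are illustrative)::

>>> triplets = [(0, 1, 0), (1, 2, 1)]  # (head, tail, relation) indexes
>>> inv_entity_vocab = {"a": 0, "b": 1, "c": 2}
>>> inv_relation_vocab = {"r1": 0, "r2": 1}
>>> dataset = KnowledgeGraphDataset()
>>> dataset.load_triplet(triplets, inv_entity_vocab=inv_entity_vocab, inv_relation_vocab=inv_relation_vocab)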
"""
entity_vocab, inv_entity_vocab = self._standarize_vocab(entity_vocab, inv_entity_vocab)
relation_vocab, inv_relation_vocab = self._standarize_vocab(relation_vocab, inv_relation_vocab)
num_node = len(entity_vocab) if entity_vocab else None
num_relation = len(relation_vocab) if relation_vocab else None
self.graph = data.Graph(triplets, num_node=num_node, num_relation=num_relation)
self.entity_vocab = entity_vocab
self.relation_vocab = relation_vocab
self.inv_entity_vocab = inv_entity_vocab
self.inv_relation_vocab = inv_relation_vocab
def load_tsv(self, tsv_file, verbose=0):
"""
Load the dataset from a tsv file.
Parameters:
tsv_file (str): file name
verbose (int, optional): output verbose level
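Example (a sketch of the expected file layout; the file name is illustrative)::

>>> # triplets.tsv rows: <head>\t<relation>\t<tail>
>>> dataset = KnowledgeGraphDataset()
>>> dataset.load_tsv("triplets.tsv", verbose=1)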
"""
inv_entity_vocab = {}
inv_relation_vocab = {}
triplets = []
with open(tsv_file, "r") as fin:
reader = csv.reader(fin, delimiter="\t")
if verbose:
reader = tqdm(reader, "Loading %s" % tsv_file)
for tokens in reader:
h_token, r_token, t_token = tokens
if h_token not in inv_entity_vocab:
inv_entity_vocab[h_token] = len(inv_entity_vocab)
h = inv_entity_vocab[h_token]
if r_token not in inv_relation_vocab:
inv_relation_vocab[r_token] = len(inv_relation_vocab)
r = inv_relation_vocab[r_token]
if t_token not in inv_entity_vocab:
inv_entity_vocab[t_token] = len(inv_entity_vocab)
t = inv_entity_vocab[t_token]
triplets.append((h, t, r))
self.load_triplet(triplets, inv_entity_vocab=inv_entity_vocab, inv_relation_vocab=inv_relation_vocab)
def load_tsvs(self, tsv_files, verbose=0):
"""
Load the dataset from multiple tsv files.
Parameters:
tsv_files (list of str): list of file names
verbose (int, optional): output verbose level
"""
inv_entity_vocab = {}
inv_relation_vocab = {}
triplets = []
num_samples = []
for tsv_file in tsv_files:
with open(tsv_file, "r") as fin:
reader = csv.reader(fin, delimiter="\t")
if verbose:
reader = tqdm(reader, "Loading %s" % tsv_file, utils.get_line_count(tsv_file))
num_sample = 0
for tokens in reader:
h_token, r_token, t_token = tokens
if h_token not in inv_entity_vocab:
inv_entity_vocab[h_token] = len(inv_entity_vocab)
h = inv_entity_vocab[h_token]
if r_token not in inv_relation_vocab:
inv_relation_vocab[r_token] = len(inv_relation_vocab)
r = inv_relation_vocab[r_token]
if t_token not in inv_entity_vocab:
inv_entity_vocab[t_token] = len(inv_entity_vocab)
t = inv_entity_vocab[t_token]
triplets.append((h, t, r))
num_sample += 1
num_samples.append(num_sample)
self.load_triplet(triplets, inv_entity_vocab=inv_entity_vocab, inv_relation_vocab=inv_relation_vocab)
self.num_samples = num_samples
def _standarize_vocab(self, vocab, inverse_vocab):
if vocab is not None:
if isinstance(vocab, dict):
assert set(vocab.keys()) == set(range(len(vocab))), "Vocab keys should be consecutive numbers"
vocab = [vocab[k] for k in range(len(vocab))]
if inverse_vocab is None:
inverse_vocab = {v: i for i, v in enumerate(vocab)}
if inverse_vocab is not None:
assert set(inverse_vocab.values()) == set(range(len(inverse_vocab))), \
"Inverse vocab values should be consecutive numbers"
if vocab is None:
vocab = sorted(inverse_vocab, key=lambda k: inverse_vocab[k])
return vocab, inverse_vocab
@property
def num_entity(self):
"""Number of entities."""
return self.graph.num_node
@property
def num_triplet(self):
"""Number of triplets."""
return self.graph.num_edge
@property
def num_relation(self):
"""Number of relations."""
return self.graph.num_relation
def __getitem__(self, index):
return self.graph.edge_list[index]
def __len__(self):
return self.graph.num_edge
def __repr__(self):
lines = [
"#entity: %d" % self.num_entity,
"#relation: %d" % self.num_relation,
"#triplet: %d" % self.num_triplet,
]
return "%s(\n %s\n)" % (self.__class__.__name__, "\n ".join(lines))
class ProteinDataset(MoleculeDataset, core.Configurable):
"""
Protein dataset.
Each sample contains a protein graph, and any number of prediction targets.
"""
@utils.copy_args(data.Protein.from_sequence)
def load_sequence(self, sequences, targets, attributes=None, transform=None, lazy=False, verbose=0, **kwargs):
"""
Load the dataset from protein sequences and targets.
Parameters:
sequences (list of str): protein sequence strings
targets (dict of list): prediction targets
attributes (dict of list): protein-level attributes
transform (Callable, optional): protein sequence transformation function
lazy (bool, optional): if lazy mode is used, the proteins are processed in the dataloader.
This may slow down the data loading process, but save a lot of CPU memory and dataset loading time.
verbose (int, optional): output verbose level
**kwargs
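Example (a minimal sketch; the sequences and target values are illustrative)::

>>> dataset = ProteinDataset()
>>> dataset.load_sequence(["MKTAYIAK", "GSHMSLF"], {"stability": [0.5, -1.2]})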
"""
num_sample = len(sequences)
if num_sample > 1000000:
warnings.warn("Preprocessing proteins of a large dataset consumes a lot of CPU memory and time. "
"Use load_sequence(lazy=True) to construct molecules in the dataloader instead.")
for field, target_list in targets.items():
if len(target_list) != num_sample:
raise ValueError("Number of target `%s` doesn't match with number of molecules. "
"Expect %d but found %d" % (field, num_sample, len(target_list)))
self.transform = transform
self.lazy = lazy
self.kwargs = kwargs
self.sequences = []
self.data = []
self.targets = defaultdict(list)
if verbose:
sequences = tqdm(sequences, "Constructing proteins from sequences")
for i, sequence in enumerate(sequences):
if not self.lazy or len(self.data) == 0:
protein = data.Protein.from_sequence(sequence, **kwargs)
else:
protein = None
if attributes is not None:
with protein.graph():
for field in attributes:
setattr(protein, field, attributes[field][i])
self.data.append(protein)
self.sequences.append(sequence)
for field in targets:
self.targets[field].append(targets[field][i])
@utils.copy_args(load_sequence)
def load_lmdbs(self, lmdb_files, sequence_field="primary", target_fields=None, number_field="num_examples",
transform=None, lazy=False, verbose=0, **kwargs):
"""
Load the dataset from lmdb files.
Parameters:
lmdb_files (list of str): list of lmdb files
sequence_field (str, optional): name of the field of protein sequence in lmdb files
target_fields (list of str, optional): name of target fields in lmdb files
number_field (str, optional): name of the field of sample count in lmdb files
transform (Callable, optional): protein sequence transformation function
lazy (bool, optional): if lazy mode is used, the proteins are processed in the dataloader.
This may slow down the data loading process, but save a lot of CPU memory and dataset loading time.
verbose (int, optional): output verbose level
**kwargs
"""
if target_fields is not None:
target_fields = set(target_fields)
sequences = []
num_samples = []
targets = defaultdict(list)
for lmdb_file in lmdb_files:
env = lmdb.open(lmdb_file, readonly=True, lock=False, readahead=False, meminit=False)
with env.begin(write=False) as txn:
num_sample = pickle.loads(txn.get(number_field.encode()))
for i in range(num_sample):
item = pickle.loads(txn.get(str(i).encode()))
sequences.append(item[sequence_field])
if target_fields:
for field in target_fields:
value = item[field]
if isinstance(value, np.ndarray) and value.size == 1:
value = value.item()
targets[field].append(value)
num_samples.append(num_sample)
self.load_sequence(sequences, targets, attributes=None, transform=transform,
lazy=lazy, verbose=verbose, **kwargs)
self.num_samples = num_samples
@utils.copy_args(data.Protein.from_molecule)
def load_pdbs(self, pdb_files, transform=None, lazy=False, verbose=0, **kwargs):
"""
Load the dataset from pdb files.
Parameters:
pdb_files (list of str): pdb file names
transform (Callable, optional): protein sequence transformation function
lazy (bool, optional): if lazy mode is used, the proteins are processed in the dataloader.
This may slow down the data loading process, but save a lot of CPU memory and dataset loading time.
verbose (int, optional): output verbose level
**kwargs
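Example (a minimal sketch; the pdb file names are illustrative)::

>>> dataset = ProteinDataset()
>>> dataset.load_pdbs(["protein_1.pdb", "protein_2.pdb"], verbose=1)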
"""
num_sample = len(pdb_files)
if num_sample > 1000000:
warnings.warn("Preprocessing proteins of a large dataset consumes a lot of CPU memory and time. "
"Use load_pdbs(lazy=True) to construct molecules in the dataloader instead.")
self.transform = transform
self.lazy = lazy
self.kwargs = kwargs
self.data = []
self.pdb_files = []
self.sequences = []
if verbose:
pdb_files = tqdm(pdb_files, "Constructing proteins from pdbs")
for i, pdb_file in enumerate(pdb_files):
if not lazy or i == 0:
mol = Chem.MolFromPDBFile(pdb_file)
if not mol:
logger.debug("Can't construct molecule from pdb file `%s`. Ignore this sample." % pdb_file)
continue
protein = data.Protein.from_molecule(mol, **kwargs)
if not protein:
logger.debug("Can't construct protein from pdb file `%s`. Ignore this sample." % pdb_file)
continue
else:
protein = None
if hasattr(protein, "residue_feature"):
with protein.residue():
protein.residue_feature = protein.residue_feature.to_sparse()
self.data.append(protein)
self.pdb_files.append(pdb_file)
self.sequences.append(protein.to_sequence() if protein else None)
@utils.copy_args(load_sequence)
def load_fasta(self, fasta_file, verbose=0, **kwargs):
"""
Load the dataset from a fasta file.
Parameters:
fasta_file (str): file name
verbose (int, optional): output verbose level
**kwargs
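Example (a sketch of the expected file layout; the file name, headers and sequences are illustrative;
targets are forwarded to ``load_sequence`` through ``**kwargs``)::

>>> # proteins.fasta
>>> # >protein_1
>>> # MKTAYIAK
>>> # >protein_2
>>> # GSHMSLF
>>> dataset = ProteinDataset()
>>> dataset.load_fasta("proteins.fasta", targets={}, verbose=1)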
"""
with open(fasta_file, "r") as fin:
if verbose:
fin = tqdm(fin, "Loading %s" % fasta_file, utils.get_line_count(fasta_file))
sequences = []
lines = []
for line in fin:
line = line.strip()
if line.startswith(">") and lines:
sequence = "".join(lines)
sequences.append(sequence)
lines = []
elif not line.startswith(">"):  # never treat a FASTA header line as part of a sequence
lines.append(line)
if lines:
sequence = "".join(lines)
sequences.append(sequence)
return self.load_sequence(sequences, verbose=verbose, **kwargs)
@utils.copy_args(data.Protein.from_molecule)
def load_pickle(self, pkl_file, transform=None, lazy=False, verbose=0, **kwargs):
"""
Load the dataset from a pickle file.
Parameters:
pkl_file (str): file name
transform (Callable, optional): protein sequence transformation function
lazy (bool, optional): if lazy mode is used, the proteins are processed in the dataloader.
This may slow down the data loading process, but save a lot of CPU memory and dataset loading time.
verbose (int, optional): output verbose level
**kwargs
"""
with utils.smart_open(pkl_file, "rb") as fin:
num_sample = pickle.load(fin)
self.transform = transform
self.lazy = lazy
self.kwargs = kwargs
self.sequences = []
self.pdb_files = []
self.data = []
indexes = range(num_sample)
if verbose:
indexes = tqdm(indexes, "Loading %s" % pkl_file)
for i in indexes:
pdb_file, sequence, protein = pickle.load(fin)
self.sequences.append(sequence)
self.pdb_files.append(pdb_file)
self.data.append(protein)
def save_pickle(self, pkl_file, verbose=0):
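"""
Save the dataset to a pickle file.
Parameters:
pkl_file (str): file name
verbose (int, optional): output verbose level
"""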
with utils.smart_open(pkl_file, "wb") as fout:
num_sample = len(self.data)
pickle.dump(num_sample, fout)
indexes = range(num_sample)
if verbose:
indexes = tqdm(indexes, "Dumping to %s" % pkl_file)
for i in indexes:
pdb_dir, pdb_name = os.path.split(self.pdb_files[i])
split = os.path.basename(pdb_dir)
pdb_file = os.path.join(split, pdb_name)
pickle.dump((pdb_file, self.sequences[i], self.data[i]), fout)
@property
def residue_feature_dim(self):
"""Dimension of residue features."""
return self.data[0].residue_feature.shape[-1]
class ProteinPairDataset(ProteinDataset, core.Configurable):
"""
Protein pair dataset.
Each sample contains two protein graphs, and any number of prediction targets.
"""
@utils.copy_args(data.Protein.from_sequence)
def load_sequence(self, sequences, targets, attributes=None, transform=None, lazy=False, verbose=0, **kwargs):
"""
Load the dataset from protein sequences and targets.
Parameters:
sequences (list of list of str): protein sequence string pairs
targets (dict of list): prediction targets
attributes (dict of list): protein-level attributes
transform (Callable, optional): protein sequence transformation function
lazy (bool, optional): if lazy mode is used, the protein pairs are processed in the dataloader.
This may slow down the data loading process, but save a lot of CPU memory and dataset loading time.
verbose (int, optional): output verbose level
**kwargs
"""
num_sample = len(sequences)
if num_sample > 1000000:
warnings.warn("Preprocessing proteins of a large dataset consumes a lot of CPU memory and time. "
"Use load_sequence(lazy=True) to construct molecules in the dataloader instead.")
for field, target_list in targets.items():
if len(target_list) != num_sample:
raise ValueError("Number of target `%s` doesn't match with number of molecules. "
"Expect %d but found %d" % (field, num_sample, len(target_list)))
self.transform = transform
self.lazy = lazy
self.kwargs = kwargs
self.sequences = []
self.data = []
self.targets = defaultdict(list)
if verbose:
sequences = tqdm(sequences, "Constructing proteins from sequences")
for i, sequence in enumerate(sequences):
if not self.lazy or len(self.data) == 0:
proteins = [data.Protein.from_sequence(seq, **kwargs) for seq in sequence]
else:
proteins = None
if attributes is not None:
for protein in proteins:
with protein.graph():
for field in attributes:
setattr(protein, field, attributes[field][i])
self.data.append(proteins)
self.sequences.append(sequence)
for field in targets:
self.targets[field].append(targets[field][i])
@utils.copy_args(load_sequence)
def load_lmdbs(self, lmdb_files, sequence_field="primary", target_fields=None, number_field="num_examples",
transform=None, lazy=False, verbose=0, **kwargs):
"""
Load the dataset from lmdb files.
Parameters:
lmdb_files (list of str): file names
sequence_field (str or list of str, optional): names of the fields of protein sequence in lmdb files
target_fields (list of str, optional): name of target fields in lmdb files
number_field (str, optional): name of the field of sample count in lmdb files
transform (Callable, optional): protein sequence transformation function
lazy (bool, optional): if lazy mode is used, the protein pairs are processed in the dataloader.
This may slow down the data loading process, but save a lot of CPU memory and dataset loading time.
verbose (int, optional): output verbose level
**kwargs
"""
if target_fields is not None:
target_fields = set(target_fields)
else:
target_fields = set()
if not isinstance(sequence_field, Sequence):
sequence_field = [sequence_field]
sequences = []
num_samples = []
targets = defaultdict(list)
for lmdb_file in lmdb_files:
env = lmdb.open(lmdb_file, readonly=True, lock=False, readahead=False, meminit=False)
with env.begin(write=False) as txn:
num_sample = pickle.loads(txn.get(number_field.encode()))
for i in range(num_sample):
item = pickle.loads(txn.get(str(i).encode()))
sequences.append([item[field] for field in sequence_field])
for field in target_fields:
value = item[field]
if isinstance(value, np.ndarray) and value.size == 1:
value = value.item()
targets[field].append(value)
num_samples.append(num_sample)
self.load_sequence(sequences, targets, transform=transform, lazy=lazy, verbose=verbose, **kwargs)
self.num_samples = num_samples
@property
def node_feature_dim(self):
"""Dimension of node features."""
return self.data[0][0].node_feature.shape[-1]
@property
def residue_feature_dim(self):
"""Dimension of residue features."""
return self.data[0][0].residue_feature.shape[-1]
class ProteinLigandDataset(ProteinDataset, core.Configurable):
"""
Protein-ligand dataset.
Each sample contains a protein graph and a molecule graph, and any number of prediction targets.
"""
@utils.copy_args(data.Protein.from_sequence)
def load_sequence(self, sequences, smiles, targets, num_samples, attributes=None, transform=None,
lazy=False, verbose=0, **kwargs):
"""
Load the dataset from protein sequences, ligand SMILES strings and targets.
Parameters:
sequences (list of str): protein sequence strings
smiles (list of str): ligand SMILES strings
targets (dict of list): prediction targets
num_samples (list of int): numbers of protein-ligand pairs in all splits
attributes (dict of list): protein-level attributes
transform (Callable, optional): protein sequence transformation function
lazy (bool, optional): if lazy mode is used, the protein-ligand pairs are processed in the dataloader.
This may slow down the data loading process, but save a lot of CPU memory and dataset loading time.
verbose (int, optional): output verbose level
**kwargs
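Example (a minimal sketch; the sequences, SMILES strings and affinity values are illustrative)::

>>> dataset = ProteinLigandDataset()
>>> dataset.load_sequence(["MKTAYIAK", "GSHMSLF"], ["CCO", "c1ccccc1"], {"affinity": [5.2, 7.1]}, num_samples=[2])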
"""
num_sample = len(sequences)
if num_sample > 1000000:
warnings.warn("Preprocessing proteins of a large dataset consumes a lot of CPU memory and time. "
"Use load_sequence(lazy=True) to construct molecules in the dataloader instead.")
if len(smiles) != len(sequences):
raise ValueError("Number of smiles doesn't match with number of proteins. "
"Expect %d but found %d" % (num_sample, len(smiles)))
for field, target_list in targets.items():
if len(target_list) != num_sample:
raise ValueError("Number of target `%s` doesn't match with number of molecules. "
"Expect %d but found %d" % (field, num_sample, len(target_list)))
self.transform = transform
self.lazy = lazy
self.kwargs = kwargs
self.sequences = []
self.smiles = []
self.data = []
self.targets = defaultdict(list)
cum_num_samples = [num_samples[0]]
for num in num_samples[1:]:
cum_num_samples.append(cum_num_samples[-1] + num)
_cur_split = 0
if verbose:
sequences = tqdm(sequences, "Constructing proteins from sequences")
for i, (sequence, smile) in enumerate(zip(sequences, smiles)):
if i >= cum_num_samples[_cur_split]:
_cur_split += 1
if not self.lazy or len(self.data) == 0:
protein = data.Protein.from_sequence(sequence, **kwargs)
mol = Chem.MolFromSmiles(smile)
if not mol:
logger.debug("Can't construct molecule from SMILES `%s`. Ignore this sample." % smile)
num_samples[_cur_split] -= 1
continue
mol = data.Molecule.from_molecule(mol)
else:
protein = None
mol = None
if attributes is not None:
with protein.graph():
for field in attributes:
setattr(protein, field, attributes[field][i])
self.data.append([protein, mol])
self.sequences.append(sequence)
self.smiles.append(smile)
for field in targets:
self.targets[field].append(targets[field][i])
return num_samples
@utils.copy_args(load_sequence)
def load_lmdbs(self, lmdb_files, sequence_field="target", smiles_field="drug", target_fields=None,
number_field="num_examples", transform=None, lazy=False, verbose=0, **kwargs):
"""
Load the dataset from lmdb files.
Parameters:
lmdb_files (list of str): file names
sequence_field (str, optional): name of the field of protein sequence in lmdb files
smiles_field (str, optional): name of the field of ligand SMILES string in lmdb files
target_fields (list of str, optional): name of target fields in lmdb files
number_field (str, optional): name of the field of sample count in lmdb files
transform (Callable, optional): protein sequence transformation function
lazy (bool, optional): if lazy mode is used, the protein-ligand pairs are processed in the dataloader.
This may slow down the data loading process, but save a lot of CPU memory and dataset loading time.
verbose (int, optional): output verbose level
**kwargs
"""
if target_fields is not None:
target_fields = set(target_fields)
sequences = []
smiles = []
num_samples = []
targets = defaultdict(list)
for lmdb_file in lmdb_files:
env = lmdb.open(lmdb_file, readonly=True, lock=False, readahead=False, meminit=False)
with env.begin(write=False) as txn:
num_sample = pickle.loads(txn.get(number_field.encode()))
for i in range(num_sample):
item = pickle.loads(txn.get(str(i).encode()))
sequences.append(item[sequence_field])
smiles.append(item[smiles_field])
if target_fields:
for field in target_fields:
value = item[field]
if isinstance(value, np.ndarray) and value.size == 1:
value = value.item()
targets[field].append(value)
num_samples.append(num_sample)
num_samples = self.load_sequence(sequences, smiles, targets, num_samples, transform=transform,
lazy=lazy, verbose=verbose, **kwargs)
self.num_samples = num_samples
@property
def ligand_node_feature_dim(self):
"""Dimension of node features for ligands."""
return self.data[0][1].node_feature.shape[-1]
@property
def protein_node_feature_dim(self):
"""Dimension of node features for proteins."""
return self.data[0][0].node_feature.shape[-1]
@property
def residue_feature_dim(self):
"""Dimension of residue features for proteins."""
return self.data[0][0].residue_feature.shape[-1]
class SemiSupervised(torch_data.Dataset, core.Configurable):
"""
Semi-supervised dataset.
Parameters:
dataset (Dataset): supervised dataset
indices (list of int): sample indices to keep supervision
"""
def __init__(self, dataset, indices):
self.dataset = dataset
self.indices = set(indices)
def __getitem__(self, idx):
item = self.dataset[idx]
item["labeled"] = (idx in self.indices)
return item
def __len__(self):
return len(self.dataset)
def semisupervised(dataset, length):
"""
Randomly construct a semi-supervised dataset based on the given length.
Parameters:
dataset (Dataset): supervised dataset
length (int): length of supervised data to keep
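Example (a minimal sketch; ``dataset`` is any supervised dataset defined above with at least 100 samples)::

>>> semi_dataset = semisupervised(dataset, length=100)
>>> semi_dataset[0]["labeled"]  # True if sample 0 keeps its supervision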
"""
if length > len(dataset):
raise ValueError("Length of labeled data exceeds the length of the dataset")
indexes = torch.randperm(len(dataset))[:length].tolist()
return SemiSupervised(dataset, indexes)
def key_split(dataset, keys, lengths=None, key_lengths=None):
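"""
Split a dataset into new datasets with non-overlapping keys.
Samples sharing the same key are always assigned to the same split.
Parameters:
dataset (Dataset): dataset to split
keys (array_like): key of each sample
lengths (list of int, optional): expected length for each split.
Note the results may be different in length due to rounding.
key_lengths (list of int, optional): expected number of keys for each split
"""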
def round_to_boundary(i):
for j in range(min(i, len(dataset) - i)):
if keys[indexes[i - j]] != keys[indexes[i - j - 1]]:
return i - j
if keys[indexes[i + j]] != keys[indexes[i + j - 1]]:
return i + j
if i < len(dataset) - i:
return 0
else:
return len(dataset)
keys = torch.as_tensor(keys)
key_set, keys = torch.unique(keys, return_inverse=True)
perm = torch.randperm(len(key_set))
keys = perm[keys]
indexes = keys.argsort().tolist()
if key_lengths is not None:
assert lengths is None
key2count = keys.bincount()
key_offset = 0
lengths = []
for key_length in key_lengths:
lengths.append(key2count[key_offset: key_offset + key_length].sum().item())
key_offset += key_length
offset = 0
offsets = [offset]
for length in lengths:
offset = round_to_boundary(offset + length)
offsets.append(offset)
offsets[-1] = len(dataset)
return [torch_data.Subset(dataset, indexes[offsets[i]: offsets[i + 1]]) for i in range(len(lengths))]
def scaffold_split(dataset, lengths):
"""
Randomly split a dataset into new datasets with non-overlapping scaffolds.
Parameters:
dataset (Dataset): dataset to split
lengths (list of int): expected length for each split.
Note the results may be different in length due to rounding.
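Example (a minimal sketch; the 80/10/10 proportions are illustrative)::

>>> lengths = [int(0.8 * len(dataset)), int(0.1 * len(dataset))]
>>> lengths += [len(dataset) - sum(lengths)]
>>> train_set, valid_set, test_set = scaffold_split(dataset, lengths)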
"""
scaffold2id = {}
keys = []
for sample in dataset:
scaffold = sample["graph"].to_scaffold()
if scaffold not in scaffold2id:
id = len(scaffold2id)
scaffold2id[scaffold] = id
else:
id = scaffold2id[scaffold]
keys.append(id)
return key_split(dataset, keys, lengths)
def ordered_scaffold_split(dataset, lengths, chirality=True):
"""
Split a dataset into new datasets with non-overlapping scaffolds, assigning scaffold sets to splits from the largest set to the smallest.
Parameters:
dataset (Dataset): dataset to split
lengths (list of int): expected length for each split.
Note the current implementation uses fixed 0.8/0.1/0.1 split fractions, so the results may differ from ``lengths`` due to rounding and scaffold grouping.
"""
frac_train, frac_valid, frac_test = 0.8, 0.1, 0.1
scaffold2id = defaultdict(list)
for idx, smiles in enumerate(dataset.smiles_list):
scaffold = MurckoScaffold.MurckoScaffoldSmiles(smiles=smiles, includeChirality=chirality)
scaffold2id[scaffold].append(idx)
scaffold2id = {key: sorted(value) for key, value in scaffold2id.items()}
scaffold_sets = [
scaffold_set for (scaffold, scaffold_set) in sorted(
scaffold2id.items(), key=lambda x: (len(x[1]), x[1][0]), reverse=True)
]
train_cutoff = frac_train * len(dataset)
valid_cutoff = (frac_train + frac_valid) * len(dataset)
train_idx, valid_idx, test_idx = [], [], []
for scaffold_set in scaffold_sets:
if len(train_idx) + len(scaffold_set) > train_cutoff:
if len(train_idx) + len(valid_idx) + len(scaffold_set) > valid_cutoff:
test_idx.extend(scaffold_set)
else:
valid_idx.extend(scaffold_set)
else:
train_idx.extend(scaffold_set)
return torch_data.Subset(dataset, train_idx), torch_data.Subset(dataset, valid_idx), torch_data.Subset(dataset, test_idx)