Source code for torchdrug.data.dataloader

from collections import deque
from collections.abc import Mapping, Sequence

import torch

from torchdrug import data


def graph_collate(batch):
    """
    Convert a list of samples with the same nested container structure into a container of tensors.

    Instances of :class:`data.Graph <torchdrug.data.Graph>` are collated by
    :meth:`data.Graph.pack <torchdrug.data.Graph.pack>`.

    Parameters:
        batch (list): list of samples with the same nested container structure
    """
    elem = batch[0]
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # in a background worker, stack directly into shared memory to avoid an extra copy
            numel = sum([x.numel() for x in batch])
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, (str, bytes)):
        return batch
    elif isinstance(elem, data.Graph):
        return elem.pack(batch)
    elif isinstance(elem, Mapping):
        return {key: graph_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, Sequence):
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('Each element in list of batch should be of equal size')
        return [graph_collate(samples) for samples in zip(*batch)]

    raise TypeError("Can't collate data with type `%s`" % type(elem))
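# A minimal usage sketch (not part of the original module), assuming samples are dicts
# holding a small data.Graph and an integer label; the helper name and the toy graphs are
# made up for illustration. It shows how graph_collate recurses into the mapping: graphs
# are packed together while labels are stacked into a tensor.
def _graph_collate_example():
    samples = [
        {"graph": data.Graph([[0, 1], [1, 2]], num_node=3), "label": 0},
        {"graph": data.Graph([[0, 1]], num_node=2), "label": 1},
    ]
    batch = graph_collate(samples)
    # batch["graph"] is a packed graph containing both samples;
    # batch["label"] is a LongTensor of shape (2,)
    return batch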
class DataLoader(torch.utils.data.DataLoader):
    """
    Extended data loader for batching graph-structured data.

    See `torch.utils.data.DataLoader`_ for more details.

    .. _torch.utils.data.DataLoader:
        https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader

    Parameters:
        dataset (Dataset): dataset from which to load the data
        batch_size (int, optional): how many samples per batch to load
        shuffle (bool, optional): set to ``True`` to have the data reshuffled at every epoch
        sampler (Sampler, optional): sampler that draws a single sample from the dataset
        batch_sampler (Sampler, optional): sampler that draws a mini-batch of data from the dataset
        num_workers (int, optional): how many subprocesses to use for data loading
        collate_fn (callable, optional): merge a list of samples into a mini-batch
        kwargs: keyword arguments for `torch.utils.data.DataLoader`_
    """

    def __init__(self, dataset, batch_size=1, shuffle=False, sampler=None, batch_sampler=None, num_workers=0,
                 collate_fn=graph_collate, **kwargs):
        super(DataLoader, self).__init__(dataset, batch_size, shuffle, sampler, batch_sampler, num_workers,
                                         collate_fn, **kwargs)
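# A minimal usage sketch (not part of the original module), assuming a torchdrug dataset
# whose samples are dicts containing a "graph" key (e.g. a molecule dataset); the batch
# size is arbitrary. Because graph_collate is the default collate_fn, each mini-batch
# packs its graphs automatically.
def _dataloader_example(dataset):
    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    batch = next(iter(loader))
    # batch["graph"] is a packed graph of up to 32 samples; other keys are collated recursively
    return batch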
class DataQueue(torch.utils.data.Dataset):
    """First-in-first-out queue exposed as a map-style dataset."""

    def __init__(self):
        self.queue = deque()

    def append(self, item):
        self.queue.append(item)

    def pop(self):
        # remove and return the oldest item
        return self.queue.popleft()

    def __getitem__(self, index):
        return self.queue[index]

    def __len__(self):
        return len(self.queue)


class ExperienceReplay(torch.utils.data.DataLoader):
    """Data loader over a bounded FIFO buffer, used as an experience replay cache."""

    def __init__(self, cache_size, batch_size=1, shuffle=True, **kwargs):
        super(ExperienceReplay, self).__init__(DataQueue(), batch_size, shuffle, **kwargs)
        self.cache_size = cache_size

    def update(self, items):
        """Add new items and evict the oldest ones once the cache size is exceeded."""
        for item in items:
            self.dataset.append(item)
        while len(self.dataset) > self.cache_size:
            self.dataset.pop()

    @property
    def cold(self):
        """Whether the buffer has not yet been filled up to the cache size."""
        return len(self.dataset) < self.cache_size
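# A minimal usage sketch (not part of the original module) of the experience replay buffer.
# Plain tensors stand in for generated samples; the cache size and batch size are arbitrary.
# update() appends new items and evicts the oldest ones beyond the cache size; once the
# buffer is no longer cold, mini-batches can be drawn from it like a regular DataLoader.
def _experience_replay_example():
    replay = ExperienceReplay(cache_size=8, batch_size=4)
    replay.update([torch.randn(5) for _ in range(10)])  # the 2 oldest items are evicted
    if not replay.cold:
        batch = next(iter(replay))  # a tensor of shape (4, 5) collated from the queue
    return replay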