import math
import torch
from torch_scatter import scatter_max
from torchdrug import utils


class PerfectHash(object):
"""
Perfect hash function.
The function can be applied to either scalar keys or vector keys.
It takes :math:`O(n\log n)` time and :math:`O(n)` space to construct the hash table.
It maps queries to their indexes in the original key set in :math:`O(1)` time.
If the query is not present in the key set, it returns -1.
The algorithm is adapted from `Storing a Sparse Table with O(1) Worst Case Access Time`_.
.. _Storing a Sparse Table with O(1) Worst Case Access Time:
http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.91.346&rep=rep1&type=pdf
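
    Example (a minimal illustration; assumes this class is exposed as ``data.PerfectHash``)::

        >>> keys = torch.tensor([5, 9, 13])
        >>> hash_func = data.PerfectHash(keys)
        >>> assert (hash_func(torch.tensor([9, 7])) == torch.tensor([1, -1])).all()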

    Parameters:
keys (LongTensor): keys of shape :math:`(N,)` or :math:`(N, D)`
weight (LongTensor, optional): weight of the level-1 hash
bias (LongTensor, optional): bias of the level-1 hash
        sub_weights (LongTensor, optional): weights of the level-2 hashes
        sub_biases (LongTensor, optional): biases of the level-2 hashes
"""
prime = 1000000000039
max_attempt = 10
    # the largest key dimension for which the level-1 hash accumulator still fits in int64
    max_input_dim = (torch.iinfo(torch.int64).max - prime) // prime
def __init__(self, keys, weight=None, bias=None, sub_weights=None, sub_biases=None):
if keys.ndim == 1:
keys = keys.unsqueeze(-1)
num_input, input_dim = keys.shape
if weight is None:
weight = torch.randint(0, self.prime, (1, input_dim), device=keys.device)
if bias is None:
bias = torch.randint(0, self.prime, (1,), device=keys.device)
if sub_weights is None:
sub_weights = torch.randint(0, self.prime, (num_input, input_dim), device=keys.device)
if sub_biases is None:
sub_biases = torch.randint(0, self.prime, (num_input,), device=keys.device)
self.keys = keys
self.weight = weight
self.bias = bias
self.sub_weights = sub_weights
self.sub_biases = sub_biases
self.num_input = num_input
self.num_output = num_input
self.input_dim = input_dim
self._construct_hash_table()
def _construct_hash_table(self):
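        # Level-1: distribute the keys into buckets with a random hash.
        # Following the FKS scheme, resample until the squared bucket sizes
        # sum to O(n); a random draw achieves this with constant probability.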
index = self.hash(self.keys)
count = index.bincount(minlength=self.num_output)
for i in range(self.max_attempt):
if (count ** 2).sum() < 4 * self.num_output:
break
self._reset_hash()
index = self.hash(self.keys)
count = index.bincount(minlength=self.num_output)
else:
raise RuntimeError("Fail to generate a level-1 hash after %d attempts. "
"Are you sure the keys are unique?" % self.max_attempt)
self.num_sub_outputs = (count ** 2).clamp(min=1)
self.num_sub_output = self.num_sub_outputs.sum()
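        # exclusive prefix sum: the start of each bucket's slice in the flat table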
self.offsets = self.num_sub_outputs.cumsum(0) - self.num_sub_outputs
sub_index = self.sub_hash(self.keys, index)
count = sub_index.bincount(minlength=self.num_sub_output)
has_collision = scatter_max(count, self.second2first, dim_size=self.num_output)[0] > 1
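        # every bucket must become collision-free, so budget O(log n) rounds
        # and re-randomize only the buckets that still collide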
max_attempt = int(self.max_attempt * math.log(self.num_input) / math.log(2))
for i in range(max_attempt):
if not has_collision.any():
break
self._reset_sub_hash(has_collision)
sub_index = self.sub_hash(self.keys, index)
count = sub_index.bincount(minlength=self.num_sub_output)
has_collision = scatter_max(count, self.second2first, dim_size=self.num_output)[0] > 1
else:
raise RuntimeError("Fail to generate level-2 hashes after %d attempts. "
"Are you sure the keys are unique?" % max_attempt)
self.table = -torch.ones(self.num_sub_output, dtype=torch.long, device=self.device)
self.table[sub_index] = torch.arange(self.num_input, device=self.device)
def __call__(self, keys):
"""
        Get the indexes of keys in the original key set.

        Return -1 for keys that are not present in the key set.
"""
keys = torch.as_tensor(keys, dtype=torch.long, device=self.device)
        if self.input_dim == 1:
            keys = keys.unsqueeze(-1)
index = self.hash(keys)
sub_index = self.sub_hash(keys, index)
final_index = self.table[sub_index]
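        # an absent key can still land on an occupied slot, so verify each
        # candidate match against the stored keys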
found = final_index != -1
found_index = final_index[found]
        equal = (keys[found] == self.keys[found_index]).all(dim=-1)
final_index[found] = torch.where(equal, found_index, -torch.ones_like(found_index))
return final_index
def _reset_hash(self):
self.weight = torch.randint_like(self.weight, 0, self.prime)
self.bias = torch.randint_like(self.bias, 0, self.prime)
def _reset_sub_hash(self, mask=None):
if mask is None:
self.sub_weights = torch.randint_like(self.sub_weights, 0, self.prime)
self.sub_biases = torch.randint_like(self.sub_biases, 0, self.prime)
else:
self.sub_weights[mask] = torch.randint_like(self.sub_weights[mask], 0, self.prime)
self.sub_biases[mask] = torch.randint_like(self.sub_biases[mask], 0, self.prime)
def hash(self, keys):
"""Apply the level-1 hash function to the keys."""
keys = keys % self.prime
hash = (keys * self.weight % self.prime).sum(dim=-1) + self.bias
return hash % self.prime % self.num_output
def sub_hash(self, keys, index):
"""
        Apply level-2 hash functions to the keys.

        Parameters:
keys (LongTensor): query keys
index (LongTensor): output of the level-1 hash function
"""
keys = keys % self.prime
weight = self.sub_weights[index]
bias = self.sub_biases[index]
num_outputs = self.num_sub_outputs[index]
offsets = self.offsets[index]
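        # map each key into its bucket's private slice [offset, offset + count^2)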
hash = (keys * weight % self.prime).sum(dim=-1) + bias
return hash % self.prime % num_outputs + offsets
@utils.cached_property
def second2first(self):
"""Level-2 hash values to level-1 hash values mapping."""
range = torch.arange(self.num_output, device=self.device)
second2first = range.repeat_interleave(self.num_sub_outputs)
return second2first
@property
def device(self):
"""Device."""
return self.keys.device
def cpu(self):
"""
        Return a copy of this hash function in CPU memory.

        This is a no-op if the hash function is already in CPU memory.
"""
keys = self.keys.cpu()
if keys is self.keys:
return self
else:
return type(self)(keys, weight=self.weight.cpu(), bias=self.bias.cpu(),
sub_weights=self.sub_weights.cpu(), sub_biases=self.sub_biases.cpu())
def cuda(self, *args, **kwargs):
"""
        Return a copy of this hash function in CUDA memory.

        This is a no-op if the hash function is already on the correct device.
"""
keys = self.keys.cuda(*args, **kwargs)
if keys is self.keys:
return self
else:
return type(self)(keys, weight=self.weight.cuda(*args, **kwargs),
bias=self.bias.cuda(*args, **kwargs),
sub_weights=self.sub_weights.cuda(*args, **kwargs),
sub_biases=self.sub_biases.cuda(*args, **kwargs))


class Dictionary(object):
"""
    Dictionary for mapping keys to values.

    This class has the same behavior as the built-in dict, except that it operates on tensors and supports batching.

    Example::

        >>> keys = torch.tensor([[0, 0], [1, 1], [2, 2]])
>>> values = torch.tensor([[0, 1], [1, 2], [2, 3]])
>>> d = data.Dictionary(keys, values)
>>> assert (d[[[0, 0], [2, 2]]] == values[[0, 2]]).all()
>>> assert (d.has_key([[0, 1], [1, 2]]) == torch.tensor([False, False])).all()

    Parameters:
keys (LongTensor): keys of shape :math:`(N,)` or :math:`(N, D)`
values (Tensor): values of shape :math:`(N, ...)`
hash (PerfectHash, optional): hash function for keys
"""
def __init__(self, keys, values, hash=None):
self.keys = keys
self.values = values
self.hash = hash or PerfectHash(keys)
def __getitem__(self, keys):
"""
        Return the value for each key. Raise ``KeyError`` if any key is not in the dictionary.
"""
keys = torch.as_tensor(keys, dtype=torch.long, device=self.device)
index = self.hash(keys)
not_found = index == -1
if not_found.any():
raise KeyError(keys[not_found].tolist())
return self.values[index]
    def get(self, keys, default=None):
"""
        Return the value for each key if the key is in the dictionary; otherwise, return the default value.

        Parameters:
keys (LongTensor): keys of arbitrary shape
default (int or Tensor, optional): default return value. By default, 0 is used.
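
        Example (a minimal illustration, reusing the keys and values from the class docstring)::

            >>> keys = torch.tensor([[0, 0], [1, 1], [2, 2]])
            >>> values = torch.tensor([[0, 1], [1, 2], [2, 3]])
            >>> d = data.Dictionary(keys, values)
            >>> assert (d.get([[1, 1], [3, 3]], default=-1) == torch.tensor([[1, 2], [-1, -1]])).all()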
"""
keys = torch.as_tensor(keys, dtype=torch.long, device=self.device)
if default is None:
default = 0
default = torch.as_tensor(default, dtype=self.values.dtype, device=self.device)
index = self.hash(keys)
shape = list(index.shape) + list(self.values.shape[1:])
values = torch.ones(shape, dtype=self.values.dtype, device=self.device) * default
found = index != -1
values[found] = self.values[index[found]]
return values
    def has_key(self, keys):
"""Check whether each key exists in the dictionary."""
index = self.hash(keys)
return index != -1
    def to_dict(self):
"""
Return a built-in dict object of this dictionary.
"""
        keys = self.keys.tolist()
        values = self.values.tolist()
        # convert vector keys / values to tuples; scalar entries stay as ints
        keys = [tuple(k) if isinstance(k, list) else k for k in keys]
        values = [tuple(v) if isinstance(v, list) else v for v in values]
        return dict(zip(keys, values))
@property
def device(self):
"""Device."""
return self.keys.device
    def cpu(self):
"""
        Return a copy of this dictionary in CPU memory.

        This is a no-op if the dictionary is already in CPU memory.
"""
keys = self.keys.cpu()
if keys is self.keys:
return self
else:
return type(self)(keys, self.values.cpu(), hash=self.hash.cpu())
    def cuda(self, *args, **kwargs):
"""
        Return a copy of this dictionary in CUDA memory.

        This is a no-op if the dictionary is already in CUDA memory.
"""
keys = self.keys.cuda(*args, **kwargs)
if keys is self.keys:
return self
else:
return type(self)(keys, self.values.cuda(*args, **kwargs), hash=self.hash.cuda(*args, **kwargs))